Coverage Report

Created: 2025-07-11 06:57

/src/openssl/crypto/ml_kem/ml_kem.c
Line
Count
Source (jump to first uncovered line)
1
/*
2
 * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
3
 *
4
 * Licensed under the Apache License 2.0 (the "License").  You may not use
5
 * this file except in compliance with the License.  You can obtain a copy
6
 * in the file LICENSE in the source distribution or at
7
 * https://www.openssl.org/source/license.html
8
 */
9
10
#include <openssl/byteorder.h>
11
#include <openssl/rand.h>
12
#include <openssl/proverr.h>
13
#include "crypto/ml_kem.h"
14
#include "internal/common.h"
15
#include "internal/constant_time.h"
16
#include "internal/sha3.h"
17
18
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
19
#include <valgrind/memcheck.h>
20
#endif
21
22
#if ML_KEM_SEED_BYTES != ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES
# error "ML-KEM keygen seed length != shared secret + random bytes length"
#endif
#if ML_KEM_SHARED_SECRET_BYTES != ML_KEM_RANDOM_BYTES
# error "Invalid unequal lengths of ML-KEM shared secret and random inputs"
#endif

#if UINT_MAX < UINT32_MAX
# error "Unsupported compiler: sizeof(unsigned int) < sizeof(uint32_t)"
#endif

/*
 * Handy function-like bit-extraction macros: bit0(b) is the least significant
 * bit of |b|, and bitn(n, b) is bit |n| of |b|.  Every argument is
 * parenthesised so that expression arguments such as "i & 3" bind as a unit
 * (unparenthesised, "b >> i & 3" would parse as "(b >> i) & 3").
 */
#define bit0(b)     ((b) & 1)
#define bitn(n, b)  (((b) >> (n)) & 1)

/*
 * 12 bits are sufficient to losslessly represent values in [0, q-1].
 * INVERSE_DEGREE is (n/2)^-1 mod q, used to scale the output of the inverse
 * NTT.  For q = 3329 that is 128^-1 = 3329 - 26 = 3303, because
 * 128 * 3303 = 127 * 3329 + 1.
 */
#define DEGREE          ML_KEM_DEGREE
#define INVERSE_DEGREE  (ML_KEM_PRIME - 2 * 13)
#define LOG2PRIME       12
#define BARRETT_SHIFT   (2 * LOG2PRIME)

#ifdef SHA3_BLOCKSIZE
# define SHAKE128_BLOCKSIZE SHA3_BLOCKSIZE(128)
#endif
49
50
/*
 * Return whether a value that can only be 0 or 1 is non-zero, in constant time
 * in practice!  The return value is a mask that is all ones if true, and all
 * zeros otherwise: unsigned arithmetic wraps modulo 2^N, so 0u - 1 has every
 * bit set and 0u - 0 has none.
 *
 * Although this is used in constant-time selects, we omit a value barrier
 * here.  Value barriers impede auto-vectorization (likely because it forces
 * the value to transit through a general-purpose register). On AArch64, this
 * is a difference of 2x.
 *
 * We usually add value barriers to selects because Clang turns consecutive
 * selects with the same condition into a branch instead of CMOV/CSEL. This
 * condition does not occur in Kyber, so omitting it seems to be safe so far,
 * but see |cbd_2|, |cbd_3|, where reduction needs to be specialised to the
 * sign of the input, rather than adding |q| in advance, and using the generic
 * |reduce_once|.  (David Benjamin, Chromium)
 *
 * The disabled alternative is kept for reference.  Like the active definition
 * it must expand to a bare expression, with no trailing semicolon, so that
 * either one can be used in any expression context.
 */
#if 0
# define constish_time_non_zero(b) (~constant_time_is_zero(b))
#else
# define constish_time_non_zero(b) (0u - (b))
#endif
72
73
/*
 * Size, in bytes, of the squeeze buffer used for rejection sampling in
 * |sample_scalar|.  It must be a multiple of 12; beyond that any size works.
 * The preferred size is the SHAKE128 block (rate) size, so that each squeeze
 * maps onto whole Keccak blocks and SHAKE128 does no internal buffering or
 * copying.  That rate is (1600 - 256) / 8 = 168 bytes, which happens to be
 * 14 * 12.
 *
 * Whenever the block size is unknown, or is not a multiple of 12, 168 is used.
 */
#if !defined(SHAKE128_BLOCKSIZE) || ((SHAKE128_BLOCKSIZE) % 12) != 0
# define SCALAR_SAMPLING_BUFSIZE 168
#else
# define SCALAR_SAMPLING_BUFSIZE (SHAKE128_BLOCKSIZE)
#endif
87
88
/*
89
 * Structure of keys
90
 */
91
typedef struct ossl_ml_kem_scalar_st {
92
    /* On every function entry and exit, 0 <= c[i] < ML_KEM_PRIME. */
93
    uint16_t c[ML_KEM_DEGREE];
94
} scalar;
95
96
/* Key material allocation layout */
97
#define DECLARE_ML_KEM_PUBKEYDATA(name, rank) \
98
    struct name##_alloc { \
99
        /* Public vector |t| */ \
100
        scalar tbuf[(rank)]; \
101
        /* Pre-computed matrix |m| (FIPS 203 |A| transpose) */ \
102
        scalar mbuf[(rank)*(rank)]; \
103
    }
104
105
#define DECLARE_ML_KEM_PRVKEYDATA(name, rank) \
106
    struct name##_alloc { \
107
        scalar sbuf[rank]; \
108
        uint8_t zbuf[2 * ML_KEM_RANDOM_BYTES]; \
109
    }
110
111
/* Declare variant-specific public and private storage */
112
#define DECLARE_ML_KEM_VARIANT_KEYDATA(bits) \
113
    DECLARE_ML_KEM_PUBKEYDATA(pubkey_##bits, ML_KEM_##bits##_RANK); \
114
    DECLARE_ML_KEM_PRVKEYDATA(prvkey_##bits, ML_KEM_##bits##_RANK)
115
116
DECLARE_ML_KEM_VARIANT_KEYDATA(512);
117
DECLARE_ML_KEM_VARIANT_KEYDATA(768);
118
DECLARE_ML_KEM_VARIANT_KEYDATA(1024);
119
#undef DECLARE_ML_KEM_VARIANT_KEYDATA
120
#undef DECLARE_ML_KEM_PUBKEYDATA
121
#undef DECLARE_ML_KEM_PRVKEYDATA
122
123
typedef __owur
124
int (*CBD_FUNC)(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
125
                EVP_MD_CTX *mdctx, const ML_KEM_KEY *key);
126
static void scalar_encode(uint8_t *out, const scalar *s, int bits);
127
128
/*
 * Encoded sizes, in bytes, for variant |b| (512, 768 or 1024).
 *
 * Vector:      ML_KEM_<b>_RANK scalars, each losslessly encoded with 12 bits
 *              per coefficient, i.e. 3 * DEGREE / 2 bytes per scalar.
 * Public key:  (wire form) encoded vector |t| || public seed |rho|.
 * Private key: (our serialisation) encoded private vector |s| || public key
 *              || public key hash || implicit-rejection secret |z|.
 *
 * Ciphertext:  encapsulation produces a vector |u| and a scalar |v| whose
 *              coefficients (mod q) are lossily compressed to DU and DV bits
 *              respectively.  The ciphertext is encoded |u| || encoded |v|,
 *              and it is the input to decapsulation.
 */
#define VECTOR_BYTES(b)     (ML_KEM_##b##_RANK * (3 * DEGREE / 2))
#define PUBKEY_BYTES(b)     (VECTOR_BYTES(b) + ML_KEM_RANDOM_BYTES)
#define PRVKEY_BYTES(b)     (VECTOR_BYTES(b) + PUBKEY_BYTES(b) \
                             + ML_KEM_PKHASH_BYTES + ML_KEM_RANDOM_BYTES)

#define U_VECTOR_BYTES(b)   (ML_KEM_##b##_RANK * ML_KEM_##b##_DU * (DEGREE / 8))
#define V_SCALAR_BYTES(b)   (ML_KEM_##b##_DV * (DEGREE / 8))
#define CTEXT_BYTES(b)      (U_VECTOR_BYTES(b) + V_SCALAR_BYTES(b))
150
151
/*
 * Constant-time validation hooks.
 *
 * CONSTTIME_SECRET(ptr, len) marks |len| bytes at |ptr| as secret.  Under
 * OPENSSL_CONSTANT_TIME_VALIDATION it tells valgrind the region is undefined.
 * Valgrind then tracks the data as it flows into registers and other memory,
 * and warns whenever it decides a branch or serves as a memory index.
 *
 * CONSTTIME_DECLASSIFY(ptr, len) marks the region as public (defined) again.
 * Public data is not subject to constant-time rules.
 *
 * In ordinary builds both macros expand to nothing.
 */
#if !defined(OPENSSL_CONSTANT_TIME_VALIDATION)

# define CONSTTIME_SECRET(ptr, len)
# define CONSTTIME_DECLASSIFY(ptr, len)

#else

# define CONSTTIME_SECRET(ptr, len) VALGRIND_MAKE_MEM_UNDEFINED(ptr, len)
# define CONSTTIME_DECLASSIFY(ptr, len) VALGRIND_MAKE_MEM_DEFINED(ptr, len)

#endif
174
175
/*
176
 * Indices of slots in the vinfo tables below
177
 */
178
0
#define ML_KEM_512_VINFO    0
179
0
#define ML_KEM_768_VINFO    1
180
0
#define ML_KEM_1024_VINFO   2
181
182
/*
183
 * Per-variant fixed parameters
184
 */
185
static const ML_KEM_VINFO vinfo_map[3] = {
186
    {
187
        "ML-KEM-512",
188
        PRVKEY_BYTES(512),
189
        sizeof(struct prvkey_512_alloc),
190
        PUBKEY_BYTES(512),
191
        sizeof(struct pubkey_512_alloc),
192
        CTEXT_BYTES(512),
193
        VECTOR_BYTES(512),
194
        U_VECTOR_BYTES(512),
195
        EVP_PKEY_ML_KEM_512,
196
        ML_KEM_512_BITS,
197
        ML_KEM_512_RANK,
198
        ML_KEM_512_DU,
199
        ML_KEM_512_DV,
200
        ML_KEM_512_SECBITS,
201
        ML_KEM_512_SECURITY_CATEGORY
202
    },
203
    {
204
        "ML-KEM-768",
205
        PRVKEY_BYTES(768),
206
        sizeof(struct prvkey_768_alloc),
207
        PUBKEY_BYTES(768),
208
        sizeof(struct pubkey_768_alloc),
209
        CTEXT_BYTES(768),
210
        VECTOR_BYTES(768),
211
        U_VECTOR_BYTES(768),
212
        EVP_PKEY_ML_KEM_768,
213
        ML_KEM_768_BITS,
214
        ML_KEM_768_RANK,
215
        ML_KEM_768_DU,
216
        ML_KEM_768_DV,
217
        ML_KEM_768_SECBITS,
218
        ML_KEM_768_SECURITY_CATEGORY
219
    },
220
    {
221
        "ML-KEM-1024",
222
        PRVKEY_BYTES(1024),
223
        sizeof(struct prvkey_1024_alloc),
224
        PUBKEY_BYTES(1024),
225
        sizeof(struct pubkey_1024_alloc),
226
        CTEXT_BYTES(1024),
227
        VECTOR_BYTES(1024),
228
        U_VECTOR_BYTES(1024),
229
        EVP_PKEY_ML_KEM_1024,
230
        ML_KEM_1024_BITS,
231
        ML_KEM_1024_RANK,
232
        ML_KEM_1024_DU,
233
        ML_KEM_1024_DV,
234
        ML_KEM_1024_SECBITS,
235
        ML_KEM_1024_SECURITY_CATEGORY
236
    }
237
};
238
239
/*
240
 * Remainders modulo `kPrime`, for sufficiently small inputs, are computed in
241
 * constant time via Barrett reduction, and a final call to reduce_once(),
242
 * which reduces inputs that are at most 2*kPrime and is also constant-time.
243
 */
244
static const int kPrime = ML_KEM_PRIME;
245
static const unsigned int kBarrettShift = BARRETT_SHIFT;
246
static const size_t   kBarrettMultiplier = (1 << BARRETT_SHIFT) / ML_KEM_PRIME;
247
static const uint16_t kHalfPrime = (ML_KEM_PRIME - 1) / 2;
248
static const uint16_t kInverseDegree = INVERSE_DEGREE;
249
250
/*-
 * Twiddle factors for the forward NTT: the zeta values of FIPS 203,
 * Appendix A.
 *
 *   kNTTRoots[i] = 17^BitRev7(i) mod 3329,  for i = 0 .. 127
 *
 * BitRev7 reverses the 7 low bits of |i|.  As Python:
 *
 *   p = 3329
 *   def bitreverse(i):
 *       ret = 0
 *       for n in range(7):
 *           ret = (ret << 1) | (i & 1)
 *           i >>= 1
 *       return ret
 *   kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)]
 *
 * The entries are grouped by the NTT layer that consumes them.  The first
 * entry is never read, because the transform pre-increments its root pointer.
 */
static const uint16_t kNTTRoots[128] = {
    /* unused */
    1,
    /* layer 1, half-length 128 */
    1729,
    /* layer 2, half-length 64 */
    2580, 3289,
    /* layer 3, half-length 32 */
    2642, 630, 1897, 848,
    /* layer 4, half-length 16 */
    1062, 1919, 193, 797, 2786, 3260, 569, 1746,
    /* layer 5, half-length 8 */
    296, 2447, 1339, 1476, 3046, 56, 2240, 1333,
    1426, 2094, 535, 2882, 2393, 2879, 1974, 821,
    /* layer 6, half-length 4 */
    289, 331, 3253, 1756, 1197, 2304, 2277, 2055,
    650, 1977, 2513, 632, 2865, 33, 1320, 1915,
    2319, 1435, 807, 452, 1438, 2868, 1534, 2402,
    2647, 2617, 1481, 648, 2474, 3110, 1227, 910,
    /* layer 7, half-length 2 */
    17, 2761, 583, 2649, 1637, 723, 2288, 1100,
    1409, 2662, 3281, 233, 756, 2156, 3015, 3050,
    1703, 1651, 2789, 1789, 1847, 952, 1461, 2687,
    939, 2308, 2437, 2388, 733, 2337, 268, 641,
    1584, 2298, 2037, 3220, 375, 2549, 2090, 1645,
    1063, 319, 2773, 757, 2099, 561, 2466, 2594,
    2804, 1092, 403, 1026, 1143, 2150, 2775, 886,
    1722, 1212, 1874, 1029, 2110, 2935, 885, 2154,
};

/*
 * Twiddle factors for the inverse NTT, stored in the order the inverse
 * transform uses them.  Each entry is the modular inverse of a kNTTRoots
 * entry, i.e. 17^-BitRev7(j) mod 3329.  The layers run in reverse, so the
 * source index j goes:
 *
 *   (unused), 64, 65, ..., 127, 32, 33, ..., 63, 16, ..., 31, 8, ..., 15,
 *   4, ..., 7, 2, 3, 1
 *
 * Note this is NOT [pow(17, -bitreverse(i), p) for i in range(128)] in
 * index order; it is that list permuted layer by layer as above.
 */
static const uint16_t kInverseNTTRoots[128] = {
    /* unused */
    1,
    /* half-length 2: inverses of kNTTRoots[64..127] */
    1175, 2444, 394, 1219, 2300, 1455, 2117, 1607,
    2443, 554, 1179, 2186, 2303, 2926, 2237, 525,
    735, 863, 2768, 1230, 2572, 556, 3010, 2266,
    1684, 1239, 780, 2954, 109, 1292, 1031, 1745,
    2688, 3061, 992, 2596, 941, 892, 1021, 2390,
    642, 1868, 2377, 1482, 1540, 540, 1678, 1626,
    279, 314, 1173, 2573, 3096, 48, 667, 1920,
    2229, 1041, 2606, 1692, 680, 2746, 568, 3312,
    /* half-length 4: inverses of kNTTRoots[32..63] */
    2419, 2102, 219, 855, 2681, 1848, 712, 682,
    927, 1795, 461, 1891, 2877, 2522, 1894, 1010,
    1414, 2009, 3296, 464, 2697, 816, 1352, 2679,
    1274, 1052, 1025, 2132, 1573, 76, 2998, 3040,
    /* half-length 8: inverses of kNTTRoots[16..31] */
    2508, 1355, 450, 936, 447, 2794, 1235, 1903,
    1996, 1089, 3273, 283, 1853, 1990, 882, 3033,
    /* half-length 16: inverses of kNTTRoots[8..15] */
    1583, 2760, 69, 543, 2532, 3136, 1410, 2267,
    /* half-length 32: inverses of kNTTRoots[4..7] */
    2481, 1432, 2699, 687,
    /* half-length 64: inverses of kNTTRoots[2..3] */
    40, 749,
    /* half-length 128: inverse of kNTTRoots[1] */
    1600,
};
311
312
/*
 * Moduli of the degree-two base rings used by |scalar_mult|: FIPS 203,
 * Appendix A, second table (gamma values), normalised to [0, 3329):
 *
 *   kModRoots[i] = 17^(2 * BitRev7(i) + 1) mod 3329,  for i = 0 .. 127
 *
 * In Python: [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)].
 * Entries come in (gamma, -gamma) pairs: kModRoots[2j] + kModRoots[2j+1] is
 * always 3329.  Laid out four pairs per line.
 */
static const uint16_t kModRoots[128] = {
    17,   3312, 2761, 568,  583,  2746, 2649, 680,
    1637, 1692, 723,  2606, 2288, 1041, 1100, 2229,
    1409, 1920, 2662, 667,  3281, 48,   233,  3096,
    756,  2573, 2156, 1173, 3015, 314,  3050, 279,
    1703, 1626, 1651, 1678, 2789, 540,  1789, 1540,
    1847, 1482, 952,  2377, 1461, 1868, 2687, 642,
    939,  2390, 2308, 1021, 2437, 892,  2388, 941,
    733,  2596, 2337, 992,  268,  3061, 641,  2688,
    1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
    375,  2954, 2549, 780,  2090, 1239, 1645, 1684,
    1063, 2266, 319,  3010, 2773, 556,  757,  2572,
    2099, 1230, 561,  2768, 2466, 863,  2594, 735,
    2804, 525,  1092, 2237, 403,  2926, 1026, 2303,
    1143, 2186, 2150, 1179, 2775, 554,  886,  2443,
    1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
    2110, 1219, 2935, 394,  885,  2444, 2154, 1175,
};
330
331
/*
332
 * single_keccak hashes |inlen| bytes from |in| and writes |outlen| bytes of
333
 * output to |out|. If the |md| specifies a fixed-output function, like
334
 * SHA3-256, then |outlen| must be the correct length for that function.
335
 */
336
static __owur
337
int single_keccak(uint8_t *out, size_t outlen, const uint8_t *in, size_t inlen,
338
                  EVP_MD_CTX *mdctx)
339
0
{
340
0
    unsigned int sz = (unsigned int) outlen;
341
342
0
    if (!EVP_DigestUpdate(mdctx, in, inlen))
343
0
        return 0;
344
0
    if (EVP_MD_xof(EVP_MD_CTX_get0_md(mdctx)))
345
0
        return EVP_DigestFinalXOF(mdctx, out, outlen);
346
0
    return EVP_DigestFinal_ex(mdctx, out, &sz)
347
0
        && ossl_assert((size_t) sz == outlen);
348
0
}
349
350
/*
351
 * FIPS 203, Section 4.1, equation (4.3): PRF. Takes 32+1 input bytes, and uses
352
 * SHAKE256 to produce the input to SamplePolyCBD_eta: FIPS 203, algorithm 8.
353
 */
354
static __owur
355
int prf(uint8_t *out, size_t len, const uint8_t in[ML_KEM_RANDOM_BYTES + 1],
356
        EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
357
0
{
358
0
    return EVP_DigestInit_ex(mdctx, key->shake256_md, NULL)
359
0
        && single_keccak(out, len, in, ML_KEM_RANDOM_BYTES + 1, mdctx);
360
0
}
361
362
/*
363
 * FIPS 203, Section 4.1, equation (4.4): H.  SHA3-256 hash of a variable
364
 * length input, producing 32 bytes of output.
365
 */
366
static __owur
367
int hash_h(uint8_t out[ML_KEM_PKHASH_BYTES], const uint8_t *in, size_t len,
368
           EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
369
0
{
370
0
    return EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL)
371
0
        && single_keccak(out, ML_KEM_PKHASH_BYTES, in, len, mdctx);
372
0
}
373
374
/* Incremental hash_h of expanded public key */
375
static int
376
hash_h_pubkey(uint8_t pkhash[ML_KEM_PKHASH_BYTES],
377
              EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
378
0
{
379
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
380
0
    const scalar *t = key->t, *end = t + vinfo->rank;
381
0
    unsigned int sz;
382
383
0
    if (!EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL))
384
0
        return 0;
385
386
0
    do {
387
0
        uint8_t buf[3 * DEGREE / 2];
388
389
0
        scalar_encode(buf, t++, 12);
390
0
        if (!EVP_DigestUpdate(mdctx, buf, sizeof(buf)))
391
0
            return 0;
392
0
    } while (t < end);
393
394
0
    if (!EVP_DigestUpdate(mdctx, key->rho, ML_KEM_RANDOM_BYTES))
395
0
        return 0;
396
0
    return EVP_DigestFinal_ex(mdctx, pkhash, &sz)
397
0
        && ossl_assert(sz == ML_KEM_PKHASH_BYTES);
398
0
}
399
400
/*
401
 * FIPS 203, Section 4.1, equation (4.5): G.  SHA3-512 hash of a variable
402
 * length input, producing 64 bytes of output, in particular the seeds
403
 * (d,z) for key generation.
404
 */
405
static __owur
406
int hash_g(uint8_t out[ML_KEM_SEED_BYTES], const uint8_t *in, size_t len,
407
           EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
408
0
{
409
0
    return EVP_DigestInit_ex(mdctx, key->sha3_512_md, NULL)
410
0
        && single_keccak(out, ML_KEM_SEED_BYTES, in, len, mdctx);
411
0
}
412
413
/*
414
 * FIPS 203, Section 4.1, equation (4.4): J. SHAKE256 taking a variable length
415
 * input to compute a 32-byte implicit rejection shared secret, of the same
416
 * length as the expected shared secret.  (Computed even on success to avoid
417
 * side-channel leaks).
418
 */
419
static __owur
420
int kdf(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
421
        const uint8_t z[ML_KEM_RANDOM_BYTES],
422
        const uint8_t *ctext, size_t len,
423
        EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
424
0
{
425
0
    return EVP_DigestInit_ex(mdctx, key->shake256_md, NULL)
426
0
        && EVP_DigestUpdate(mdctx, z, ML_KEM_RANDOM_BYTES)
427
0
        && EVP_DigestUpdate(mdctx, ctext, len)
428
0
        && EVP_DigestFinalXOF(mdctx, out, ML_KEM_SHARED_SECRET_BYTES);
429
0
}
430
431
/*
432
 * FIPS 203, Section 4.2.2, Algorithm 7: "SampleNTT" (steps 3-17, steps 1, 2
433
 * are performed by the caller). Rejection-samples a Keccak stream to get
434
 * uniformly distributed elements in the range [0,q). This is used for matrix
435
 * expansion and only operates on public inputs.
436
 */
437
static __owur
438
int sample_scalar(scalar *out, EVP_MD_CTX *mdctx)
439
0
{
440
0
    uint16_t *curr = out->c, *endout = curr + DEGREE;
441
0
    uint8_t buf[SCALAR_SAMPLING_BUFSIZE], *in;
442
0
    uint8_t *endin = buf + sizeof(buf);
443
0
    uint16_t d;
444
0
    uint8_t b1, b2, b3;
445
446
0
    do {
447
0
        if (!EVP_DigestSqueeze(mdctx, in = buf, sizeof(buf)))
448
0
            return 0;
449
0
        do {
450
0
            b1 = *in++;
451
0
            b2 = *in++;
452
0
            b3 = *in++;
453
454
0
            if (curr >= endout)
455
0
                break;
456
0
            if ((d = ((b2 & 0x0f) << 8) + b1) < kPrime)
457
0
                *curr++ = d;
458
0
            if (curr >= endout)
459
0
                break;
460
0
            if ((d = (b3 << 4) + (b2 >> 4)) < kPrime)
461
0
                *curr++ = d;
462
0
        } while (in < endin);
463
0
    } while (curr < endout);
464
0
    return 1;
465
0
}
466
467
/*-
 * reduce_once: reduce |x| modulo kPrime, for inputs 0 <= x < 2*kPrime.
 *
 * Subtract |q| if the input is at least |q|, without exposing a side-channel,
 * avoiding the "clangover" attack.  Both candidates (x and x - kPrime) are
 * computed, and one is selected with a mask rather than a branch.  The exact
 * expression form is deliberate; see |constish_time_non_zero| for why the
 * value barrier is by default omitted.
 *
 * Returns the reduced value, in [0, kPrime).
 */
474
static __owur uint16_t reduce_once(uint16_t x)
475
0
{
476
0
    /*
     * For x < 2*kPrime, this wraps (top bit set) exactly when x < kPrime,
     * since x - kPrime then lies in [65536 - kPrime, 65536).
     */
    const uint16_t subtracted = x - kPrime;
477
0
    /* All ones when x < kPrime (keep |x|), all zeros otherwise. */
    uint16_t mask = constish_time_non_zero(subtracted >> 15);
478
479
0
    return (mask & x) | (~mask & subtracted);
480
0
}
481
482
/*
483
 * Constant-time reduce x mod kPrime using Barrett reduction. x must be less
484
 * than kPrime + 2 * kPrime^2.  This is sufficient to reduce a product of
485
 * two already reduced u_int16 values, in fact it is sufficient for each
486
 * to be less than 2^12, because (kPrime * (2 * kPrime + 1)) > 2^24.
487
 */
488
static __owur uint16_t reduce(uint32_t x)
489
0
{
490
0
    uint64_t product = (uint64_t)x * kBarrettMultiplier;
491
0
    uint32_t quotient = (uint32_t)(product >> kBarrettShift);
492
0
    uint32_t remainder = x - quotient * kPrime;
493
494
0
    return reduce_once(remainder);
495
0
}
496
497
/* Multiply a scalar by a constant. */
498
static void scalar_mult_const(scalar *s, uint16_t a)
499
0
{
500
0
    uint16_t *curr = s->c, *end = curr + DEGREE, tmp;
501
502
0
    do {
503
0
        tmp = reduce(*curr * a);
504
0
        *curr++ = tmp;
505
0
    } while (curr < end);
506
0
}
507
508
/*-
509
 * FIPS 203, Section 4.3, Algoritm 9: "NTT".
510
 * In-place number theoretic transform of a given scalar.  Note that ML-KEM's
511
 * kPrime 3329 does not have a 512th root of unity, so this transform leaves
512
 * off the last iteration of the usual FFT code, with the 128 relevant roots of
513
 * unity being stored in NTTRoots.  This means the output should be seen as 128
514
 * elements in GF(3329^2), with the coefficients of the elements being
515
 * consecutive entries in |s->c|.
516
 */
517
static void scalar_ntt(scalar *s)
518
0
{
519
0
    const uint16_t *roots = kNTTRoots;
520
0
    uint16_t *end = s->c + DEGREE;
521
0
    int offset = DEGREE / 2;
522
523
0
    do {
524
0
        uint16_t *curr = s->c, *peer;
525
526
0
        do {
527
0
            uint16_t *pause = curr + offset, even, odd;
528
0
            uint32_t zeta = *++roots;
529
530
0
            peer = pause;
531
0
            do {
532
0
                even = *curr;
533
0
                odd = reduce(*peer * zeta);
534
0
                *peer++ = reduce_once(even - odd + kPrime);
535
0
                *curr++ = reduce_once(odd + even);
536
0
            } while (curr < pause);
537
0
        } while ((curr = peer) < end);
538
0
    } while ((offset >>= 1) >= 2);
539
0
}
540
541
/*-
542
 * FIPS 203, Section 4.3, Algoritm 10: "NTT^(-1)".
543
 * In-place inverse number theoretic transform of a given scalar, with pairs of
544
 * entries of s->v being interpreted as elements of GF(3329^2). Just as with
545
 * the number theoretic transform, this leaves off the first step of the normal
546
 * iFFT to account for the fact that 3329 does not have a 512th root of unity,
547
 * using the precomputed 128 roots of unity stored in InverseNTTRoots.
548
 */
549
static void scalar_inverse_ntt(scalar *s)
550
0
{
551
0
    const uint16_t *roots = kInverseNTTRoots;
552
0
    uint16_t *end = s->c + DEGREE;
553
0
    int offset = 2;
554
555
0
    do {
556
0
        uint16_t *curr = s->c, *peer;
557
558
0
        do {
559
0
            uint16_t *pause = curr + offset, even, odd;
560
0
            uint32_t zeta = *++roots;
561
562
0
            peer = pause;
563
0
            do {
564
0
                even = *curr;
565
0
                odd = *peer;
566
0
                *peer++ = reduce(zeta * (even - odd + kPrime));
567
0
                *curr++ = reduce_once(odd + even);
568
0
            } while (curr < pause);
569
0
        } while ((curr = peer) < end);
570
0
    } while ((offset <<= 1) < DEGREE);
571
0
    scalar_mult_const(s, kInverseDegree);
572
0
}
573
574
/* Addition updating the LHS scalar in-place. */
575
static void scalar_add(scalar *lhs, const scalar *rhs)
576
0
{
577
0
    int i;
578
579
0
    for (i = 0; i < DEGREE; i++)
580
0
        lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
581
0
}
582
583
/* Subtraction updating the LHS scalar in-place. */
584
static void scalar_sub(scalar *lhs, const scalar *rhs)
585
0
{
586
0
    int i;
587
588
0
    for (i = 0; i < DEGREE; i++)
589
0
        lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime);
590
0
}
591
592
/*
593
 * Multiplying two scalars in the number theoretically transformed state. Since
594
 * 3329 does not have a 512th root of unity, this means we have to interpret
595
 * the 2*ith and (2*i+1)th entries of the scalar as elements of
596
 * GF(3329)[X]/(X^2 - 17^(2*bitreverse(i)+1)).
597
 *
598
 * The value of 17^(2*bitreverse(i)+1) mod 3329 is stored in the precomputed
599
 * ModRoots table. Note that our Barrett transform only allows us to multipy
600
 * two reduced numbers together, so we need some intermediate reduction steps,
601
 * even if an uint64_t could hold 3 multiplied numbers.
602
 */
603
static void scalar_mult(scalar *out, const scalar *lhs,
604
                        const scalar *rhs)
605
0
{
606
0
    uint16_t *curr = out->c, *end = curr + DEGREE;
607
0
    const uint16_t *lc = lhs->c, *rc = rhs->c;
608
0
    const uint16_t *roots = kModRoots;
609
610
0
    do {
611
0
        uint32_t l0 = *lc++, r0 = *rc++;
612
0
        uint32_t l1 = *lc++, r1 = *rc++;
613
0
        uint32_t zetapow = *roots++;
614
615
0
        *curr++ = reduce(l0 * r0 + reduce(l1 * r1) * zetapow);
616
0
        *curr++ = reduce(l0 * r1 + l1 * r0);
617
0
    } while (curr < end);
618
0
}
619
620
/* Above, but add the result to an existing scalar */
621
static ossl_inline
622
void scalar_mult_add(scalar *out, const scalar *lhs,
623
                     const scalar *rhs)
624
0
{
625
0
    uint16_t *curr = out->c, *end = curr + DEGREE;
626
0
    const uint16_t *lc = lhs->c, *rc = rhs->c;
627
0
    const uint16_t *roots = kModRoots;
628
629
0
    do {
630
0
        uint32_t l0 = *lc++, r0 = *rc++;
631
0
        uint32_t l1 = *lc++, r1 = *rc++;
632
0
        uint16_t *c0 = curr++;
633
0
        uint16_t *c1 = curr++;
634
0
        uint32_t zetapow = *roots++;
635
636
0
        *c0 = reduce(*c0 + l0 * r0 + reduce(l1 * r1) * zetapow);
637
0
        *c1 = reduce(*c1 + l0 * r1 + l1 * r0);
638
0
    } while (curr < end);
639
0
}
640
641
/*-
642
 * FIPS 203, Section 4.2.1, Algorithm 5: "ByteEncode_d", for 2<=d<=12.
643
 * Here |bits| is |d|.  For efficiency, we handle the d=1 case separately.
644
 */
645
static void scalar_encode(uint8_t *out, const scalar *s, int bits)
646
0
{
647
0
    const uint16_t *curr = s->c, *end = curr + DEGREE;
648
0
    uint64_t accum = 0, element;
649
0
    int used = 0;
650
651
0
    do {
652
0
        element = *curr++;
653
0
        if (used + bits < 64) {
654
0
            accum |= element << used;
655
0
            used += bits;
656
0
        } else if (used + bits > 64) {
657
0
            out = OPENSSL_store_u64_le(out, accum | (element << used));
658
0
            accum = element >> (64 - used);
659
0
            used = (used + bits) - 64;
660
0
        } else {
661
0
            out = OPENSSL_store_u64_le(out, accum | (element << used));
662
0
            accum = 0;
663
0
            used = 0;
664
0
        }
665
0
    } while (curr < end);
666
0
}
667
668
/*
669
 * scalar_encode_1 is |scalar_encode| specialised for |bits| == 1.
670
 */
671
static void scalar_encode_1(uint8_t out[DEGREE / 8], const scalar *s)
672
0
{
673
0
    int i, j;
674
0
    uint8_t out_byte;
675
676
0
    for (i = 0; i < DEGREE; i += 8) {
677
0
        out_byte = 0;
678
0
        for (j = 0; j < 8; j++)
679
0
            out_byte |= bit0(s->c[i + j]) << j;
680
0
        *out = out_byte;
681
0
        out++;
682
0
    }
683
0
}
684
685
/*-
686
 * FIPS 203, Section 4.2.1, Algorithm 6: "ByteDecode_d", for 2<=d<12.
687
 * Here |bits| is |d|.  For efficiency, we handle the d=1 and d=12 cases
688
 * separately.
689
 *
690
 * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
691
 * |out|.
692
 */
693
static void scalar_decode(scalar *out, const uint8_t *in, int bits)
694
0
{
695
0
    uint16_t *curr = out->c, *end = curr + DEGREE;
696
0
    uint64_t accum = 0;
697
0
    int accum_bits = 0, todo = bits;
698
0
    uint16_t bitmask = (((uint16_t) 1) << bits) - 1, mask = bitmask;
699
0
    uint16_t element = 0;
700
701
0
    do {
702
0
        if (accum_bits == 0) {
703
0
            in = OPENSSL_load_u64_le(&accum, in);
704
0
            accum_bits = 64;
705
0
        }
706
0
        if (todo == bits && accum_bits >= bits) {
707
            /* No partial "element", and all the required bits available */
708
0
            *curr++ = ((uint16_t) accum) & mask;
709
0
            accum >>= bits;
710
0
            accum_bits -= bits;
711
0
        } else if (accum_bits >= todo) {
712
            /* A partial "element", and all the required bits available */
713
0
            *curr++ = element | ((((uint16_t) accum) & mask) << (bits - todo));
714
0
            accum >>= todo;
715
0
            accum_bits -= todo;
716
0
            element = 0;
717
0
            todo = bits;
718
0
            mask = bitmask;
719
0
        } else {
720
            /*
721
             * Only some of the requisite bits accumulated, store |accum_bits|
722
             * of these in |element|.  The accumulated bitcount becomes 0, but
723
             * as soon as we have more bits we'll want to merge accum_bits
724
             * fewer of them into the final |element|.
725
             *
726
             * Note that with a 64-bit accumulator and |bits| always 12 or
727
             * less, if we're here, the previous iteration had all the
728
             * requisite bits, and so there are no kept bits in |element|.
729
             */
730
0
            element = ((uint16_t) accum) & mask;
731
0
            todo -= accum_bits;
732
0
            mask = bitmask >> accum_bits;
733
0
            accum_bits = 0;
734
0
        }
735
0
    } while (curr < end);
736
0
}
737
738
static __owur
739
int scalar_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2])
740
0
{
741
0
    int i;
742
0
    uint16_t *c = out->c;
743
744
0
    for (i = 0; i < DEGREE / 2; ++i) {
745
0
        uint8_t b1 = *in++;
746
0
        uint8_t b2 = *in++;
747
0
        uint8_t b3 = *in++;
748
0
        int outOfRange1 = (*c++ = b1 | ((b2 & 0x0f) << 8)) >= kPrime;
749
0
        int outOfRange2 = (*c++ = (b2 >> 4) | (b3 << 4)) >= kPrime;
750
751
0
        if (outOfRange1 | outOfRange2)
752
0
            return 0;
753
0
    }
754
0
    return 1;
755
0
}
756
757
/*-
758
 * scalar_decode_decompress_add is a combination of decoding and decompression
759
 * both specialised for |bits| == 1, with the result added (and sum reduced) to
760
 * the output scalar.
761
 *
762
 * NOTE: this function MUST not leak an input-data-depedennt timing signal.
763
 * A timing leak in a related function in the reference Kyber implementation
764
 * made the "clangover" attack (CVE-2024-37880) possible, giving key recovery
765
 * for ML-KEM-512 in minutes, provided the attacker has access to precise
766
 * timing of a CPU performing chosen-ciphertext decap.  Admittedly this is only
767
 * a risk when private keys are reused (perhaps KEMTLS servers).
768
 */
769
static void
770
scalar_decode_decompress_add(scalar *out, const uint8_t in[DEGREE / 8])
771
0
{
772
0
    static const uint16_t half_q_plus_1 = (ML_KEM_PRIME >> 1) + 1;
773
0
    uint16_t *curr = out->c, *end = curr + DEGREE;
774
0
    uint16_t mask;
775
0
    uint8_t b;
776
777
    /*
778
     * Add |half_q_plus_1| if the bit is set, without exposing a side-channel,
779
     * avoiding the "clangover" attack.  See |constish_time_non_zero| for a
780
     * discussion on why the value barrier is by default omitted.
781
     */
782
0
#define decode_decompress_add_bit                               \
783
0
        mask = constish_time_non_zero(bit0(b));                 \
784
0
        *curr = reduce_once(*curr + (mask & half_q_plus_1));    \
785
0
        curr++;                                                 \
786
0
        b >>= 1
787
788
    /* Unrolled to process each byte in one iteration */
789
0
    do {
790
0
        b = *in++;
791
0
        decode_decompress_add_bit;
792
0
        decode_decompress_add_bit;
793
0
        decode_decompress_add_bit;
794
0
        decode_decompress_add_bit;
795
796
0
        decode_decompress_add_bit;
797
0
        decode_decompress_add_bit;
798
0
        decode_decompress_add_bit;
799
0
        decode_decompress_add_bit;
800
0
    } while (curr < end);
801
0
#undef decode_decompress_add_bit
802
0
}
803
804
/*
 * FIPS 203, Section 4.2.1, Equation (4.7): Compress_d.
 *
 * Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
 * numbers close to each other together. The formula used is
 * round(2^|bits|/kPrime*x) mod 2^|bits|.
 * Uses Barrett reduction to achieve constant time. Since we need both the
 * remainder (for rounding) and the quotient (as the result), we cannot use
 * |reduce| here, but need to do the Barrett reduction directly.
 *
 * The final mask keeps only the low |bits| bits, so the result is always in
 * [0, 2^|bits|).  Within this file it is reached via scalar_compress() with
 * |bits| equal to 1, du or dv.
 * NOTE(review): |x| is presumably already reduced to [0, kPrime); confirm
 * at the call sites.
 */
static __owur uint16_t compress(uint16_t x, int bits)
{
    /* Numerator of the rounded division: x * 2^bits */
    uint32_t shifted = (uint32_t)x << bits;
    /*
     * Barrett estimate of floor(shifted / kPrime).  Per the rounding table
     * below, the resulting remainder can be anywhere in [0, 2 * kPrime).
     */
    uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
    uint32_t quotient = (uint32_t)(product >> kBarrettShift);
    uint32_t remainder = shifted - quotient * kPrime;

    /*
     * Adjust the quotient to round correctly:
     *   0 <= remainder <= kHalfPrime round to 0
     *   kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
     *   kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
     * Each comparison yields an all-ones/all-zeros mask, so "1 &" adds 0 or 1
     * without branching.
     */
    quotient += 1 & constant_time_lt_32(kHalfPrime, remainder);
    quotient += 1 & constant_time_lt_32(kPrime + kHalfPrime, remainder);
    /* The "mod 2^bits" of the formula above */
    return quotient & ((1 << bits) - 1);
}
831
832
/*
 * FIPS 203, Section 4.2.1, Equation (4.8): Decompress_d.
 *
 * Decompresses |x| by using a close equi-distant representative. The formula
 * is round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us
 * to implement this logic using only bit operations.
 *
 * |bits| must be at least 1: the rounding step shifts by |bits| - 1, and a
 * negative shift count is undefined behaviour in C.
 */
static __owur uint16_t decompress(uint16_t x, int bits)
{
    uint32_t product = (uint32_t)x * kPrime;
    uint32_t power = 1 << bits;
    /* This is |product| % power, since |power| is a power of 2. */
    uint32_t remainder = product & (power - 1);
    /* This is |product| / power, since |power| is a power of 2. */
    uint32_t lower = product >> bits;

    /*
     * The rounding logic works since the first half of numbers mod |power|
     * have a 0 as first bit, and the second half has a 1 as first bit, since
     * |power| is a power of 2. As a 12 bit number, |remainder| is always
     * positive, so we will shift in 0s for a right shift.
     * (Here "first bit" means bit |bits| - 1, the most significant bit of
     * |remainder|, which is exactly the round-half-up carry.)
     */
    return lower + (remainder >> (bits - 1));
}
856
857
/*-
858
 * FIPS 203, Section 4.2.1, Equation (4.7): "Compress_d".
859
 * In-place lossy rounding of scalars to 2^d bits.
860
 */
861
static void scalar_compress(scalar *s, int bits)
862
0
{
863
0
    int i;
864
865
0
    for (i = 0; i < DEGREE; i++)
866
0
        s->c[i] = compress(s->c[i], bits);
867
0
}
868
869
/*
870
 * FIPS 203, Section 4.2.1, Equation (4.8): "Decompress_d".
871
 * In-place approximate recovery of scalars from 2^d bit compression.
872
 */
873
static void scalar_decompress(scalar *s, int bits)
874
0
{
875
0
    int i;
876
877
0
    for (i = 0; i < DEGREE; i++)
878
0
        s->c[i] = decompress(s->c[i], bits);
879
0
}
880
881
/* Addition updating the LHS vector in-place. */
882
static void vector_add(scalar *lhs, const scalar *rhs, int rank)
883
0
{
884
0
    do {
885
0
        scalar_add(lhs++, rhs++);
886
0
    } while (--rank > 0);
887
0
}
888
889
/*
890
 * Encodes an entire vector into 32*|rank|*|bits| bytes. Note that since 256
891
 * (DEGREE) is divisible by 8, the individual vector entries will always fill a
892
 * whole number of bytes, so we do not need to worry about bit packing here.
893
 */
894
static void vector_encode(uint8_t *out, const scalar *a, int bits, int rank)
895
0
{
896
0
    int stride = bits * DEGREE / 8;
897
898
0
    for (; rank-- > 0; out += stride)
899
0
        scalar_encode(out, a++, bits);
900
0
}
901
902
/*
903
 * Decodes 32*|rank|*|bits| bytes from |in| into |out|. It returns early
904
 * if any parsed value is >= |ML_KEM_PRIME|.  The resulting scalars are
905
 * then decompressed and transformed via the NTT.
906
 *
907
 * Note: Used only in decrypt_cpa(), which returns void and so does not check
908
 * the return value of this function.  Side-channels are fine when the input
909
 * ciphertext to decap() is simply syntactically invalid.
910
 */
911
static void
912
vector_decode_decompress_ntt(scalar *out, const uint8_t *in, int bits, int rank)
913
0
{
914
0
    int stride = bits * DEGREE / 8;
915
916
0
    for (; rank-- > 0; in += stride, ++out) {
917
0
        scalar_decode(out, in, bits);
918
0
        scalar_decompress(out, bits);
919
0
        scalar_ntt(out);
920
0
    }
921
0
}
922
923
/* vector_decode(), specialised to bits == 12. */
924
static __owur
925
int vector_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2], int rank)
926
0
{
927
0
    int stride = 3 * DEGREE / 2;
928
929
0
    for (; rank-- > 0; in += stride)
930
0
        if (!scalar_decode_12(out++, in))
931
0
            return 0;
932
0
    return 1;
933
0
}
934
935
/* In-place compression of each scalar component */
936
static void vector_compress(scalar *a, int bits, int rank)
937
0
{
938
0
    do {
939
0
        scalar_compress(a++, bits);
940
0
    } while (--rank > 0);
941
0
}
942
943
/* The output scalar must not overlap with the inputs */
944
static void inner_product(scalar *out, const scalar *lhs, const scalar *rhs,
945
                          int rank)
946
0
{
947
0
    scalar_mult(out, lhs, rhs);
948
0
    while (--rank > 0)
949
0
        scalar_mult_add(out, ++lhs, ++rhs);
950
0
}
951
952
/*
953
 * Here, the output vector must not overlap with the inputs, the result is
954
 * directly subjected to inverse NTT.
955
 */
956
static void
957
matrix_mult_intt(scalar *out, const scalar *m, const scalar *a, int rank)
958
0
{
959
0
    const scalar *ar;
960
0
    int i, j;
961
962
0
    for (i = rank; i-- > 0; ++out) {
963
0
        scalar_mult(out, m++, ar = a);
964
0
        for (j = rank - 1; j > 0; --j)
965
0
            scalar_mult_add(out, m++, ++ar);
966
0
        scalar_inverse_ntt(out);
967
0
    }
968
0
}
969
970
/* Here, the output vector must not overlap with the inputs */
971
static void
972
matrix_mult_transpose_add(scalar *out, const scalar *m, const scalar *a, int rank)
973
0
{
974
0
    const scalar *mc = m, *mr, *ar;
975
0
    int i, j;
976
977
0
    for (i = rank; i-- > 0; ++out) {
978
0
        scalar_mult_add(out, mr = mc++, ar = a);
979
0
        for (j = rank; --j > 0; )
980
0
            scalar_mult_add(out, (mr += rank), ++ar);
981
0
    }
982
0
}
983
984
/*-
985
 * Expands the matrix from a seed for key generation and for encaps-CPA.
986
 * NOTE: FIPS 203 matrix "A" is the transpose of this matrix, computed
987
 * by appending the (i,j) indices to the seed in the opposite order!
988
 *
989
 * Where FIPS 203 computes t = A * s + e, we use the transpose of "m".
990
 */
991
static __owur
992
int matrix_expand(EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
993
0
{
994
0
    scalar *out = key->m;
995
0
    uint8_t input[ML_KEM_RANDOM_BYTES + 2];
996
0
    int rank = key->vinfo->rank;
997
0
    int i, j;
998
999
0
    memcpy(input, key->rho, ML_KEM_RANDOM_BYTES);
1000
0
    for (i = 0; i < rank; i++) {
1001
0
        for (j = 0; j < rank; j++) {
1002
0
            input[ML_KEM_RANDOM_BYTES] = i;
1003
0
            input[ML_KEM_RANDOM_BYTES + 1] = j;
1004
0
            if (!EVP_DigestInit_ex(mdctx, key->shake128_md, NULL)
1005
0
                || !EVP_DigestUpdate(mdctx, input, sizeof(input))
1006
0
                || !sample_scalar(out++, mdctx))
1007
0
                return 0;
1008
0
        }
1009
0
    }
1010
0
    return 1;
1011
0
}
1012
1013
/*
 * Algorithm 7 from the spec, with eta fixed to two and the PRF call
 * included. Creates binomially distributed elements by sampling 2*|eta| bits,
 * and setting the coefficient to the count of the first bits minus the count of
 * the second bits, resulting in a centered binomial distribution. Since eta is
 * two this gives each of -2 and 2 with probability 1/16, each of -1 and 1 with
 * probability 1/4, and 0 with probability 3/8.
 *
 * Negative results are stored as q + value (q - 1 or q - 2), so every
 * coefficient is in [0, q).  Returns 0 if the PRF fails, 1 otherwise.
 */
static __owur
int cbd_2(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
          EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
{
    uint16_t *curr = out->c, *end = curr + DEGREE;
    /* 2 * eta = 4 random bits per coefficient, so DEGREE / 2 bytes */
    uint8_t randbuf[4 * DEGREE / 8], *r = randbuf;  /* 64 * eta slots */
    uint16_t value, mask;
    uint8_t b;

    /* NOTE(review): prf() presumably implements FIPS 203 PRF_eta(s, b); confirm */
    if (!prf(randbuf, sizeof(randbuf), in, mdctx, key))
        return 0;

    /* Each random byte yields two coefficients: low nibble, then high nibble */
    do {
        b = *r++;

        /*
         * Add |kPrime| if |value| underflowed.  See |constish_time_non_zero|
         * for a discussion on why the value barrier is by default omitted.
         * While this could have been written reduce_once(value + kPrime), this
         * is one extra addition and small range of |value| tempts some
         * versions of Clang to emit a branch.
         *
         * A negative difference wraps the uint16_t |value| to 0xFFFE/0xFFFF,
         * so bit 15 is the sign; adding kPrime then wraps mod 2^16 to q-2/q-1.
         */
        value = bit0(b) + bitn(1, b);
        value -= bitn(2, b) + bitn(3, b);
        mask = constish_time_non_zero(value >> 15);
        *curr++ = value + (kPrime & mask);

        value = bitn(4, b) + bitn(5, b);
        value -= bitn(6, b) + bitn(7, b);
        mask = constish_time_non_zero(value >> 15);
        *curr++ = value + (kPrime & mask);
    } while (curr < end);
    return 1;
}
1055
1056
/*
 * Algorithm 7 from the spec, with eta fixed to three and the PRF call
 * included. Creates binomially distributed elements by sampling 2*|eta| (six)
 * bits per coefficient, and setting the coefficient to the count of the first
 * three bits minus the count of the second three bits, resulting in a centered
 * binomial distribution over [-3, 3].
 *
 * Negative results are stored as q + value, so every coefficient is in
 * [0, q).  Returns 0 if the PRF fails, 1 otherwise.
 */
static __owur
int cbd_3(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
          EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
{
    uint16_t *curr = out->c, *end = curr + DEGREE;
    /* 2 * eta = 6 random bits per coefficient, so 3 * DEGREE / 4 bytes */
    uint8_t randbuf[6 * DEGREE / 8], *r = randbuf;  /* 64 * eta slots */
    uint8_t b1, b2, b3;
    uint16_t value, mask;

    /* NOTE(review): prf() presumably implements FIPS 203 PRF_eta(s, b); confirm */
    if (!prf(randbuf, sizeof(randbuf), in, mdctx, key))
        return 0;

    /*
     * Each group of three random bytes (24 bits) yields four coefficients of
     * six bits each; the 2nd and 3rd coefficients straddle byte boundaries.
     */
    do {
        b1 = *r++;
        b2 = *r++;
        b3 = *r++;

        /*
         * Add |kPrime| if |value| underflowed.  See |constish_time_non_zero|
         * for a discussion on why the value barrier is by default omitted.
         * While this could have been written reduce_once(value + kPrime), this
         * is one extra addition and small range of |value| tempts some
         * versions of Clang to emit a branch.
         *
         * A negative difference wraps the uint16_t |value|, setting bit 15;
         * adding kPrime then wraps mod 2^16 to the representative q + value.
         */
        value = bit0(b1) + bitn(1, b1) + bitn(2, b1);
        value -= bitn(3, b1)  + bitn(4, b1) + bitn(5, b1);
        mask = constish_time_non_zero(value >> 15);
        *curr++ = value + (kPrime & mask);

        value = bitn(6, b1) + bitn(7, b1) + bit0(b2);
        value -= bitn(1, b2) + bitn(2, b2) + bitn(3, b2);
        mask = constish_time_non_zero(value >> 15);
        *curr++ = value + (kPrime & mask);

        value = bitn(4, b2) + bitn(5, b2) + bitn(6, b2);
        value -= bitn(7, b2) + bit0(b3) + bitn(1, b3);
        mask = constish_time_non_zero(value >> 15);
        *curr++ = value + (kPrime & mask);

        value = bitn(2, b3) + bitn(3, b3) + bitn(4, b3);
        value -= bitn(5, b3) + bitn(6, b3) + bitn(7, b3);
        mask = constish_time_non_zero(value >> 15);
        *curr++ = value + (kPrime & mask);
    } while (curr < end);
    return 1;
}
1108
1109
/*
1110
 * Generates a secret vector by using |cbd| with the given seed to generate
1111
 * scalar elements and incrementing |counter| for each slot of the vector.
1112
 */
1113
static __owur
1114
int gencbd_vector(scalar *out, CBD_FUNC cbd, uint8_t *counter,
1115
                  const uint8_t seed[ML_KEM_RANDOM_BYTES], int rank,
1116
                  EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1117
0
{
1118
0
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1119
1120
0
    memcpy(input, seed, ML_KEM_RANDOM_BYTES);
1121
0
    do {
1122
0
        input[ML_KEM_RANDOM_BYTES] = (*counter)++;
1123
0
        if (!cbd(out++, input, mdctx, key))
1124
0
            return 0;
1125
0
    } while (--rank > 0);
1126
0
    return 1;
1127
0
}
1128
1129
/*
1130
 * As above plus NTT transform.
1131
 */
1132
static __owur
1133
int gencbd_vector_ntt(scalar *out, CBD_FUNC cbd, uint8_t *counter,
1134
                      const uint8_t seed[ML_KEM_RANDOM_BYTES], int rank,
1135
                      EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1136
0
{
1137
0
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1138
1139
0
    memcpy(input, seed, ML_KEM_RANDOM_BYTES);
1140
0
    do {
1141
0
        input[ML_KEM_RANDOM_BYTES] = (*counter)++;
1142
0
        if (!cbd(out, input, mdctx, key))
1143
0
            return 0;
1144
0
        scalar_ntt(out++);
1145
0
    } while (--rank > 0);
1146
0
    return 1;
1147
0
}
1148
1149
/*
 * Picks the sampler for |ETA1| noise: eta = 3 (cbd_3) for ML-KEM-512 and
 * eta = 2 (cbd_2) for every other variant.  All |ETA2| noise uses cbd_2.
 */
#define CBD1(evp_type) \
    ((evp_type) != EVP_PKEY_ML_KEM_512 ? cbd_2 : cbd_3)
1151
1152
/*
1153
 * FIPS 203, Section 5.2, Algorithm 14: K-PKE.Encrypt.
1154
 *
1155
 * Encrypts a message with given randomness to the ciphertext in |out|. Without
1156
 * applying the Fujisaki-Okamoto transform this would not result in a CCA
1157
 * secure scheme, since lattice schemes are vulnerable to decryption failure
1158
 * oracles.
1159
 *
1160
 * The steps are re-ordered to make more efficient/localised use of storage.
1161
 *
1162
 * Note also that the input public key is assumed to hold a precomputed matrix
1163
 * |A| (our key->m, with the public key holding an expanded (16-bit per scalar
1164
 * coefficient) key->t vector).
1165
 *
1166
 * Caller passes storage in |tmp| for for two temporary vectors.
1167
 */
1168
static __owur
1169
int encrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
1170
                const uint8_t message[DEGREE / 8],
1171
                const uint8_t r[ML_KEM_RANDOM_BYTES], scalar *tmp,
1172
                EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1173
0
{
1174
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1175
0
    CBD_FUNC cbd_1 = CBD1(vinfo->evp_type);
1176
0
    int rank = vinfo->rank;
1177
    /* We can use tmp[0..rank-1] as storage for |y|, then |e1|, ... */
1178
0
    scalar *y = &tmp[0], *e1 = y, *e2 = y;
1179
    /* We can use tmp[rank]..tmp[2*rank - 1] for |u| */
1180
0
    scalar *u = &tmp[rank];
1181
0
    scalar v;
1182
0
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1183
0
    uint8_t counter = 0;
1184
0
    int du = vinfo->du;
1185
0
    int dv = vinfo->dv;
1186
1187
    /* FIPS 203 "y" vector */
1188
0
    if (!gencbd_vector_ntt(y, cbd_1, &counter, r, rank, mdctx, key))
1189
0
        return 0;
1190
    /* FIPS 203 "v" scalar */
1191
0
    inner_product(&v, key->t, y, rank);
1192
0
    scalar_inverse_ntt(&v);
1193
    /* FIPS 203 "u" vector */
1194
0
    matrix_mult_intt(u, key->m, y, rank);
1195
1196
    /* All done with |y|, now free to reuse tmp[0] for FIPS 203 |e1| */
1197
0
    if (!gencbd_vector(e1, cbd_2, &counter, r, rank, mdctx, key))
1198
0
        return 0;
1199
0
    vector_add(u, e1, rank);
1200
0
    vector_compress(u, du, rank);
1201
0
    vector_encode(out, u, du, rank);
1202
1203
    /* All done with |e1|, now free to reuse tmp[0] for FIPS 203 |e2| */
1204
0
    memcpy(input, r, ML_KEM_RANDOM_BYTES);
1205
0
    input[ML_KEM_RANDOM_BYTES] = counter;
1206
0
    if (!cbd_2(e2, input, mdctx, key))
1207
0
        return 0;
1208
0
    scalar_add(&v, e2);
1209
1210
    /* Combine message with |v| */
1211
0
    scalar_decode_decompress_add(&v, message);
1212
0
    scalar_compress(&v, dv);
1213
0
    scalar_encode(out + vinfo->u_vector_bytes, &v, dv);
1214
0
    return 1;
1215
0
}
1216
1217
/*
1218
 * FIPS 203, Section 5.3, Algorithm 15: K-PKE.Decrypt.
1219
 */
1220
static void
1221
decrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
1222
            const uint8_t *ctext, scalar *u, const ML_KEM_KEY *key)
1223
0
{
1224
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1225
0
    scalar v, mask;
1226
0
    int rank = vinfo->rank;
1227
0
    int du = vinfo->du;
1228
0
    int dv = vinfo->dv;
1229
1230
0
    vector_decode_decompress_ntt(u, ctext, du, rank);
1231
0
    scalar_decode(&v, ctext + vinfo->u_vector_bytes, dv);
1232
0
    scalar_decompress(&v, dv);
1233
0
    inner_product(&mask, key->s, u, rank);
1234
0
    scalar_inverse_ntt(&mask);
1235
0
    scalar_sub(&v, &mask);
1236
0
    scalar_compress(&v, 1);
1237
0
    scalar_encode_1(out, &v);
1238
0
}
1239
1240
/*-
1241
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1242
 * FIPS 203, Section 7.2, Algorithm 20: "ML-KEM.Encaps".
1243
 *
1244
 * Fills the |out| buffer with the |ek| output of "ML-KEM.KeyGen", or,
1245
 * equivalently, the |ek| input of "ML-KEM.Encaps", i.e. returns the
1246
 * wire-format of an ML-KEM public key.
1247
 */
1248
static void encode_pubkey(uint8_t *out, const ML_KEM_KEY *key)
1249
0
{
1250
0
    const uint8_t *rho = key->rho;
1251
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1252
1253
0
    vector_encode(out, key->t, 12, vinfo->rank);
1254
0
    memcpy(out + vinfo->vector_bytes, rho, ML_KEM_RANDOM_BYTES);
1255
0
}
1256
1257
/*-
1258
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1259
 *
1260
 * Fills the |out| buffer with the |dk| output of "ML-KEM.KeyGen".
1261
 * This matches the input format of parse_prvkey() below.
1262
 */
1263
static void encode_prvkey(uint8_t *out, const ML_KEM_KEY *key)
1264
0
{
1265
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1266
1267
0
    vector_encode(out, key->s, 12, vinfo->rank);
1268
0
    out += vinfo->vector_bytes;
1269
0
    encode_pubkey(out, key);
1270
0
    out += vinfo->pubkey_bytes;
1271
0
    memcpy(out, key->pkhash, ML_KEM_PKHASH_BYTES);
1272
0
    out += ML_KEM_PKHASH_BYTES;
1273
0
    memcpy(out, key->z, ML_KEM_RANDOM_BYTES);
1274
0
}
1275
1276
/*-
1277
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1278
 * FIPS 203, Section 7.2, Algorithm 20: "ML-KEM.Encaps".
1279
 *
1280
 * This function parses the |in| buffer as the |ek| output of "ML-KEM.KeyGen",
1281
 * or, equivalently, the |ek| input of "ML-KEM.Encaps", i.e. decodes the
1282
 * wire-format of the ML-KEM public key.
1283
 */
1284
static int parse_pubkey(const uint8_t *in, EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
1285
0
{
1286
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1287
1288
    /* Decode and check |t| */
1289
0
    if (!vector_decode_12(key->t, in, vinfo->rank)) {
1290
0
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1291
0
                       "%s invalid public 't' vector",
1292
0
                       vinfo->algorithm_name);
1293
0
        return 0;
1294
0
    }
1295
    /* Save the matrix |m| recovery seed |rho| */
1296
0
    memcpy(key->rho, in + vinfo->vector_bytes, ML_KEM_RANDOM_BYTES);
1297
    /*
1298
     * Pre-compute the public key hash, needed for both encap and decap.
1299
     * Also pre-compute the matrix expansion, stored with the public key.
1300
     */
1301
0
    if (!hash_h(key->pkhash, in, vinfo->pubkey_bytes, mdctx, key)
1302
0
        || !matrix_expand(mdctx, key)) {
1303
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1304
0
                       "internal error while parsing %s public key",
1305
0
                       vinfo->algorithm_name);
1306
0
        return 0;
1307
0
    }
1308
0
    return 1;
1309
0
}
1310
1311
/*
1312
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1313
 *
1314
 * Parses the |in| buffer as a |dk| output of "ML-KEM.KeyGen".
1315
 * This matches the output format of encode_prvkey() above.
1316
 */
1317
static int parse_prvkey(const uint8_t *in, EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
1318
0
{
1319
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1320
1321
    /* Decode and check |s|. */
1322
0
    if (!vector_decode_12(key->s, in, vinfo->rank)) {
1323
0
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1324
0
                       "%s invalid private 's' vector",
1325
0
                       vinfo->algorithm_name);
1326
0
        return 0;
1327
0
    }
1328
0
    in += vinfo->vector_bytes;
1329
1330
0
    if (!parse_pubkey(in, mdctx, key))
1331
0
        return 0;
1332
0
    in += vinfo->pubkey_bytes;
1333
1334
    /* Check public key hash. */
1335
0
    if (memcmp(key->pkhash, in, ML_KEM_PKHASH_BYTES) != 0) {
1336
0
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1337
0
                       "%s public key hash mismatch",
1338
0
                       vinfo->algorithm_name);
1339
0
        return 0;
1340
0
    }
1341
0
    in += ML_KEM_PKHASH_BYTES;
1342
1343
0
    memcpy(key->z, in, ML_KEM_RANDOM_BYTES);
1344
0
    return 1;
1345
0
}
1346
1347
/*
1348
 * FIPS 203, Section 6.1, Algorithm 16: "ML-KEM.KeyGen_internal".
1349
 *
1350
 * The implementation of Section 5.1, Algorithm 13, "K-PKE.KeyGen(d)" is
1351
 * inlined.
1352
 *
1353
 * The caller MUST pass a pre-allocated digest context that is not shared with
1354
 * any concurrent computation.
1355
 *
1356
 * This function optionally outputs the serialised wire-form |ek| public key
1357
 * into the provided |pubenc| buffer, and generates the content of the |rho|,
1358
 * |pkhash|, |t|, |m|, |s| and |z| components of the private |key| (which must
1359
 * have preallocated space for these).
1360
 *
1361
 * Keys are computed from a 32-byte random |d| plus the 1 byte rank for
1362
 * domain separation.  These are concatenated and hashed to produce a pair of
1363
 * 32-byte seeds public "rho", used to generate the matrix, and private "sigma",
1364
 * used to generate the secret vector |s|.
1365
 *
1366
 * The second random input |z| is copied verbatim into the Fujisaki-Okamoto
1367
 * (FO) transform "implicit-rejection" secret (the |z| component of the private
1368
 * key), which thwarts chosen-ciphertext attacks, provided decap() runs in
1369
 * constant time, with no side channel leaks, on all well-formed (valid length,
1370
 * and correctly encoded) ciphertext inputs.
1371
 */
1372
static __owur
1373
int genkey(const uint8_t seed[ML_KEM_SEED_BYTES],
1374
           EVP_MD_CTX *mdctx, uint8_t *pubenc, ML_KEM_KEY *key)
1375
0
{
1376
0
    uint8_t hashed[2 * ML_KEM_RANDOM_BYTES];
1377
0
    const uint8_t *const sigma = hashed + ML_KEM_RANDOM_BYTES;
1378
0
    uint8_t augmented_seed[ML_KEM_RANDOM_BYTES + 1];
1379
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1380
0
    CBD_FUNC cbd_1 = CBD1(vinfo->evp_type);
1381
0
    int rank = vinfo->rank;
1382
0
    uint8_t counter = 0;
1383
0
    int ret = 0;
1384
1385
    /*
1386
     * Use the "d" seed salted with the rank to derive the public and private
1387
     * seeds rho and sigma.
1388
     */
1389
0
    memcpy(augmented_seed, seed, ML_KEM_RANDOM_BYTES);
1390
0
    augmented_seed[ML_KEM_RANDOM_BYTES] = (uint8_t) rank;
1391
0
    if (!hash_g(hashed, augmented_seed, sizeof(augmented_seed), mdctx, key))
1392
0
        goto end;
1393
0
    memcpy(key->rho, hashed, ML_KEM_RANDOM_BYTES);
1394
    /* The |rho| matrix seed is public */
1395
0
    CONSTTIME_DECLASSIFY(key->rho, ML_KEM_RANDOM_BYTES);
1396
1397
    /* FIPS 203 |e| vector is initial value of key->t */
1398
0
    if (!matrix_expand(mdctx, key)
1399
0
        || !gencbd_vector_ntt(key->s, cbd_1, &counter, sigma, rank, mdctx, key)
1400
0
        || !gencbd_vector_ntt(key->t, cbd_1, &counter, sigma, rank, mdctx, key))
1401
0
        goto end;
1402
1403
    /* To |e| we now add the product of transpose |m| and |s|, giving |t|. */
1404
0
    matrix_mult_transpose_add(key->t, key->m, key->s, rank);
1405
    /* The |t| vector is public */
1406
0
    CONSTTIME_DECLASSIFY(key->t, vinfo->rank * sizeof(scalar));
1407
1408
0
    if (pubenc == NULL) {
1409
        /* Incremental digest of public key without in-full serialisation. */
1410
0
        if (!hash_h_pubkey(key->pkhash, mdctx, key))
1411
0
            goto end;
1412
0
    } else {
1413
0
        encode_pubkey(pubenc, key);
1414
0
        if (!hash_h(key->pkhash, pubenc, vinfo->pubkey_bytes, mdctx, key))
1415
0
            goto end;
1416
0
    }
1417
1418
    /* Save |z| portion of seed for "implicit rejection" on failure. */
1419
0
    memcpy(key->z, seed + ML_KEM_RANDOM_BYTES, ML_KEM_RANDOM_BYTES);
1420
1421
    /* Optionally save the |d| portion of the seed */
1422
0
    key->d = key->z + ML_KEM_RANDOM_BYTES;
1423
0
    if (key->prov_flags & ML_KEM_KEY_RETAIN_SEED) {
1424
0
        memcpy(key->d, seed, ML_KEM_RANDOM_BYTES);
1425
0
    } else {
1426
0
        OPENSSL_cleanse(key->d, ML_KEM_RANDOM_BYTES);
1427
0
        key->d = NULL;
1428
0
    }
1429
1430
0
    ret = 1;
1431
0
 end:
1432
0
    OPENSSL_cleanse((void *)augmented_seed, ML_KEM_RANDOM_BYTES);
1433
0
    OPENSSL_cleanse((void *)sigma, ML_KEM_RANDOM_BYTES);
1434
0
    if (ret == 0) {
1435
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1436
0
                       "internal error while generating %s private key",
1437
0
                       vinfo->algorithm_name);
1438
0
    }
1439
0
    return ret;
1440
0
}
1441
1442
/*-
1443
 * FIPS 203, Section 6.2, Algorithm 17: "ML-KEM.Encaps_internal".
1444
 * This is the deterministic version with randomness supplied externally.
1445
 *
1446
 * The caller must pass space for two vectors in |tmp|.
1447
 * The |ctext| buffer have space for the ciphertext of the ML-KEM variant
1448
 * of the provided key.
1449
 */
1450
static
1451
int encap(uint8_t *ctext, uint8_t secret[ML_KEM_SHARED_SECRET_BYTES],
1452
          const uint8_t entropy[ML_KEM_RANDOM_BYTES],
1453
          scalar *tmp, EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1454
0
{
1455
0
    uint8_t input[ML_KEM_RANDOM_BYTES + ML_KEM_PKHASH_BYTES];
1456
0
    uint8_t Kr[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES];
1457
0
    uint8_t *r = Kr + ML_KEM_SHARED_SECRET_BYTES;
1458
0
    int ret;
1459
1460
0
    memcpy(input, entropy, ML_KEM_RANDOM_BYTES);
1461
0
    memcpy(input + ML_KEM_RANDOM_BYTES, key->pkhash, ML_KEM_PKHASH_BYTES);
1462
0
    ret = hash_g(Kr, input, sizeof(input), mdctx, key)
1463
0
        && encrypt_cpa(ctext, entropy, r, tmp, mdctx, key);
1464
0
    OPENSSL_cleanse((void *)input, sizeof(input));
1465
1466
0
    if (ret)
1467
0
        memcpy(secret, Kr, ML_KEM_SHARED_SECRET_BYTES);
1468
0
    else
1469
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1470
0
                       "internal error while performing %s encapsulation",
1471
0
                       key->vinfo->algorithm_name);
1472
0
    return ret;
1473
0
}
1474
1475
/*
1476
 * FIPS 203, Section 6.3, Algorithm 18: ML-KEM.Decaps_internal
1477
 *
1478
 * Barring failure of the supporting SHA3/SHAKE primitives, this is fully
1479
 * deterministic, the randomness for the FO transform is extracted during
1480
 * private key generation.
1481
 *
1482
 * The caller must pass space for two vectors in |tmp|.
1483
 * The |ctext| and |tmp_ctext| buffers must each have space for the ciphertext
1484
 * of the key's ML-KEM variant.
1485
 */
1486
static
1487
int decap(uint8_t secret[ML_KEM_SHARED_SECRET_BYTES],
1488
          const uint8_t *ctext, uint8_t *tmp_ctext, scalar *tmp,
1489
          EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1490
0
{
1491
0
    uint8_t decrypted[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_PKHASH_BYTES];
1492
0
    uint8_t failure_key[ML_KEM_RANDOM_BYTES];
1493
0
    uint8_t Kr[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES];
1494
0
    uint8_t *r = Kr + ML_KEM_SHARED_SECRET_BYTES;
1495
0
    const uint8_t *pkhash = key->pkhash;
1496
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1497
0
    int i;
1498
0
    uint8_t mask;
1499
1500
    /*
1501
     * If our KDF is unavailable, fail early! Otherwise, keep going ignoring
1502
     * any further errors, returning success, and whatever we got for a shared
1503
     * secret.  The decrypt_cpa() function is just arithmetic on secret data,
1504
     * so should not be subject to failure that makes its output predictable.
1505
     *
1506
     * We guard against "should never happen" catastrophic failure of the
1507
     * "pure" function |hash_g| by overwriting the shared secret with the
1508
     * content of the failure key and returning early, if nevertheless hash_g
1509
     * fails.  This is not constant-time, but a failure of |hash_g| already
1510
     * implies loss of side-channel resistance.
1511
     *
1512
     * The same action is taken, if also |encrypt_cpa| should catastrophically
1513
     * fail, due to failure of the |PRF| underlying the CBD functions.
1514
     */
1515
0
    if (!kdf(failure_key, key->z, ctext, vinfo->ctext_bytes, mdctx, key)) {
1516
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1517
0
                       "internal error while performing %s decapsulation",
1518
0
                       vinfo->algorithm_name);
1519
0
        return 0;
1520
0
    }
1521
0
    decrypt_cpa(decrypted, ctext, tmp, key);
1522
0
    memcpy(decrypted + ML_KEM_SHARED_SECRET_BYTES, pkhash, ML_KEM_PKHASH_BYTES);
1523
0
    if (!hash_g(Kr, decrypted, sizeof(decrypted), mdctx, key)
1524
0
        || !encrypt_cpa(tmp_ctext, decrypted, r, tmp, mdctx, key)) {
1525
0
        memcpy(secret, failure_key, ML_KEM_SHARED_SECRET_BYTES);
1526
0
        OPENSSL_cleanse(decrypted, ML_KEM_SHARED_SECRET_BYTES);
1527
0
        return 1;
1528
0
    }
1529
0
    mask = constant_time_eq_int_8(0,
1530
0
        CRYPTO_memcmp(ctext, tmp_ctext, vinfo->ctext_bytes));
1531
0
    for (i = 0; i < ML_KEM_SHARED_SECRET_BYTES; i++)
1532
0
        secret[i] = constant_time_select_8(mask, Kr[i], failure_key[i]);
1533
0
    OPENSSL_cleanse(decrypted, ML_KEM_SHARED_SECRET_BYTES);
1534
0
    OPENSSL_cleanse(Kr, sizeof(Kr));
1535
0
    return 1;
1536
0
}
1537
1538
/*
1539
 * After allocating storage for public or private key data, update the key
1540
 * component pointers to reference that storage.
1541
 *
1542
 * The caller should only store private data in `priv` *after* a successful
1543
 * (non-zero) return from this function.
1544
 */
1545
static __owur
1546
int add_storage(scalar *pub, scalar *priv, int private, ML_KEM_KEY *key)
1547
0
{
1548
0
    int rank = key->vinfo->rank;
1549
1550
0
    if (pub == NULL || (private && priv == NULL)) {
1551
        /*
1552
         * One of these could be allocated correctly. It is legal to call free with a NULL
1553
         * pointer, so always attempt to free both allocations here
1554
         */
1555
0
        OPENSSL_free(pub);
1556
0
        OPENSSL_secure_free(priv);
1557
0
        return 0;
1558
0
    }
1559
1560
    /*
1561
     * We're adding key material, set up rho and pkhash to point to the rho_pkhash buffer
1562
     */
1563
0
    memset(key->rho_pkhash, 0, sizeof(key->rho_pkhash));
1564
0
    key->rho = key->rho_pkhash;
1565
0
    key->pkhash = key->rho_pkhash + ML_KEM_RANDOM_BYTES;
1566
0
    key->d = key->z = NULL;
1567
1568
    /* A public key needs space for |t| and |m| */
1569
0
    key->m = (key->t = pub) + rank;
1570
1571
    /*
1572
     * A private key also needs space for |s| and |z|.
1573
     * The |z| buffer always includes additional space for |d|, but a key's |d|
1574
     * pointer is left NULL when parsed from the NIST format, which omits that
1575
     * information.  Only keys generated from a (d, z) seed pair will have a
1576
     * non-NULL |d| pointer.
1577
     */
1578
0
    if (private)
1579
0
        key->z = (uint8_t *)(rank + (key->s = priv));
1580
0
    return 1;
1581
0
}
1582
1583
/*
1584
 * After freeing the storage associated with a key that failed to be
1585
 * constructed, reset the internal pointers back to NULL.
1586
 */
1587
void
1588
ossl_ml_kem_key_reset(ML_KEM_KEY *key)
1589
0
{
1590
    /*
1591
     * seedbuf can be allocated and contain |z| and |d| if the key is
1592
     * being created from a private key encoding.  Similarly a pending
1593
     * serialised (encoded) private key may be queued up to load.
1594
     * Clear and free that data now.
1595
     */
1596
0
    if (key->seedbuf != NULL)
1597
0
        OPENSSL_secure_clear_free(key->seedbuf, ML_KEM_SEED_BYTES);
1598
0
    if (ossl_ml_kem_have_dkenc(key))
1599
0
        OPENSSL_secure_clear_free(key->encoded_dk, key->vinfo->prvkey_bytes);
1600
1601
    /*-
1602
     * Cleanse any sensitive data:
1603
     * - The private vector |s| is immediately followed by the FO failure
1604
     *   secret |z|, and seed |d|, we can cleanse all three in one call.
1605
     */
1606
0
    if (key->t != NULL) {
1607
0
        if (ossl_ml_kem_have_prvkey(key))
1608
0
            OPENSSL_secure_clear_free(key->s, key->vinfo->prvalloc);
1609
0
        OPENSSL_free(key->t);
1610
0
    }
1611
0
    key->d = key->z = key->seedbuf = key->encoded_dk =
1612
0
        (uint8_t *)(key->s = key->m = key->t = NULL);
1613
0
}
1614
1615
/*
1616
 * ----- API exported to the provider
1617
 *
1618
 * Parameters with an implicit fixed length in the internal static API of each
1619
 * variant have an explicit checked length argument at this layer.
1620
 */
1621
1622
/* Retrieve the parameters of one of the ML-KEM variants */
1623
const ML_KEM_VINFO *ossl_ml_kem_get_vinfo(int evp_type)
1624
0
{
1625
0
    switch (evp_type) {
1626
0
    case EVP_PKEY_ML_KEM_512:
1627
0
        return &vinfo_map[ML_KEM_512_VINFO];
1628
0
    case EVP_PKEY_ML_KEM_768:
1629
0
        return &vinfo_map[ML_KEM_768_VINFO];
1630
0
    case EVP_PKEY_ML_KEM_1024:
1631
0
        return &vinfo_map[ML_KEM_1024_VINFO];
1632
0
    }
1633
0
    return NULL;
1634
0
}
1635
1636
/*
 * Allocate an empty ML-KEM key of the variant named by |evp_type|.
 *
 * The four SHA3-family digests used by the algorithm are fetched from
 * |libctx| using |properties|.  Returns NULL, with an error raised, when
 * |evp_type| is unsupported or a digest cannot be fetched.  Also returns
 * NULL if the key structure itself cannot be allocated.
 */
ML_KEM_KEY *ossl_ml_kem_key_new(OSSL_LIB_CTX *libctx, const char *properties,
                                int evp_type)
{
    const ML_KEM_VINFO *vinfo = ossl_ml_kem_get_vinfo(evp_type);
    ML_KEM_KEY *nk;
    int digests_ok;

    if (vinfo == NULL) {
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_PASSED_INVALID_ARGUMENT,
                       "unsupported ML-KEM key type: %d", evp_type);
        return NULL;
    }

    nk = OPENSSL_malloc(sizeof(*nk));
    if (nk == NULL)
        return NULL;

    nk->vinfo = vinfo;
    nk->libctx = libctx;
    nk->prov_flags = ML_KEM_KEY_PROV_FLAGS_DEFAULT;

    nk->shake128_md = EVP_MD_fetch(libctx, "SHAKE128", properties);
    nk->shake256_md = EVP_MD_fetch(libctx, "SHAKE256", properties);
    nk->sha3_256_md = EVP_MD_fetch(libctx, "SHA3-256", properties);
    nk->sha3_512_md = EVP_MD_fetch(libctx, "SHA3-512", properties);

    /* A new key carries no key material, seed or pending encoding yet */
    nk->d = nk->z = NULL;
    nk->rho = nk->pkhash = NULL;
    nk->encoded_dk = nk->seedbuf = NULL;
    nk->s = nk->m = nk->t = NULL;

    digests_ok = nk->shake128_md != NULL
        && nk->shake256_md != NULL
        && nk->sha3_256_md != NULL
        && nk->sha3_512_md != NULL;
    if (digests_ok)
        return nk;

    /* Releases whichever digests were fetched, then the key itself */
    ossl_ml_kem_key_free(nk);
    ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
                   "missing SHA3 digest algorithms while creating %s key",
                   vinfo->algorithm_name);
    return NULL;
}
1673
1674
ML_KEM_KEY *ossl_ml_kem_key_dup(const ML_KEM_KEY *key, int selection)
1675
0
{
1676
0
    int ok = 0;
1677
0
    ML_KEM_KEY *ret;
1678
0
    void *tmp_pub;
1679
0
    void *tmp_priv;
1680
1681
0
    if (key == NULL)
1682
0
        return NULL;
1683
    /*
1684
     * Partially decoded keys, not yet imported or loaded, should never be
1685
     * duplicated.
1686
     */
1687
0
    if (ossl_ml_kem_decoded_key(key))
1688
0
        return NULL;
1689
1690
0
    else if ((ret = OPENSSL_memdup(key, sizeof(*key))) == NULL)
1691
0
        return NULL;
1692
1693
0
    ret->d = ret->z = ret->rho = ret->pkhash = NULL;
1694
0
    ret->s = ret->m = ret->t = NULL;
1695
1696
    /* Clear selection bits we can't fulfill */
1697
0
    if (!ossl_ml_kem_have_pubkey(key))
1698
0
        selection = 0;
1699
0
    else if (!ossl_ml_kem_have_prvkey(key))
1700
0
        selection &= ~OSSL_KEYMGMT_SELECT_PRIVATE_KEY;
1701
1702
0
    switch (selection & OSSL_KEYMGMT_SELECT_KEYPAIR) {
1703
0
    case 0:
1704
0
        ok = 1;
1705
0
        break;
1706
0
    case OSSL_KEYMGMT_SELECT_PUBLIC_KEY:
1707
0
        ok = add_storage(OPENSSL_memdup(key->t, key->vinfo->puballoc), NULL, 0, ret);
1708
0
        break;
1709
0
    case OSSL_KEYMGMT_SELECT_PRIVATE_KEY:
1710
0
        tmp_pub = OPENSSL_memdup(key->t, key->vinfo->puballoc);
1711
0
        if (tmp_pub == NULL)
1712
0
            break;
1713
0
        tmp_priv = OPENSSL_secure_malloc(key->vinfo->prvalloc);
1714
0
        if (tmp_priv == NULL) {
1715
0
            OPENSSL_free(tmp_pub);
1716
0
            break;
1717
0
        }
1718
0
        if ((ok = add_storage(tmp_pub, tmp_priv, 1, ret)) != 0)
1719
0
            memcpy(tmp_priv, key->s, key->vinfo->prvalloc);
1720
        /* Duplicated keys retain |d|, if available */
1721
0
        if (key->d != NULL)
1722
0
            ret->d = ret->z + ML_KEM_RANDOM_BYTES;
1723
0
        break;
1724
0
    }
1725
1726
0
    if (!ok) {
1727
0
        OPENSSL_free(ret);
1728
0
        return NULL;
1729
0
    }
1730
1731
0
    EVP_MD_up_ref(ret->shake128_md);
1732
0
    EVP_MD_up_ref(ret->shake256_md);
1733
0
    EVP_MD_up_ref(ret->sha3_256_md);
1734
0
    EVP_MD_up_ref(ret->sha3_512_md);
1735
1736
0
    return ret;
1737
0
}
1738
1739
/*
 * Release |key| and everything it owns: the fetched digests, any key
 * material (sensitive parts are cleansed by ossl_ml_kem_key_reset()), and
 * the key structure itself.  A NULL |key| is a no-op.
 */
void ossl_ml_kem_key_free(ML_KEM_KEY *key)
{
    if (key != NULL) {
        EVP_MD_free(key->shake128_md);
        EVP_MD_free(key->shake256_md);
        EVP_MD_free(key->sha3_256_md);
        EVP_MD_free(key->sha3_512_md);

        ossl_ml_kem_key_reset(key);
        OPENSSL_free(key);
    }
}
1752
1753
/* Serialise the public component of an ML-KEM key */
1754
int ossl_ml_kem_encode_public_key(uint8_t *out, size_t len,
1755
                                  const ML_KEM_KEY *key)
1756
0
{
1757
0
    if (!ossl_ml_kem_have_pubkey(key)
1758
0
        || len != key->vinfo->pubkey_bytes)
1759
0
        return 0;
1760
0
    encode_pubkey(out, key);
1761
0
    return 1;
1762
0
}
1763
1764
/* Serialise an ML-KEM private key */
1765
int ossl_ml_kem_encode_private_key(uint8_t *out, size_t len,
1766
                                   const ML_KEM_KEY *key)
1767
0
{
1768
0
    if (!ossl_ml_kem_have_prvkey(key)
1769
0
        || len != key->vinfo->prvkey_bytes)
1770
0
        return 0;
1771
0
    encode_prvkey(out, key);
1772
0
    return 1;
1773
0
}
1774
1775
int ossl_ml_kem_encode_seed(uint8_t *out, size_t len,
1776
                            const ML_KEM_KEY *key)
1777
0
{
1778
0
    if (key == NULL || key->d == NULL || len != ML_KEM_SEED_BYTES)
1779
0
        return 0;
1780
    /*
1781
     * Both in the seed buffer, and in the allocated storage, the |d| component
1782
     * of the seed is stored last, so we must copy each separately.
1783
     */
1784
0
    memcpy(out, key->d, ML_KEM_RANDOM_BYTES);
1785
0
    out += ML_KEM_RANDOM_BYTES;
1786
0
    memcpy(out, key->z, ML_KEM_RANDOM_BYTES);
1787
0
    return 1;
1788
0
}
1789
1790
/*
1791
 * Stash the seed without (yet) performing a keygen, used during decoding, to
1792
 * avoid an extra keygen if we're only going to export the key again to load
1793
 * into another provider.
1794
 */
1795
ML_KEM_KEY *ossl_ml_kem_set_seed(const uint8_t *seed, size_t seedlen, ML_KEM_KEY *key)
1796
0
{
1797
0
    if (key == NULL
1798
0
        || ossl_ml_kem_have_pubkey(key)
1799
0
        || ossl_ml_kem_have_seed(key)
1800
0
        || seedlen != ML_KEM_SEED_BYTES)
1801
0
        return NULL;
1802
1803
0
    if (key->seedbuf == NULL) {
1804
0
        key->seedbuf = OPENSSL_secure_malloc(seedlen);
1805
0
        if (key->seedbuf == NULL)
1806
0
            return NULL;
1807
0
    }
1808
1809
0
    key->z = key->seedbuf;
1810
0
    key->d = key->z + ML_KEM_RANDOM_BYTES;
1811
0
    memcpy(key->d, seed, ML_KEM_RANDOM_BYTES);
1812
0
    seed += ML_KEM_RANDOM_BYTES;
1813
0
    memcpy(key->z, seed, ML_KEM_RANDOM_BYTES);
1814
0
    return key;
1815
0
}
1816
1817
/* Parse input as a public key */
1818
int ossl_ml_kem_parse_public_key(const uint8_t *in, size_t len, ML_KEM_KEY *key)
1819
0
{
1820
0
    EVP_MD_CTX *mdctx = NULL;
1821
0
    const ML_KEM_VINFO *vinfo;
1822
0
    int ret = 0;
1823
1824
    /* Keys with key material are immutable */
1825
0
    if (key == NULL
1826
0
        || ossl_ml_kem_have_pubkey(key)
1827
0
        || ossl_ml_kem_have_dkenc(key))
1828
0
        return 0;
1829
0
    vinfo = key->vinfo;
1830
1831
0
    if (len != vinfo->pubkey_bytes
1832
0
        || (mdctx = EVP_MD_CTX_new()) == NULL)
1833
0
        return 0;
1834
1835
0
    if (add_storage(OPENSSL_malloc(vinfo->puballoc), NULL, 0, key))
1836
0
        ret = parse_pubkey(in, mdctx, key);
1837
1838
0
    if (!ret)
1839
0
        ossl_ml_kem_key_reset(key);
1840
0
    EVP_MD_CTX_free(mdctx);
1841
0
    return ret;
1842
0
}
1843
1844
/* Parse input as a new private key */
1845
int ossl_ml_kem_parse_private_key(const uint8_t *in, size_t len,
1846
                                  ML_KEM_KEY *key)
1847
0
{
1848
0
    EVP_MD_CTX *mdctx = NULL;
1849
0
    const ML_KEM_VINFO *vinfo;
1850
0
    int ret = 0;
1851
1852
    /* Keys with key material are immutable */
1853
0
    if (key == NULL
1854
0
        || ossl_ml_kem_have_pubkey(key)
1855
0
        || ossl_ml_kem_have_dkenc(key))
1856
0
        return 0;
1857
0
    vinfo = key->vinfo;
1858
1859
0
    if (len != vinfo->prvkey_bytes
1860
0
        || (mdctx = EVP_MD_CTX_new()) == NULL)
1861
0
        return 0;
1862
1863
0
    if (add_storage(OPENSSL_malloc(vinfo->puballoc),
1864
0
                    OPENSSL_secure_malloc(vinfo->prvalloc), 1, key))
1865
0
        ret = parse_prvkey(in, mdctx, key);
1866
1867
0
    if (!ret)
1868
0
        ossl_ml_kem_key_reset(key);
1869
0
    EVP_MD_CTX_free(mdctx);
1870
0
    return ret;
1871
0
}
1872
1873
/*
1874
 * Generate a new keypair, either from the saved seed (when non-null), or from
1875
 * the RNG.
1876
 */
1877
int ossl_ml_kem_genkey(uint8_t *pubenc, size_t publen, ML_KEM_KEY *key)
1878
0
{
1879
0
    uint8_t seed[ML_KEM_SEED_BYTES];
1880
0
    EVP_MD_CTX *mdctx = NULL;
1881
0
    const ML_KEM_VINFO *vinfo;
1882
0
    int ret = 0;
1883
1884
0
    if (key == NULL
1885
0
        || ossl_ml_kem_have_pubkey(key)
1886
0
        || ossl_ml_kem_have_dkenc(key))
1887
0
        return 0;
1888
0
    vinfo = key->vinfo;
1889
1890
0
    if (pubenc != NULL && publen != vinfo->pubkey_bytes)
1891
0
        return 0;
1892
1893
0
    if (key->seedbuf != NULL) {
1894
0
        if (!ossl_ml_kem_encode_seed(seed, sizeof(seed), key))
1895
0
            return 0;
1896
0
        ossl_ml_kem_key_reset(key);
1897
0
    } else if (RAND_priv_bytes_ex(key->libctx, seed, sizeof(seed),
1898
0
                                  key->vinfo->secbits) <= 0) {
1899
0
        return 0;
1900
0
    }
1901
1902
0
    if ((mdctx = EVP_MD_CTX_new()) == NULL)
1903
0
        return 0;
1904
1905
    /*
1906
     * Data derived from (d, z) defaults secret, and to avoid side-channel
1907
     * leaks should not influence control flow.
1908
     */
1909
0
    CONSTTIME_SECRET(seed, ML_KEM_SEED_BYTES);
1910
1911
0
    if (add_storage(OPENSSL_malloc(vinfo->puballoc),
1912
0
                    OPENSSL_secure_malloc(vinfo->prvalloc), 1, key))
1913
0
        ret = genkey(seed, mdctx, pubenc, key);
1914
0
    OPENSSL_cleanse(seed, sizeof(seed));
1915
1916
    /* Declassify secret inputs and derived outputs before returning control */
1917
0
    CONSTTIME_DECLASSIFY(seed, ML_KEM_SEED_BYTES);
1918
1919
0
    EVP_MD_CTX_free(mdctx);
1920
0
    if (!ret) {
1921
0
        ossl_ml_kem_key_reset(key);
1922
0
        return 0;
1923
0
    }
1924
1925
    /* The public components are already declassified */
1926
0
    CONSTTIME_DECLASSIFY(key->s, vinfo->rank * sizeof(scalar));
1927
0
    CONSTTIME_DECLASSIFY(key->z, 2 * ML_KEM_RANDOM_BYTES);
1928
0
    return 1;
1929
0
}
1930
1931
/*
1932
 * FIPS 203, Section 6.2, Algorithm 17: ML-KEM.Encaps_internal
1933
 * This is the deterministic version with randomness supplied externally.
1934
 */
1935
int ossl_ml_kem_encap_seed(uint8_t *ctext, size_t clen,
1936
                           uint8_t *shared_secret, size_t slen,
1937
                           const uint8_t *entropy, size_t elen,
1938
                           const ML_KEM_KEY *key)
1939
0
{
1940
0
    const ML_KEM_VINFO *vinfo;
1941
0
    EVP_MD_CTX *mdctx;
1942
0
    int ret = 0;
1943
1944
0
    if (key == NULL || !ossl_ml_kem_have_pubkey(key))
1945
0
        return 0;
1946
0
    vinfo = key->vinfo;
1947
1948
0
    if (ctext == NULL || clen != vinfo->ctext_bytes
1949
0
        || shared_secret == NULL || slen != ML_KEM_SHARED_SECRET_BYTES
1950
0
        || entropy == NULL || elen != ML_KEM_RANDOM_BYTES
1951
0
        || (mdctx = EVP_MD_CTX_new()) == NULL)
1952
0
        return 0;
1953
    /*
1954
     * Data derived from the encap entropy defaults secret, and to avoid
1955
     * side-channel leaks should not influence control flow.
1956
     */
1957
0
    CONSTTIME_SECRET(entropy, elen);
1958
1959
    /*-
1960
     * This avoids the need to handle allocation failures for two (max 2KB
1961
     * each) vectors, that are never retained on return from this function.
1962
     * We stack-allocate these.
1963
     */
1964
0
#   define case_encap_seed(bits)                                            \
1965
0
    case EVP_PKEY_ML_KEM_##bits:                                            \
1966
0
        {                                                                   \
1967
0
            scalar tmp[2 * ML_KEM_##bits##_RANK];                           \
1968
0
                                                                            \
1969
0
            ret = encap(ctext, shared_secret, entropy, tmp, mdctx, key);    \
1970
0
            OPENSSL_cleanse((void *)tmp, sizeof(tmp));                      \
1971
0
            break;                                                          \
1972
0
        }
1973
0
    switch (vinfo->evp_type) {
1974
0
    case_encap_seed(512);
1975
0
    case_encap_seed(768);
1976
0
    case_encap_seed(1024);
1977
0
    }
1978
0
#   undef case_encap_seed
1979
1980
    /* Declassify secret inputs and derived outputs before returning control */
1981
0
    CONSTTIME_DECLASSIFY(entropy, elen);
1982
0
    CONSTTIME_DECLASSIFY(ctext, clen);
1983
0
    CONSTTIME_DECLASSIFY(shared_secret, slen);
1984
1985
0
    EVP_MD_CTX_free(mdctx);
1986
0
    return ret;
1987
0
}
1988
1989
int ossl_ml_kem_encap_rand(uint8_t *ctext, size_t clen,
1990
                           uint8_t *shared_secret, size_t slen,
1991
                           const ML_KEM_KEY *key)
1992
0
{
1993
0
    uint8_t r[ML_KEM_RANDOM_BYTES];
1994
1995
0
    if (key == NULL)
1996
0
        return 0;
1997
1998
0
    if (RAND_bytes_ex(key->libctx, r, ML_KEM_RANDOM_BYTES,
1999
0
                      key->vinfo->secbits) < 1)
2000
0
        return 0;
2001
2002
0
    return ossl_ml_kem_encap_seed(ctext, clen, shared_secret, slen,
2003
0
                                  r, sizeof(r), key);
2004
0
}
2005
2006
int ossl_ml_kem_decap(uint8_t *shared_secret, size_t slen,
2007
                      const uint8_t *ctext, size_t clen,
2008
                      const ML_KEM_KEY *key)
2009
0
{
2010
0
    const ML_KEM_VINFO *vinfo;
2011
0
    EVP_MD_CTX *mdctx;
2012
0
    int ret = 0;
2013
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
2014
    int classify_bytes;
2015
#endif
2016
2017
    /* Need a private key here */
2018
0
    if (!ossl_ml_kem_have_prvkey(key))
2019
0
        return 0;
2020
0
    vinfo = key->vinfo;
2021
2022
0
    if (shared_secret == NULL || slen != ML_KEM_SHARED_SECRET_BYTES
2023
0
        || ctext == NULL || clen != vinfo->ctext_bytes
2024
0
        || (mdctx = EVP_MD_CTX_new()) == NULL) {
2025
0
        (void)RAND_bytes_ex(key->libctx, shared_secret,
2026
0
                            ML_KEM_SHARED_SECRET_BYTES, vinfo->secbits);
2027
0
        return 0;
2028
0
    }
2029
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
2030
    /*
2031
     * Data derived from |s| and |z| defaults secret, and to avoid side-channel
2032
     * leaks should not influence control flow.
2033
     */
2034
    classify_bytes = 2 * sizeof(scalar) + ML_KEM_RANDOM_BYTES;
2035
    CONSTTIME_SECRET(key->s, classify_bytes);
2036
#endif
2037
2038
    /*-
2039
     * This avoids the need to handle allocation failures for two (max 2KB
2040
     * each) vectors and an encoded ciphertext (max 1568 bytes), that are never
2041
     * retained on return from this function.
2042
     * We stack-allocate these.
2043
     */
2044
0
#   define case_decap(bits)                                             \
2045
0
    case EVP_PKEY_ML_KEM_##bits:                                        \
2046
0
        {                                                               \
2047
0
            uint8_t cbuf[CTEXT_BYTES(bits)];                            \
2048
0
            scalar tmp[2 * ML_KEM_##bits##_RANK];                       \
2049
0
                                                                        \
2050
0
            ret = decap(shared_secret, ctext, cbuf, tmp, mdctx, key);   \
2051
0
            OPENSSL_cleanse((void *)tmp, sizeof(tmp));                  \
2052
0
            break;                                                      \
2053
0
        }
2054
0
    switch (vinfo->evp_type) {
2055
0
    case_decap(512);
2056
0
    case_decap(768);
2057
0
    case_decap(1024);
2058
0
    }
2059
2060
    /* Declassify secret inputs and derived outputs before returning control */
2061
0
    CONSTTIME_DECLASSIFY(key->s, classify_bytes);
2062
0
    CONSTTIME_DECLASSIFY(shared_secret, slen);
2063
0
    EVP_MD_CTX_free(mdctx);
2064
2065
0
    return ret;
2066
0
#   undef case_decap
2067
0
}
2068
2069
int ossl_ml_kem_pubkey_cmp(const ML_KEM_KEY *key1, const ML_KEM_KEY *key2)
2070
0
{
2071
    /*
2072
     * This handles any unexpected differences in the ML-KEM variant rank,
2073
     * giving different key component structures, barring SHA3-256 hash
2074
     * collisions, the keys are the same size.
2075
     */
2076
0
    if (ossl_ml_kem_have_pubkey(key1) && ossl_ml_kem_have_pubkey(key2))
2077
0
        return memcmp(key1->pkhash, key2->pkhash, ML_KEM_PKHASH_BYTES) == 0;
2078
2079
    /*
2080
     * No match if just one of the public keys is not available, otherwise both
2081
     * are unavailable, and for now such keys are considered equal.
2082
     */
2083
0
    return (ossl_ml_kem_have_pubkey(key1) ^ ossl_ml_kem_have_pubkey(key2));
2084
0
}