Coverage Report

Created: 2025-11-16 06:40

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/openssl35/crypto/ml_kem/ml_kem.c
Line
Count
Source
1
/*
2
 * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
3
 *
4
 * Licensed under the Apache License 2.0 (the "License").  You may not use
5
 * this file except in compliance with the License.  You can obtain a copy
6
 * in the file LICENSE in the source distribution or at
7
 * https://www.openssl.org/source/license.html
8
 */
9
10
#include <openssl/byteorder.h>
11
#include <openssl/rand.h>
12
#include <openssl/proverr.h>
13
#include "crypto/ml_kem.h"
14
#include "internal/common.h"
15
#include "internal/constant_time.h"
16
#include "internal/sha3.h"
17
18
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
19
#include <valgrind/memcheck.h>
20
#endif
21
22
#if ML_KEM_SEED_BYTES != ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES
23
# error "ML-KEM keygen seed length != shared secret + random bytes length"
24
#endif
25
#if ML_KEM_SHARED_SECRET_BYTES != ML_KEM_RANDOM_BYTES
26
# error "Invalid unequal lengths of ML-KEM shared secret and random inputs"
27
#endif
28
29
#if UINT_MAX < UINT32_MAX
30
# error "Unsupported compiler: sizeof(unsigned int) < sizeof(uint32_t)"
31
#endif
32
33
/* Handy function-like bit-extraction macros */
34
24.6M
#define bit0(b)     ((b) & 1)
35
172M
/* Extract bit |n| of |b|; both arguments are parenthesised in the expansion. */
#define bitn(n, b)  (((b) >> (n)) & 1)
36
37
/*
38
 * 12 bits are sufficient to losslessly represent values in [0, q-1].
39
 * INVERSE_DEGREE is (n/2)^-1 mod q; used in inverse NTT.
40
 */
41
1.76M
#define DEGREE          ML_KEM_DEGREE
42
#define INVERSE_DEGREE  (ML_KEM_PRIME - 2 * 13)
43
#define LOG2PRIME       12
44
#define BARRETT_SHIFT   (2 * LOG2PRIME)
45
46
#ifdef SHA3_BLOCKSIZE
47
# define SHAKE128_BLOCKSIZE SHA3_BLOCKSIZE(128)
48
#endif
49
50
/*
51
 * Return whether a value that can only be 0 or 1 is non-zero, in constant time
52
 * in practice!  The return value is a mask that is all ones if true, and all
53
 * zeros otherwise (twos-complement arithmetic assumed for unsigned values).
54
 *
55
 * Although this is used in constant-time selects, we omit a value barrier
56
 * here.  Value barriers impede auto-vectorization (likely because it forces
57
 * the value to transit through a general-purpose register). On AArch64, this
58
 * is a difference of 2x.
59
 *
60
 * We usually add value barriers to selects because Clang turns consecutive
61
 * selects with the same condition into a branch instead of CMOV/CSEL. This
62
 * condition does not occur in Kyber, so omitting it seems to be safe so far,
63
 * but see |cbd_2|, |cbd_3|, where reduction needs to be specialised to the
64
 * sign of the input, rather than adding |q| in advance, and using the generic
65
 * |reduce_once|.  (David Benjamin, Chromium)
66
 */
67
#if 0
68
# define constish_time_non_zero(b) (~constant_time_is_zero(b))
69
#else
70
676M
# define constish_time_non_zero(b) (0u - (b))
71
#endif
72
73
/*
74
 * The scalar rejection-sampling buffer size needs to be a multiple of 12, but
75
 * is otherwise arbitrary, the preferred block size matches the internal buffer
76
 * size of SHAKE128, avoiding internal buffering and copying in SHAKE128. That
77
 * block size of (1600 - 256)/8 bytes, or 168, just happens to divide by 12!
78
 *
79
 * If the blocksize is unknown, or is not divisible by 12, 168 is used as a
80
 * fallback.
81
 */
82
#if defined(SHAKE128_BLOCKSIZE) && (SHAKE128_BLOCKSIZE) % 12 == 0
83
# define SCALAR_SAMPLING_BUFSIZE (SHAKE128_BLOCKSIZE)
84
#else
85
# define SCALAR_SAMPLING_BUFSIZE 168
86
#endif
87
88
/*
89
 * Structure of keys
90
 */
91
typedef struct ossl_ml_kem_scalar_st {
92
    /* On every function entry and exit, 0 <= c[i] < ML_KEM_PRIME. */
93
    uint16_t c[ML_KEM_DEGREE];
94
} scalar;
95
96
/* Key material allocation layout */
97
#define DECLARE_ML_KEM_KEYDATA(name, rank, private_sz) \
98
    struct name##_alloc { \
99
        /* Public vector |t| */ \
100
        scalar tbuf[(rank)]; \
101
        /* Pre-computed matrix |m| (FIPS 203 |A| transpose) */ \
102
        scalar mbuf[(rank)*(rank)] \
103
        /* optional private key data */ \
104
        private_sz \
105
    }
106
107
/* Declare variant-specific public and private storage */
108
#define DECLARE_ML_KEM_VARIANT_KEYDATA(bits) \
109
    DECLARE_ML_KEM_KEYDATA(pubkey_##bits, ML_KEM_##bits##_RANK,;); \
110
    DECLARE_ML_KEM_KEYDATA(prvkey_##bits, ML_KEM_##bits##_RANK,;\
111
        scalar sbuf[ML_KEM_##bits##_RANK]; \
112
        uint8_t zbuf[2 * ML_KEM_RANDOM_BYTES];)
113
DECLARE_ML_KEM_VARIANT_KEYDATA(512);
114
DECLARE_ML_KEM_VARIANT_KEYDATA(768);
115
DECLARE_ML_KEM_VARIANT_KEYDATA(1024);
116
#undef DECLARE_ML_KEM_VARIANT_KEYDATA
117
#undef DECLARE_ML_KEM_KEYDATA
118
119
typedef __owur
120
int (*CBD_FUNC)(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
121
                EVP_MD_CTX *mdctx, const ML_KEM_KEY *key);
122
static void scalar_encode(uint8_t *out, const scalar *s, int bits);
123
124
/*
125
 * The wire-form of a losslessly encoded vector uses 12-bits per element.
126
 *
127
 * The wire-form public key consists of the lossless encoding of the public
128
 * vector |t|, followed by the public seed |rho|.
129
 *
130
 * Our serialised private key concatenates serialisations of the private vector
131
 * |s|, the public key, the public key hash, and the failure secret |z|.
132
 */
133
#define VECTOR_BYTES(b)     ((3 * DEGREE / 2) * ML_KEM_##b##_RANK)
134
#define PUBKEY_BYTES(b)     (VECTOR_BYTES(b) + ML_KEM_RANDOM_BYTES)
135
#define PRVKEY_BYTES(b)     (2 * PUBKEY_BYTES(b) + ML_KEM_PKHASH_BYTES)
136
137
/*
138
 * Encapsulation produces a vector "u" and a scalar "v", whose coordinates
139
 * (numbers modulo the ML-KEM prime "q") are lossily encoded using "du" and
140
 * "dv" bits, respectively.  This encoding is the ciphertext input for
141
 * decapsulation.
142
 */
143
#define U_VECTOR_BYTES(b)   ((DEGREE / 8) * ML_KEM_##b##_DU * ML_KEM_##b##_RANK)
144
#define V_SCALAR_BYTES(b)   ((DEGREE / 8) * ML_KEM_##b##_DV)
145
#define CTEXT_BYTES(b)      (U_VECTOR_BYTES(b) + V_SCALAR_BYTES(b))
146
147
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
148
149
/*
150
 * CONSTTIME_SECRET takes a pointer and a number of bytes and marks that region
151
 * of memory as secret. Secret data is tracked as it flows to registers and
152
 * other parts of a memory. If secret data is used as a condition for a branch,
153
 * or as a memory index, it will trigger warnings in valgrind.
154
 */
155
# define CONSTTIME_SECRET(ptr, len) VALGRIND_MAKE_MEM_UNDEFINED(ptr, len)
156
157
/*
158
 * CONSTTIME_DECLASSIFY takes a pointer and a number of bytes and marks that
159
 * region of memory as public. Public data is not subject to constant-time
160
 * rules.
161
 */
162
# define CONSTTIME_DECLASSIFY(ptr, len) VALGRIND_MAKE_MEM_DEFINED(ptr, len)
163
164
#else
165
166
# define CONSTTIME_SECRET(ptr, len)
167
# define CONSTTIME_DECLASSIFY(ptr, len)
168
169
#endif
170
171
/*
172
 * Indices of slots in the vinfo tables below
173
 */
174
38.8k
#define ML_KEM_512_VINFO    0
175
101k
#define ML_KEM_768_VINFO    1
176
38.8k
#define ML_KEM_1024_VINFO   2
177
178
/*
179
 * Per-variant fixed parameters
180
 */
181
static const ML_KEM_VINFO vinfo_map[3] = {
182
    {
183
        "ML-KEM-512",
184
        PRVKEY_BYTES(512),
185
        sizeof(struct prvkey_512_alloc),
186
        PUBKEY_BYTES(512),
187
        sizeof(struct pubkey_512_alloc),
188
        CTEXT_BYTES(512),
189
        VECTOR_BYTES(512),
190
        U_VECTOR_BYTES(512),
191
        EVP_PKEY_ML_KEM_512,
192
        ML_KEM_512_BITS,
193
        ML_KEM_512_RANK,
194
        ML_KEM_512_DU,
195
        ML_KEM_512_DV,
196
        ML_KEM_512_SECBITS
197
    },
198
    {
199
        "ML-KEM-768",
200
        PRVKEY_BYTES(768),
201
        sizeof(struct prvkey_768_alloc),
202
        PUBKEY_BYTES(768),
203
        sizeof(struct pubkey_768_alloc),
204
        CTEXT_BYTES(768),
205
        VECTOR_BYTES(768),
206
        U_VECTOR_BYTES(768),
207
        EVP_PKEY_ML_KEM_768,
208
        ML_KEM_768_BITS,
209
        ML_KEM_768_RANK,
210
        ML_KEM_768_DU,
211
        ML_KEM_768_DV,
212
        ML_KEM_768_SECBITS
213
    },
214
    {
215
        "ML-KEM-1024",
216
        PRVKEY_BYTES(1024),
217
        sizeof(struct prvkey_1024_alloc),
218
        PUBKEY_BYTES(1024),
219
        sizeof(struct pubkey_1024_alloc),
220
        CTEXT_BYTES(1024),
221
        VECTOR_BYTES(1024),
222
        U_VECTOR_BYTES(1024),
223
        EVP_PKEY_ML_KEM_1024,
224
        ML_KEM_1024_BITS,
225
        ML_KEM_1024_RANK,
226
        ML_KEM_1024_DU,
227
        ML_KEM_1024_DV,
228
        ML_KEM_1024_SECBITS
229
    }
230
};
231
232
/*
233
 * Remainders modulo `kPrime`, for sufficiently small inputs, are computed in
234
 * constant time via Barrett reduction, and a final call to reduce_once(),
235
 * which reduces inputs that are at most 2*kPrime and is also constant-time.
236
 */
237
static const int kPrime = ML_KEM_PRIME;
238
static const unsigned int kBarrettShift = BARRETT_SHIFT;
239
static const size_t   kBarrettMultiplier = (1 << BARRETT_SHIFT) / ML_KEM_PRIME;
240
static const uint16_t kHalfPrime = (ML_KEM_PRIME - 1) / 2;
241
static const uint16_t kInverseDegree = INVERSE_DEGREE;
242
243
/*
244
 * Python helper:
245
 *
246
 * p = 3329
247
 * def bitreverse(i):
248
 *     ret = 0
249
 *     for n in range(7):
250
 *         bit = i & 1
251
 *         ret <<= 1
252
 *         ret |= bit
253
 *         i >>= 1
254
 *     return ret
255
 */
256
257
/*-
258
 * First precomputed array from Appendix A of FIPS 203, or else Python:
259
 * kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)]
260
 */
261
static const uint16_t kNTTRoots[128] = {
262
    1,    1729, 2580, 3289, 2642, 630,  1897, 848,
263
    1062, 1919, 193,  797,  2786, 3260, 569,  1746,
264
    296,  2447, 1339, 1476, 3046, 56,   2240, 1333,
265
    1426, 2094, 535,  2882, 2393, 2879, 1974, 821,
266
    289,  331,  3253, 1756, 1197, 2304, 2277, 2055,
267
    650,  1977, 2513, 632,  2865, 33,   1320, 1915,
268
    2319, 1435, 807,  452,  1438, 2868, 1534, 2402,
269
    2647, 2617, 1481, 648,  2474, 3110, 1227, 910,
270
    17,   2761, 583,  2649, 1637, 723,  2288, 1100,
271
    1409, 2662, 3281, 233,  756,  2156, 3015, 3050,
272
    1703, 1651, 2789, 1789, 1847, 952,  1461, 2687,
273
    939,  2308, 2437, 2388, 733,  2337, 268,  641,
274
    1584, 2298, 2037, 3220, 375,  2549, 2090, 1645,
275
    1063, 319,  2773, 757,  2099, 561,  2466, 2594,
276
    2804, 1092, 403,  1026, 1143, 2150, 2775, 886,
277
    1722, 1212, 1874, 1029, 2110, 2935, 885,  2154,
278
};
279
280
/*
281
 * InverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)]
282
 * Listed in order of use in the inverse NTT loop (index 0 is skipped):
283
 *
284
 *  0, 64, 65, ..., 127, 32, 33, ..., 63, 16, 17, ..., 31, 8, 9, ...
285
 */
286
static const uint16_t kInverseNTTRoots[128] = {
287
    1,    1175, 2444, 394,  1219, 2300, 1455, 2117,
288
    1607, 2443, 554,  1179, 2186, 2303, 2926, 2237,
289
    525,  735,  863,  2768, 1230, 2572, 556,  3010,
290
    2266, 1684, 1239, 780,  2954, 109,  1292, 1031,
291
    1745, 2688, 3061, 992,  2596, 941,  892,  1021,
292
    2390, 642,  1868, 2377, 1482, 1540, 540,  1678,
293
    1626, 279,  314,  1173, 2573, 3096, 48,   667,
294
    1920, 2229, 1041, 2606, 1692, 680,  2746, 568,
295
    3312, 2419, 2102, 219,  855,  2681, 1848, 712,
296
    682,  927,  1795, 461,  1891, 2877, 2522, 1894,
297
    1010, 1414, 2009, 3296, 464,  2697, 816,  1352,
298
    2679, 1274, 1052, 1025, 2132, 1573, 76,   2998,
299
    3040, 2508, 1355, 450,  936,  447,  2794, 1235,
300
    1903, 1996, 1089, 3273, 283,  1853, 1990, 882,
301
    3033, 1583, 2760, 69,   543,  2532, 3136, 1410,
302
    2267, 2481, 1432, 2699, 687,  40,   749,  1600,
303
};
304
305
/*
306
 * Second precomputed array from Appendix A of FIPS 203 (normalised positive),
307
 * or else Python:
308
 * ModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)]
309
 */
310
static const uint16_t kModRoots[128] = {
311
    17,   3312, 2761, 568,  583,  2746, 2649, 680,  1637, 1692, 723,  2606,
312
    2288, 1041, 1100, 2229, 1409, 1920, 2662, 667,  3281, 48,   233,  3096,
313
    756,  2573, 2156, 1173, 3015, 314,  3050, 279,  1703, 1626, 1651, 1678,
314
    2789, 540,  1789, 1540, 1847, 1482, 952,  2377, 1461, 1868, 2687, 642,
315
    939,  2390, 2308, 1021, 2437, 892,  2388, 941,  733,  2596, 2337, 992,
316
    268,  3061, 641,  2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
317
    375,  2954, 2549, 780,  2090, 1239, 1645, 1684, 1063, 2266, 319,  3010,
318
    2773, 556,  757,  2572, 2099, 1230, 561,  2768, 2466, 863,  2594, 735,
319
    2804, 525,  1092, 2237, 403,  2926, 1026, 2303, 1143, 2186, 2150, 1179,
320
    2775, 554,  886,  2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
321
    2110, 1219, 2935, 394,  885,  2444, 2154, 1175,
322
};
323
324
/*
325
 * single_keccak hashes |inlen| bytes from |in| and writes |outlen| bytes of
326
 * output to |out|. If the |md| specifies a fixed-output function, like
327
 * SHA3-256, then |outlen| must be the correct length for that function.
328
 */
329
static __owur
330
int single_keccak(uint8_t *out, size_t outlen, const uint8_t *in, size_t inlen,
331
                  EVP_MD_CTX *mdctx)
332
224k
{
333
224k
    unsigned int sz = (unsigned int) outlen;
334
335
224k
    if (!EVP_DigestUpdate(mdctx, in, inlen))
336
0
        return 0;
337
224k
    if (EVP_MD_xof(EVP_MD_CTX_get0_md(mdctx)))
338
192k
        return EVP_DigestFinalXOF(mdctx, out, outlen);
339
32.2k
    return EVP_DigestFinal_ex(mdctx, out, &sz)
340
32.2k
        && ossl_assert((size_t) sz == outlen);
341
224k
}
342
343
/*
344
 * FIPS 203, Section 4.1, equation (4.3): PRF. Takes 32+1 input bytes, and uses
345
 * SHAKE256 to produce the input to SamplePolyCBD_eta: FIPS 203, algorithm 8.
346
 */
347
static __owur
348
int prf(uint8_t *out, size_t len, const uint8_t in[ML_KEM_RANDOM_BYTES + 1],
349
        EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
350
192k
{
351
192k
    return EVP_DigestInit_ex(mdctx, key->shake256_md, NULL)
352
192k
        && single_keccak(out, len, in, ML_KEM_RANDOM_BYTES + 1, mdctx);
353
192k
}
354
355
/*
356
 * FIPS 203, Section 4.1, equation (4.4): H.  SHA3-256 hash of a variable
357
 * length input, producing 32 bytes of output.
358
 */
359
static __owur
360
int hash_h(uint8_t out[ML_KEM_PKHASH_BYTES], const uint8_t *in, size_t len,
361
           EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
362
181
{
363
181
    return EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL)
364
181
        && single_keccak(out, ML_KEM_PKHASH_BYTES, in, len, mdctx);
365
181
}
366
367
/* Incremental hash_h of expanded public key */
368
static int
369
hash_h_pubkey(uint8_t pkhash[ML_KEM_PKHASH_BYTES],
370
              EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
371
31.8k
{
372
31.8k
    const ML_KEM_VINFO *vinfo = key->vinfo;
373
31.8k
    const scalar *t = key->t, *end = t + vinfo->rank;
374
31.8k
    unsigned int sz;
375
376
31.8k
    if (!EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL))
377
0
        return 0;
378
379
95.6k
    do {
380
95.6k
        uint8_t buf[3 * DEGREE / 2];
381
382
95.6k
        scalar_encode(buf, t++, 12);
383
95.6k
        if (!EVP_DigestUpdate(mdctx, buf, sizeof(buf)))
384
0
            return 0;
385
95.6k
    } while (t < end);
386
387
31.8k
    if (!EVP_DigestUpdate(mdctx, key->rho, ML_KEM_RANDOM_BYTES))
388
0
        return 0;
389
31.8k
    return EVP_DigestFinal_ex(mdctx, pkhash, &sz)
390
31.8k
        && ossl_assert(sz == ML_KEM_PKHASH_BYTES);
391
31.8k
}
392
393
/*
394
 * FIPS 203, Section 4.1, equation (4.5): G.  SHA3-512 hash of a variable
395
 * length input, producing 64 bytes of output, in particular the seeds
396
 * (d,z) for key generation.
397
 */
398
static __owur
399
int hash_g(uint8_t out[ML_KEM_SEED_BYTES], const uint8_t *in, size_t len,
400
           EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
401
32.0k
{
402
32.0k
    return EVP_DigestInit_ex(mdctx, key->sha3_512_md, NULL)
403
32.0k
        && single_keccak(out, ML_KEM_SEED_BYTES, in, len, mdctx);
404
32.0k
}
405
406
/*
407
 * FIPS 203, Section 4.1, equation (4.4): J. SHAKE256 taking a variable length
408
 * input to compute a 32-byte implicit rejection shared secret, of the same
409
 * length as the expected shared secret.  (Computed even on success to avoid
410
 * side-channel leaks).
411
 */
412
static __owur
413
int kdf(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
414
        const uint8_t z[ML_KEM_RANDOM_BYTES],
415
        const uint8_t *ctext, size_t len,
416
        EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
417
53
{
418
53
    return EVP_DigestInit_ex(mdctx, key->shake256_md, NULL)
419
53
        && EVP_DigestUpdate(mdctx, z, ML_KEM_RANDOM_BYTES)
420
53
        && EVP_DigestUpdate(mdctx, ctext, len)
421
53
        && EVP_DigestFinalXOF(mdctx, out, ML_KEM_SHARED_SECRET_BYTES);
422
53
}
423
424
/*
425
 * FIPS 203, Section 4.2.2, Algorithm 7: "SampleNTT" (steps 3-17, steps 1, 2
426
 * are performed by the caller). Rejection-samples a Keccak stream to get
427
 * uniformly distributed elements in the range [0,q). This is used for matrix
428
 * expansion and only operates on public inputs.
429
 */
430
static __owur
431
int sample_scalar(scalar *out, EVP_MD_CTX *mdctx)
432
288k
{
433
288k
    uint16_t *curr = out->c, *endout = curr + DEGREE;
434
288k
    uint8_t buf[SCALAR_SAMPLING_BUFSIZE], *in;
435
288k
    uint8_t *endin = buf + sizeof(buf);
436
288k
    uint16_t d;
437
288k
    uint8_t b1, b2, b3;
438
439
865k
    do {
440
865k
        if (!EVP_DigestSqueeze(mdctx, in = buf, sizeof(buf)))
441
0
            return 0;
442
45.2M
        do {
443
45.2M
            b1 = *in++;
444
45.2M
            b2 = *in++;
445
45.2M
            b3 = *in++;
446
447
45.2M
            if (curr >= endout)
448
96.8k
                break;
449
45.1M
            if ((d = ((b2 & 0x0f) << 8) + b1) < kPrime)
450
37.2M
                *curr++ = d;
451
45.1M
            if (curr >= endout)
452
191k
                break;
453
44.9M
            if ((d = (b3 << 4) + (b2 >> 4)) < kPrime)
454
36.5M
                *curr++ = d;
455
44.9M
        } while (in < endin);
456
865k
    } while (curr < endout);
457
288k
    return 1;
458
288k
}
459
460
/*-
461
 * reduce_once reduces 0 <= x < 2*kPrime, mod kPrime.
462
 *
463
 * Subtract |q| if the input is larger, without exposing a side-channel,
464
 * avoiding the "clangover" attack.  See |constish_time_non_zero| for a
465
 * discussion on why the value barrier is by default omitted.
466
 */
467
static __owur uint16_t reduce_once(uint16_t x)
{
    const uint16_t diff = x - kPrime;
    /*
     * |pick_x| is all-ones when bit 15 of |diff| is set (the subtraction
     * wrapped, so |x| was already below kPrime) and all-zeros otherwise.
     */
    const uint16_t pick_x = constish_time_non_zero(diff >> 15);

    /* Branch-free select between |x| and |diff| */
    return (pick_x & x) | (~pick_x & diff);
}
474
475
/*
476
 * Constant-time reduce x mod kPrime using Barrett reduction. x must be less
477
 * than kPrime + 2 * kPrime^2.  This is sufficient to reduce a product of
478
 * two already reduced uint16_t values, in fact it is sufficient for each
479
 * to be less than 2^12, because (kPrime * (2 * kPrime + 1)) > 2^24.
480
 */
481
static __owur uint16_t reduce(uint32_t x)
{
    /* Barrett estimate of x / kPrime via a multiply and a shift */
    const uint64_t scaled = (uint64_t)x * kBarrettMultiplier;
    const uint32_t q_est = (uint32_t)(scaled >> kBarrettShift);
    const uint32_t rem = x - q_est * kPrime;

    /* One conditional subtraction finishes the reduction */
    return reduce_once(rem);
}
489
490
/* Multiply a scalar by a constant. */
491
static void scalar_mult_const(scalar *s, uint16_t a)
{
    int i;

    /* Scale every coefficient by |a|, reducing each product mod kPrime */
    for (i = 0; i < DEGREE; i++)
        s->c[i] = reduce(s->c[i] * a);
}
500
501
/*-
502
 * FIPS 203, Section 4.3, Algorithm 9: "NTT".
503
 * In-place number theoretic transform of a given scalar.  Note that ML-KEM's
504
 * kPrime 3329 does not have a 512th root of unity, so this transform leaves
505
 * off the last iteration of the usual FFT code, with the 128 relevant roots of
506
 * unity being stored in NTTRoots.  This means the output should be seen as 128
507
 * elements in GF(3329^2), with the coefficients of the elements being
508
 * consecutive entries in |s->c|.
509
 */
510
static void scalar_ntt(scalar *s)
{
    uint16_t *const c = s->c;
    int root = 0;
    int half, start, j;

    /* Butterfly span halves each pass: DEGREE/2, DEGREE/4, ..., 2 */
    for (half = DEGREE / 2; half >= 2; half >>= 1) {
        /* Each group of 2 * half coefficients uses the next root in turn */
        for (start = 0; start < DEGREE; start += 2 * half) {
            const uint32_t zeta = kNTTRoots[++root];

            for (j = start; j < start + half; j++) {
                const uint16_t even = c[j];
                const uint16_t odd = reduce(c[j + half] * zeta);

                c[j + half] = reduce_once(even - odd + kPrime);
                c[j] = reduce_once(odd + even);
            }
        }
    }
}
533
534
/*-
535
 * FIPS 203, Section 4.3, Algorithm 10: "NTT^(-1)".
536
 * In-place inverse number theoretic transform of a given scalar, with pairs of
537
 * entries of s->c being interpreted as elements of GF(3329^2). Just as with
538
 * the number theoretic transform, this leaves off the first step of the normal
539
 * iFFT to account for the fact that 3329 does not have a 512th root of unity,
540
 * using the precomputed 128 roots of unity stored in InverseNTTRoots.
541
 */
542
static void scalar_inverse_ntt(scalar *s)
{
    uint16_t *const c = s->c;
    int root = 0;
    int half, start, j;

    /* Butterfly span doubles each pass: 2, 4, ..., DEGREE/2 */
    for (half = 2; half < DEGREE; half <<= 1) {
        /* Each group of 2 * half coefficients uses the next root in turn */
        for (start = 0; start < DEGREE; start += 2 * half) {
            const uint32_t zeta = kInverseNTTRoots[++root];

            for (j = start; j < start + half; j++) {
                const uint16_t even = c[j];
                const uint16_t odd = c[j + half];

                c[j + half] = reduce(zeta * (even - odd + kPrime));
                c[j] = reduce_once(odd + even);
            }
        }
    }

    /* Final scaling of every coefficient by kInverseDegree */
    scalar_mult_const(s, kInverseDegree);
}
566
567
/* Addition updating the LHS scalar in-place. */
568
static void scalar_add(scalar *lhs, const scalar *rhs)
{
    uint16_t *acc = lhs->c;
    const uint16_t *addend = rhs->c;
    const uint16_t *const stop = acc + DEGREE;

    /* Coefficient-wise sum, each brought back below kPrime */
    while (acc < stop) {
        *acc = reduce_once(*acc + *addend++);
        acc++;
    }
}
575
576
/* Subtraction updating the LHS scalar in-place. */
577
static void scalar_sub(scalar *lhs, const scalar *rhs)
{
    uint16_t *acc = lhs->c;
    const uint16_t *subtrahend = rhs->c;
    const uint16_t *const stop = acc + DEGREE;

    /* Adding kPrime first keeps the difference non-negative before reducing */
    while (acc < stop) {
        *acc = reduce_once(*acc - *subtrahend++ + kPrime);
        acc++;
    }
}
584
585
/*
586
 * Multiplying two scalars in the number theoretically transformed state. Since
587
 * 3329 does not have a 512th root of unity, this means we have to interpret
588
 * the 2*ith and (2*i+1)th entries of the scalar as elements of
589
 * GF(3329)[X]/(X^2 - 17^(2*bitreverse(i)+1)).
590
 *
591
 * The value of 17^(2*bitreverse(i)+1) mod 3329 is stored in the precomputed
592
 * ModRoots table. Note that our Barrett transform only allows us to multiply
593
 * two reduced numbers together, so we need some intermediate reduction steps,
594
 * even if an uint64_t could hold 3 multiplied numbers.
595
 */
596
static void scalar_mult(scalar *out, const scalar *lhs,
                        const scalar *rhs)
{
    int i;

    /* One degree-1 polynomial product per pair of adjacent coefficients */
    for (i = 0; i < DEGREE / 2; i++) {
        const uint32_t a0 = lhs->c[2 * i];
        const uint32_t b0 = rhs->c[2 * i];
        const uint32_t a1 = lhs->c[2 * i + 1];
        const uint32_t b1 = rhs->c[2 * i + 1];
        const uint32_t zetapow = kModRoots[i];

        /* The a1 * b1 term is reduced before being scaled by the root */
        out->c[2 * i] = reduce(a0 * b0 + reduce(a1 * b1) * zetapow);
        out->c[2 * i + 1] = reduce(a0 * b1 + a1 * b0);
    }
}
612
613
/* Above, but add the result to an existing scalar */
614
static ossl_inline
615
void scalar_mult_add(scalar *out, const scalar *lhs,
616
                     const scalar *rhs)
617
287k
{
618
287k
    uint16_t *curr = out->c, *end = curr + DEGREE;
619
287k
    const uint16_t *lc = lhs->c, *rc = rhs->c;
620
287k
    const uint16_t *roots = kModRoots;
621
622
36.8M
    do {
623
36.8M
        uint32_t l0 = *lc++, r0 = *rc++;
624
36.8M
        uint32_t l1 = *lc++, r1 = *rc++;
625
36.8M
        uint16_t *c0 = curr++;
626
36.8M
        uint16_t *c1 = curr++;
627
36.8M
        uint32_t zetapow = *roots++;
628
629
36.8M
        *c0 = reduce(*c0 + l0 * r0 + reduce(l1 * r1) * zetapow);
630
36.8M
        *c1 = reduce(*c1 + l0 * r1 + l1 * r0);
631
36.8M
    } while (curr < end);
632
287k
}
633
634
/*-
635
 * FIPS 203, Section 4.2.1, Algorithm 5: "ByteEncode_d", for 2<=d<=12.
636
 * Here |bits| is |d|.  For efficiency, we handle the d=1 case separately.
637
 */
638
static void scalar_encode(uint8_t *out, const scalar *s, int bits)
639
191k
{
640
191k
    const uint16_t *curr = s->c, *end = curr + DEGREE;
641
191k
    uint64_t accum = 0, element;
642
191k
    int used = 0;
643
644
49.0M
    do {
645
49.0M
        element = *curr++;
646
49.0M
        if (used + bits < 64) {
647
39.8M
            accum |= element << used;
648
39.8M
            used += bits;
649
39.8M
        } else if (used + bits > 64) {
650
6.13M
            out = OPENSSL_store_u64_le(out, accum | (element << used));
651
6.13M
            accum = element >> (64 - used);
652
6.13M
            used = (used + bits) - 64;
653
6.13M
        } else {
654
3.06M
            out = OPENSSL_store_u64_le(out, accum | (element << used));
655
3.06M
            accum = 0;
656
3.06M
            used = 0;
657
3.06M
        }
658
49.0M
    } while (curr < end);
659
191k
}
660
661
/*
662
 * scalar_encode_1 is |scalar_encode| specialised for |bits| == 1.
663
 */
664
static void scalar_encode_1(uint8_t out[DEGREE / 8], const scalar *s)
665
53
{
666
53
    int i, j;
667
53
    uint8_t out_byte;
668
669
1.74k
    for (i = 0; i < DEGREE; i += 8) {
670
1.69k
        out_byte = 0;
671
15.2k
        for (j = 0; j < 8; j++)
672
13.5k
            out_byte |= bit0(s->c[i + j]) << j;
673
1.69k
        *out = out_byte;
674
1.69k
        out++;
675
1.69k
    }
676
53
}
677
678
/*-
679
 * FIPS 203, Section 4.2.1, Algorithm 6: "ByteDecode_d", for 2<=d<12.
680
 * Here |bits| is |d|.  For efficiency, we handle the d=1 and d=12 cases
681
 * separately.
682
 *
683
 * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
684
 * |out|.
685
 */
686
static void scalar_decode(scalar *out, const uint8_t *in, int bits)
{
    uint16_t *dst = out->c;
    uint16_t *const stop = dst + DEGREE;
    const uint16_t full_mask = (((uint16_t) 1) << bits) - 1;
    uint16_t mask = full_mask;
    uint16_t partial = 0;
    uint64_t window = 0;
    int avail = 0;
    int need = bits;

    do {
        if (avail == 0) {
            in = OPENSSL_load_u64_le(&window, in);
            avail = 64;
        }

        if (avail < need) {
            /*
             * The window holds only part of the next element: stash what
             * is there in |partial| and note how many bits are still owed.
             */
            partial = ((uint16_t) window) & mask;
            need -= avail;
            mask = full_mask >> avail;
            avail = 0;
            continue;
        }

        if (need == bits) {
            /* No stashed bits: the element comes entirely from the window */
            *dst++ = ((uint16_t) window) & mask;
        } else {
            /* Complete the stashed element with the remaining high bits */
            *dst++ = partial | ((((uint16_t) window) & mask) << (bits - need));
            partial = 0;
            mask = full_mask;
        }
        window >>= need;
        avail -= need;
        need = bits;
    } while (dst < stop);
}
730
731
static __owur
732
int scalar_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2])
733
697
{
734
697
    int i;
735
697
    uint16_t *c = out->c;
736
737
72.7k
    for (i = 0; i < DEGREE / 2; ++i) {
738
72.2k
        uint8_t b1 = *in++;
739
72.2k
        uint8_t b2 = *in++;
740
72.2k
        uint8_t b3 = *in++;
741
72.2k
        int outOfRange1 = (*c++ = b1 | ((b2 & 0x0f) << 8)) >= kPrime;
742
72.2k
        int outOfRange2 = (*c++ = (b2 >> 4) | (b3 << 4)) >= kPrime;
743
744
72.2k
        if (outOfRange1 | outOfRange2)
745
158
            return 0;
746
72.2k
    }
747
539
    return 1;
748
697
}
749
750
/*-
751
 * scalar_decode_decompress_add is a combination of decoding and decompression
752
 * both specialised for |bits| == 1, with the result added (and sum reduced) to
753
 * the output scalar.
754
 *
755
 * NOTE: this function MUST NOT leak an input-data-dependent timing signal.
756
 * A timing leak in a related function in the reference Kyber implementation
757
 * made the "clangover" attack (CVE-2024-37880) possible, giving key recovery
758
 * for ML-KEM-512 in minutes, provided the attacker has access to precise
759
 * timing of a CPU performing chosen-ciphertext decap.  Admittedly this is only
760
 * a risk when private keys are reused (perhaps KEMTLS servers).
761
 */
762
static void
763
scalar_decode_decompress_add(scalar *out, const uint8_t in[DEGREE / 8])
764
125
{
765
125
    static const uint16_t half_q_plus_1 = (ML_KEM_PRIME >> 1) + 1;
766
125
    uint16_t *curr = out->c, *end = curr + DEGREE;
767
125
    uint16_t mask;
768
125
    uint8_t b;
769
770
    /*
771
     * Add |half_q_plus_1| if the bit is set, without exposing a side-channel,
772
     * avoiding the "clangover" attack.  See |constish_time_non_zero| for a
773
     * discussion on why the value barrier is by default omitted.
774
     */
775
125
#define decode_decompress_add_bit                               \
776
32.0k
        mask = constish_time_non_zero(bit0(b));                 \
777
32.0k
        *curr = reduce_once(*curr + (mask & half_q_plus_1));    \
778
32.0k
        curr++;                                                 \
779
32.0k
        b >>= 1
780
781
    /* Unrolled to process each byte in one iteration */
782
4.00k
    do {
783
4.00k
        b = *in++;
784
4.00k
        decode_decompress_add_bit;
785
4.00k
        decode_decompress_add_bit;
786
4.00k
        decode_decompress_add_bit;
787
4.00k
        decode_decompress_add_bit;
788
789
4.00k
        decode_decompress_add_bit;
790
4.00k
        decode_decompress_add_bit;
791
4.00k
        decode_decompress_add_bit;
792
4.00k
        decode_decompress_add_bit;
793
4.00k
    } while (curr < end);
794
125
#undef decode_decompress_add_bit
795
125
}
796
797
/*
798
 * FIPS 203, Section 4.2.1, Equation (4.7): Compress_d.
799
 *
800
 * Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
801
 * numbers close to each other together. The formula used is
802
 * round(2^|bits|/kPrime*x) mod 2^|bits|.
803
 * Uses Barrett reduction to achieve constant time. Since we need both the
804
 * remainder (for rounding) and the quotient (as the result), we cannot use
805
 * |reduce| here, but need to do the Barrett reduction directly.
806
 */
807
static __owur uint16_t compress(uint16_t x, int bits)
808
130k
{
809
130k
    uint32_t shifted = (uint32_t)x << bits;
810
130k
    uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
811
130k
    uint32_t quotient = (uint32_t)(product >> kBarrettShift);
812
130k
    uint32_t remainder = shifted - quotient * kPrime;
813
814
    /*
815
     * Adjust the quotient to round correctly:
816
     *   0 <= remainder <= kHalfPrime round to 0
817
     *   kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
818
     *   kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
819
     */
820
130k
    quotient += 1 & constant_time_lt_32(kHalfPrime, remainder);
821
130k
    quotient += 1 & constant_time_lt_32(kPrime + kHalfPrime, remainder);
822
130k
    return quotient & ((1 << bits) - 1);
823
130k
}
824
825
/*
826
 * FIPS 203, Section 4.2.1, Equation (4.8): Decompress_d.
827
828
 * Decompresses |x| by using a close equi-distant representative. The formula
829
 * is round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us
830
 * to implement this logic using only bit operations.
831
 */
832
static __owur uint16_t decompress(uint16_t x, int bits)
833
48.6k
{
834
48.6k
    uint32_t product = (uint32_t)x * kPrime;
835
48.6k
    uint32_t power = 1 << bits;
836
    /* This is |product| % power, since |power| is a power of 2. */
837
48.6k
    uint32_t remainder = product & (power - 1);
838
    /* This is |product| / power, since |power| is a power of 2. */
839
48.6k
    uint32_t lower = product >> bits;
840
841
    /*
842
     * The rounding logic works since the first half of numbers mod |power|
843
     * have a 0 as first bit, and the second half has a 1 as first bit, since
844
     * |power| is a power of 2. As a 12 bit number, |remainder| is always
845
     * positive, so we will shift in 0s for a right shift.
846
     */
847
48.6k
    return lower + (remainder >> (bits - 1));
848
48.6k
}
849
850
/*-
851
 * FIPS 203, Section 4.2.1, Equation (4.7): "Compress_d".
852
 * In-place lossy rounding of scalars to 2^d bits.
853
 */
854
static void scalar_compress(scalar *s, int bits)
855
509
{
856
509
    int i;
857
858
130k
    for (i = 0; i < DEGREE; i++)
859
130k
        s->c[i] = compress(s->c[i], bits);
860
509
}
861
862
/*
863
 * FIPS 203, Section 4.2.1, Equation (4.8): "Decompress_d".
864
 * In-place approximate recovery of scalars from 2^d bit compression.
865
 */
866
static void scalar_decompress(scalar *s, int bits)
867
190
{
868
190
    int i;
869
870
48.8k
    for (i = 0; i < DEGREE; i++)
871
48.6k
        s->c[i] = decompress(s->c[i], bits);
872
190
}
873
874
/* Addition updating the LHS vector in-place. */
875
static void vector_add(scalar *lhs, const scalar *rhs, int rank)
876
125
{
877
331
    do {
878
331
        scalar_add(lhs++, rhs++);
879
331
    } while (--rank > 0);
880
125
}
881
882
/*
883
 * Encodes an entire vector into 32*|rank|*|bits| bytes. Note that since 256
884
 * (DEGREE) is divisible by 8, the individual vector entries will always fill a
885
 * whole number of bytes, so we do not need to worry about bit packing here.
886
 */
887
static void vector_encode(uint8_t *out, const scalar *a, int bits, int rank)
888
32.0k
{
889
32.0k
    int stride = bits * DEGREE / 8;
890
891
127k
    for (; rank-- > 0; out += stride)
892
95.9k
        scalar_encode(out, a++, bits);
893
32.0k
}
894
895
/*
896
 * Decodes 32*|rank|*|bits| bytes from |in| into |out|. It returns early
897
 * if any parsed value is >= |ML_KEM_PRIME|.  The resulting scalars are
898
 * then decompressed and transformed via the NTT.
899
 *
900
 * Note: Used only in decrypt_cpa(), which returns void and so does not check
901
 * the return value of this function.  Side-channels are fine when the input
902
 * ciphertext to decap() is simply syntactically invalid.
903
 */
904
static void
905
vector_decode_decompress_ntt(scalar *out, const uint8_t *in, int bits, int rank)
906
53
{
907
53
    int stride = bits * DEGREE / 8;
908
909
190
    for (; rank-- > 0; in += stride, ++out) {
910
137
        scalar_decode(out, in, bits);
911
137
        scalar_decompress(out, bits);
912
137
        scalar_ntt(out);
913
137
    }
914
53
}
915
916
/* vector_decode(), specialised to bits == 12. */
917
static __owur
918
int vector_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2], int rank)
919
339
{
920
339
    int stride = 3 * DEGREE / 2;
921
922
878
    for (; rank-- > 0; in += stride)
923
697
        if (!scalar_decode_12(out++, in))
924
158
            return 0;
925
181
    return 1;
926
339
}
927
928
/* In-place compression of each scalar component */
929
static void vector_compress(scalar *a, int bits, int rank)
930
125
{
931
331
    do {
932
331
        scalar_compress(a++, bits);
933
331
    } while (--rank > 0);
934
125
}
935
936
/* The output scalar must not overlap with the inputs */
937
static void inner_product(scalar *out, const scalar *lhs, const scalar *rhs,
938
                          int rank)
939
178
{
940
178
    scalar_mult(out, lhs, rhs);
941
468
    while (--rank > 0)
942
290
        scalar_mult_add(out, ++lhs, ++rhs);
943
178
}
944
945
/*
946
 * Here, the output vector must not overlap with the inputs, the result is
947
 * directly subjected to inverse NTT.
948
 */
949
static void
950
matrix_mult_intt(scalar *out, const scalar *m, const scalar *a, int rank)
951
125
{
952
125
    const scalar *ar;
953
125
    int i, j;
954
955
456
    for (i = rank; i-- > 0; ++out) {
956
331
        scalar_mult(out, m++, ar = a);
957
941
        for (j = rank - 1; j > 0; --j)
958
610
            scalar_mult_add(out, m++, ++ar);
959
331
        scalar_inverse_ntt(out);
960
331
    }
961
125
}
962
963
/* Here, the output vector must not overlap with the inputs */
964
static void
965
matrix_mult_transpose_add(scalar *out, const scalar *m, const scalar *a, int rank)
966
31.8k
{
967
31.8k
    const scalar *mc = m, *mr, *ar;
968
31.8k
    int i, j;
969
970
127k
    for (i = rank; i-- > 0; ++out) {
971
95.6k
        scalar_mult_add(out, mr = mc++, ar = a);
972
286k
        for (j = rank; --j > 0; )
973
191k
            scalar_mult_add(out, (mr += rank), ++ar);
974
95.6k
    }
975
31.8k
}
976
977
/*-
978
 * Expands the matrix from a seed for key generation and for encaps-CPA.
979
 * NOTE: FIPS 203 matrix "A" is the transpose of this matrix, computed
980
 * by appending the (i,j) indices to the seed in the opposite order!
981
 *
982
 * Where FIPS 203 computes t = A * s + e, we use the transpose of "m".
983
 */
984
static __owur
985
int matrix_expand(EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
986
32.0k
{
987
32.0k
    scalar *out = key->m;
988
32.0k
    uint8_t input[ML_KEM_RANDOM_BYTES + 2];
989
32.0k
    int rank = key->vinfo->rank;
990
32.0k
    int i, j;
991
992
32.0k
    memcpy(input, key->rho, ML_KEM_RANDOM_BYTES);
993
128k
    for (i = 0; i < rank; i++) {
994
384k
        for (j = 0; j < rank; j++) {
995
288k
            input[ML_KEM_RANDOM_BYTES] = i;
996
288k
            input[ML_KEM_RANDOM_BYTES + 1] = j;
997
288k
            if (!EVP_DigestInit_ex(mdctx, key->shake128_md, NULL)
998
288k
                || !EVP_DigestUpdate(mdctx, input, sizeof(input))
999
288k
                || !sample_scalar(out++, mdctx))
1000
0
                return 0;
1001
288k
        }
1002
96.1k
    }
1003
32.0k
    return 1;
1004
32.0k
}
1005
1006
/*
1007
 * Algorithm 7 from the spec, with eta fixed to two and the PRF call
1008
 * included. Creates binominally distributed elements by sampling 2*|eta| bits,
1009
 * and setting the coefficient to the count of the first bits minus the count of
1010
 * the second bits, resulting in a centered binomial distribution. Since eta is
1011
 * two this gives -2/2 with a probability of 1/16, -1/1 with probability 1/4,
1012
 * and 0 with probability 3/8.
1013
 */
1014
static __owur
1015
int cbd_2(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
1016
          EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1017
191k
{
1018
191k
    uint16_t *curr = out->c, *end = curr + DEGREE;
1019
191k
    uint8_t randbuf[4 * DEGREE / 8], *r = randbuf;  /* 64 * eta slots */
1020
191k
    uint16_t value, mask;
1021
191k
    uint8_t b;
1022
1023
191k
    if (!prf(randbuf, sizeof(randbuf), in, mdctx, key))
1024
0
        return 0;
1025
1026
24.4M
    do {
1027
24.4M
        b = *r++;
1028
1029
        /*
1030
         * Add |kPrime| if |value| underflowed.  See |constish_time_non_zero|
1031
         * for a discussion on why the value barrier is by default omitted.
1032
         * While this could have been written reduce_once(value + kPrime), this
1033
         * is one extra addition and small range of |value| tempts some
1034
         * versions of Clang to emit a branch.
1035
         */
1036
24.4M
        value = bit0(b) + bitn(1, b);
1037
24.4M
        value -= bitn(2, b) + bitn(3, b);
1038
24.4M
        mask = constish_time_non_zero(value >> 15);
1039
24.4M
        *curr++ = value + (kPrime & mask);
1040
1041
24.4M
        value = bitn(4, b) + bitn(5, b);
1042
24.4M
        value -= bitn(6, b) + bitn(7, b);
1043
24.4M
        mask = constish_time_non_zero(value >> 15);
1044
24.4M
        *curr++ = value + (kPrime & mask);
1045
24.4M
    } while (curr < end);
1046
191k
    return 1;
1047
191k
}
1048
1049
/*
1050
 * Algorithm 7 from the spec, with eta fixed to three and the PRF call
1051
 * included. Creates binominally distributed elements by sampling 3*|eta| bits,
1052
 * and setting the coefficient to the count of the first bits minus the count of
1053
 * the second bits, resulting in a centered binomial distribution.
1054
 */
1055
static __owur
1056
int cbd_3(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
1057
          EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1058
800
{
1059
800
    uint16_t *curr = out->c, *end = curr + DEGREE;
1060
800
    uint8_t randbuf[6 * DEGREE / 8], *r = randbuf;  /* 64 * eta slots */
1061
800
    uint8_t b1, b2, b3;
1062
800
    uint16_t value, mask;
1063
1064
800
    if (!prf(randbuf, sizeof(randbuf), in, mdctx, key))
1065
0
        return 0;
1066
1067
51.2k
    do {
1068
51.2k
        b1 = *r++;
1069
51.2k
        b2 = *r++;
1070
51.2k
        b3 = *r++;
1071
1072
        /*
1073
         * Add |kPrime| if |value| underflowed.  See |constish_time_non_zero|
1074
         * for a discussion on why the value barrier is by default omitted.
1075
         * While this could have been written reduce_once(value + kPrime), this
1076
         * is one extra addition and small range of |value| tempts some
1077
         * versions of Clang to emit a branch.
1078
         */
1079
51.2k
        value = bit0(b1) + bitn(1, b1) + bitn(2, b1);
1080
51.2k
        value -= bitn(3, b1)  + bitn(4, b1) + bitn(5, b1);
1081
51.2k
        mask = constish_time_non_zero(value >> 15);
1082
51.2k
        *curr++ = value + (kPrime & mask);
1083
1084
51.2k
        value = bitn(6, b1) + bitn(7, b1) + bit0(b2);
1085
51.2k
        value -= bitn(1, b2) + bitn(2, b2) + bitn(3, b2);
1086
51.2k
        mask = constish_time_non_zero(value >> 15);
1087
51.2k
        *curr++ = value + (kPrime & mask);
1088
1089
51.2k
        value = bitn(4, b2) + bitn(5, b2) + bitn(6, b2);
1090
51.2k
        value -= bitn(7, b2) + bit0(b3) + bitn(1, b3);
1091
51.2k
        mask = constish_time_non_zero(value >> 15);
1092
51.2k
        *curr++ = value + (kPrime & mask);
1093
1094
51.2k
        value = bitn(2, b3) + bitn(3, b3) + bitn(4, b3);
1095
51.2k
        value -= bitn(5, b3) + bitn(6, b3) + bitn(7, b3);
1096
51.2k
        mask = constish_time_non_zero(value >> 15);
1097
51.2k
        *curr++ = value + (kPrime & mask);
1098
51.2k
    } while (curr < end);
1099
800
    return 1;
1100
800
}
1101
1102
/*
1103
 * Generates a secret vector by using |cbd| with the given seed to generate
1104
 * scalar elements and incrementing |counter| for each slot of the vector.
1105
 */
1106
static __owur
1107
int gencbd_vector(scalar *out, CBD_FUNC cbd, uint8_t *counter,
1108
                  const uint8_t seed[ML_KEM_RANDOM_BYTES], int rank,
1109
                  EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1110
125
{
1111
125
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1112
1113
125
    memcpy(input, seed, ML_KEM_RANDOM_BYTES);
1114
331
    do {
1115
331
        input[ML_KEM_RANDOM_BYTES] = (*counter)++;
1116
331
        if (!cbd(out++, input, mdctx, key))
1117
0
            return 0;
1118
331
    } while (--rank > 0);
1119
125
    return 1;
1120
125
}
1121
1122
/*
1123
 * As above plus NTT transform.
1124
 */
1125
static __owur
1126
int gencbd_vector_ntt(scalar *out, CBD_FUNC cbd, uint8_t *counter,
1127
                      const uint8_t seed[ML_KEM_RANDOM_BYTES], int rank,
1128
                      EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1129
63.9k
{
1130
63.9k
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1131
1132
63.9k
    memcpy(input, seed, ML_KEM_RANDOM_BYTES);
1133
191k
    do {
1134
191k
        input[ML_KEM_RANDOM_BYTES] = (*counter)++;
1135
191k
        if (!cbd(out, input, mdctx, key))
1136
0
            return 0;
1137
191k
        scalar_ntt(out++);
1138
191k
    } while (--rank > 0);
1139
63.9k
    return 1;
1140
63.9k
}
1141
1142
/*
 * Selects the CBD sampler for |ETA1|: ML-KEM-512 uses eta = 3 (cbd_3), every
 * other variant uses eta = 2 (cbd_2).  All |ETA2| values are 2.
 */
#define CBD1(evp_type)  ((evp_type) != EVP_PKEY_ML_KEM_512 ? cbd_2 : cbd_3)
1144
1145
/*
1146
 * FIPS 203, Section 5.2, Algorithm 14: K-PKE.Encrypt.
1147
 *
1148
 * Encrypts a message with given randomness to the ciphertext in |out|. Without
1149
 * applying the Fujisaki-Okamoto transform this would not result in a CCA
1150
 * secure scheme, since lattice schemes are vulnerable to decryption failure
1151
 * oracles.
1152
 *
1153
 * The steps are re-ordered to make more efficient/localised use of storage.
1154
 *
1155
 * Note also that the input public key is assumed to hold a precomputed matrix
1156
 * |A| (our key->m, with the public key holding an expanded (16-bit per scalar
1157
 * coefficient) key->t vector).
1158
 *
1159
 * Caller passes storage in |tmp| for for two temporary vectors.
1160
 */
1161
static __owur
1162
int encrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
1163
                const uint8_t message[DEGREE / 8],
1164
                const uint8_t r[ML_KEM_RANDOM_BYTES], scalar *tmp,
1165
                EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1166
125
{
1167
125
    const ML_KEM_VINFO *vinfo = key->vinfo;
1168
125
    CBD_FUNC cbd_1 = CBD1(vinfo->evp_type);
1169
125
    int rank = vinfo->rank;
1170
    /* We can use tmp[0..rank-1] as storage for |y|, then |e1|, ... */
1171
125
    scalar *y = &tmp[0], *e1 = y, *e2 = y;
1172
    /* We can use tmp[rank]..tmp[2*rank - 1] for |u| */
1173
125
    scalar *u = &tmp[rank];
1174
125
    scalar v;
1175
125
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1176
125
    uint8_t counter = 0;
1177
125
    int du = vinfo->du;
1178
125
    int dv = vinfo->dv;
1179
1180
    /* FIPS 203 "y" vector */
1181
125
    if (!gencbd_vector_ntt(y, cbd_1, &counter, r, rank, mdctx, key))
1182
0
        return 0;
1183
    /* FIPS 203 "v" scalar */
1184
125
    inner_product(&v, key->t, y, rank);
1185
125
    scalar_inverse_ntt(&v);
1186
    /* FIPS 203 "u" vector */
1187
125
    matrix_mult_intt(u, key->m, y, rank);
1188
1189
    /* All done with |y|, now free to reuse tmp[0] for FIPS 203 |e1| */
1190
125
    if (!gencbd_vector(e1, cbd_2, &counter, r, rank, mdctx, key))
1191
0
        return 0;
1192
125
    vector_add(u, e1, rank);
1193
125
    vector_compress(u, du, rank);
1194
125
    vector_encode(out, u, du, rank);
1195
1196
    /* All done with |e1|, now free to reuse tmp[0] for FIPS 203 |e2| */
1197
125
    memcpy(input, r, ML_KEM_RANDOM_BYTES);
1198
125
    input[ML_KEM_RANDOM_BYTES] = counter;
1199
125
    if (!cbd_2(e2, input, mdctx, key))
1200
0
        return 0;
1201
125
    scalar_add(&v, e2);
1202
1203
    /* Combine message with |v| */
1204
125
    scalar_decode_decompress_add(&v, message);
1205
125
    scalar_compress(&v, dv);
1206
125
    scalar_encode(out + vinfo->u_vector_bytes, &v, dv);
1207
125
    return 1;
1208
125
}
1209
1210
/*
1211
 * FIPS 203, Section 5.3, Algorithm 15: K-PKE.Decrypt.
1212
 */
1213
static void
1214
decrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
1215
            const uint8_t *ctext, scalar *u, const ML_KEM_KEY *key)
1216
53
{
1217
53
    const ML_KEM_VINFO *vinfo = key->vinfo;
1218
53
    scalar v, mask;
1219
53
    int rank = vinfo->rank;
1220
53
    int du = vinfo->du;
1221
53
    int dv = vinfo->dv;
1222
1223
53
    vector_decode_decompress_ntt(u, ctext, du, rank);
1224
53
    scalar_decode(&v, ctext + vinfo->u_vector_bytes, dv);
1225
53
    scalar_decompress(&v, dv);
1226
53
    inner_product(&mask, key->s, u, rank);
1227
53
    scalar_inverse_ntt(&mask);
1228
53
    scalar_sub(&v, &mask);
1229
53
    scalar_compress(&v, 1);
1230
53
    scalar_encode_1(out, &v);
1231
53
}
1232
1233
/*-
1234
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1235
 * FIPS 203, Section 7.2, Algorithm 20: "ML-KEM.Encaps".
1236
 *
1237
 * Fills the |out| buffer with the |ek| output of "ML-KEM.KeyGen", or,
1238
 * equivalently, the |ek| input of "ML-KEM.Encaps", i.e. returns the
1239
 * wire-format of an ML-KEM public key.
1240
 */
1241
static void encode_pubkey(uint8_t *out, const ML_KEM_KEY *key)
1242
31.8k
{
1243
31.8k
    const uint8_t *rho = key->rho;
1244
31.8k
    const ML_KEM_VINFO *vinfo = key->vinfo;
1245
1246
31.8k
    vector_encode(out, key->t, 12, vinfo->rank);
1247
31.8k
    memcpy(out + vinfo->vector_bytes, rho, ML_KEM_RANDOM_BYTES);
1248
31.8k
}
1249
1250
/*-
1251
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1252
 *
1253
 * Fills the |out| buffer with the |dk| output of "ML-KEM.KeyGen".
1254
 * This matches the input format of parse_prvkey() below.
1255
 */
1256
static void encode_prvkey(uint8_t *out, const ML_KEM_KEY *key)
1257
66
{
1258
66
    const ML_KEM_VINFO *vinfo = key->vinfo;
1259
1260
66
    vector_encode(out, key->s, 12, vinfo->rank);
1261
66
    out += vinfo->vector_bytes;
1262
66
    encode_pubkey(out, key);
1263
66
    out += vinfo->pubkey_bytes;
1264
66
    memcpy(out, key->pkhash, ML_KEM_PKHASH_BYTES);
1265
66
    out += ML_KEM_PKHASH_BYTES;
1266
66
    memcpy(out, key->z, ML_KEM_RANDOM_BYTES);
1267
66
}
1268
1269
/*-
1270
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1271
 * FIPS 203, Section 7.2, Algorithm 20: "ML-KEM.Encaps".
1272
 *
1273
 * This function parses the |in| buffer as the |ek| output of "ML-KEM.KeyGen",
1274
 * or, equivalently, the |ek| input of "ML-KEM.Encaps", i.e. decodes the
1275
 * wire-format of the ML-KEM public key.
1276
 */
1277
static int parse_pubkey(const uint8_t *in, EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
1278
311
{
1279
311
    const ML_KEM_VINFO *vinfo = key->vinfo;
1280
1281
    /* Decode and check |t| */
1282
311
    if (!vector_decode_12(key->t, in, vinfo->rank)) {
1283
130
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1284
130
                       "%s invalid public 't' vector",
1285
130
                       vinfo->algorithm_name);
1286
130
        return 0;
1287
130
    }
1288
    /* Save the matrix |m| recovery seed |rho| */
1289
181
    memcpy(key->rho, in + vinfo->vector_bytes, ML_KEM_RANDOM_BYTES);
1290
    /*
1291
     * Pre-compute the public key hash, needed for both encap and decap.
1292
     * Also pre-compute the matrix expansion, stored with the public key.
1293
     */
1294
181
    if (!hash_h(key->pkhash, in, vinfo->pubkey_bytes, mdctx, key)
1295
181
        || !matrix_expand(mdctx, key)) {
1296
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1297
0
                       "internal error while parsing %s public key",
1298
0
                       vinfo->algorithm_name);
1299
0
        return 0;
1300
0
    }
1301
181
    return 1;
1302
181
}
1303
1304
/*
1305
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1306
 *
1307
 * Parses the |in| buffer as a |dk| output of "ML-KEM.KeyGen".
1308
 * This matches the output format of encode_prvkey() above.
1309
 */
1310
static int parse_prvkey(const uint8_t *in, EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
1311
28
{
1312
28
    const ML_KEM_VINFO *vinfo = key->vinfo;
1313
1314
    /* Decode and check |s|. */
1315
28
    if (!vector_decode_12(key->s, in, vinfo->rank)) {
1316
28
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1317
28
                       "%s invalid private 's' vector",
1318
28
                       vinfo->algorithm_name);
1319
28
        return 0;
1320
28
    }
1321
0
    in += vinfo->vector_bytes;
1322
1323
0
    if (!parse_pubkey(in, mdctx, key))
1324
0
        return 0;
1325
0
    in += vinfo->pubkey_bytes;
1326
1327
    /* Check public key hash. */
1328
0
    if (memcmp(key->pkhash, in, ML_KEM_PKHASH_BYTES) != 0) {
1329
0
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1330
0
                       "%s public key hash mismatch",
1331
0
                       vinfo->algorithm_name);
1332
0
        return 0;
1333
0
    }
1334
0
    in += ML_KEM_PKHASH_BYTES;
1335
1336
0
    memcpy(key->z, in, ML_KEM_RANDOM_BYTES);
1337
0
    return 1;
1338
0
}
1339
1340
/*
1341
 * FIPS 203, Section 6.1, Algorithm 16: "ML-KEM.KeyGen_internal".
1342
 *
1343
 * The implementation of Section 5.1, Algorithm 13, "K-PKE.KeyGen(d)" is
1344
 * inlined.
1345
 *
1346
 * The caller MUST pass a pre-allocated digest context that is not shared with
1347
 * any concurrent computation.
1348
 *
1349
 * This function optionally outputs the serialised wire-form |ek| public key
1350
 * into the provided |pubenc| buffer, and generates the content of the |rho|,
1351
 * |pkhash|, |t|, |m|, |s| and |z| components of the private |key| (which must
1352
 * have preallocated space for these).
1353
 *
1354
 * Keys are computed from a 32-byte random |d| plus the 1 byte rank for
1355
 * domain separation.  These are concatenated and hashed to produce a pair of
1356
 * 32-byte seeds public "rho", used to generate the matrix, and private "sigma",
1357
 * used to generate the secret vector |s|.
1358
 *
1359
 * The second random input |z| is copied verbatim into the Fujisaki-Okamoto
1360
 * (FO) transform "implicit-rejection" secret (the |z| component of the private
1361
 * key), which thwarts chosen-ciphertext attacks, provided decap() runs in
1362
 * constant time, with no side channel leaks, on all well-formed (valid length,
1363
 * and correctly encoded) ciphertext inputs.
1364
 */
1365
static __owur
1366
int genkey(const uint8_t seed[ML_KEM_SEED_BYTES],
1367
           EVP_MD_CTX *mdctx, uint8_t *pubenc, ML_KEM_KEY *key)
1368
31.8k
{
1369
31.8k
    uint8_t hashed[2 * ML_KEM_RANDOM_BYTES];
1370
31.8k
    const uint8_t *const sigma = hashed + ML_KEM_RANDOM_BYTES;
1371
31.8k
    uint8_t augmented_seed[ML_KEM_RANDOM_BYTES + 1];
1372
31.8k
    const ML_KEM_VINFO *vinfo = key->vinfo;
1373
31.8k
    CBD_FUNC cbd_1 = CBD1(vinfo->evp_type);
1374
31.8k
    int rank = vinfo->rank;
1375
31.8k
    uint8_t counter = 0;
1376
31.8k
    int ret = 0;
1377
1378
    /*
1379
     * Use the "d" seed salted with the rank to derive the public and private
1380
     * seeds rho and sigma.
1381
     */
1382
31.8k
    memcpy(augmented_seed, seed, ML_KEM_RANDOM_BYTES);
1383
31.8k
    augmented_seed[ML_KEM_RANDOM_BYTES] = (uint8_t) rank;
1384
31.8k
    if (!hash_g(hashed, augmented_seed, sizeof(augmented_seed), mdctx, key))
1385
0
        goto end;
1386
31.8k
    memcpy(key->rho, hashed, ML_KEM_RANDOM_BYTES);
1387
    /* The |rho| matrix seed is public */
1388
31.8k
    CONSTTIME_DECLASSIFY(key->rho, ML_KEM_RANDOM_BYTES);
1389
1390
    /* FIPS 203 |e| vector is initial value of key->t */
1391
31.8k
    if (!matrix_expand(mdctx, key)
1392
31.8k
        || !gencbd_vector_ntt(key->s, cbd_1, &counter, sigma, rank, mdctx, key)
1393
31.8k
        || !gencbd_vector_ntt(key->t, cbd_1, &counter, sigma, rank, mdctx, key))
1394
0
        goto end;
1395
1396
    /* To |e| we now add the product of transpose |m| and |s|, giving |t|. */
1397
31.8k
    matrix_mult_transpose_add(key->t, key->m, key->s, rank);
1398
    /* The |t| vector is public */
1399
31.8k
    CONSTTIME_DECLASSIFY(key->t, vinfo->rank * sizeof(scalar));
1400
1401
31.8k
    if (pubenc == NULL) {
1402
        /* Incremental digest of public key without in-full serialisation. */
1403
31.8k
        if (!hash_h_pubkey(key->pkhash, mdctx, key))
1404
0
            goto end;
1405
31.8k
    } else {
1406
0
        encode_pubkey(pubenc, key);
1407
0
        if (!hash_h(key->pkhash, pubenc, vinfo->pubkey_bytes, mdctx, key))
1408
0
            goto end;
1409
0
    }
1410
1411
    /* Save |z| portion of seed for "implicit rejection" on failure. */
1412
31.8k
    memcpy(key->z, seed + ML_KEM_RANDOM_BYTES, ML_KEM_RANDOM_BYTES);
1413
1414
    /* Optionally save the |d| portion of the seed */
1415
31.8k
    key->d = key->z + ML_KEM_RANDOM_BYTES;
1416
31.8k
    if (key->prov_flags & ML_KEM_KEY_RETAIN_SEED) {
1417
31.8k
        memcpy(key->d, seed, ML_KEM_RANDOM_BYTES);
1418
31.8k
    } else {
1419
0
        OPENSSL_cleanse(key->d, ML_KEM_RANDOM_BYTES);
1420
0
        key->d = NULL;
1421
0
    }
1422
1423
31.8k
    ret = 1;
1424
31.8k
 end:
1425
31.8k
    OPENSSL_cleanse((void *)augmented_seed, ML_KEM_RANDOM_BYTES);
1426
31.8k
    OPENSSL_cleanse((void *)sigma, ML_KEM_RANDOM_BYTES);
1427
31.8k
    if (ret == 0) {
1428
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1429
0
                       "internal error while generating %s private key",
1430
0
                       vinfo->algorithm_name);
1431
0
    }
1432
31.8k
    return ret;
1433
31.8k
}
1434
1435
/*-
1436
 * FIPS 203, Section 6.2, Algorithm 17: "ML-KEM.Encaps_internal".
1437
 * This is the deterministic version with randomness supplied externally.
1438
 *
1439
 * The caller must pass space for two vectors in |tmp|.
1440
 * The |ctext| buffer have space for the ciphertext of the ML-KEM variant
1441
 * of the provided key.
1442
 */
1443
static
1444
int encap(uint8_t *ctext, uint8_t secret[ML_KEM_SHARED_SECRET_BYTES],
1445
          const uint8_t entropy[ML_KEM_RANDOM_BYTES],
1446
          scalar *tmp, EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1447
72
{
1448
72
    uint8_t input[ML_KEM_RANDOM_BYTES + ML_KEM_PKHASH_BYTES];
1449
72
    uint8_t Kr[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES];
1450
72
    uint8_t *r = Kr + ML_KEM_SHARED_SECRET_BYTES;
1451
72
    int ret;
1452
1453
72
    memcpy(input, entropy, ML_KEM_RANDOM_BYTES);
1454
72
    memcpy(input + ML_KEM_RANDOM_BYTES, key->pkhash, ML_KEM_PKHASH_BYTES);
1455
72
    ret = hash_g(Kr, input, sizeof(input), mdctx, key)
1456
72
        && encrypt_cpa(ctext, entropy, r, tmp, mdctx, key);
1457
72
    OPENSSL_cleanse((void *)input, sizeof(input));
1458
1459
72
    if (ret)
1460
72
        memcpy(secret, Kr, ML_KEM_SHARED_SECRET_BYTES);
1461
0
    else
1462
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1463
0
                       "internal error while performing %s encapsulation",
1464
0
                       key->vinfo->algorithm_name);
1465
72
    return ret;
1466
72
}
1467
1468
/*
1469
 * FIPS 203, Section 6.3, Algorithm 18: ML-KEM.Decaps_internal
1470
 *
1471
 * Barring failure of the supporting SHA3/SHAKE primitives, this is fully
1472
 * deterministic, the randomness for the FO transform is extracted during
1473
 * private key generation.
1474
 *
1475
 * The caller must pass space for two vectors in |tmp|.
1476
 * The |ctext| and |tmp_ctext| buffers must each have space for the ciphertext
1477
 * of the key's ML-KEM variant.
1478
 */
1479
static
1480
int decap(uint8_t secret[ML_KEM_SHARED_SECRET_BYTES],
1481
          const uint8_t *ctext, uint8_t *tmp_ctext, scalar *tmp,
1482
          EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1483
53
{
1484
53
    uint8_t decrypted[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_PKHASH_BYTES];
1485
53
    uint8_t failure_key[ML_KEM_RANDOM_BYTES];
1486
53
    uint8_t Kr[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES];
1487
53
    uint8_t *r = Kr + ML_KEM_SHARED_SECRET_BYTES;
1488
53
    const uint8_t *pkhash = key->pkhash;
1489
53
    const ML_KEM_VINFO *vinfo = key->vinfo;
1490
53
    int i;
1491
53
    uint8_t mask;
1492
1493
    /*
1494
     * If our KDF is unavailable, fail early! Otherwise, keep going ignoring
1495
     * any further errors, returning success, and whatever we got for a shared
1496
     * secret.  The decrypt_cpa() function is just arithmetic on secret data,
1497
     * so should not be subject to failure that makes its output predictable.
1498
     *
1499
     * We guard against "should never happen" catastrophic failure of the
1500
     * "pure" function |hash_g| by overwriting the shared secret with the
1501
     * content of the failure key and returning early, if nevertheless hash_g
1502
     * fails.  This is not constant-time, but a failure of |hash_g| already
1503
     * implies loss of side-channel resistance.
1504
     *
1505
     * The same action is taken, if also |encrypt_cpa| should catastrophically
1506
     * fail, due to failure of the |PRF| underlying the CBD functions.
1507
     */
1508
53
    if (!kdf(failure_key, key->z, ctext, vinfo->ctext_bytes, mdctx, key)) {
1509
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1510
0
                       "internal error while performing %s decapsulation",
1511
0
                       vinfo->algorithm_name);
1512
0
        return 0;
1513
0
    }
1514
53
    decrypt_cpa(decrypted, ctext, tmp, key);
1515
53
    memcpy(decrypted + ML_KEM_SHARED_SECRET_BYTES, pkhash, ML_KEM_PKHASH_BYTES);
1516
53
    if (!hash_g(Kr, decrypted, sizeof(decrypted), mdctx, key)
1517
53
        || !encrypt_cpa(tmp_ctext, decrypted, r, tmp, mdctx, key)) {
1518
0
        memcpy(secret, failure_key, ML_KEM_SHARED_SECRET_BYTES);
1519
0
        OPENSSL_cleanse(decrypted, ML_KEM_SHARED_SECRET_BYTES);
1520
0
        return 1;
1521
0
    }
1522
53
    mask = constant_time_eq_int_8(0,
1523
53
        CRYPTO_memcmp(ctext, tmp_ctext, vinfo->ctext_bytes));
1524
1.74k
    for (i = 0; i < ML_KEM_SHARED_SECRET_BYTES; i++)
1525
1.69k
        secret[i] = constant_time_select_8(mask, Kr[i], failure_key[i]);
1526
53
    OPENSSL_cleanse(decrypted, ML_KEM_SHARED_SECRET_BYTES);
1527
53
    OPENSSL_cleanse(Kr, sizeof(Kr));
1528
53
    return 1;
1529
53
}
1530
1531
/*
1532
 * After allocating storage for public or private key data, update the key
1533
 * component pointers to reference that storage.
1534
 */
1535
static __owur
1536
int add_storage(scalar *p, int private, ML_KEM_KEY *key)
1537
15.6k
{
1538
15.6k
    int rank = key->vinfo->rank;
1539
1540
15.6k
    if (p == NULL)
1541
0
        return 0;
1542
1543
    /*
1544
     * We're adding key material, the seed buffer will now hold |rho| and
1545
     * |pkhash|.
1546
     */
1547
15.6k
    memset(key->seedbuf, 0, sizeof(key->seedbuf));
1548
15.6k
    key->rho = key->seedbuf;
1549
15.6k
    key->pkhash = key->seedbuf + ML_KEM_RANDOM_BYTES;
1550
15.6k
    key->d = key->z = NULL;
1551
1552
    /* A public key needs space for |t| and |m| */
1553
15.6k
    key->m = (key->t = p) + rank;
1554
1555
    /*
1556
     * A private key also needs space for |s| and |z|.
1557
     * The |z| buffer always includes additional space for |d|, but a key's |d|
1558
     * pointer is left NULL when parsed from the NIST format, which omits that
1559
     * information.  Only keys generated from a (d, z) seed pair will have a
1560
     * non-NULL |d| pointer.
1561
     */
1562
15.6k
    if (private)
1563
15.5k
        key->z = (uint8_t *)(rank + (key->s = key->m + rank * rank));
1564
15.6k
    return 1;
1565
15.6k
}
1566
1567
/*
1568
 * After freeing the storage associated with a key that failed to be
1569
 * constructed, reset the internal pointers back to NULL.
1570
 */
1571
void
1572
ossl_ml_kem_key_reset(ML_KEM_KEY *key)
1573
15.7k
{
1574
15.7k
    if (key->t == NULL)
1575
86
        return;
1576
    /*-
1577
     * Cleanse any sensitive data:
1578
     * - The private vector |s| is immediately followed by the FO failure
1579
     *   secret |z|, and seed |d|, we can cleanse all three in one call.
1580
     *
1581
     * - Otherwise, when key->d is set, cleanse the stashed seed.
1582
     */
1583
15.6k
    if (ossl_ml_kem_have_prvkey(key))
1584
15.5k
        OPENSSL_cleanse(key->s,
1585
15.5k
                        key->vinfo->rank * sizeof(scalar) + 2 * ML_KEM_RANDOM_BYTES);
1586
15.6k
    OPENSSL_free(key->t);
1587
15.6k
    key->d = key->z = (uint8_t *)(key->s = key->m = key->t = NULL);
1588
15.6k
}
1589
1590
/*
1591
 * ----- API exported to the provider
1592
 *
1593
 * Parameters with an implicit fixed length in the internal static API of each
1594
 * variant have an explicit checked length argument at this layer.
1595
 */
1596
1597
/* Retrieve the parameters of one of the ML-KEM variants */
1598
const ML_KEM_VINFO *ossl_ml_kem_get_vinfo(int evp_type)
1599
179k
{
1600
179k
    switch (evp_type) {
1601
38.8k
    case EVP_PKEY_ML_KEM_512:
1602
38.8k
        return &vinfo_map[ML_KEM_512_VINFO];
1603
101k
    case EVP_PKEY_ML_KEM_768:
1604
101k
        return &vinfo_map[ML_KEM_768_VINFO];
1605
38.8k
    case EVP_PKEY_ML_KEM_1024:
1606
38.8k
        return &vinfo_map[ML_KEM_1024_VINFO];
1607
179k
    }
1608
0
    return NULL;
1609
179k
}
1610
1611
/*
 * Allocate an empty key object for the ML-KEM variant selected by |evp_type|.
 * The SHA3/SHAKE digests the key needs are fetched from |libctx| using
 * |properties|.  The new key has no key material.
 *
 * Returns the new key, or NULL (with an error raised for an unsupported type
 * or a missing digest) on failure.
 */
ML_KEM_KEY *ossl_ml_kem_key_new(OSSL_LIB_CTX *libctx, const char *properties,
                                int evp_type)
{
    const ML_KEM_VINFO *vinfo = ossl_ml_kem_get_vinfo(evp_type);
    ML_KEM_KEY *key;

    if (vinfo == NULL) {
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_PASSED_INVALID_ARGUMENT,
                       "unsupported ML-KEM key type: %d", evp_type);
        return NULL;
    }

    key = OPENSSL_malloc(sizeof(*key));
    if (key == NULL)
        return NULL;

    /* Start out with no key material of any kind */
    key->t = NULL;
    key->m = NULL;
    key->s = NULL;
    key->encoded_dk = NULL;
    key->pkhash = NULL;
    key->rho = NULL;
    key->z = NULL;
    key->d = NULL;

    key->vinfo = vinfo;
    key->libctx = libctx;
    key->prov_flags = ML_KEM_KEY_PROV_FLAGS_DEFAULT;

    key->shake128_md = EVP_MD_fetch(libctx, "SHAKE128", properties);
    key->shake256_md = EVP_MD_fetch(libctx, "SHAKE256", properties);
    key->sha3_256_md = EVP_MD_fetch(libctx, "SHA3-256", properties);
    key->sha3_512_md = EVP_MD_fetch(libctx, "SHA3-512", properties);

    /* All four digests are required, give up if any fetch failed */
    if (key->shake128_md == NULL
        || key->shake256_md == NULL
        || key->sha3_256_md == NULL
        || key->sha3_512_md == NULL) {
        ossl_ml_kem_key_free(key);
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
                       "missing SHA3 digest algorithms while creating %s key",
                       vinfo->algorithm_name);
        return NULL;
    }

    return key;
}
1648
1649
ML_KEM_KEY *ossl_ml_kem_key_dup(const ML_KEM_KEY *key, int selection)
1650
17
{
1651
17
    int ok = 0;
1652
17
    ML_KEM_KEY *ret;
1653
1654
    /*
1655
     * Partially decoded keys, not yet imported or loaded, should never be
1656
     * duplicated.
1657
     */
1658
17
    if (ossl_ml_kem_decoded_key(key))
1659
0
        return NULL;
1660
1661
17
    if (key == NULL
1662
17
        || (ret = OPENSSL_memdup(key, sizeof(*key))) == NULL)
1663
0
        return NULL;
1664
17
    ret->d = ret->z = ret->rho = ret->pkhash = NULL;
1665
17
    ret->s = ret->m = ret->t = NULL;
1666
1667
    /* Clear selection bits we can't fulfill */
1668
17
    if (!ossl_ml_kem_have_pubkey(key))
1669
0
        selection = 0;
1670
17
    else if (!ossl_ml_kem_have_prvkey(key))
1671
17
        selection &= ~OSSL_KEYMGMT_SELECT_PRIVATE_KEY;
1672
1673
17
    switch (selection & OSSL_KEYMGMT_SELECT_KEYPAIR) {
1674
0
    case 0:
1675
0
        ok = 1;
1676
0
        break;
1677
17
    case OSSL_KEYMGMT_SELECT_PUBLIC_KEY:
1678
17
        ok = add_storage(OPENSSL_memdup(key->t, key->vinfo->puballoc), 0, ret);
1679
17
        ret->rho = ret->seedbuf;
1680
17
        ret->pkhash = ret->rho + ML_KEM_RANDOM_BYTES;
1681
17
        break;
1682
0
    case OSSL_KEYMGMT_SELECT_PRIVATE_KEY:
1683
0
        ok = add_storage(OPENSSL_memdup(key->t, key->vinfo->prvalloc), 1, ret);
1684
        /* Duplicated keys retain |d|, if available */
1685
0
        if (key->d != NULL)
1686
0
            ret->d = ret->z + ML_KEM_RANDOM_BYTES;
1687
0
        break;
1688
17
    }
1689
1690
17
    if (!ok) {
1691
0
        OPENSSL_free(ret);
1692
0
        return NULL;
1693
0
    }
1694
1695
17
    EVP_MD_up_ref(ret->shake128_md);
1696
17
    EVP_MD_up_ref(ret->shake256_md);
1697
17
    EVP_MD_up_ref(ret->sha3_256_md);
1698
17
    EVP_MD_up_ref(ret->sha3_512_md);
1699
1700
17
    return ret;
1701
17
}
1702
1703
/*
 * Free a key object: wipe any secrets it holds, release its storage and its
 * digest references, then free the object itself.  A NULL |key| is ignored.
 */
void ossl_ml_kem_key_free(ML_KEM_KEY *key)
{
    if (key != NULL) {
        /*
         * A decoded, not yet loaded, key keeps its secrets outside the
         * main allocation: wipe the seed buffer and, when present, wipe and
         * release the encoded private key.
         */
        if (ossl_ml_kem_decoded_key(key)) {
            OPENSSL_cleanse(key->seedbuf, sizeof(key->seedbuf));
            if (ossl_ml_kem_have_dkenc(key)) {
                OPENSSL_cleanse(key->encoded_dk, key->vinfo->prvkey_bytes);
                OPENSSL_free(key->encoded_dk);
            }
        }

        /* Wipe and release the key component storage, if any */
        ossl_ml_kem_key_reset(key);

        EVP_MD_free(key->shake128_md);
        EVP_MD_free(key->shake256_md);
        EVP_MD_free(key->sha3_256_md);
        EVP_MD_free(key->sha3_512_md);

        OPENSSL_free(key);
    }
}
1723
1724
/* Serialise the public component of an ML-KEM key */
1725
int ossl_ml_kem_encode_public_key(uint8_t *out, size_t len,
1726
                                  const ML_KEM_KEY *key)
1727
31.7k
{
1728
31.7k
    if (!ossl_ml_kem_have_pubkey(key)
1729
31.7k
        || len != key->vinfo->pubkey_bytes)
1730
0
        return 0;
1731
31.7k
    encode_pubkey(out, key);
1732
31.7k
    return 1;
1733
31.7k
}
1734
1735
/* Serialise an ML-KEM private key */
1736
int ossl_ml_kem_encode_private_key(uint8_t *out, size_t len,
1737
                                   const ML_KEM_KEY *key)
1738
66
{
1739
66
    if (!ossl_ml_kem_have_prvkey(key)
1740
66
        || len != key->vinfo->prvkey_bytes)
1741
0
        return 0;
1742
66
    encode_prvkey(out, key);
1743
66
    return 1;
1744
66
}
1745
1746
int ossl_ml_kem_encode_seed(uint8_t *out, size_t len,
1747
                            const ML_KEM_KEY *key)
1748
94
{
1749
94
    if (key == NULL || key->d == NULL || len != ML_KEM_SEED_BYTES)
1750
0
        return 0;
1751
    /*
1752
     * Both in the seed buffer, and in the allocated storage, the |d| component
1753
     * of the seed is stored last, so we must copy each separately.
1754
     */
1755
94
    memcpy(out, key->d, ML_KEM_RANDOM_BYTES);
1756
94
    out += ML_KEM_RANDOM_BYTES;
1757
94
    memcpy(out, key->z, ML_KEM_RANDOM_BYTES);
1758
94
    return 1;
1759
94
}
1760
1761
/*
1762
 * Stash the seed without (yet) performing a keygen, used during decoding, to
1763
 * avoid an extra keygen if we're only going to export the key again to load
1764
 * into another provider.
1765
 */
1766
ML_KEM_KEY *ossl_ml_kem_set_seed(const uint8_t *seed, size_t seedlen, ML_KEM_KEY *key)
1767
28
{
1768
28
    if (key == NULL
1769
28
        || ossl_ml_kem_have_pubkey(key)
1770
28
        || ossl_ml_kem_have_seed(key)
1771
28
        || seedlen != ML_KEM_SEED_BYTES)
1772
0
        return NULL;
1773
    /*
1774
     * With no public or private key material on hand, we can use the seed
1775
     * buffer for |z| and |d|, in that order.
1776
     */
1777
28
    key->z = key->seedbuf;
1778
28
    key->d = key->z + ML_KEM_RANDOM_BYTES;
1779
28
    memcpy(key->d, seed, ML_KEM_RANDOM_BYTES);
1780
28
    seed += ML_KEM_RANDOM_BYTES;
1781
28
    memcpy(key->z, seed, ML_KEM_RANDOM_BYTES);
1782
28
    return key;
1783
28
}
1784
1785
/* Parse input as a public key */
1786
int ossl_ml_kem_parse_public_key(const uint8_t *in, size_t len, ML_KEM_KEY *key)
1787
311
{
1788
311
    EVP_MD_CTX *mdctx = NULL;
1789
311
    const ML_KEM_VINFO *vinfo;
1790
311
    int ret = 0;
1791
1792
    /* Keys with key material are immutable */
1793
311
    if (key == NULL
1794
311
        || ossl_ml_kem_have_pubkey(key)
1795
311
        || ossl_ml_kem_have_dkenc(key))
1796
0
        return 0;
1797
311
    vinfo = key->vinfo;
1798
1799
311
    if (len != vinfo->pubkey_bytes
1800
311
        || (mdctx = EVP_MD_CTX_new()) == NULL)
1801
0
        return 0;
1802
1803
311
    if (add_storage(OPENSSL_malloc(vinfo->puballoc), 0, key))
1804
311
        ret = parse_pubkey(in, mdctx, key);
1805
1806
311
    if (!ret)
1807
130
        ossl_ml_kem_key_reset(key);
1808
311
    EVP_MD_CTX_free(mdctx);
1809
311
    return ret;
1810
311
}
1811
1812
/* Parse input as a new private key */
1813
int ossl_ml_kem_parse_private_key(const uint8_t *in, size_t len,
1814
                                  ML_KEM_KEY *key)
1815
28
{
1816
28
    EVP_MD_CTX *mdctx = NULL;
1817
28
    const ML_KEM_VINFO *vinfo;
1818
28
    int ret = 0;
1819
1820
    /* Keys with key material are immutable */
1821
28
    if (key == NULL
1822
28
        || ossl_ml_kem_have_pubkey(key)
1823
28
        || ossl_ml_kem_have_dkenc(key))
1824
0
        return 0;
1825
28
    vinfo = key->vinfo;
1826
1827
28
    if (len != vinfo->prvkey_bytes
1828
28
        || (mdctx = EVP_MD_CTX_new()) == NULL)
1829
0
        return 0;
1830
1831
28
    if (add_storage(OPENSSL_malloc(vinfo->prvalloc), 1, key))
1832
28
        ret = parse_prvkey(in, mdctx, key);
1833
1834
28
    if (!ret)
1835
28
        ossl_ml_kem_key_reset(key);
1836
28
    EVP_MD_CTX_free(mdctx);
1837
28
    return ret;
1838
28
}
1839
1840
/*
1841
 * Generate a new keypair, either from the saved seed (when non-null), or from
1842
 * the RNG.
1843
 */
1844
int ossl_ml_kem_genkey(uint8_t *pubenc, size_t publen, ML_KEM_KEY *key)
1845
31.8k
{
1846
31.8k
    uint8_t seed[ML_KEM_SEED_BYTES];
1847
31.8k
    EVP_MD_CTX *mdctx = NULL;
1848
31.8k
    const ML_KEM_VINFO *vinfo;
1849
31.8k
    int ret = 0;
1850
1851
31.8k
    if (key == NULL
1852
31.8k
        || ossl_ml_kem_have_pubkey(key)
1853
31.8k
        || ossl_ml_kem_have_dkenc(key))
1854
0
        return 0;
1855
31.8k
    vinfo = key->vinfo;
1856
1857
31.8k
    if (pubenc != NULL && publen != vinfo->pubkey_bytes)
1858
0
        return 0;
1859
1860
31.8k
    if (ossl_ml_kem_have_seed(key)) {
1861
28
        if (!ossl_ml_kem_encode_seed(seed, sizeof(seed), key))
1862
0
            return 0;
1863
28
        key->d = key->z = NULL;
1864
31.8k
    } else if (RAND_priv_bytes_ex(key->libctx, seed, sizeof(seed),
1865
31.8k
                                  key->vinfo->secbits) <= 0) {
1866
0
        return 0;
1867
0
    }
1868
1869
31.8k
    if ((mdctx = EVP_MD_CTX_new()) == NULL)
1870
0
        return 0;
1871
1872
    /*
1873
     * Data derived from (d, z) defaults secret, and to avoid side-channel
1874
     * leaks should not influence control flow.
1875
     */
1876
31.8k
    CONSTTIME_SECRET(seed, ML_KEM_SEED_BYTES);
1877
1878
31.8k
    if (add_storage(OPENSSL_malloc(vinfo->prvalloc), 1, key))
1879
31.8k
        ret = genkey(seed, mdctx, pubenc, key);
1880
31.8k
    OPENSSL_cleanse(seed, sizeof(seed));
1881
1882
    /* Declassify secret inputs and derived outputs before returning control */
1883
31.8k
    CONSTTIME_DECLASSIFY(seed, ML_KEM_SEED_BYTES);
1884
1885
31.8k
    EVP_MD_CTX_free(mdctx);
1886
31.8k
    if (!ret) {
1887
0
        ossl_ml_kem_key_reset(key);
1888
0
        return 0;
1889
0
    }
1890
1891
    /* The public components are already declassified */
1892
31.8k
    CONSTTIME_DECLASSIFY(key->s, vinfo->rank * sizeof(scalar));
1893
31.8k
    CONSTTIME_DECLASSIFY(key->z, 2 * ML_KEM_RANDOM_BYTES);
1894
31.8k
    return 1;
1895
31.8k
}
1896
1897
/*
1898
 * FIPS 203, Section 6.2, Algorithm 17: ML-KEM.Encaps_internal
1899
 * This is the deterministic version with randomness supplied externally.
1900
 */
1901
int ossl_ml_kem_encap_seed(uint8_t *ctext, size_t clen,
1902
                           uint8_t *shared_secret, size_t slen,
1903
                           const uint8_t *entropy, size_t elen,
1904
                           const ML_KEM_KEY *key)
1905
72
{
1906
72
    const ML_KEM_VINFO *vinfo;
1907
72
    EVP_MD_CTX *mdctx;
1908
72
    int ret = 0;
1909
1910
72
    if (key == NULL || !ossl_ml_kem_have_pubkey(key))
1911
0
        return 0;
1912
72
    vinfo = key->vinfo;
1913
1914
72
    if (ctext == NULL || clen != vinfo->ctext_bytes
1915
72
        || shared_secret == NULL || slen != ML_KEM_SHARED_SECRET_BYTES
1916
72
        || entropy == NULL || elen != ML_KEM_RANDOM_BYTES
1917
72
        || (mdctx = EVP_MD_CTX_new()) == NULL)
1918
0
        return 0;
1919
    /*
1920
     * Data derived from the encap entropy defaults secret, and to avoid
1921
     * side-channel leaks should not influence control flow.
1922
     */
1923
72
    CONSTTIME_SECRET(entropy, elen);
1924
1925
    /*-
1926
     * This avoids the need to handle allocation failures for two (max 2KB
1927
     * each) vectors, that are never retained on return from this function.
1928
     * We stack-allocate these.
1929
     */
1930
72
#   define case_encap_seed(bits)                                            \
1931
72
    case EVP_PKEY_ML_KEM_##bits:                                            \
1932
72
        {                                                                   \
1933
72
            scalar tmp[2 * ML_KEM_##bits##_RANK];                           \
1934
72
                                                                            \
1935
72
            ret = encap(ctext, shared_secret, entropy, tmp, mdctx, key);    \
1936
72
            OPENSSL_cleanse((void *)tmp, sizeof(tmp));                      \
1937
72
            break;                                                          \
1938
72
        }
1939
72
    switch (vinfo->evp_type) {
1940
31
    case_encap_seed(512);
1941
32
    case_encap_seed(768);
1942
9
    case_encap_seed(1024);
1943
72
    }
1944
72
#   undef case_encap_seed
1945
1946
    /* Declassify secret inputs and derived outputs before returning control */
1947
72
    CONSTTIME_DECLASSIFY(entropy, elen);
1948
72
    CONSTTIME_DECLASSIFY(ctext, clen);
1949
72
    CONSTTIME_DECLASSIFY(shared_secret, slen);
1950
1951
72
    EVP_MD_CTX_free(mdctx);
1952
72
    return ret;
1953
72
}
1954
1955
int ossl_ml_kem_encap_rand(uint8_t *ctext, size_t clen,
1956
                           uint8_t *shared_secret, size_t slen,
1957
                           const ML_KEM_KEY *key)
1958
72
{
1959
72
    uint8_t r[ML_KEM_RANDOM_BYTES];
1960
1961
72
    if (key == NULL)
1962
0
        return 0;
1963
1964
72
    if (RAND_bytes_ex(key->libctx, r, ML_KEM_RANDOM_BYTES,
1965
72
                      key->vinfo->secbits) < 1)
1966
0
        return 0;
1967
1968
72
    return ossl_ml_kem_encap_seed(ctext, clen, shared_secret, slen,
1969
72
                                  r, sizeof(r), key);
1970
72
}
1971
1972
int ossl_ml_kem_decap(uint8_t *shared_secret, size_t slen,
1973
                      const uint8_t *ctext, size_t clen,
1974
                      const ML_KEM_KEY *key)
1975
53
{
1976
53
    const ML_KEM_VINFO *vinfo;
1977
53
    EVP_MD_CTX *mdctx;
1978
53
    int ret = 0;
1979
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
1980
    int classify_bytes;
1981
#endif
1982
1983
    /* Need a private key here */
1984
53
    if (!ossl_ml_kem_have_prvkey(key))
1985
0
        return 0;
1986
53
    vinfo = key->vinfo;
1987
1988
53
    if (shared_secret == NULL || slen != ML_KEM_SHARED_SECRET_BYTES
1989
53
        || ctext == NULL || clen != vinfo->ctext_bytes
1990
53
        || (mdctx = EVP_MD_CTX_new()) == NULL) {
1991
0
        (void)RAND_bytes_ex(key->libctx, shared_secret,
1992
0
                            ML_KEM_SHARED_SECRET_BYTES, vinfo->secbits);
1993
0
        return 0;
1994
0
    }
1995
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
1996
    /*
1997
     * Data derived from |s| and |z| defaults secret, and to avoid side-channel
1998
     * leaks should not influence control flow.
1999
     */
2000
    classify_bytes = 2 * sizeof(scalar) + ML_KEM_RANDOM_BYTES;
2001
    CONSTTIME_SECRET(key->s, classify_bytes);
2002
#endif
2003
2004
    /*-
2005
     * This avoids the need to handle allocation failures for two (max 2KB
2006
     * each) vectors and an encoded ciphertext (max 1568 bytes), that are never
2007
     * retained on return from this function.
2008
     * We stack-allocate these.
2009
     */
2010
53
#   define case_decap(bits)                                             \
2011
53
    case EVP_PKEY_ML_KEM_##bits:                                        \
2012
53
        {                                                               \
2013
53
            uint8_t cbuf[CTEXT_BYTES(bits)];                            \
2014
53
            scalar tmp[2 * ML_KEM_##bits##_RANK];                       \
2015
53
                                                                        \
2016
53
            ret = decap(shared_secret, ctext, cbuf, tmp, mdctx, key);   \
2017
53
            OPENSSL_cleanse((void *)tmp, sizeof(tmp));                  \
2018
53
            break;                                                      \
2019
53
        }
2020
53
    switch (vinfo->evp_type) {
2021
31
    case_decap(512);
2022
13
    case_decap(768);
2023
9
    case_decap(1024);
2024
53
    }
2025
2026
    /* Declassify secret inputs and derived outputs before returning control */
2027
53
    CONSTTIME_DECLASSIFY(key->s, classify_bytes);
2028
53
    CONSTTIME_DECLASSIFY(shared_secret, slen);
2029
53
    EVP_MD_CTX_free(mdctx);
2030
2031
53
    return ret;
2032
53
#   undef case_decap
2033
53
}
2034
2035
int ossl_ml_kem_pubkey_cmp(const ML_KEM_KEY *key1, const ML_KEM_KEY *key2)
2036
74
{
2037
    /*
2038
     * This handles any unexpected differences in the ML-KEM variant rank,
2039
     * giving different key component structures, barring SHA3-256 hash
2040
     * collisions, the keys are the same size.
2041
     */
2042
74
    if (ossl_ml_kem_have_pubkey(key1) && ossl_ml_kem_have_pubkey(key2))
2043
74
        return memcmp(key1->pkhash, key2->pkhash, ML_KEM_PKHASH_BYTES) == 0;
2044
2045
    /*
2046
     * No match if just one of the public keys is not available, otherwise both
2047
     * are unavailable, and for now such keys are considered equal.
2048
     */
2049
0
    return (!(ossl_ml_kem_have_pubkey(key1) ^ ossl_ml_kem_have_pubkey(key2)));
2050
74
}