Coverage Report

Created: 2025-12-04 06:33

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/openssl36/crypto/ml_kem/ml_kem.c
Line
Count
Source
1
/*
2
 * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
3
 *
4
 * Licensed under the Apache License 2.0 (the "License").  You may not use
5
 * this file except in compliance with the License.  You can obtain a copy
6
 * in the file LICENSE in the source distribution or at
7
 * https://www.openssl.org/source/license.html
8
 */
9
10
#include <openssl/byteorder.h>
11
#include <openssl/rand.h>
12
#include <openssl/proverr.h>
13
#include "crypto/ml_kem.h"
14
#include "internal/common.h"
15
#include "internal/constant_time.h"
16
#include "internal/sha3.h"
17
18
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
19
#include <valgrind/memcheck.h>
20
#endif
21
22
/*
 * Compile-time consistency checks on the ML-KEM parameter definitions and on
 * the integer widths this file assumes.
 */
#if (ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES) != ML_KEM_SEED_BYTES
# error "ML-KEM keygen seed length != shared secret + random bytes length"
#endif
#if ML_KEM_RANDOM_BYTES != ML_KEM_SHARED_SECRET_BYTES
# error "Invalid unequal lengths of ML-KEM shared secret and random inputs"
#endif

/* This file assumes |unsigned int| is at least as wide as |uint32_t|. */
#if UINT32_MAX > UINT_MAX
# error "Unsupported compiler: sizeof(unsigned int) < sizeof(uint32_t)"
#endif
32
33
/*
 * Handy function-like bit-extraction macros.  Every argument is parenthesised
 * so that compound expressions (e.g. |bitn(i & 7, b)|) expand correctly.
 */
#define bit0(b)     ((b) & 1)
#define bitn(n, b)  (((b) >> (n)) & 1)

/*
 * 12 bits are sufficient to losslessly represent values in [0, q-1].
 * INVERSE_DEGREE is (n/2)^-1 mod q (128^-1 = 3303 mod 3329); used in the
 * inverse NTT.
 */
#define DEGREE          ML_KEM_DEGREE
#define INVERSE_DEGREE  (ML_KEM_PRIME - 2 * 13)
#define LOG2PRIME       12
#define BARRETT_SHIFT   (2 * LOG2PRIME)

#ifdef SHA3_BLOCKSIZE
# define SHAKE128_BLOCKSIZE SHA3_BLOCKSIZE(128)
#endif

/*
 * Return whether a value that can only be 0 or 1 is non-zero, in constant time
 * in practice!  The return value is a mask that is all ones if true, and all
 * zeros otherwise (unsigned negation wraps, so 0u - 1 is all ones).
 *
 * Although this is used in constant-time selects, we omit a value barrier
 * here.  Value barriers impede auto-vectorization (likely because it forces
 * the value to transit through a general-purpose register). On AArch64, this
 * is a difference of 2x.
 *
 * We usually add value barriers to selects because Clang turns consecutive
 * selects with the same condition into a branch instead of CMOV/CSEL. This
 * condition does not occur in Kyber, so omitting it seems to be safe so far,
 * but see |cbd_2|, |cbd_3|, where reduction needs to be specialised to the
 * sign of the input, rather than adding |q| in advance, and using the generic
 * |reduce_once|.  (David Benjamin, Chromium)
 *
 * The disabled alternative below must remain a pure expression (no trailing
 * semicolon) so that it can be used inside larger expressions if enabled.
 */
#if 0
# define constish_time_non_zero(b) (~constant_time_is_zero(b))
#else
# define constish_time_non_zero(b) (0u - (b))
#endif

/*
 * The scalar rejection-sampling buffer size needs to be a multiple of 12, but
 * is otherwise arbitrary, the preferred block size matches the internal buffer
 * size of SHAKE128, avoiding internal buffering and copying in SHAKE128. That
 * block size of (1600 - 256)/8 bytes, or 168, just happens to divide by 12!
 *
 * If the blocksize is unknown, or is not divisible by 12, 168 is used as a
 * fallback.
 */
#if defined(SHAKE128_BLOCKSIZE) && (SHAKE128_BLOCKSIZE) % 12 == 0
# define SCALAR_SAMPLING_BUFSIZE (SHAKE128_BLOCKSIZE)
#else
# define SCALAR_SAMPLING_BUFSIZE 168
#endif
87
88
/*
89
 * Structure of keys
90
 */
91
typedef struct ossl_ml_kem_scalar_st {
92
    /* On every function entry and exit, 0 <= c[i] < ML_KEM_PRIME. */
93
    uint16_t c[ML_KEM_DEGREE];
94
} scalar;
95
96
/* Key material allocation layout */
97
#define DECLARE_ML_KEM_PUBKEYDATA(name, rank) \
98
    struct name##_alloc { \
99
        /* Public vector |t| */ \
100
        scalar tbuf[(rank)]; \
101
        /* Pre-computed matrix |m| (FIPS 203 |A| transpose) */ \
102
        scalar mbuf[(rank)*(rank)]; \
103
    }
104
105
#define DECLARE_ML_KEM_PRVKEYDATA(name, rank) \
106
    struct name##_alloc { \
107
        scalar sbuf[rank]; \
108
        uint8_t zbuf[2 * ML_KEM_RANDOM_BYTES]; \
109
    }
110
111
/* Declare variant-specific public and private storage */
112
#define DECLARE_ML_KEM_VARIANT_KEYDATA(bits) \
113
    DECLARE_ML_KEM_PUBKEYDATA(pubkey_##bits, ML_KEM_##bits##_RANK); \
114
    DECLARE_ML_KEM_PRVKEYDATA(prvkey_##bits, ML_KEM_##bits##_RANK)
115
116
DECLARE_ML_KEM_VARIANT_KEYDATA(512);
117
DECLARE_ML_KEM_VARIANT_KEYDATA(768);
118
DECLARE_ML_KEM_VARIANT_KEYDATA(1024);
119
#undef DECLARE_ML_KEM_VARIANT_KEYDATA
120
#undef DECLARE_ML_KEM_PUBKEYDATA
121
#undef DECLARE_ML_KEM_PRVKEYDATA
122
123
typedef __owur
124
int (*CBD_FUNC)(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
125
                EVP_MD_CTX *mdctx, const ML_KEM_KEY *key);
126
static void scalar_encode(uint8_t *out, const scalar *s, int bits);
127
128
/*
 * Wire-format sizes.
 *
 * A vector is encoded losslessly with 12 bits per coefficient, i.e.
 * 3 * DEGREE / 2 bytes per scalar.  The public key is the encoded public
 * vector |t| followed by the public seed |rho|.  Our serialised private key
 * is the encoded private vector |s|, the public key, the public key hash and
 * the failure secret |z|, concatenated.
 */
#define VECTOR_BYTES(b)     (ML_KEM_##b##_RANK * (3 * DEGREE / 2))
#define PUBKEY_BYTES(b)     (VECTOR_BYTES(b) + ML_KEM_RANDOM_BYTES)
#define PRVKEY_BYTES(b)     (VECTOR_BYTES(b) + PUBKEY_BYTES(b) \
                             + ML_KEM_PKHASH_BYTES + ML_KEM_RANDOM_BYTES)

/*
 * The ciphertext is the vector "u" followed by the scalar "v", whose
 * coefficients (mod q) are lossily compressed to "du" and "dv" bits
 * respectively.  This encoding is the input to decapsulation.
 */
#define U_VECTOR_BYTES(b)   (ML_KEM_##b##_RANK * ML_KEM_##b##_DU * (DEGREE / 8))
#define V_SCALAR_BYTES(b)   (ML_KEM_##b##_DV * (DEGREE / 8))
#define CTEXT_BYTES(b)      (U_VECTOR_BYTES(b) + V_SCALAR_BYTES(b))
150
151
/*
 * Constant-time validation markers.
 *
 * CONSTTIME_SECRET(ptr, len) marks |len| bytes at |ptr| as secret: under
 * valgrind, branching on or indexing memory with that data is reported.
 * CONSTTIME_DECLASSIFY(ptr, len) marks the region as public again, so it is
 * no longer subject to constant-time rules.
 *
 * Unless OPENSSL_CONSTANT_TIME_VALIDATION is defined, both expand to nothing.
 */
#if !defined(OPENSSL_CONSTANT_TIME_VALIDATION)

# define CONSTTIME_SECRET(ptr, len)
# define CONSTTIME_DECLASSIFY(ptr, len)

#else

# define CONSTTIME_SECRET(ptr, len) VALGRIND_MAKE_MEM_UNDEFINED(ptr, len)
# define CONSTTIME_DECLASSIFY(ptr, len) VALGRIND_MAKE_MEM_DEFINED(ptr, len)

#endif
174
175
/*
176
 * Indices of slots in the vinfo tables below
177
 */
178
53.0k
#define ML_KEM_512_VINFO    0
179
136k
#define ML_KEM_768_VINFO    1
180
52.8k
#define ML_KEM_1024_VINFO   2
181
182
/*
183
 * Per-variant fixed parameters
184
 */
185
static const ML_KEM_VINFO vinfo_map[3] = {
186
    {
187
        "ML-KEM-512",
188
        PRVKEY_BYTES(512),
189
        sizeof(struct prvkey_512_alloc),
190
        PUBKEY_BYTES(512),
191
        sizeof(struct pubkey_512_alloc),
192
        CTEXT_BYTES(512),
193
        VECTOR_BYTES(512),
194
        U_VECTOR_BYTES(512),
195
        EVP_PKEY_ML_KEM_512,
196
        ML_KEM_512_BITS,
197
        ML_KEM_512_RANK,
198
        ML_KEM_512_DU,
199
        ML_KEM_512_DV,
200
        ML_KEM_512_SECBITS,
201
        ML_KEM_512_SECURITY_CATEGORY
202
    },
203
    {
204
        "ML-KEM-768",
205
        PRVKEY_BYTES(768),
206
        sizeof(struct prvkey_768_alloc),
207
        PUBKEY_BYTES(768),
208
        sizeof(struct pubkey_768_alloc),
209
        CTEXT_BYTES(768),
210
        VECTOR_BYTES(768),
211
        U_VECTOR_BYTES(768),
212
        EVP_PKEY_ML_KEM_768,
213
        ML_KEM_768_BITS,
214
        ML_KEM_768_RANK,
215
        ML_KEM_768_DU,
216
        ML_KEM_768_DV,
217
        ML_KEM_768_SECBITS,
218
        ML_KEM_768_SECURITY_CATEGORY
219
    },
220
    {
221
        "ML-KEM-1024",
222
        PRVKEY_BYTES(1024),
223
        sizeof(struct prvkey_1024_alloc),
224
        PUBKEY_BYTES(1024),
225
        sizeof(struct pubkey_1024_alloc),
226
        CTEXT_BYTES(1024),
227
        VECTOR_BYTES(1024),
228
        U_VECTOR_BYTES(1024),
229
        EVP_PKEY_ML_KEM_1024,
230
        ML_KEM_1024_BITS,
231
        ML_KEM_1024_RANK,
232
        ML_KEM_1024_DU,
233
        ML_KEM_1024_DV,
234
        ML_KEM_1024_SECBITS,
235
        ML_KEM_1024_SECURITY_CATEGORY
236
    }
237
};
238
239
/*
240
 * Remainders modulo `kPrime`, for sufficiently small inputs, are computed in
241
 * constant time via Barrett reduction, and a final call to reduce_once(),
242
 * which reduces inputs that are at most 2*kPrime and is also constant-time.
243
 */
244
static const int kPrime = ML_KEM_PRIME;
245
static const unsigned int kBarrettShift = BARRETT_SHIFT;
246
static const size_t   kBarrettMultiplier = (1 << BARRETT_SHIFT) / ML_KEM_PRIME;
247
static const uint16_t kHalfPrime = (ML_KEM_PRIME - 1) / 2;
248
static const uint16_t kInverseDegree = INVERSE_DEGREE;
249
250
/*
 * Python helper used to derive the tables below:
 *
 * p = 3329
 * def bitreverse(i):
 *     ret = 0
 *     for n in range(7):
 *         bit = i & 1
 *         ret <<= 1
 *         ret |= bit
 *         i >>= 1
 *     return ret
 */

/*-
 * NTT twiddle factors: the first precomputed array from Appendix A of
 * FIPS 203, equivalently
 * kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)]
 * Each row is prefixed by the index of its first entry.
 */
static const uint16_t kNTTRoots[128] = {
    /*   0 */ 1,    1729, 2580, 3289, 2642, 630,  1897, 848,
    /*   8 */ 1062, 1919, 193,  797,  2786, 3260, 569,  1746,
    /*  16 */ 296,  2447, 1339, 1476, 3046, 56,   2240, 1333,
    /*  24 */ 1426, 2094, 535,  2882, 2393, 2879, 1974, 821,
    /*  32 */ 289,  331,  3253, 1756, 1197, 2304, 2277, 2055,
    /*  40 */ 650,  1977, 2513, 632,  2865, 33,   1320, 1915,
    /*  48 */ 2319, 1435, 807,  452,  1438, 2868, 1534, 2402,
    /*  56 */ 2647, 2617, 1481, 648,  2474, 3110, 1227, 910,
    /*  64 */ 17,   2761, 583,  2649, 1637, 723,  2288, 1100,
    /*  72 */ 1409, 2662, 3281, 233,  756,  2156, 3015, 3050,
    /*  80 */ 1703, 1651, 2789, 1789, 1847, 952,  1461, 2687,
    /*  88 */ 939,  2308, 2437, 2388, 733,  2337, 268,  641,
    /*  96 */ 1584, 2298, 2037, 3220, 375,  2549, 2090, 1645,
    /* 104 */ 1063, 319,  2773, 757,  2099, 561,  2466, 2594,
    /* 112 */ 2804, 1092, 403,  1026, 1143, 2150, 2775, 886,
    /* 120 */ 1722, 1212, 1874, 1029, 2110, 2935, 885,  2154,
};
286
287
/*
 * Inverse NTT twiddle factors, InverseNTTRoots[i] = pow(17, -bitreverse(i), p),
 * listed in the order the inverse NTT loop consumes them (index 0 is unused):
 *
 *  0, 64, 65, ..., 127, 32, 33, ..., 63, 16, 17, ..., 31, 8, 9, ...
 *
 * Each row is prefixed by the index of its first entry.
 */
static const uint16_t kInverseNTTRoots[128] = {
    /*   0 */ 1,    1175, 2444, 394,  1219, 2300, 1455, 2117,
    /*   8 */ 1607, 2443, 554,  1179, 2186, 2303, 2926, 2237,
    /*  16 */ 525,  735,  863,  2768, 1230, 2572, 556,  3010,
    /*  24 */ 2266, 1684, 1239, 780,  2954, 109,  1292, 1031,
    /*  32 */ 1745, 2688, 3061, 992,  2596, 941,  892,  1021,
    /*  40 */ 2390, 642,  1868, 2377, 1482, 1540, 540,  1678,
    /*  48 */ 1626, 279,  314,  1173, 2573, 3096, 48,   667,
    /*  56 */ 1920, 2229, 1041, 2606, 1692, 680,  2746, 568,
    /*  64 */ 3312, 2419, 2102, 219,  855,  2681, 1848, 712,
    /*  72 */ 682,  927,  1795, 461,  1891, 2877, 2522, 1894,
    /*  80 */ 1010, 1414, 2009, 3296, 464,  2697, 816,  1352,
    /*  88 */ 2679, 1274, 1052, 1025, 2132, 1573, 76,   2998,
    /*  96 */ 3040, 2508, 1355, 450,  936,  447,  2794, 1235,
    /* 104 */ 1903, 1996, 1089, 3273, 283,  1853, 1990, 882,
    /* 112 */ 3033, 1583, 2760, 69,   543,  2532, 3136, 1410,
    /* 120 */ 2267, 2481, 1432, 2699, 687,  40,   749,  1600,
};
311
312
/*
 * Second precomputed array from Appendix A of FIPS 203 (normalised to be
 * positive), equivalently
 * ModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)]
 * Each row is prefixed by the index of its first entry.
 */
static const uint16_t kModRoots[128] = {
    /*   0 */ 17,   3312, 2761, 568,  583,  2746, 2649, 680,  1637, 1692, 723,  2606,
    /*  12 */ 2288, 1041, 1100, 2229, 1409, 1920, 2662, 667,  3281, 48,   233,  3096,
    /*  24 */ 756,  2573, 2156, 1173, 3015, 314,  3050, 279,  1703, 1626, 1651, 1678,
    /*  36 */ 2789, 540,  1789, 1540, 1847, 1482, 952,  2377, 1461, 1868, 2687, 642,
    /*  48 */ 939,  2390, 2308, 1021, 2437, 892,  2388, 941,  733,  2596, 2337, 992,
    /*  60 */ 268,  3061, 641,  2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
    /*  72 */ 375,  2954, 2549, 780,  2090, 1239, 1645, 1684, 1063, 2266, 319,  3010,
    /*  84 */ 2773, 556,  757,  2572, 2099, 1230, 561,  2768, 2466, 863,  2594, 735,
    /*  96 */ 2804, 525,  1092, 2237, 403,  2926, 1026, 2303, 1143, 2186, 2150, 1179,
    /* 108 */ 2775, 554,  886,  2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
    /* 120 */ 2110, 1219, 2935, 394,  885,  2444, 2154, 1175,
};
330
331
/*
332
 * single_keccak hashes |inlen| bytes from |in| and writes |outlen| bytes of
333
 * output to |out|. If the |md| specifies a fixed-output function, like
334
 * SHA3-256, then |outlen| must be the correct length for that function.
335
 */
336
static __owur
337
int single_keccak(uint8_t *out, size_t outlen, const uint8_t *in, size_t inlen,
338
                  EVP_MD_CTX *mdctx)
339
296k
{
340
296k
    unsigned int sz = (unsigned int) outlen;
341
342
296k
    if (!EVP_DigestUpdate(mdctx, in, inlen))
343
0
        return 0;
344
296k
    if (EVP_MD_xof(EVP_MD_CTX_get0_md(mdctx)))
345
254k
        return EVP_DigestFinalXOF(mdctx, out, outlen);
346
42.5k
    return EVP_DigestFinal_ex(mdctx, out, &sz)
347
42.5k
        && ossl_assert((size_t) sz == outlen);
348
296k
}
349
350
/*
351
 * FIPS 203, Section 4.1, equation (4.3): PRF. Takes 32+1 input bytes, and uses
352
 * SHAKE256 to produce the input to SamplePolyCBD_eta: FIPS 203, algorithm 8.
353
 */
354
static __owur
355
int prf(uint8_t *out, size_t len, const uint8_t in[ML_KEM_RANDOM_BYTES + 1],
356
        EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
357
254k
{
358
254k
    return EVP_DigestInit_ex(mdctx, key->shake256_md, NULL)
359
254k
        && single_keccak(out, len, in, ML_KEM_RANDOM_BYTES + 1, mdctx);
360
254k
}
361
362
/*
363
 * FIPS 203, Section 4.1, equation (4.4): H.  SHA3-256 hash of a variable
364
 * length input, producing 32 bytes of output.
365
 */
366
static __owur
367
int hash_h(uint8_t out[ML_KEM_PKHASH_BYTES], const uint8_t *in, size_t len,
368
           EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
369
208
{
370
208
    return EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL)
371
208
        && single_keccak(out, ML_KEM_PKHASH_BYTES, in, len, mdctx);
372
208
}
373
374
/* Incremental hash_h of expanded public key */
375
static int
376
hash_h_pubkey(uint8_t pkhash[ML_KEM_PKHASH_BYTES],
377
              EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
378
42.1k
{
379
42.1k
    const ML_KEM_VINFO *vinfo = key->vinfo;
380
42.1k
    const scalar *t = key->t, *end = t + vinfo->rank;
381
42.1k
    unsigned int sz;
382
383
42.1k
    if (!EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL))
384
0
        return 0;
385
386
126k
    do {
387
126k
        uint8_t buf[3 * DEGREE / 2];
388
389
126k
        scalar_encode(buf, t++, 12);
390
126k
        if (!EVP_DigestUpdate(mdctx, buf, sizeof(buf)))
391
0
            return 0;
392
126k
    } while (t < end);
393
394
42.1k
    if (!EVP_DigestUpdate(mdctx, key->rho, ML_KEM_RANDOM_BYTES))
395
0
        return 0;
396
42.1k
    return EVP_DigestFinal_ex(mdctx, pkhash, &sz)
397
42.1k
        && ossl_assert(sz == ML_KEM_PKHASH_BYTES);
398
42.1k
}
399
400
/*
401
 * FIPS 203, Section 4.1, equation (4.5): G.  SHA3-512 hash of a variable
402
 * length input, producing 64 bytes of output, in particular the seeds
403
 * (d,z) for key generation.
404
 */
405
static __owur
406
int hash_g(uint8_t out[ML_KEM_SEED_BYTES], const uint8_t *in, size_t len,
407
           EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
408
42.3k
{
409
42.3k
    return EVP_DigestInit_ex(mdctx, key->sha3_512_md, NULL)
410
42.3k
        && single_keccak(out, ML_KEM_SEED_BYTES, in, len, mdctx);
411
42.3k
}
412
413
/*
414
 * FIPS 203, Section 4.1, equation (4.4): J. SHAKE256 taking a variable length
415
 * input to compute a 32-byte implicit rejection shared secret, of the same
416
 * length as the expected shared secret.  (Computed even on success to avoid
417
 * side-channel leaks).
418
 */
419
static __owur
420
int kdf(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
421
        const uint8_t z[ML_KEM_RANDOM_BYTES],
422
        const uint8_t *ctext, size_t len,
423
        EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
424
100
{
425
100
    return EVP_DigestInit_ex(mdctx, key->shake256_md, NULL)
426
100
        && EVP_DigestUpdate(mdctx, z, ML_KEM_RANDOM_BYTES)
427
100
        && EVP_DigestUpdate(mdctx, ctext, len)
428
100
        && EVP_DigestFinalXOF(mdctx, out, ML_KEM_SHARED_SECRET_BYTES);
429
100
}
430
431
/*
432
 * FIPS 203, Section 4.2.2, Algorithm 7: "SampleNTT" (steps 3-17, steps 1, 2
433
 * are performed by the caller). Rejection-samples a Keccak stream to get
434
 * uniformly distributed elements in the range [0,q). This is used for matrix
435
 * expansion and only operates on public inputs.
436
 */
437
static __owur
438
int sample_scalar(scalar *out, EVP_MD_CTX *mdctx)
439
380k
{
440
380k
    uint16_t *curr = out->c, *endout = curr + DEGREE;
441
380k
    uint8_t buf[SCALAR_SAMPLING_BUFSIZE], *in;
442
380k
    uint8_t *endin = buf + sizeof(buf);
443
380k
    uint16_t d;
444
380k
    uint8_t b1, b2, b3;
445
446
1.14M
    do {
447
1.14M
        if (!EVP_DigestSqueeze(mdctx, in = buf, sizeof(buf)))
448
0
            return 0;
449
59.6M
        do {
450
59.6M
            b1 = *in++;
451
59.6M
            b2 = *in++;
452
59.6M
            b3 = *in++;
453
454
59.6M
            if (curr >= endout)
455
127k
                break;
456
59.5M
            if ((d = ((b2 & 0x0f) << 8) + b1) < kPrime)
457
49.1M
                *curr++ = d;
458
59.5M
            if (curr >= endout)
459
252k
                break;
460
59.3M
            if ((d = (b3 << 4) + (b2 >> 4)) < kPrime)
461
48.3M
                *curr++ = d;
462
59.3M
        } while (in < endin);
463
1.14M
    } while (curr < endout);
464
380k
    return 1;
465
380k
}
466
467
/*-
468
 * reduce_once reduces 0 <= x < 2*kPrime, mod kPrime.
469
 *
470
 * Subtract |q| if the input is larger, without exposing a side-channel,
471
 * avoiding the "clangover" attack.  See |constish_time_non_zero| for a
472
 * discussion on why the value barrier is by default omitted.
473
 */
474
static __owur uint16_t reduce_once(uint16_t x)
475
830M
{
476
830M
    const uint16_t subtracted = x - kPrime;
477
830M
    uint16_t mask = constish_time_non_zero(subtracted >> 15);
478
479
830M
    return (mask & x) | (~mask & subtracted);
480
830M
}
481
482
/*
483
 * Constant-time reduce x mod kPrime using Barrett reduction. x must be less
484
 * than kPrime + 2 * kPrime^2.  This is sufficient to reduce a product of
485
 * two already reduced u_int16 values, in fact it is sufficient for each
486
 * to be less than 2^12, because (kPrime * (2 * kPrime + 1)) > 2^24.
487
 */
488
static __owur uint16_t reduce(uint32_t x)
489
374M
{
490
374M
    uint64_t product = (uint64_t)x * kBarrettMultiplier;
491
374M
    uint32_t quotient = (uint32_t)(product >> kBarrettShift);
492
374M
    uint32_t remainder = x - quotient * kPrime;
493
494
374M
    return reduce_once(remainder);
495
374M
}
496
497
/* Multiply a scalar by a constant. */
498
static void scalar_mult_const(scalar *s, uint16_t a)
499
926
{
500
926
    uint16_t *curr = s->c, *end = curr + DEGREE, tmp;
501
502
237k
    do {
503
237k
        tmp = reduce(*curr * a);
504
237k
        *curr++ = tmp;
505
237k
    } while (curr < end);
506
926
}
507
508
/*-
509
 * FIPS 203, Section 4.3, Algoritm 9: "NTT".
510
 * In-place number theoretic transform of a given scalar.  Note that ML-KEM's
511
 * kPrime 3329 does not have a 512th root of unity, so this transform leaves
512
 * off the last iteration of the usual FFT code, with the 128 relevant roots of
513
 * unity being stored in NTTRoots.  This means the output should be seen as 128
514
 * elements in GF(3329^2), with the coefficients of the elements being
515
 * consecutive entries in |s->c|.
516
 */
517
static void scalar_ntt(scalar *s)
518
253k
{
519
253k
    const uint16_t *roots = kNTTRoots;
520
253k
    uint16_t *end = s->c + DEGREE;
521
253k
    int offset = DEGREE / 2;
522
523
1.77M
    do {
524
1.77M
        uint16_t *curr = s->c, *peer;
525
526
32.2M
        do {
527
32.2M
            uint16_t *pause = curr + offset, even, odd;
528
32.2M
            uint32_t zeta = *++roots;
529
530
32.2M
            peer = pause;
531
227M
            do {
532
227M
                even = *curr;
533
227M
                odd = reduce(*peer * zeta);
534
227M
                *peer++ = reduce_once(even - odd + kPrime);
535
227M
                *curr++ = reduce_once(odd + even);
536
227M
            } while (curr < pause);
537
32.2M
        } while ((curr = peer) < end);
538
1.77M
    } while ((offset >>= 1) >= 2);
539
253k
}
540
541
/*-
542
 * FIPS 203, Section 4.3, Algoritm 10: "NTT^(-1)".
543
 * In-place inverse number theoretic transform of a given scalar, with pairs of
544
 * entries of s->v being interpreted as elements of GF(3329^2). Just as with
545
 * the number theoretic transform, this leaves off the first step of the normal
546
 * iFFT to account for the fact that 3329 does not have a 512th root of unity,
547
 * using the precomputed 128 roots of unity stored in InverseNTTRoots.
548
 */
549
static void scalar_inverse_ntt(scalar *s)
550
926
{
551
926
    const uint16_t *roots = kInverseNTTRoots;
552
926
    uint16_t *end = s->c + DEGREE;
553
926
    int offset = 2;
554
555
6.48k
    do {
556
6.48k
        uint16_t *curr = s->c, *peer;
557
558
117k
        do {
559
117k
            uint16_t *pause = curr + offset, even, odd;
560
117k
            uint32_t zeta = *++roots;
561
562
117k
            peer = pause;
563
829k
            do {
564
829k
                even = *curr;
565
829k
                odd = *peer;
566
829k
                *peer++ = reduce(zeta * (even - odd + kPrime));
567
829k
                *curr++ = reduce_once(odd + even);
568
829k
            } while (curr < pause);
569
117k
        } while ((curr = peer) < end);
570
6.48k
    } while ((offset <<= 1) < DEGREE);
571
926
    scalar_mult_const(s, kInverseDegree);
572
926
}
573
574
/* Addition updating the LHS scalar in-place. */
575
static void scalar_add(scalar *lhs, const scalar *rhs)
576
826
{
577
826
    int i;
578
579
212k
    for (i = 0; i < DEGREE; i++)
580
211k
        lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
581
826
}
582
583
/* Subtraction updating the LHS scalar in-place. */
584
static void scalar_sub(scalar *lhs, const scalar *rhs)
585
100
{
586
100
    int i;
587
588
25.7k
    for (i = 0; i < DEGREE; i++)
589
25.6k
        lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime);
590
100
}
591
592
/*
593
 * Multiplying two scalars in the number theoretically transformed state. Since
594
 * 3329 does not have a 512th root of unity, this means we have to interpret
595
 * the 2*ith and (2*i+1)th entries of the scalar as elements of
596
 * GF(3329)[X]/(X^2 - 17^(2*bitreverse(i)+1)).
597
 *
598
 * The value of 17^(2*bitreverse(i)+1) mod 3329 is stored in the precomputed
599
 * ModRoots table. Note that our Barrett transform only allows us to multipy
600
 * two reduced numbers together, so we need some intermediate reduction steps,
601
 * even if an uint64_t could hold 3 multiplied numbers.
602
 */
603
static void scalar_mult(scalar *out, const scalar *lhs,
604
                        const scalar *rhs)
605
926
{
606
926
    uint16_t *curr = out->c, *end = curr + DEGREE;
607
926
    const uint16_t *lc = lhs->c, *rc = rhs->c;
608
926
    const uint16_t *roots = kModRoots;
609
610
118k
    do {
611
118k
        uint32_t l0 = *lc++, r0 = *rc++;
612
118k
        uint32_t l1 = *lc++, r1 = *rc++;
613
118k
        uint32_t zetapow = *roots++;
614
615
118k
        *curr++ = reduce(l0 * r0 + reduce(l1 * r1) * zetapow);
616
118k
        *curr++ = reduce(l0 * r1 + l1 * r0);
617
118k
    } while (curr < end);
618
926
}
619
620
/* Above, but add the result to an existing scalar */
621
static ossl_inline
622
void scalar_mult_add(scalar *out, const scalar *lhs,
623
                     const scalar *rhs)
624
380k
{
625
380k
    uint16_t *curr = out->c, *end = curr + DEGREE;
626
380k
    const uint16_t *lc = lhs->c, *rc = rhs->c;
627
380k
    const uint16_t *roots = kModRoots;
628
629
48.7M
    do {
630
48.7M
        uint32_t l0 = *lc++, r0 = *rc++;
631
48.7M
        uint32_t l1 = *lc++, r1 = *rc++;
632
48.7M
        uint16_t *c0 = curr++;
633
48.7M
        uint16_t *c1 = curr++;
634
48.7M
        uint32_t zetapow = *roots++;
635
636
48.7M
        *c0 = reduce(*c0 + l0 * r0 + reduce(l1 * r1) * zetapow);
637
48.7M
        *c1 = reduce(*c1 + l0 * r1 + l1 * r0);
638
48.7M
    } while (curr < end);
639
380k
}
640
641
/*-
642
 * FIPS 203, Section 4.2.1, Algorithm 5: "ByteEncode_d", for 2<=d<=12.
643
 * Here |bits| is |d|.  For efficiency, we handle the d=1 case separately.
644
 */
645
static void scalar_encode(uint8_t *out, const scalar *s, int bits)
646
253k
{
647
253k
    const uint16_t *curr = s->c, *end = curr + DEGREE;
648
253k
    uint64_t accum = 0, element;
649
253k
    int used = 0;
650
651
64.8M
    do {
652
64.8M
        element = *curr++;
653
64.8M
        if (used + bits < 64) {
654
52.7M
            accum |= element << used;
655
52.7M
            used += bits;
656
52.7M
        } else if (used + bits > 64) {
657
8.10M
            out = OPENSSL_store_u64_le(out, accum | (element << used));
658
8.10M
            accum = element >> (64 - used);
659
8.10M
            used = (used + bits) - 64;
660
8.10M
        } else {
661
4.04M
            out = OPENSSL_store_u64_le(out, accum | (element << used));
662
4.04M
            accum = 0;
663
4.04M
            used = 0;
664
4.04M
        }
665
64.8M
    } while (curr < end);
666
253k
}
667
668
/*
669
 * scalar_encode_1 is |scalar_encode| specialised for |bits| == 1.
670
 */
671
static void scalar_encode_1(uint8_t out[DEGREE / 8], const scalar *s)
672
100
{
673
100
    int i, j;
674
100
    uint8_t out_byte;
675
676
3.30k
    for (i = 0; i < DEGREE; i += 8) {
677
3.20k
        out_byte = 0;
678
28.8k
        for (j = 0; j < 8; j++)
679
25.6k
            out_byte |= bit0(s->c[i + j]) << j;
680
3.20k
        *out = out_byte;
681
3.20k
        out++;
682
3.20k
    }
683
100
}
684
685
/*-
686
 * FIPS 203, Section 4.2.1, Algorithm 6: "ByteDecode_d", for 2<=d<12.
687
 * Here |bits| is |d|.  For efficiency, we handle the d=1 and d=12 cases
688
 * separately.
689
 *
690
 * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
691
 * |out|.
692
 */
693
static void scalar_decode(scalar *out, const uint8_t *in, int bits)
694
381
{
695
381
    uint16_t *curr = out->c, *end = curr + DEGREE;
696
381
    uint64_t accum = 0;
697
381
    int accum_bits = 0, todo = bits;
698
381
    uint16_t bitmask = (((uint16_t) 1) << bits) - 1, mask = bitmask;
699
381
    uint16_t element = 0;
700
701
107k
    do {
702
107k
        if (accum_bits == 0) {
703
13.3k
            in = OPENSSL_load_u64_le(&accum, in);
704
13.3k
            accum_bits = 64;
705
13.3k
        }
706
107k
        if (todo == bits && accum_bits >= bits) {
707
            /* No partial "element", and all the required bits available */
708
87.3k
            *curr++ = ((uint16_t) accum) & mask;
709
87.3k
            accum >>= bits;
710
87.3k
            accum_bits -= bits;
711
87.3k
        } else if (accum_bits >= todo) {
712
            /* A partial "element", and all the required bits available */
713
10.1k
            *curr++ = element | ((((uint16_t) accum) & mask) << (bits - todo));
714
10.1k
            accum >>= todo;
715
10.1k
            accum_bits -= todo;
716
10.1k
            element = 0;
717
10.1k
            todo = bits;
718
10.1k
            mask = bitmask;
719
10.1k
        } else {
720
            /*
721
             * Only some of the requisite bits accumulated, store |accum_bits|
722
             * of these in |element|.  The accumulated bitcount becomes 0, but
723
             * as soon as we have more bits we'll want to merge accum_bits
724
             * fewer of them into the final |element|.
725
             *
726
             * Note that with a 64-bit accumulator and |bits| always 12 or
727
             * less, if we're here, the previous iteration had all the
728
             * requisite bits, and so there are no kept bits in |element|.
729
             */
730
10.1k
            element = ((uint16_t) accum) & mask;
731
10.1k
            todo -= accum_bits;
732
10.1k
            mask = bitmask >> accum_bits;
733
10.1k
            accum_bits = 0;
734
10.1k
        }
735
107k
    } while (curr < end);
736
381
}
737
738
static __owur
739
int scalar_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2])
740
773
{
741
773
    int i;
742
773
    uint16_t *c = out->c;
743
744
78.4k
    for (i = 0; i < DEGREE / 2; ++i) {
745
77.8k
        uint8_t b1 = *in++;
746
77.8k
        uint8_t b2 = *in++;
747
77.8k
        uint8_t b3 = *in++;
748
77.8k
        int outOfRange1 = (*c++ = b1 | ((b2 & 0x0f) << 8)) >= kPrime;
749
77.8k
        int outOfRange2 = (*c++ = (b2 >> 4) | (b3 << 4)) >= kPrime;
750
751
77.8k
        if (outOfRange1 | outOfRange2)
752
195
            return 0;
753
77.8k
    }
754
578
    return 1;
755
773
}
756
757
/*-
758
 * scalar_decode_decompress_add is a combination of decoding and decompression
759
 * both specialised for |bits| == 1, with the result added (and sum reduced) to
760
 * the output scalar.
761
 *
762
 * NOTE: this function MUST not leak an input-data-depedennt timing signal.
763
 * A timing leak in a related function in the reference Kyber implementation
764
 * made the "clangover" attack (CVE-2024-37880) possible, giving key recovery
765
 * for ML-KEM-512 in minutes, provided the attacker has access to precise
766
 * timing of a CPU performing chosen-ciphertext decap.  Admittedly this is only
767
 * a risk when private keys are reused (perhaps KEMTLS servers).
768
 */
769
static void
770
scalar_decode_decompress_add(scalar *out, const uint8_t in[DEGREE / 8])
771
216
{
772
216
    static const uint16_t half_q_plus_1 = (ML_KEM_PRIME >> 1) + 1;
773
216
    uint16_t *curr = out->c, *end = curr + DEGREE;
774
216
    uint16_t mask;
775
216
    uint8_t b;
776
777
    /*
778
     * Add |half_q_plus_1| if the bit is set, without exposing a side-channel,
779
     * avoiding the "clangover" attack.  See |constish_time_non_zero| for a
780
     * discussion on why the value barrier is by default omitted.
781
     */
782
216
#define decode_decompress_add_bit                               \
783
55.2k
        mask = constish_time_non_zero(bit0(b));                 \
784
55.2k
        *curr = reduce_once(*curr + (mask & half_q_plus_1));    \
785
55.2k
        curr++;                                                 \
786
55.2k
        b >>= 1
787
788
    /* Unrolled to process each byte in one iteration */
789
6.91k
    do {
790
6.91k
        b = *in++;
791
6.91k
        decode_decompress_add_bit;
792
6.91k
        decode_decompress_add_bit;
793
6.91k
        decode_decompress_add_bit;
794
6.91k
        decode_decompress_add_bit;
795
796
6.91k
        decode_decompress_add_bit;
797
6.91k
        decode_decompress_add_bit;
798
6.91k
        decode_decompress_add_bit;
799
6.91k
        decode_decompress_add_bit;
800
6.91k
    } while (curr < end);
801
216
#undef decode_decompress_add_bit
802
216
}
803
804
/*
805
 * FIPS 203, Section 4.2.1, Equation (4.7): Compress_d.
806
 *
807
 * Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
808
 * numbers close to each other together. The formula used is
809
 * round(2^|bits|/kPrime*x) mod 2^|bits|.
810
 * Uses Barrett reduction to achieve constant time. Since we need both the
811
 * remainder (for rounding) and the quotient (as the result), we cannot use
812
 * |reduce| here, but need to do the Barrett reduction directly.
813
 */
814
static __owur uint16_t compress(uint16_t x, int bits)
815
237k
{
816
237k
    uint32_t shifted = (uint32_t)x << bits;
817
237k
    uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
818
237k
    uint32_t quotient = (uint32_t)(product >> kBarrettShift);
819
237k
    uint32_t remainder = shifted - quotient * kPrime;
820
821
    /*
822
     * Adjust the quotient to round correctly:
823
     *   0 <= remainder <= kHalfPrime round to 0
824
     *   kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
825
     *   kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
826
     */
827
237k
    quotient += 1 & constant_time_lt_32(kHalfPrime, remainder);
828
237k
    quotient += 1 & constant_time_lt_32(kPrime + kHalfPrime, remainder);
829
237k
    return quotient & ((1 << bits) - 1);
830
237k
}
831
832
/*
833
 * FIPS 203, Section 4.2.1, Equation (4.8): Decompress_d.
834
835
 * Decompresses |x| by using a close equi-distant representative. The formula
836
 * is round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us
837
 * to implement this logic using only bit operations.
838
 */
839
static __owur uint16_t decompress(uint16_t x, int bits)
840
97.5k
{
841
97.5k
    uint32_t product = (uint32_t)x * kPrime;
842
97.5k
    uint32_t power = 1 << bits;
843
    /* This is |product| % power, since |power| is a power of 2. */
844
97.5k
    uint32_t remainder = product & (power - 1);
845
    /* This is |product| / power, since |power| is a power of 2. */
846
97.5k
    uint32_t lower = product >> bits;
847
848
    /*
849
     * The rounding logic works since the first half of numbers mod |power|
850
     * have a 0 as first bit, and the second half has a 1 as first bit, since
851
     * |power| is a power of 2. As a 12 bit number, |remainder| is always
852
     * positive, so we will shift in 0s for a right shift.
853
     */
854
97.5k
    return lower + (remainder >> (bits - 1));
855
97.5k
}
856
857
/*-
858
 * FIPS 203, Section 4.2.1, Equation (4.7): "Compress_d".
859
 * In-place lossy rounding of scalars to 2^d bits.
860
 */
861
static void scalar_compress(scalar *s, int bits)
862
926
{
863
926
    int i;
864
865
237k
    for (i = 0; i < DEGREE; i++)
866
237k
        s->c[i] = compress(s->c[i], bits);
867
926
}
868
869
/*
870
 * FIPS 203, Section 4.2.1, Equation (4.8): "Decompress_d".
871
 * In-place approximate recovery of scalars from 2^d bit compression.
872
 */
873
static void scalar_decompress(scalar *s, int bits)
874
381
{
875
381
    int i;
876
877
97.9k
    for (i = 0; i < DEGREE; i++)
878
97.5k
        s->c[i] = decompress(s->c[i], bits);
879
381
}
880
881
/* Addition updating the LHS vector in-place. */
882
static void vector_add(scalar *lhs, const scalar *rhs, int rank)
883
216
{
884
610
    do {
885
610
        scalar_add(lhs++, rhs++);
886
610
    } while (--rank > 0);
887
216
}
888
889
/*
890
 * Encodes an entire vector into 32*|rank|*|bits| bytes. Note that since 256
891
 * (DEGREE) is divisible by 8, the individual vector entries will always fill a
892
 * whole number of bytes, so we do not need to worry about bit packing here.
893
 */
894
static void vector_encode(uint8_t *out, const scalar *a, int bits, int rank)
895
42.3k
{
896
42.3k
    int stride = bits * DEGREE / 8;
897
898
169k
    for (; rank-- > 0; out += stride)
899
126k
        scalar_encode(out, a++, bits);
900
42.3k
}
901
902
/*
903
 * Decodes 32*|rank|*|bits| bytes from |in| into |out|. It returns early
904
 * if any parsed value is >= |ML_KEM_PRIME|.  The resulting scalars are
905
 * then decompressed and transformed via the NTT.
906
 *
907
 * Note: Used only in decrypt_cpa(), which returns void and so does not check
908
 * the return value of this function.  Side-channels are fine when the input
909
 * ciphertext to decap() is simply syntactically invalid.
910
 */
911
static void
912
vector_decode_decompress_ntt(scalar *out, const uint8_t *in, int bits, int rank)
913
100
{
914
100
    int stride = bits * DEGREE / 8;
915
916
381
    for (; rank-- > 0; in += stride, ++out) {
917
281
        scalar_decode(out, in, bits);
918
281
        scalar_decompress(out, bits);
919
281
        scalar_ntt(out);
920
281
    }
921
100
}
922
923
/* vector_decode(), specialised to bits == 12. */
924
static __owur
925
int vector_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2], int rank)
926
403
{
927
403
    int stride = 3 * DEGREE / 2;
928
929
981
    for (; rank-- > 0; in += stride)
930
773
        if (!scalar_decode_12(out++, in))
931
195
            return 0;
932
208
    return 1;
933
403
}
934
935
/* In-place compression of each scalar component */
936
static void vector_compress(scalar *a, int bits, int rank)
937
216
{
938
610
    do {
939
610
        scalar_compress(a++, bits);
940
610
    } while (--rank > 0);
941
216
}
942
943
/* The output scalar must not overlap with the inputs */
944
static void inner_product(scalar *out, const scalar *lhs, const scalar *rhs,
945
                          int rank)
946
316
{
947
316
    scalar_mult(out, lhs, rhs);
948
891
    while (--rank > 0)
949
575
        scalar_mult_add(out, ++lhs, ++rhs);
950
316
}
951
952
/*
953
 * Here, the output vector must not overlap with the inputs, the result is
954
 * directly subjected to inverse NTT.
955
 */
956
static void
957
matrix_mult_intt(scalar *out, const scalar *m, const scalar *a, int rank)
958
216
{
959
216
    const scalar *ar;
960
216
    int i, j;
961
962
826
    for (i = rank; i-- > 0; ++out) {
963
610
        scalar_mult(out, m++, ar = a);
964
1.85k
        for (j = rank - 1; j > 0; --j)
965
1.24k
            scalar_mult_add(out, m++, ++ar);
966
610
        scalar_inverse_ntt(out);
967
610
    }
968
216
}
969
970
/* Here, the output vector must not overlap with the inputs */
971
static void
972
matrix_mult_transpose_add(scalar *out, const scalar *m, const scalar *a, int rank)
973
42.1k
{
974
42.1k
    const scalar *mc = m, *mr, *ar;
975
42.1k
    int i, j;
976
977
168k
    for (i = rank; i-- > 0; ++out) {
978
126k
        scalar_mult_add(out, mr = mc++, ar = a);
979
379k
        for (j = rank; --j > 0; )
980
252k
            scalar_mult_add(out, (mr += rank), ++ar);
981
126k
    }
982
42.1k
}
983
984
/*-
985
 * Expands the matrix from a seed for key generation and for encaps-CPA.
986
 * NOTE: FIPS 203 matrix "A" is the transpose of this matrix, computed
987
 * by appending the (i,j) indices to the seed in the opposite order!
988
 *
989
 * Where FIPS 203 computes t = A * s + e, we use the transpose of "m".
990
 */
991
static __owur
992
int matrix_expand(EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
993
42.3k
{
994
42.3k
    scalar *out = key->m;
995
42.3k
    uint8_t input[ML_KEM_RANDOM_BYTES + 2];
996
42.3k
    int rank = key->vinfo->rank;
997
42.3k
    int i, j;
998
999
42.3k
    memcpy(input, key->rho, ML_KEM_RANDOM_BYTES);
1000
169k
    for (i = 0; i < rank; i++) {
1001
507k
        for (j = 0; j < rank; j++) {
1002
380k
            input[ML_KEM_RANDOM_BYTES] = i;
1003
380k
            input[ML_KEM_RANDOM_BYTES + 1] = j;
1004
380k
            if (!EVP_DigestInit_ex(mdctx, key->shake128_md, NULL)
1005
380k
                || !EVP_DigestUpdate(mdctx, input, sizeof(input))
1006
380k
                || !sample_scalar(out++, mdctx))
1007
0
                return 0;
1008
380k
        }
1009
126k
    }
1010
42.3k
    return 1;
1011
42.3k
}
1012
1013
/*
1014
 * Algorithm 7 from the spec, with eta fixed to two and the PRF call
1015
 * included. Creates binominally distributed elements by sampling 2*|eta| bits,
1016
 * and setting the coefficient to the count of the first bits minus the count of
1017
 * the second bits, resulting in a centered binomial distribution. Since eta is
1018
 * two this gives -2/2 with a probability of 1/16, -1/1 with probability 1/4,
1019
 * and 0 with probability 3/8.
1020
 */
1021
static __owur
1022
int cbd_2(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
1023
          EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1024
253k
{
1025
253k
    uint16_t *curr = out->c, *end = curr + DEGREE;
1026
253k
    uint8_t randbuf[4 * DEGREE / 8], *r = randbuf;  /* 64 * eta slots */
1027
253k
    uint16_t value, mask;
1028
253k
    uint8_t b;
1029
1030
253k
    if (!prf(randbuf, sizeof(randbuf), in, mdctx, key))
1031
0
        return 0;
1032
1033
32.4M
    do {
1034
32.4M
        b = *r++;
1035
1036
        /*
1037
         * Add |kPrime| if |value| underflowed.  See |constish_time_non_zero|
1038
         * for a discussion on why the value barrier is by default omitted.
1039
         * While this could have been written reduce_once(value + kPrime), this
1040
         * is one extra addition and small range of |value| tempts some
1041
         * versions of Clang to emit a branch.
1042
         */
1043
32.4M
        value = bit0(b) + bitn(1, b);
1044
32.4M
        value -= bitn(2, b) + bitn(3, b);
1045
32.4M
        mask = constish_time_non_zero(value >> 15);
1046
32.4M
        *curr++ = value + (kPrime & mask);
1047
1048
32.4M
        value = bitn(4, b) + bitn(5, b);
1049
32.4M
        value -= bitn(6, b) + bitn(7, b);
1050
32.4M
        mask = constish_time_non_zero(value >> 15);
1051
32.4M
        *curr++ = value + (kPrime & mask);
1052
32.4M
    } while (curr < end);
1053
253k
    return 1;
1054
253k
}
1055
1056
/*
1057
 * Algorithm 7 from the spec, with eta fixed to three and the PRF call
1058
 * included. Creates binominally distributed elements by sampling 3*|eta| bits,
1059
 * and setting the coefficient to the count of the first bits minus the count of
1060
 * the second bits, resulting in a centered binomial distribution.
1061
 */
1062
static __owur
1063
int cbd_3(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
1064
          EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1065
1.08k
{
1066
1.08k
    uint16_t *curr = out->c, *end = curr + DEGREE;
1067
1.08k
    uint8_t randbuf[6 * DEGREE / 8], *r = randbuf;  /* 64 * eta slots */
1068
1.08k
    uint8_t b1, b2, b3;
1069
1.08k
    uint16_t value, mask;
1070
1071
1.08k
    if (!prf(randbuf, sizeof(randbuf), in, mdctx, key))
1072
0
        return 0;
1073
1074
69.1k
    do {
1075
69.1k
        b1 = *r++;
1076
69.1k
        b2 = *r++;
1077
69.1k
        b3 = *r++;
1078
1079
        /*
1080
         * Add |kPrime| if |value| underflowed.  See |constish_time_non_zero|
1081
         * for a discussion on why the value barrier is by default omitted.
1082
         * While this could have been written reduce_once(value + kPrime), this
1083
         * is one extra addition and small range of |value| tempts some
1084
         * versions of Clang to emit a branch.
1085
         */
1086
69.1k
        value = bit0(b1) + bitn(1, b1) + bitn(2, b1);
1087
69.1k
        value -= bitn(3, b1)  + bitn(4, b1) + bitn(5, b1);
1088
69.1k
        mask = constish_time_non_zero(value >> 15);
1089
69.1k
        *curr++ = value + (kPrime & mask);
1090
1091
69.1k
        value = bitn(6, b1) + bitn(7, b1) + bit0(b2);
1092
69.1k
        value -= bitn(1, b2) + bitn(2, b2) + bitn(3, b2);
1093
69.1k
        mask = constish_time_non_zero(value >> 15);
1094
69.1k
        *curr++ = value + (kPrime & mask);
1095
1096
69.1k
        value = bitn(4, b2) + bitn(5, b2) + bitn(6, b2);
1097
69.1k
        value -= bitn(7, b2) + bit0(b3) + bitn(1, b3);
1098
69.1k
        mask = constish_time_non_zero(value >> 15);
1099
69.1k
        *curr++ = value + (kPrime & mask);
1100
1101
69.1k
        value = bitn(2, b3) + bitn(3, b3) + bitn(4, b3);
1102
69.1k
        value -= bitn(5, b3) + bitn(6, b3) + bitn(7, b3);
1103
69.1k
        mask = constish_time_non_zero(value >> 15);
1104
69.1k
        *curr++ = value + (kPrime & mask);
1105
69.1k
    } while (curr < end);
1106
1.08k
    return 1;
1107
1.08k
}
1108
1109
/*
1110
 * Generates a secret vector by using |cbd| with the given seed to generate
1111
 * scalar elements and incrementing |counter| for each slot of the vector.
1112
 */
1113
static __owur
1114
int gencbd_vector(scalar *out, CBD_FUNC cbd, uint8_t *counter,
1115
                  const uint8_t seed[ML_KEM_RANDOM_BYTES], int rank,
1116
                  EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1117
216
{
1118
216
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1119
1120
216
    memcpy(input, seed, ML_KEM_RANDOM_BYTES);
1121
610
    do {
1122
610
        input[ML_KEM_RANDOM_BYTES] = (*counter)++;
1123
610
        if (!cbd(out++, input, mdctx, key))
1124
0
            return 0;
1125
610
    } while (--rank > 0);
1126
216
    return 1;
1127
216
}
1128
1129
/*
1130
 * As above plus NTT transform.
1131
 */
1132
static __owur
1133
int gencbd_vector_ntt(scalar *out, CBD_FUNC cbd, uint8_t *counter,
1134
                      const uint8_t seed[ML_KEM_RANDOM_BYTES], int rank,
1135
                      EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1136
84.5k
{
1137
84.5k
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1138
1139
84.5k
    memcpy(input, seed, ML_KEM_RANDOM_BYTES);
1140
253k
    do {
1141
253k
        input[ML_KEM_RANDOM_BYTES] = (*counter)++;
1142
253k
        if (!cbd(out, input, mdctx, key))
1143
0
            return 0;
1144
253k
        scalar_ntt(out++);
1145
253k
    } while (--rank > 0);
1146
84.5k
    return 1;
1147
84.5k
}
1148
1149
/* ML-KEM-512 uses ETA1 = 3; every other ETA1 value and all ETA2 values are 2. */
#define CBD1(evp_type)  ((evp_type) != EVP_PKEY_ML_KEM_512 ? cbd_2 : cbd_3)
1151
1152
/*
1153
 * FIPS 203, Section 5.2, Algorithm 14: K-PKE.Encrypt.
1154
 *
1155
 * Encrypts a message with given randomness to the ciphertext in |out|. Without
1156
 * applying the Fujisaki-Okamoto transform this would not result in a CCA
1157
 * secure scheme, since lattice schemes are vulnerable to decryption failure
1158
 * oracles.
1159
 *
1160
 * The steps are re-ordered to make more efficient/localised use of storage.
1161
 *
1162
 * Note also that the input public key is assumed to hold a precomputed matrix
1163
 * |A| (our key->m, with the public key holding an expanded (16-bit per scalar
1164
 * coefficient) key->t vector).
1165
 *
1166
 * Caller passes storage in |tmp| for for two temporary vectors.
1167
 */
1168
static __owur
1169
int encrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
1170
                const uint8_t message[DEGREE / 8],
1171
                const uint8_t r[ML_KEM_RANDOM_BYTES], scalar *tmp,
1172
                EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1173
216
{
1174
216
    const ML_KEM_VINFO *vinfo = key->vinfo;
1175
216
    CBD_FUNC cbd_1 = CBD1(vinfo->evp_type);
1176
216
    int rank = vinfo->rank;
1177
    /* We can use tmp[0..rank-1] as storage for |y|, then |e1|, ... */
1178
216
    scalar *y = &tmp[0], *e1 = y, *e2 = y;
1179
    /* We can use tmp[rank]..tmp[2*rank - 1] for |u| */
1180
216
    scalar *u = &tmp[rank];
1181
216
    scalar v;
1182
216
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1183
216
    uint8_t counter = 0;
1184
216
    int du = vinfo->du;
1185
216
    int dv = vinfo->dv;
1186
1187
    /* FIPS 203 "y" vector */
1188
216
    if (!gencbd_vector_ntt(y, cbd_1, &counter, r, rank, mdctx, key))
1189
0
        return 0;
1190
    /* FIPS 203 "v" scalar */
1191
216
    inner_product(&v, key->t, y, rank);
1192
216
    scalar_inverse_ntt(&v);
1193
    /* FIPS 203 "u" vector */
1194
216
    matrix_mult_intt(u, key->m, y, rank);
1195
1196
    /* All done with |y|, now free to reuse tmp[0] for FIPS 203 |e1| */
1197
216
    if (!gencbd_vector(e1, cbd_2, &counter, r, rank, mdctx, key))
1198
0
        return 0;
1199
216
    vector_add(u, e1, rank);
1200
216
    vector_compress(u, du, rank);
1201
216
    vector_encode(out, u, du, rank);
1202
1203
    /* All done with |e1|, now free to reuse tmp[0] for FIPS 203 |e2| */
1204
216
    memcpy(input, r, ML_KEM_RANDOM_BYTES);
1205
216
    input[ML_KEM_RANDOM_BYTES] = counter;
1206
216
    if (!cbd_2(e2, input, mdctx, key))
1207
0
        return 0;
1208
216
    scalar_add(&v, e2);
1209
1210
    /* Combine message with |v| */
1211
216
    scalar_decode_decompress_add(&v, message);
1212
216
    scalar_compress(&v, dv);
1213
216
    scalar_encode(out + vinfo->u_vector_bytes, &v, dv);
1214
216
    return 1;
1215
216
}
1216
1217
/*
1218
 * FIPS 203, Section 5.3, Algorithm 15: K-PKE.Decrypt.
1219
 */
1220
static void
1221
decrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
1222
            const uint8_t *ctext, scalar *u, const ML_KEM_KEY *key)
1223
100
{
1224
100
    const ML_KEM_VINFO *vinfo = key->vinfo;
1225
100
    scalar v, mask;
1226
100
    int rank = vinfo->rank;
1227
100
    int du = vinfo->du;
1228
100
    int dv = vinfo->dv;
1229
1230
100
    vector_decode_decompress_ntt(u, ctext, du, rank);
1231
100
    scalar_decode(&v, ctext + vinfo->u_vector_bytes, dv);
1232
100
    scalar_decompress(&v, dv);
1233
100
    inner_product(&mask, key->s, u, rank);
1234
100
    scalar_inverse_ntt(&mask);
1235
100
    scalar_sub(&v, &mask);
1236
100
    scalar_compress(&v, 1);
1237
100
    scalar_encode_1(out, &v);
1238
100
}
1239
1240
/*-
1241
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1242
 * FIPS 203, Section 7.2, Algorithm 20: "ML-KEM.Encaps".
1243
 *
1244
 * Fills the |out| buffer with the |ek| output of "ML-KEM.KeyGen", or,
1245
 * equivalently, the |ek| input of "ML-KEM.Encaps", i.e. returns the
1246
 * wire-format of an ML-KEM public key.
1247
 */
1248
static void encode_pubkey(uint8_t *out, const ML_KEM_KEY *key)
1249
42.0k
{
1250
42.0k
    const uint8_t *rho = key->rho;
1251
42.0k
    const ML_KEM_VINFO *vinfo = key->vinfo;
1252
1253
42.0k
    vector_encode(out, key->t, 12, vinfo->rank);
1254
42.0k
    memcpy(out + vinfo->vector_bytes, rho, ML_KEM_RANDOM_BYTES);
1255
42.0k
}
1256
1257
/*-
1258
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1259
 *
1260
 * Fills the |out| buffer with the |dk| output of "ML-KEM.KeyGen".
1261
 * This matches the input format of parse_prvkey() below.
1262
 */
1263
static void encode_prvkey(uint8_t *out, const ML_KEM_KEY *key)
1264
63
{
1265
63
    const ML_KEM_VINFO *vinfo = key->vinfo;
1266
1267
63
    vector_encode(out, key->s, 12, vinfo->rank);
1268
63
    out += vinfo->vector_bytes;
1269
63
    encode_pubkey(out, key);
1270
63
    out += vinfo->pubkey_bytes;
1271
63
    memcpy(out, key->pkhash, ML_KEM_PKHASH_BYTES);
1272
63
    out += ML_KEM_PKHASH_BYTES;
1273
63
    memcpy(out, key->z, ML_KEM_RANDOM_BYTES);
1274
63
}
1275
1276
/*-
1277
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1278
 * FIPS 203, Section 7.2, Algorithm 20: "ML-KEM.Encaps".
1279
 *
1280
 * This function parses the |in| buffer as the |ek| output of "ML-KEM.KeyGen",
1281
 * or, equivalently, the |ek| input of "ML-KEM.Encaps", i.e. decodes the
1282
 * wire-format of the ML-KEM public key.
1283
 */
1284
static int parse_pubkey(const uint8_t *in, EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
1285
364
{
1286
364
    const ML_KEM_VINFO *vinfo = key->vinfo;
1287
1288
    /* Decode and check |t| */
1289
364
    if (!vector_decode_12(key->t, in, vinfo->rank)) {
1290
156
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1291
156
                       "%s invalid public 't' vector",
1292
156
                       vinfo->algorithm_name);
1293
156
        return 0;
1294
156
    }
1295
    /* Save the matrix |m| recovery seed |rho| */
1296
208
    memcpy(key->rho, in + vinfo->vector_bytes, ML_KEM_RANDOM_BYTES);
1297
    /*
1298
     * Pre-compute the public key hash, needed for both encap and decap.
1299
     * Also pre-compute the matrix expansion, stored with the public key.
1300
     */
1301
208
    if (!hash_h(key->pkhash, in, vinfo->pubkey_bytes, mdctx, key)
1302
208
        || !matrix_expand(mdctx, key)) {
1303
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1304
0
                       "internal error while parsing %s public key",
1305
0
                       vinfo->algorithm_name);
1306
0
        return 0;
1307
0
    }
1308
208
    return 1;
1309
208
}
1310
1311
/*
1312
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1313
 *
1314
 * Parses the |in| buffer as a |dk| output of "ML-KEM.KeyGen".
1315
 * This matches the output format of encode_prvkey() above.
1316
 */
1317
static int parse_prvkey(const uint8_t *in, EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
1318
39
{
1319
39
    const ML_KEM_VINFO *vinfo = key->vinfo;
1320
1321
    /* Decode and check |s|. */
1322
39
    if (!vector_decode_12(key->s, in, vinfo->rank)) {
1323
39
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1324
39
                       "%s invalid private 's' vector",
1325
39
                       vinfo->algorithm_name);
1326
39
        return 0;
1327
39
    }
1328
0
    in += vinfo->vector_bytes;
1329
1330
0
    if (!parse_pubkey(in, mdctx, key))
1331
0
        return 0;
1332
0
    in += vinfo->pubkey_bytes;
1333
1334
    /* Check public key hash. */
1335
0
    if (memcmp(key->pkhash, in, ML_KEM_PKHASH_BYTES) != 0) {
1336
0
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1337
0
                       "%s public key hash mismatch",
1338
0
                       vinfo->algorithm_name);
1339
0
        return 0;
1340
0
    }
1341
0
    in += ML_KEM_PKHASH_BYTES;
1342
1343
0
    memcpy(key->z, in, ML_KEM_RANDOM_BYTES);
1344
0
    return 1;
1345
0
}
1346
1347
/*
1348
 * FIPS 203, Section 6.1, Algorithm 16: "ML-KEM.KeyGen_internal".
1349
 *
1350
 * The implementation of Section 5.1, Algorithm 13, "K-PKE.KeyGen(d)" is
1351
 * inlined.
1352
 *
1353
 * The caller MUST pass a pre-allocated digest context that is not shared with
1354
 * any concurrent computation.
1355
 *
1356
 * This function optionally outputs the serialised wire-form |ek| public key
1357
 * into the provided |pubenc| buffer, and generates the content of the |rho|,
1358
 * |pkhash|, |t|, |m|, |s| and |z| components of the private |key| (which must
1359
 * have preallocated space for these).
1360
 *
1361
 * Keys are computed from a 32-byte random |d| plus the 1 byte rank for
1362
 * domain separation.  These are concatenated and hashed to produce a pair of
1363
 * 32-byte seeds public "rho", used to generate the matrix, and private "sigma",
1364
 * used to generate the secret vector |s|.
1365
 *
1366
 * The second random input |z| is copied verbatim into the Fujisaki-Okamoto
1367
 * (FO) transform "implicit-rejection" secret (the |z| component of the private
1368
 * key), which thwarts chosen-ciphertext attacks, provided decap() runs in
1369
 * constant time, with no side channel leaks, on all well-formed (valid length,
1370
 * and correctly encoded) ciphertext inputs.
1371
 */
1372
static __owur
1373
int genkey(const uint8_t seed[ML_KEM_SEED_BYTES],
1374
           EVP_MD_CTX *mdctx, uint8_t *pubenc, ML_KEM_KEY *key)
1375
42.1k
{
1376
42.1k
    uint8_t hashed[2 * ML_KEM_RANDOM_BYTES];
1377
42.1k
    const uint8_t *const sigma = hashed + ML_KEM_RANDOM_BYTES;
1378
42.1k
    uint8_t augmented_seed[ML_KEM_RANDOM_BYTES + 1];
1379
42.1k
    const ML_KEM_VINFO *vinfo = key->vinfo;
1380
42.1k
    CBD_FUNC cbd_1 = CBD1(vinfo->evp_type);
1381
42.1k
    int rank = vinfo->rank;
1382
42.1k
    uint8_t counter = 0;
1383
42.1k
    int ret = 0;
1384
1385
    /*
1386
     * Use the "d" seed salted with the rank to derive the public and private
1387
     * seeds rho and sigma.
1388
     */
1389
42.1k
    memcpy(augmented_seed, seed, ML_KEM_RANDOM_BYTES);
1390
42.1k
    augmented_seed[ML_KEM_RANDOM_BYTES] = (uint8_t) rank;
1391
42.1k
    if (!hash_g(hashed, augmented_seed, sizeof(augmented_seed), mdctx, key))
1392
0
        goto end;
1393
42.1k
    memcpy(key->rho, hashed, ML_KEM_RANDOM_BYTES);
1394
    /* The |rho| matrix seed is public */
1395
42.1k
    CONSTTIME_DECLASSIFY(key->rho, ML_KEM_RANDOM_BYTES);
1396
1397
    /* FIPS 203 |e| vector is initial value of key->t */
1398
42.1k
    if (!matrix_expand(mdctx, key)
1399
42.1k
        || !gencbd_vector_ntt(key->s, cbd_1, &counter, sigma, rank, mdctx, key)
1400
42.1k
        || !gencbd_vector_ntt(key->t, cbd_1, &counter, sigma, rank, mdctx, key))
1401
0
        goto end;
1402
1403
    /* To |e| we now add the product of transpose |m| and |s|, giving |t|. */
1404
42.1k
    matrix_mult_transpose_add(key->t, key->m, key->s, rank);
1405
    /* The |t| vector is public */
1406
42.1k
    CONSTTIME_DECLASSIFY(key->t, vinfo->rank * sizeof(scalar));
1407
1408
42.1k
    if (pubenc == NULL) {
1409
        /* Incremental digest of public key without in-full serialisation. */
1410
42.1k
        if (!hash_h_pubkey(key->pkhash, mdctx, key))
1411
0
            goto end;
1412
42.1k
    } else {
1413
0
        encode_pubkey(pubenc, key);
1414
0
        if (!hash_h(key->pkhash, pubenc, vinfo->pubkey_bytes, mdctx, key))
1415
0
            goto end;
1416
0
    }
1417
1418
    /* Save |z| portion of seed for "implicit rejection" on failure. */
1419
42.1k
    memcpy(key->z, seed + ML_KEM_RANDOM_BYTES, ML_KEM_RANDOM_BYTES);
1420
1421
    /* Optionally save the |d| portion of the seed */
1422
42.1k
    key->d = key->z + ML_KEM_RANDOM_BYTES;
1423
42.1k
    if (key->prov_flags & ML_KEM_KEY_RETAIN_SEED) {
1424
42.1k
        memcpy(key->d, seed, ML_KEM_RANDOM_BYTES);
1425
42.1k
    } else {
1426
0
        OPENSSL_cleanse(key->d, ML_KEM_RANDOM_BYTES);
1427
0
        key->d = NULL;
1428
0
    }
1429
1430
42.1k
    ret = 1;
1431
42.1k
 end:
1432
42.1k
    OPENSSL_cleanse((void *)augmented_seed, ML_KEM_RANDOM_BYTES);
1433
42.1k
    OPENSSL_cleanse((void *)sigma, ML_KEM_RANDOM_BYTES);
1434
42.1k
    if (ret == 0) {
1435
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1436
0
                       "internal error while generating %s private key",
1437
0
                       vinfo->algorithm_name);
1438
0
    }
1439
42.1k
    return ret;
1440
42.1k
}
1441
1442
/*-
1443
 * FIPS 203, Section 6.2, Algorithm 17: "ML-KEM.Encaps_internal".
1444
 * This is the deterministic version with randomness supplied externally.
1445
 *
1446
 * The caller must pass space for two vectors in |tmp|.
1447
 * The |ctext| buffer have space for the ciphertext of the ML-KEM variant
1448
 * of the provided key.
1449
 */
1450
static
1451
int encap(uint8_t *ctext, uint8_t secret[ML_KEM_SHARED_SECRET_BYTES],
1452
          const uint8_t entropy[ML_KEM_RANDOM_BYTES],
1453
          scalar *tmp, EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1454
116
{
1455
116
    uint8_t input[ML_KEM_RANDOM_BYTES + ML_KEM_PKHASH_BYTES];
1456
116
    uint8_t Kr[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES];
1457
116
    uint8_t *r = Kr + ML_KEM_SHARED_SECRET_BYTES;
1458
116
    int ret;
1459
1460
116
    memcpy(input, entropy, ML_KEM_RANDOM_BYTES);
1461
116
    memcpy(input + ML_KEM_RANDOM_BYTES, key->pkhash, ML_KEM_PKHASH_BYTES);
1462
116
    ret = hash_g(Kr, input, sizeof(input), mdctx, key)
1463
116
        && encrypt_cpa(ctext, entropy, r, tmp, mdctx, key);
1464
116
    OPENSSL_cleanse((void *)input, sizeof(input));
1465
1466
116
    if (ret)
1467
116
        memcpy(secret, Kr, ML_KEM_SHARED_SECRET_BYTES);
1468
0
    else
1469
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1470
0
                       "internal error while performing %s encapsulation",
1471
0
                       key->vinfo->algorithm_name);
1472
116
    return ret;
1473
116
}
1474
1475
/*
1476
 * FIPS 203, Section 6.3, Algorithm 18: ML-KEM.Decaps_internal
1477
 *
1478
 * Barring failure of the supporting SHA3/SHAKE primitives, this is fully
1479
 * deterministic, the randomness for the FO transform is extracted during
1480
 * private key generation.
1481
 *
1482
 * The caller must pass space for two vectors in |tmp|.
1483
 * The |ctext| and |tmp_ctext| buffers must each have space for the ciphertext
1484
 * of the key's ML-KEM variant.
1485
 */
1486
static
1487
int decap(uint8_t secret[ML_KEM_SHARED_SECRET_BYTES],
1488
          const uint8_t *ctext, uint8_t *tmp_ctext, scalar *tmp,
1489
          EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1490
100
{
1491
100
    uint8_t decrypted[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_PKHASH_BYTES];
1492
100
    uint8_t failure_key[ML_KEM_RANDOM_BYTES];
1493
100
    uint8_t Kr[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES];
1494
100
    uint8_t *r = Kr + ML_KEM_SHARED_SECRET_BYTES;
1495
100
    const uint8_t *pkhash = key->pkhash;
1496
100
    const ML_KEM_VINFO *vinfo = key->vinfo;
1497
100
    int i;
1498
100
    uint8_t mask;
1499
1500
    /*
1501
     * If our KDF is unavailable, fail early! Otherwise, keep going ignoring
1502
     * any further errors, returning success, and whatever we got for a shared
1503
     * secret.  The decrypt_cpa() function is just arithmetic on secret data,
1504
     * so should not be subject to failure that makes its output predictable.
1505
     *
1506
     * We guard against "should never happen" catastrophic failure of the
1507
     * "pure" function |hash_g| by overwriting the shared secret with the
1508
     * content of the failure key and returning early, if nevertheless hash_g
1509
     * fails.  This is not constant-time, but a failure of |hash_g| already
1510
     * implies loss of side-channel resistance.
1511
     *
1512
     * The same action is taken, if also |encrypt_cpa| should catastrophically
1513
     * fail, due to failure of the |PRF| underlying the CBD functions.
1514
     */
1515
100
    if (!kdf(failure_key, key->z, ctext, vinfo->ctext_bytes, mdctx, key)) {
1516
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1517
0
                       "internal error while performing %s decapsulation",
1518
0
                       vinfo->algorithm_name);
1519
0
        return 0;
1520
0
    }
1521
100
    decrypt_cpa(decrypted, ctext, tmp, key);
1522
100
    memcpy(decrypted + ML_KEM_SHARED_SECRET_BYTES, pkhash, ML_KEM_PKHASH_BYTES);
1523
100
    if (!hash_g(Kr, decrypted, sizeof(decrypted), mdctx, key)
1524
100
        || !encrypt_cpa(tmp_ctext, decrypted, r, tmp, mdctx, key)) {
1525
0
        memcpy(secret, failure_key, ML_KEM_SHARED_SECRET_BYTES);
1526
0
        OPENSSL_cleanse(decrypted, ML_KEM_SHARED_SECRET_BYTES);
1527
0
        return 1;
1528
0
    }
1529
100
    mask = constant_time_eq_int_8(0,
1530
100
        CRYPTO_memcmp(ctext, tmp_ctext, vinfo->ctext_bytes));
1531
3.30k
    for (i = 0; i < ML_KEM_SHARED_SECRET_BYTES; i++)
1532
3.20k
        secret[i] = constant_time_select_8(mask, Kr[i], failure_key[i]);
1533
100
    OPENSSL_cleanse(decrypted, ML_KEM_SHARED_SECRET_BYTES);
1534
100
    OPENSSL_cleanse(Kr, sizeof(Kr));
1535
100
    return 1;
1536
100
}
1537
1538
/*
1539
 * After allocating storage for public or private key data, update the key
1540
 * component pointers to reference that storage.
1541
 *
1542
 * The caller should only store private data in `priv` *after* a successful
1543
 * (non-zero) return from this function.
1544
 */
1545
static __owur
1546
int add_storage(scalar *pub, scalar *priv, int private, ML_KEM_KEY *key)
1547
27.3k
{
1548
27.3k
    int rank = key->vinfo->rank;
1549
1550
27.3k
    if (pub == NULL || (private && priv == NULL)) {
1551
        /*
1552
         * One of these could be allocated correctly. It is legal to call free with a NULL
1553
         * pointer, so always attempt to free both allocations here
1554
         */
1555
0
        OPENSSL_free(pub);
1556
0
        OPENSSL_secure_free(priv);
1557
0
        return 0;
1558
0
    }
1559
1560
    /*
1561
     * We're adding key material, set up rho and pkhash to point to the rho_pkhash buffer
1562
     */
1563
27.3k
    memset(key->rho_pkhash, 0, sizeof(key->rho_pkhash));
1564
27.3k
    key->rho = key->rho_pkhash;
1565
27.3k
    key->pkhash = key->rho_pkhash + ML_KEM_RANDOM_BYTES;
1566
27.3k
    key->d = key->z = NULL;
1567
1568
    /* A public key needs space for |t| and |m| */
1569
27.3k
    key->m = (key->t = pub) + rank;
1570
1571
    /*
1572
     * A private key also needs space for |s| and |z|.
1573
     * The |z| buffer always includes additional space for |d|, but a key's |d|
1574
     * pointer is left NULL when parsed from the NIST format, which omits that
1575
     * information.  Only keys generated from a (d, z) seed pair will have a
1576
     * non-NULL |d| pointer.
1577
     */
1578
27.3k
    if (private)
1579
27.0k
        key->z = (uint8_t *)(rank + (key->s = priv));
1580
27.3k
    return 1;
1581
27.3k
}
1582
1583
/*
1584
 * After freeing the storage associated with a key that failed to be
1585
 * constructed, reset the internal pointers back to NULL.
1586
 */
1587
void
1588
ossl_ml_kem_key_reset(ML_KEM_KEY *key)
1589
27.5k
{
1590
    /*
1591
     * seedbuf can be allocated and contain |z| and |d| if the key is
1592
     * being created from a private key encoding.  Similarly a pending
1593
     * serialised (encoded) private key may be queued up to load.
1594
     * Clear and free that data now.
1595
     */
1596
27.5k
    if (key->seedbuf != NULL)
1597
0
        OPENSSL_secure_clear_free(key->seedbuf, ML_KEM_SEED_BYTES);
1598
27.5k
    if (ossl_ml_kem_have_dkenc(key))
1599
0
        OPENSSL_secure_clear_free(key->encoded_dk, key->vinfo->prvkey_bytes);
1600
1601
    /*-
1602
     * Cleanse any sensitive data:
1603
     * - The private vector |s| is immediately followed by the FO failure
1604
     *   secret |z|, and seed |d|, we can cleanse all three in one call.
1605
     */
1606
27.5k
    if (key->t != NULL) {
1607
27.3k
        if (ossl_ml_kem_have_prvkey(key))
1608
27.0k
            OPENSSL_secure_clear_free(key->s, key->vinfo->prvalloc);
1609
27.3k
        OPENSSL_free(key->t);
1610
27.3k
    }
1611
27.5k
    key->d = key->z = key->seedbuf = key->encoded_dk =
1612
27.5k
        (uint8_t *)(key->s = key->m = key->t = NULL);
1613
27.5k
}
1614
1615
/*
1616
 * ----- API exported to the provider
1617
 *
1618
 * Parameters with an implicit fixed length in the internal static API of each
1619
 * variant have an explicit checked length argument at this layer.
1620
 */
1621
1622
/* Retrieve the parameters of one of the ML-KEM variants */
1623
const ML_KEM_VINFO *ossl_ml_kem_get_vinfo(int evp_type)
1624
242k
{
1625
242k
    switch (evp_type) {
1626
53.0k
    case EVP_PKEY_ML_KEM_512:
1627
53.0k
        return &vinfo_map[ML_KEM_512_VINFO];
1628
136k
    case EVP_PKEY_ML_KEM_768:
1629
136k
        return &vinfo_map[ML_KEM_768_VINFO];
1630
52.8k
    case EVP_PKEY_ML_KEM_1024:
1631
52.8k
        return &vinfo_map[ML_KEM_1024_VINFO];
1632
242k
    }
1633
0
    return NULL;
1634
242k
}
1635
1636
/*
 * Allocate an empty ML-KEM key of the variant named by |evp_type| and fetch
 * the SHAKE128, SHAKE256, SHA3-256 and SHA3-512 digests it needs from
 * |libctx| using |properties|.
 *
 * Returns the new key, or NULL on failure.  An error is raised here for an
 * unsupported |evp_type| or when any digest cannot be fetched.
 */
ML_KEM_KEY *ossl_ml_kem_key_new(OSSL_LIB_CTX *libctx, const char *properties,
                                int evp_type)
{
    const ML_KEM_VINFO *vinfo = ossl_ml_kem_get_vinfo(evp_type);
    ML_KEM_KEY *key;

    if (vinfo == NULL) {
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_PASSED_INVALID_ARGUMENT,
                       "unsupported ML-KEM key type: %d", evp_type);
        return NULL;
    }

    key = OPENSSL_malloc(sizeof(*key));
    if (key == NULL)
        return NULL;

    key->vinfo = vinfo;
    key->libctx = libctx;
    key->prov_flags = ML_KEM_KEY_PROV_FLAGS_DEFAULT;

    /* No key material yet; ossl_ml_kem_key_free() below relies on this */
    key->t = NULL;
    key->m = NULL;
    key->s = NULL;
    key->rho = NULL;
    key->pkhash = NULL;
    key->z = NULL;
    key->d = NULL;
    key->seedbuf = NULL;
    key->encoded_dk = NULL;

    key->shake128_md = EVP_MD_fetch(libctx, "SHAKE128", properties);
    key->shake256_md = EVP_MD_fetch(libctx, "SHAKE256", properties);
    key->sha3_256_md = EVP_MD_fetch(libctx, "SHA3-256", properties);
    key->sha3_512_md = EVP_MD_fetch(libctx, "SHA3-512", properties);

    if (key->shake128_md == NULL || key->shake256_md == NULL
        || key->sha3_256_md == NULL || key->sha3_512_md == NULL) {
        ossl_ml_kem_key_free(key);
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
                       "missing SHA3 digest algorithms while creating %s key",
                       vinfo->algorithm_name);
        return NULL;
    }
    return key;
}
1673
1674
ML_KEM_KEY *ossl_ml_kem_key_dup(const ML_KEM_KEY *key, int selection)
1675
44
{
1676
44
    int ok = 0;
1677
44
    ML_KEM_KEY *ret;
1678
44
    void *tmp_pub;
1679
44
    void *tmp_priv;
1680
1681
44
    if (key == NULL)
1682
0
        return NULL;
1683
    /*
1684
     * Partially decoded keys, not yet imported or loaded, should never be
1685
     * duplicated.
1686
     */
1687
44
    if (ossl_ml_kem_decoded_key(key))
1688
0
        return NULL;
1689
1690
44
    else if ((ret = OPENSSL_memdup(key, sizeof(*key))) == NULL)
1691
0
        return NULL;
1692
1693
44
    ret->d = ret->z = ret->rho = ret->pkhash = NULL;
1694
44
    ret->s = ret->m = ret->t = NULL;
1695
1696
    /* Clear selection bits we can't fulfill */
1697
44
    if (!ossl_ml_kem_have_pubkey(key))
1698
0
        selection = 0;
1699
44
    else if (!ossl_ml_kem_have_prvkey(key))
1700
44
        selection &= ~OSSL_KEYMGMT_SELECT_PRIVATE_KEY;
1701
1702
44
    switch (selection & OSSL_KEYMGMT_SELECT_KEYPAIR) {
1703
0
    case 0:
1704
0
        ok = 1;
1705
0
        break;
1706
44
    case OSSL_KEYMGMT_SELECT_PUBLIC_KEY:
1707
44
        ok = add_storage(OPENSSL_memdup(key->t, key->vinfo->puballoc), NULL, 0, ret);
1708
44
        break;
1709
0
    case OSSL_KEYMGMT_SELECT_PRIVATE_KEY:
1710
0
        tmp_pub = OPENSSL_memdup(key->t, key->vinfo->puballoc);
1711
0
        if (tmp_pub == NULL)
1712
0
            break;
1713
0
        tmp_priv = OPENSSL_secure_malloc(key->vinfo->prvalloc);
1714
0
        if (tmp_priv == NULL) {
1715
0
            OPENSSL_free(tmp_pub);
1716
0
            break;
1717
0
        }
1718
0
        if ((ok = add_storage(tmp_pub, tmp_priv, 1, ret)) != 0)
1719
0
            memcpy(tmp_priv, key->s, key->vinfo->prvalloc);
1720
        /* Duplicated keys retain |d|, if available */
1721
0
        if (key->d != NULL)
1722
0
            ret->d = ret->z + ML_KEM_RANDOM_BYTES;
1723
0
        break;
1724
44
    }
1725
1726
44
    if (!ok) {
1727
0
        OPENSSL_free(ret);
1728
0
        return NULL;
1729
0
    }
1730
1731
44
    EVP_MD_up_ref(ret->shake128_md);
1732
44
    EVP_MD_up_ref(ret->shake256_md);
1733
44
    EVP_MD_up_ref(ret->sha3_256_md);
1734
44
    EVP_MD_up_ref(ret->sha3_512_md);
1735
1736
44
    return ret;
1737
44
}
1738
1739
/*
 * Free |key|: release its fetched digests, clear and free any key material,
 * then free the structure itself.  A NULL |key| is ignored.
 */
void ossl_ml_kem_key_free(ML_KEM_KEY *key)
{
    if (key != NULL) {
        EVP_MD_free(key->shake128_md);
        EVP_MD_free(key->shake256_md);
        EVP_MD_free(key->sha3_256_md);
        EVP_MD_free(key->sha3_512_md);

        ossl_ml_kem_key_reset(key);
        OPENSSL_free(key);
    }
}
1752
1753
/* Serialise the public component of an ML-KEM key */
1754
int ossl_ml_kem_encode_public_key(uint8_t *out, size_t len,
1755
                                  const ML_KEM_KEY *key)
1756
41.9k
{
1757
41.9k
    if (!ossl_ml_kem_have_pubkey(key)
1758
41.9k
        || len != key->vinfo->pubkey_bytes)
1759
0
        return 0;
1760
41.9k
    encode_pubkey(out, key);
1761
41.9k
    return 1;
1762
41.9k
}
1763
1764
/* Serialise an ML-KEM private key */
1765
int ossl_ml_kem_encode_private_key(uint8_t *out, size_t len,
1766
                                   const ML_KEM_KEY *key)
1767
63
{
1768
63
    if (!ossl_ml_kem_have_prvkey(key)
1769
63
        || len != key->vinfo->prvkey_bytes)
1770
0
        return 0;
1771
63
    encode_prvkey(out, key);
1772
63
    return 1;
1773
63
}
1774
1775
int ossl_ml_kem_encode_seed(uint8_t *out, size_t len,
1776
                            const ML_KEM_KEY *key)
1777
88
{
1778
88
    if (key == NULL || key->d == NULL || len != ML_KEM_SEED_BYTES)
1779
0
        return 0;
1780
    /*
1781
     * Both in the seed buffer, and in the allocated storage, the |d| component
1782
     * of the seed is stored last, so we must copy each separately.
1783
     */
1784
88
    memcpy(out, key->d, ML_KEM_RANDOM_BYTES);
1785
88
    out += ML_KEM_RANDOM_BYTES;
1786
88
    memcpy(out, key->z, ML_KEM_RANDOM_BYTES);
1787
88
    return 1;
1788
88
}
1789
1790
/*
1791
 * Stash the seed without (yet) performing a keygen, used during decoding, to
1792
 * avoid an extra keygen if we're only going to export the key again to load
1793
 * into another provider.
1794
 */
1795
ML_KEM_KEY *ossl_ml_kem_set_seed(const uint8_t *seed, size_t seedlen, ML_KEM_KEY *key)
1796
{
1797
    if (key == NULL
1798
        || ossl_ml_kem_have_pubkey(key)
1799
        || ossl_ml_kem_have_seed(key)
1800
        || seedlen != ML_KEM_SEED_BYTES)
1801
        return NULL;
1802
1803
    if (key->seedbuf == NULL) {
1804
        key->seedbuf = OPENSSL_secure_malloc(seedlen);
1805
        if (key->seedbuf == NULL)
1806
            return NULL;
1807
    }
1808
1809
    key->z = key->seedbuf;
1810
    key->d = key->z + ML_KEM_RANDOM_BYTES;
1811
    memcpy(key->d, seed, ML_KEM_RANDOM_BYTES);
1812
    seed += ML_KEM_RANDOM_BYTES;
1813
    memcpy(key->z, seed, ML_KEM_RANDOM_BYTES);
1814
    return key;
1815
}
1816
1817
/* Parse input as a public key */
1818
int ossl_ml_kem_parse_public_key(const uint8_t *in, size_t len, ML_KEM_KEY *key)
1819
364
{
1820
364
    EVP_MD_CTX *mdctx = NULL;
1821
364
    const ML_KEM_VINFO *vinfo;
1822
364
    int ret = 0;
1823
1824
    /* Keys with key material are immutable */
1825
364
    if (key == NULL
1826
364
        || ossl_ml_kem_have_pubkey(key)
1827
364
        || ossl_ml_kem_have_dkenc(key))
1828
0
        return 0;
1829
364
    vinfo = key->vinfo;
1830
1831
364
    if (len != vinfo->pubkey_bytes
1832
364
        || (mdctx = EVP_MD_CTX_new()) == NULL)
1833
0
        return 0;
1834
1835
364
    if (add_storage(OPENSSL_malloc(vinfo->puballoc), NULL, 0, key))
1836
364
        ret = parse_pubkey(in, mdctx, key);
1837
1838
364
    if (!ret)
1839
156
        ossl_ml_kem_key_reset(key);
1840
364
    EVP_MD_CTX_free(mdctx);
1841
364
    return ret;
1842
364
}
1843
1844
/* Parse input as a new private key */
1845
int ossl_ml_kem_parse_private_key(const uint8_t *in, size_t len,
1846
                                  ML_KEM_KEY *key)
1847
39
{
1848
39
    EVP_MD_CTX *mdctx = NULL;
1849
39
    const ML_KEM_VINFO *vinfo;
1850
39
    int ret = 0;
1851
1852
    /* Keys with key material are immutable */
1853
39
    if (key == NULL
1854
39
        || ossl_ml_kem_have_pubkey(key)
1855
39
        || ossl_ml_kem_have_dkenc(key))
1856
0
        return 0;
1857
39
    vinfo = key->vinfo;
1858
1859
39
    if (len != vinfo->prvkey_bytes
1860
39
        || (mdctx = EVP_MD_CTX_new()) == NULL)
1861
0
        return 0;
1862
1863
39
    if (add_storage(OPENSSL_malloc(vinfo->puballoc),
1864
39
                    OPENSSL_secure_malloc(vinfo->prvalloc), 1, key))
1865
39
        ret = parse_prvkey(in, mdctx, key);
1866
1867
39
    if (!ret)
1868
39
        ossl_ml_kem_key_reset(key);
1869
39
    EVP_MD_CTX_free(mdctx);
1870
39
    return ret;
1871
39
}
1872
1873
/*
1874
 * Generate a new keypair, either from the saved seed (when non-null), or from
1875
 * the RNG.
1876
 */
1877
int ossl_ml_kem_genkey(uint8_t *pubenc, size_t publen, ML_KEM_KEY *key)
1878
42.1k
{
1879
42.1k
    uint8_t seed[ML_KEM_SEED_BYTES];
1880
42.1k
    EVP_MD_CTX *mdctx = NULL;
1881
42.1k
    const ML_KEM_VINFO *vinfo;
1882
42.1k
    int ret = 0;
1883
1884
42.1k
    if (key == NULL
1885
42.1k
        || ossl_ml_kem_have_pubkey(key)
1886
42.1k
        || ossl_ml_kem_have_dkenc(key))
1887
0
        return 0;
1888
42.1k
    vinfo = key->vinfo;
1889
1890
42.1k
    if (pubenc != NULL && publen != vinfo->pubkey_bytes)
1891
0
        return 0;
1892
1893
42.1k
    if (key->seedbuf != NULL) {
1894
25
        if (!ossl_ml_kem_encode_seed(seed, sizeof(seed), key))
1895
0
            return 0;
1896
25
        ossl_ml_kem_key_reset(key);
1897
42.1k
    } else if (RAND_priv_bytes_ex(key->libctx, seed, sizeof(seed),
1898
42.1k
                                  key->vinfo->secbits) <= 0) {
1899
0
        return 0;
1900
0
    }
1901
1902
42.1k
    if ((mdctx = EVP_MD_CTX_new()) == NULL)
1903
0
        return 0;
1904
1905
    /*
1906
     * Data derived from (d, z) defaults secret, and to avoid side-channel
1907
     * leaks should not influence control flow.
1908
     */
1909
42.1k
    CONSTTIME_SECRET(seed, ML_KEM_SEED_BYTES);
1910
1911
42.1k
    if (add_storage(OPENSSL_malloc(vinfo->puballoc),
1912
42.1k
                    OPENSSL_secure_malloc(vinfo->prvalloc), 1, key))
1913
42.1k
        ret = genkey(seed, mdctx, pubenc, key);
1914
42.1k
    OPENSSL_cleanse(seed, sizeof(seed));
1915
1916
    /* Declassify secret inputs and derived outputs before returning control */
1917
42.1k
    CONSTTIME_DECLASSIFY(seed, ML_KEM_SEED_BYTES);
1918
1919
42.1k
    EVP_MD_CTX_free(mdctx);
1920
42.1k
    if (!ret) {
1921
0
        ossl_ml_kem_key_reset(key);
1922
0
        return 0;
1923
0
    }
1924
1925
    /* The public components are already declassified */
1926
42.1k
    CONSTTIME_DECLASSIFY(key->s, vinfo->rank * sizeof(scalar));
1927
42.1k
    CONSTTIME_DECLASSIFY(key->z, 2 * ML_KEM_RANDOM_BYTES);
1928
42.1k
    return 1;
1929
42.1k
}
1930
1931
/*
1932
 * FIPS 203, Section 6.2, Algorithm 17: ML-KEM.Encaps_internal
1933
 * This is the deterministic version with randomness supplied externally.
1934
 */
1935
int ossl_ml_kem_encap_seed(uint8_t *ctext, size_t clen,
1936
                           uint8_t *shared_secret, size_t slen,
1937
                           const uint8_t *entropy, size_t elen,
1938
                           const ML_KEM_KEY *key)
1939
116
{
1940
116
    const ML_KEM_VINFO *vinfo;
1941
116
    EVP_MD_CTX *mdctx;
1942
116
    int ret = 0;
1943
1944
116
    if (key == NULL || !ossl_ml_kem_have_pubkey(key))
1945
0
        return 0;
1946
116
    vinfo = key->vinfo;
1947
1948
116
    if (ctext == NULL || clen != vinfo->ctext_bytes
1949
116
        || shared_secret == NULL || slen != ML_KEM_SHARED_SECRET_BYTES
1950
116
        || entropy == NULL || elen != ML_KEM_RANDOM_BYTES
1951
116
        || (mdctx = EVP_MD_CTX_new()) == NULL)
1952
0
        return 0;
1953
    /*
1954
     * Data derived from the encap entropy defaults secret, and to avoid
1955
     * side-channel leaks should not influence control flow.
1956
     */
1957
116
    CONSTTIME_SECRET(entropy, elen);
1958
1959
    /*-
1960
     * This avoids the need to handle allocation failures for two (max 2KB
1961
     * each) vectors, that are never retained on return from this function.
1962
     * We stack-allocate these.
1963
     */
1964
116
#   define case_encap_seed(bits)                                            \
1965
116
    case EVP_PKEY_ML_KEM_##bits:                                            \
1966
116
        {                                                                   \
1967
116
            scalar tmp[2 * ML_KEM_##bits##_RANK];                           \
1968
116
                                                                            \
1969
116
            ret = encap(ctext, shared_secret, entropy, tmp, mdctx, key);    \
1970
116
            OPENSSL_cleanse((void *)tmp, sizeof(tmp));                      \
1971
116
            break;                                                          \
1972
116
        }
1973
116
    switch (vinfo->evp_type) {
1974
43
    case_encap_seed(512);
1975
49
    case_encap_seed(768);
1976
24
    case_encap_seed(1024);
1977
116
    }
1978
116
#   undef case_encap_seed
1979
1980
    /* Declassify secret inputs and derived outputs before returning control */
1981
116
    CONSTTIME_DECLASSIFY(entropy, elen);
1982
116
    CONSTTIME_DECLASSIFY(ctext, clen);
1983
116
    CONSTTIME_DECLASSIFY(shared_secret, slen);
1984
1985
116
    EVP_MD_CTX_free(mdctx);
1986
116
    return ret;
1987
116
}
1988
1989
int ossl_ml_kem_encap_rand(uint8_t *ctext, size_t clen,
1990
                           uint8_t *shared_secret, size_t slen,
1991
                           const ML_KEM_KEY *key)
1992
116
{
1993
116
    uint8_t r[ML_KEM_RANDOM_BYTES];
1994
1995
116
    if (key == NULL)
1996
0
        return 0;
1997
1998
116
    if (RAND_bytes_ex(key->libctx, r, ML_KEM_RANDOM_BYTES,
1999
116
                      key->vinfo->secbits) < 1)
2000
0
        return 0;
2001
2002
116
    return ossl_ml_kem_encap_seed(ctext, clen, shared_secret, slen,
2003
116
                                  r, sizeof(r), key);
2004
116
}
2005
2006
int ossl_ml_kem_decap(uint8_t *shared_secret, size_t slen,
2007
                      const uint8_t *ctext, size_t clen,
2008
                      const ML_KEM_KEY *key)
2009
100
{
2010
100
    const ML_KEM_VINFO *vinfo;
2011
100
    EVP_MD_CTX *mdctx;
2012
100
    int ret = 0;
2013
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
2014
    int classify_bytes;
2015
#endif
2016
2017
    /* Need a private key here */
2018
100
    if (!ossl_ml_kem_have_prvkey(key))
2019
0
        return 0;
2020
100
    vinfo = key->vinfo;
2021
2022
100
    if (shared_secret == NULL || slen != ML_KEM_SHARED_SECRET_BYTES
2023
100
        || ctext == NULL || clen != vinfo->ctext_bytes
2024
100
        || (mdctx = EVP_MD_CTX_new()) == NULL) {
2025
0
        (void)RAND_bytes_ex(key->libctx, shared_secret,
2026
0
                            ML_KEM_SHARED_SECRET_BYTES, vinfo->secbits);
2027
0
        return 0;
2028
0
    }
2029
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
2030
    /*
2031
     * Data derived from |s| and |z| defaults secret, and to avoid side-channel
2032
     * leaks should not influence control flow.
2033
     */
2034
    classify_bytes = 2 * sizeof(scalar) + ML_KEM_RANDOM_BYTES;
2035
    CONSTTIME_SECRET(key->s, classify_bytes);
2036
#endif
2037
2038
    /*-
2039
     * This avoids the need to handle allocation failures for two (max 2KB
2040
     * each) vectors and an encoded ciphertext (max 1568 bytes), that are never
2041
     * retained on return from this function.
2042
     * We stack-allocate these.
2043
     */
2044
100
#   define case_decap(bits)                                             \
2045
100
    case EVP_PKEY_ML_KEM_##bits:                                        \
2046
100
        {                                                               \
2047
100
            uint8_t cbuf[CTEXT_BYTES(bits)];                            \
2048
100
            scalar tmp[2 * ML_KEM_##bits##_RANK];                       \
2049
100
                                                                        \
2050
100
            ret = decap(shared_secret, ctext, cbuf, tmp, mdctx, key);   \
2051
100
            OPENSSL_cleanse((void *)tmp, sizeof(tmp));                  \
2052
100
            break;                                                      \
2053
100
        }
2054
100
    switch (vinfo->evp_type) {
2055
43
    case_decap(512);
2056
33
    case_decap(768);
2057
24
    case_decap(1024);
2058
100
    }
2059
2060
    /* Declassify secret inputs and derived outputs before returning control */
2061
100
    CONSTTIME_DECLASSIFY(key->s, classify_bytes);
2062
100
    CONSTTIME_DECLASSIFY(shared_secret, slen);
2063
100
    EVP_MD_CTX_free(mdctx);
2064
2065
100
    return ret;
2066
100
#   undef case_decap
2067
100
}
2068
2069
int ossl_ml_kem_pubkey_cmp(const ML_KEM_KEY *key1, const ML_KEM_KEY *key2)
2070
122
{
2071
    /*
2072
     * This handles any unexpected differences in the ML-KEM variant rank,
2073
     * giving different key component structures, barring SHA3-256 hash
2074
     * collisions, the keys are the same size.
2075
     */
2076
122
    if (ossl_ml_kem_have_pubkey(key1) && ossl_ml_kem_have_pubkey(key2))
2077
122
        return memcmp(key1->pkhash, key2->pkhash, ML_KEM_PKHASH_BYTES) == 0;
2078
2079
    /*
2080
     * No match if just one of the public keys is not available, otherwise both
2081
     * are unavailable, and for now such keys are considered equal.
2082
     */
2083
0
    return (!(ossl_ml_kem_have_pubkey(key1) ^ ossl_ml_kem_have_pubkey(key2)));
2084
122
}