Coverage Report

Created: 2025-06-13 06:56

/src/openssl/crypto/ml_kem/ml_kem.c
Line
Count
Source (jump to first uncovered line)
1
/*
2
 * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
3
 *
4
 * Licensed under the Apache License 2.0 (the "License").  You may not use
5
 * this file except in compliance with the License.  You can obtain a copy
6
 * in the file LICENSE in the source distribution or at
7
 * https://www.openssl.org/source/license.html
8
 */
9
10
#include <openssl/byteorder.h>
11
#include <openssl/rand.h>
12
#include "crypto/ml_kem.h"
13
#include "internal/common.h"
14
#include "internal/constant_time.h"
15
#include "internal/sha3.h"
16
17
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
18
#include <valgrind/memcheck.h>
19
#endif
20
21
/* Compile-time sanity checks on the ML-KEM size constants */
#if ML_KEM_SEED_BYTES != ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES
# error "ML-KEM keygen seed length != shared secret + random bytes length"
#endif
#if ML_KEM_SHARED_SECRET_BYTES != ML_KEM_RANDOM_BYTES
# error "Invalid unequal lengths of ML-KEM shared secret and random inputs"
#endif

#if UINT_MAX < UINT32_MAX
# error "Unsupported compiler: sizeof(unsigned int) < sizeof(uint32_t)"
#endif

/*
 * Handy function-like bit-extraction macros.  Every argument is
 * parenthesised, so compound arguments such as |bitn(i & 7, b)| shift by the
 * intended amount.
 */
#define bit0(b)     ((b) & 1)
#define bitn(n, b)  (((b) >> (n)) & 1)

/*
 * 12 bits are sufficient to losslessly represent values in [0, q-1].
 * INVERSE_DEGREE is (n/2)^-1 mod q, used in the inverse NTT.  For q = 3329
 * it is 3303, since 128 * 3303 = 127 * 3329 + 1.
 */
#define DEGREE          ML_KEM_DEGREE
#define INVERSE_DEGREE  (ML_KEM_PRIME - 2 * 13)
#define LOG2PRIME       12
#define BARRETT_SHIFT   (2 * LOG2PRIME)

#ifdef SHA3_BLOCKSIZE
# define SHAKE128_BLOCKSIZE SHA3_BLOCKSIZE(128)
#endif

/*
 * Return whether a value that can only be 0 or 1 is non-zero, in constant time
 * in practice!  The return value is a mask that is all ones if true, and all
 * zeros otherwise (twos-complement arithmetic assumed for unsigned values).
 *
 * Although this is used in constant-time selects, we omit a value barrier
 * here.  Value barriers impede auto-vectorization (likely because it forces
 * the value to transit through a general-purpose register). On AArch64, this
 * is a difference of 2x.
 *
 * We usually add value barriers to selects because Clang turns consecutive
 * selects with the same condition into a branch instead of CMOV/CSEL. This
 * condition does not occur in Kyber, so omitting it seems to be safe so far,
 * but see |cbd_2|, |cbd_3|, where reduction needs to be specialised to the
 * sign of the input, rather than adding |q| in advance, and using the generic
 * |reduce_once|.  (David Benjamin, Chromium)
 *
 * Both variants below must expand to a plain parenthesised expression (no
 * trailing ';'), so that the macro can be used inside larger expressions.
 */
#if 0
# define constish_time_non_zero(b) (~constant_time_is_zero(b))
#else
# define constish_time_non_zero(b) (0u - (b))
#endif

/*
 * The scalar rejection-sampling buffer size needs to be a multiple of 12, but
 * is otherwise arbitrary.  The preferred block size matches the internal
 * buffer size of SHAKE128, which avoids internal buffering and copying in
 * SHAKE128.  That block size is (1600 - 256)/8 bytes, or 168, which happens to
 * be divisible by 12.
 *
 * If the block size is unknown, or is not divisible by 12, 168 is used as a
 * fallback.
 */
#if defined(SHAKE128_BLOCKSIZE) && (SHAKE128_BLOCKSIZE) % 12 == 0
# define SCALAR_SAMPLING_BUFSIZE (SHAKE128_BLOCKSIZE)
#else
# define SCALAR_SAMPLING_BUFSIZE 168
#endif
86
87
/*
88
 * Structure of keys
89
 */
90
typedef struct ossl_ml_kem_scalar_st {
91
    /* On every function entry and exit, 0 <= c[i] < ML_KEM_PRIME. */
92
    uint16_t c[ML_KEM_DEGREE];
93
} scalar;
94
95
/* Key material allocation layout */
96
#define DECLARE_ML_KEM_KEYDATA(name, rank, private_sz) \
97
    struct name##_alloc { \
98
        /* Public vector |t| */ \
99
        scalar tbuf[(rank)]; \
100
        /* Pre-computed matrix |m| (FIPS 203 |A| transpose) */ \
101
        scalar mbuf[(rank)*(rank)] \
102
        /* optional private key data */ \
103
        private_sz \
104
    }
105
106
/* Declare variant-specific public and private storage */
107
#define DECLARE_ML_KEM_VARIANT_KEYDATA(bits) \
108
    DECLARE_ML_KEM_KEYDATA(pubkey_##bits, ML_KEM_##bits##_RANK,;); \
109
    DECLARE_ML_KEM_KEYDATA(prvkey_##bits, ML_KEM_##bits##_RANK,;\
110
        scalar sbuf[ML_KEM_##bits##_RANK]; \
111
        uint8_t zbuf[2 * ML_KEM_RANDOM_BYTES];)
112
DECLARE_ML_KEM_VARIANT_KEYDATA(512);
113
DECLARE_ML_KEM_VARIANT_KEYDATA(768);
114
DECLARE_ML_KEM_VARIANT_KEYDATA(1024);
115
#undef DECLARE_ML_KEM_VARIANT_KEYDATA
116
#undef DECLARE_ML_KEM_KEYDATA
117
118
typedef __owur
119
int (*CBD_FUNC)(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
120
                EVP_MD_CTX *mdctx, const ML_KEM_KEY *key);
121
static void scalar_encode(uint8_t *out, const scalar *s, int bits);
122
123
/*
 * Wire-format sizes, per parameter set |b| (512, 768 or 1024).
 *
 * A losslessly encoded vector stores each coefficient in 12 bits, i.e.
 * 3 * DEGREE / 2 bytes per scalar.
 *
 * The wire-form public key is the lossless encoding of the public vector |t|
 * followed by the public seed |rho|.
 *
 * Our serialised private key concatenates the encodings of the private vector
 * |s|, the public key, the public key hash and the failure secret |z|.
 * |s| encodes to as many bytes as |t|, and |z| is as long as |rho|, so this is
 * twice the public key size plus the hash.
 *
 * Encapsulation produces a vector "u" and a scalar "v".  Their coefficients
 * (numbers modulo the ML-KEM prime "q") are lossily encoded using "du" and
 * "dv" bits respectively.  This encoding is the ciphertext that decapsulation
 * takes as input.
 */
#define VECTOR_BYTES(b)     (ML_KEM_##b##_RANK * (3 * DEGREE / 2))
#define PUBKEY_BYTES(b)     (ML_KEM_RANDOM_BYTES + VECTOR_BYTES(b))
#define PRVKEY_BYTES(b)     (ML_KEM_PKHASH_BYTES + 2 * PUBKEY_BYTES(b))

#define U_VECTOR_BYTES(b)   (ML_KEM_##b##_RANK * ML_KEM_##b##_DU * (DEGREE / 8))
#define V_SCALAR_BYTES(b)   (ML_KEM_##b##_DV * (DEGREE / 8))
#define CTEXT_BYTES(b)      (V_SCALAR_BYTES(b) + U_VECTOR_BYTES(b))
145
146
/*
 * Constant-time validation hooks.
 *
 * In builds with OPENSSL_CONSTANT_TIME_VALIDATION:
 *   CONSTTIME_SECRET(ptr, len) marks |len| bytes at |ptr| as secret, by
 *   telling valgrind's memcheck that they are undefined.  Secret data is then
 *   tracked as it flows into registers and other memory.  Branching on it, or
 *   using it as a memory index, triggers a valgrind warning.
 *
 *   CONSTTIME_DECLASSIFY(ptr, len) marks the region as public (defined)
 *   again, so it is no longer subject to constant-time rules.
 *
 * In all other builds both macros expand to nothing.
 */
#ifdef OPENSSL_CONSTANT_TIME_VALIDATION
# define CONSTTIME_SECRET(ptr, len)     VALGRIND_MAKE_MEM_UNDEFINED(ptr, len)
# define CONSTTIME_DECLASSIFY(ptr, len) VALGRIND_MAKE_MEM_DEFINED(ptr, len)
#else
# define CONSTTIME_SECRET(ptr, len)
# define CONSTTIME_DECLASSIFY(ptr, len)
#endif
169
170
/*
171
 * Indices of slots in the vinfo tables below
172
 */
173
0
#define ML_KEM_512_VINFO    0
174
0
#define ML_KEM_768_VINFO    1
175
0
#define ML_KEM_1024_VINFO   2
176
177
/*
178
 * Per-variant fixed parameters
179
 */
180
static const ML_KEM_VINFO vinfo_map[3] = {
181
    {
182
        "ML-KEM-512",
183
        PRVKEY_BYTES(512),
184
        sizeof(struct prvkey_512_alloc),
185
        PUBKEY_BYTES(512),
186
        sizeof(struct pubkey_512_alloc),
187
        CTEXT_BYTES(512),
188
        VECTOR_BYTES(512),
189
        U_VECTOR_BYTES(512),
190
        EVP_PKEY_ML_KEM_512,
191
        ML_KEM_512_BITS,
192
        ML_KEM_512_RANK,
193
        ML_KEM_512_DU,
194
        ML_KEM_512_DV,
195
        ML_KEM_512_SECBITS,
196
        ML_KEM_512_SECURITY_CATEGORY
197
    },
198
    {
199
        "ML-KEM-768",
200
        PRVKEY_BYTES(768),
201
        sizeof(struct prvkey_768_alloc),
202
        PUBKEY_BYTES(768),
203
        sizeof(struct pubkey_768_alloc),
204
        CTEXT_BYTES(768),
205
        VECTOR_BYTES(768),
206
        U_VECTOR_BYTES(768),
207
        EVP_PKEY_ML_KEM_768,
208
        ML_KEM_768_BITS,
209
        ML_KEM_768_RANK,
210
        ML_KEM_768_DU,
211
        ML_KEM_768_DV,
212
        ML_KEM_768_SECBITS,
213
        ML_KEM_768_SECURITY_CATEGORY
214
    },
215
    {
216
        "ML-KEM-1024",
217
        PRVKEY_BYTES(1024),
218
        sizeof(struct prvkey_1024_alloc),
219
        PUBKEY_BYTES(1024),
220
        sizeof(struct pubkey_1024_alloc),
221
        CTEXT_BYTES(1024),
222
        VECTOR_BYTES(1024),
223
        U_VECTOR_BYTES(1024),
224
        EVP_PKEY_ML_KEM_1024,
225
        ML_KEM_1024_BITS,
226
        ML_KEM_1024_RANK,
227
        ML_KEM_1024_DU,
228
        ML_KEM_1024_DV,
229
        ML_KEM_1024_SECBITS,
230
        ML_KEM_1024_SECURITY_CATEGORY
231
    }
232
};
233
234
/*
235
 * Remainders modulo `kPrime`, for sufficiently small inputs, are computed in
236
 * constant time via Barrett reduction, and a final call to reduce_once(),
237
 * which reduces inputs that are at most 2*kPrime and is also constant-time.
238
 */
239
static const int kPrime = ML_KEM_PRIME;
240
static const unsigned int kBarrettShift = BARRETT_SHIFT;
241
static const size_t   kBarrettMultiplier = (1 << BARRETT_SHIFT) / ML_KEM_PRIME;
242
static const uint16_t kHalfPrime = (ML_KEM_PRIME - 1) / 2;
243
static const uint16_t kInverseDegree = INVERSE_DEGREE;
244
245
/*
 * Python helper used to derive the root tables below:
 *
 *   p = 3329
 *   def bitreverse(i):
 *       ret = 0
 *       for n in range(7):
 *           bit = i & 1
 *           ret <<= 1
 *           ret |= bit
 *           i >>= 1
 *       return ret
 */

/*-
 * Forward NTT roots: the first precomputed array from Appendix A of FIPS 203.
 * Equivalently, in Python:
 *
 *   kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)]
 *
 * The comments give the index of the first entry on each row.
 */
static const uint16_t kNTTRoots[128] = {
    /*   0 */ 1,    1729, 2580, 3289, 2642, 630,  1897, 848,
    /*   8 */ 1062, 1919, 193,  797,  2786, 3260, 569,  1746,
    /*  16 */ 296,  2447, 1339, 1476, 3046, 56,   2240, 1333,
    /*  24 */ 1426, 2094, 535,  2882, 2393, 2879, 1974, 821,
    /*  32 */ 289,  331,  3253, 1756, 1197, 2304, 2277, 2055,
    /*  40 */ 650,  1977, 2513, 632,  2865, 33,   1320, 1915,
    /*  48 */ 2319, 1435, 807,  452,  1438, 2868, 1534, 2402,
    /*  56 */ 2647, 2617, 1481, 648,  2474, 3110, 1227, 910,
    /*  64 */ 17,   2761, 583,  2649, 1637, 723,  2288, 1100,
    /*  72 */ 1409, 2662, 3281, 233,  756,  2156, 3015, 3050,
    /*  80 */ 1703, 1651, 2789, 1789, 1847, 952,  1461, 2687,
    /*  88 */ 939,  2308, 2437, 2388, 733,  2337, 268,  641,
    /*  96 */ 1584, 2298, 2037, 3220, 375,  2549, 2090, 1645,
    /* 104 */ 1063, 319,  2773, 757,  2099, 561,  2466, 2594,
    /* 112 */ 2804, 1092, 403,  1026, 1143, 2150, 2775, 886,
    /* 120 */ 1722, 1212, 1874, 1029, 2110, 2935, 885,  2154,
};
281
282
/*
 * Inverse NTT roots, listed in the order the inverse NTT loop consumes them.
 * Index 0 is never used.  Entry j is the modular inverse of the forward root
 * kNTTRoots[k], i.e. pow(17, -bitreverse(k), p).  k runs through:
 *
 *   0, 64, 65, ..., 127, 32, 33, ..., 63, 16, 17, ..., 31, 8, 9, ...
 *
 * The comments give the index of the first entry on each row.
 */
static const uint16_t kInverseNTTRoots[128] = {
    /*   0 */ 1,    1175, 2444, 394,  1219, 2300, 1455, 2117,
    /*   8 */ 1607, 2443, 554,  1179, 2186, 2303, 2926, 2237,
    /*  16 */ 525,  735,  863,  2768, 1230, 2572, 556,  3010,
    /*  24 */ 2266, 1684, 1239, 780,  2954, 109,  1292, 1031,
    /*  32 */ 1745, 2688, 3061, 992,  2596, 941,  892,  1021,
    /*  40 */ 2390, 642,  1868, 2377, 1482, 1540, 540,  1678,
    /*  48 */ 1626, 279,  314,  1173, 2573, 3096, 48,   667,
    /*  56 */ 1920, 2229, 1041, 2606, 1692, 680,  2746, 568,
    /*  64 */ 3312, 2419, 2102, 219,  855,  2681, 1848, 712,
    /*  72 */ 682,  927,  1795, 461,  1891, 2877, 2522, 1894,
    /*  80 */ 1010, 1414, 2009, 3296, 464,  2697, 816,  1352,
    /*  88 */ 2679, 1274, 1052, 1025, 2132, 1573, 76,   2998,
    /*  96 */ 3040, 2508, 1355, 450,  936,  447,  2794, 1235,
    /* 104 */ 1903, 1996, 1089, 3273, 283,  1853, 1990, 882,
    /* 112 */ 3033, 1583, 2760, 69,   543,  2532, 3136, 1410,
    /* 120 */ 2267, 2481, 1432, 2699, 687,  40,   749,  1600,
};
306
307
/*
 * Second precomputed array from Appendix A of FIPS 203, normalised to
 * positive values.  Equivalently, in Python:
 *
 *   kModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)]
 *
 * Consecutive entries come in pairs (r, p - r).  The comments give the index
 * of the first entry on each row.
 */
static const uint16_t kModRoots[128] = {
    /*   0 */ 17,   3312, 2761, 568,  583,  2746, 2649, 680,
    /*   8 */ 1637, 1692, 723,  2606, 2288, 1041, 1100, 2229,
    /*  16 */ 1409, 1920, 2662, 667,  3281, 48,   233,  3096,
    /*  24 */ 756,  2573, 2156, 1173, 3015, 314,  3050, 279,
    /*  32 */ 1703, 1626, 1651, 1678, 2789, 540,  1789, 1540,
    /*  40 */ 1847, 1482, 952,  2377, 1461, 1868, 2687, 642,
    /*  48 */ 939,  2390, 2308, 1021, 2437, 892,  2388, 941,
    /*  56 */ 733,  2596, 2337, 992,  268,  3061, 641,  2688,
    /*  64 */ 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
    /*  72 */ 375,  2954, 2549, 780,  2090, 1239, 1645, 1684,
    /*  80 */ 1063, 2266, 319,  3010, 2773, 556,  757,  2572,
    /*  88 */ 2099, 1230, 561,  2768, 2466, 863,  2594, 735,
    /*  96 */ 2804, 525,  1092, 2237, 403,  2926, 1026, 2303,
    /* 104 */ 1143, 2186, 2150, 1179, 2775, 554,  886,  2443,
    /* 112 */ 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
    /* 120 */ 2110, 1219, 2935, 394,  885,  2444, 2154, 1175,
};
325
326
/*
327
 * single_keccak hashes |inlen| bytes from |in| and writes |outlen| bytes of
328
 * output to |out|. If the |md| specifies a fixed-output function, like
329
 * SHA3-256, then |outlen| must be the correct length for that function.
330
 */
331
static __owur
332
int single_keccak(uint8_t *out, size_t outlen, const uint8_t *in, size_t inlen,
333
                  EVP_MD_CTX *mdctx)
334
0
{
335
0
    unsigned int sz = (unsigned int) outlen;
336
337
0
    if (!EVP_DigestUpdate(mdctx, in, inlen))
338
0
        return 0;
339
0
    if (EVP_MD_xof(EVP_MD_CTX_get0_md(mdctx)))
340
0
        return EVP_DigestFinalXOF(mdctx, out, outlen);
341
0
    return EVP_DigestFinal_ex(mdctx, out, &sz)
342
0
        && ossl_assert((size_t) sz == outlen);
343
0
}
344
345
/*
346
 * FIPS 203, Section 4.1, equation (4.3): PRF. Takes 32+1 input bytes, and uses
347
 * SHAKE256 to produce the input to SamplePolyCBD_eta: FIPS 203, algorithm 8.
348
 */
349
static __owur
350
int prf(uint8_t *out, size_t len, const uint8_t in[ML_KEM_RANDOM_BYTES + 1],
351
        EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
352
0
{
353
0
    return EVP_DigestInit_ex(mdctx, key->shake256_md, NULL)
354
0
        && single_keccak(out, len, in, ML_KEM_RANDOM_BYTES + 1, mdctx);
355
0
}
356
357
/*
358
 * FIPS 203, Section 4.1, equation (4.4): H.  SHA3-256 hash of a variable
359
 * length input, producing 32 bytes of output.
360
 */
361
static __owur
362
int hash_h(uint8_t out[ML_KEM_PKHASH_BYTES], const uint8_t *in, size_t len,
363
           EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
364
0
{
365
0
    return EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL)
366
0
        && single_keccak(out, ML_KEM_PKHASH_BYTES, in, len, mdctx);
367
0
}
368
369
/* Incremental hash_h of expanded public key */
370
static int
371
hash_h_pubkey(uint8_t pkhash[ML_KEM_PKHASH_BYTES],
372
              EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
373
0
{
374
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
375
0
    const scalar *t = key->t, *end = t + vinfo->rank;
376
0
    unsigned int sz;
377
378
0
    if (!EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL))
379
0
        return 0;
380
381
0
    do {
382
0
        uint8_t buf[3 * DEGREE / 2];
383
384
0
        scalar_encode(buf, t++, 12);
385
0
        if (!EVP_DigestUpdate(mdctx, buf, sizeof(buf)))
386
0
            return 0;
387
0
    } while (t < end);
388
389
0
    if (!EVP_DigestUpdate(mdctx, key->rho, ML_KEM_RANDOM_BYTES))
390
0
        return 0;
391
0
    return EVP_DigestFinal_ex(mdctx, pkhash, &sz)
392
0
        && ossl_assert(sz == ML_KEM_PKHASH_BYTES);
393
0
}
394
395
/*
396
 * FIPS 203, Section 4.1, equation (4.5): G.  SHA3-512 hash of a variable
397
 * length input, producing 64 bytes of output, in particular the seeds
398
 * (d,z) for key generation.
399
 */
400
static __owur
401
int hash_g(uint8_t out[ML_KEM_SEED_BYTES], const uint8_t *in, size_t len,
402
           EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
403
0
{
404
0
    return EVP_DigestInit_ex(mdctx, key->sha3_512_md, NULL)
405
0
        && single_keccak(out, ML_KEM_SEED_BYTES, in, len, mdctx);
406
0
}
407
408
/*
409
 * FIPS 203, Section 4.1, equation (4.4): J. SHAKE256 taking a variable length
410
 * input to compute a 32-byte implicit rejection shared secret, of the same
411
 * length as the expected shared secret.  (Computed even on success to avoid
412
 * side-channel leaks).
413
 */
414
static __owur
415
int kdf(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
416
        const uint8_t z[ML_KEM_RANDOM_BYTES],
417
        const uint8_t *ctext, size_t len,
418
        EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
419
0
{
420
0
    return EVP_DigestInit_ex(mdctx, key->shake256_md, NULL)
421
0
        && EVP_DigestUpdate(mdctx, z, ML_KEM_RANDOM_BYTES)
422
0
        && EVP_DigestUpdate(mdctx, ctext, len)
423
0
        && EVP_DigestFinalXOF(mdctx, out, ML_KEM_SHARED_SECRET_BYTES);
424
0
}
425
426
/*
427
 * FIPS 203, Section 4.2.2, Algorithm 7: "SampleNTT" (steps 3-17, steps 1, 2
428
 * are performed by the caller). Rejection-samples a Keccak stream to get
429
 * uniformly distributed elements in the range [0,q). This is used for matrix
430
 * expansion and only operates on public inputs.
431
 */
432
static __owur
433
int sample_scalar(scalar *out, EVP_MD_CTX *mdctx)
434
0
{
435
0
    uint16_t *curr = out->c, *endout = curr + DEGREE;
436
0
    uint8_t buf[SCALAR_SAMPLING_BUFSIZE], *in;
437
0
    uint8_t *endin = buf + sizeof(buf);
438
0
    uint16_t d;
439
0
    uint8_t b1, b2, b3;
440
441
0
    do {
442
0
        if (!EVP_DigestSqueeze(mdctx, in = buf, sizeof(buf)))
443
0
            return 0;
444
0
        do {
445
0
            b1 = *in++;
446
0
            b2 = *in++;
447
0
            b3 = *in++;
448
449
0
            if (curr >= endout)
450
0
                break;
451
0
            if ((d = ((b2 & 0x0f) << 8) + b1) < kPrime)
452
0
                *curr++ = d;
453
0
            if (curr >= endout)
454
0
                break;
455
0
            if ((d = (b3 << 4) + (b2 >> 4)) < kPrime)
456
0
                *curr++ = d;
457
0
        } while (in < endin);
458
0
    } while (curr < endout);
459
0
    return 1;
460
0
}
461
462
/*-
463
 * reduce_once reduces 0 <= x < 2*kPrime, mod kPrime.
464
 *
465
 * Subtract |q| if the input is larger, without exposing a side-channel,
466
 * avoiding the "clangover" attack.  See |constish_time_non_zero| for a
467
 * discussion on why the value barrier is by default omitted.
468
 */
469
static __owur uint16_t reduce_once(uint16_t x)
470
0
{
471
0
    const uint16_t subtracted = x - kPrime;
472
0
    uint16_t mask = constish_time_non_zero(subtracted >> 15);
473
474
0
    return (mask & x) | (~mask & subtracted);
475
0
}
476
477
/*
478
 * Constant-time reduce x mod kPrime using Barrett reduction. x must be less
479
 * than kPrime + 2 * kPrime^2.  This is sufficient to reduce a product of
480
 * two already reduced u_int16 values, in fact it is sufficient for each
481
 * to be less than 2^12, because (kPrime * (2 * kPrime + 1)) > 2^24.
482
 */
483
static __owur uint16_t reduce(uint32_t x)
484
0
{
485
0
    uint64_t product = (uint64_t)x * kBarrettMultiplier;
486
0
    uint32_t quotient = (uint32_t)(product >> kBarrettShift);
487
0
    uint32_t remainder = x - quotient * kPrime;
488
489
0
    return reduce_once(remainder);
490
0
}
491
492
/* Multiply a scalar by a constant. */
493
static void scalar_mult_const(scalar *s, uint16_t a)
494
0
{
495
0
    uint16_t *curr = s->c, *end = curr + DEGREE, tmp;
496
497
0
    do {
498
0
        tmp = reduce(*curr * a);
499
0
        *curr++ = tmp;
500
0
    } while (curr < end);
501
0
}
502
503
/*-
504
 * FIPS 203, Section 4.3, Algoritm 9: "NTT".
505
 * In-place number theoretic transform of a given scalar.  Note that ML-KEM's
506
 * kPrime 3329 does not have a 512th root of unity, so this transform leaves
507
 * off the last iteration of the usual FFT code, with the 128 relevant roots of
508
 * unity being stored in NTTRoots.  This means the output should be seen as 128
509
 * elements in GF(3329^2), with the coefficients of the elements being
510
 * consecutive entries in |s->c|.
511
 */
512
static void scalar_ntt(scalar *s)
513
0
{
514
0
    const uint16_t *roots = kNTTRoots;
515
0
    uint16_t *end = s->c + DEGREE;
516
0
    int offset = DEGREE / 2;
517
518
0
    do {
519
0
        uint16_t *curr = s->c, *peer;
520
521
0
        do {
522
0
            uint16_t *pause = curr + offset, even, odd;
523
0
            uint32_t zeta = *++roots;
524
525
0
            peer = pause;
526
0
            do {
527
0
                even = *curr;
528
0
                odd = reduce(*peer * zeta);
529
0
                *peer++ = reduce_once(even - odd + kPrime);
530
0
                *curr++ = reduce_once(odd + even);
531
0
            } while (curr < pause);
532
0
        } while ((curr = peer) < end);
533
0
    } while ((offset >>= 1) >= 2);
534
0
}
535
536
/*-
537
 * FIPS 203, Section 4.3, Algoritm 10: "NTT^(-1)".
538
 * In-place inverse number theoretic transform of a given scalar, with pairs of
539
 * entries of s->v being interpreted as elements of GF(3329^2). Just as with
540
 * the number theoretic transform, this leaves off the first step of the normal
541
 * iFFT to account for the fact that 3329 does not have a 512th root of unity,
542
 * using the precomputed 128 roots of unity stored in InverseNTTRoots.
543
 */
544
static void scalar_inverse_ntt(scalar *s)
545
0
{
546
0
    const uint16_t *roots = kInverseNTTRoots;
547
0
    uint16_t *end = s->c + DEGREE;
548
0
    int offset = 2;
549
550
0
    do {
551
0
        uint16_t *curr = s->c, *peer;
552
553
0
        do {
554
0
            uint16_t *pause = curr + offset, even, odd;
555
0
            uint32_t zeta = *++roots;
556
557
0
            peer = pause;
558
0
            do {
559
0
                even = *curr;
560
0
                odd = *peer;
561
0
                *peer++ = reduce(zeta * (even - odd + kPrime));
562
0
                *curr++ = reduce_once(odd + even);
563
0
            } while (curr < pause);
564
0
        } while ((curr = peer) < end);
565
0
    } while ((offset <<= 1) < DEGREE);
566
0
    scalar_mult_const(s, kInverseDegree);
567
0
}
568
569
/* Addition updating the LHS scalar in-place. */
570
static void scalar_add(scalar *lhs, const scalar *rhs)
571
0
{
572
0
    int i;
573
574
0
    for (i = 0; i < DEGREE; i++)
575
0
        lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
576
0
}
577
578
/* Subtraction updating the LHS scalar in-place. */
579
static void scalar_sub(scalar *lhs, const scalar *rhs)
580
0
{
581
0
    int i;
582
583
0
    for (i = 0; i < DEGREE; i++)
584
0
        lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime);
585
0
}
586
587
/*
588
 * Multiplying two scalars in the number theoretically transformed state. Since
589
 * 3329 does not have a 512th root of unity, this means we have to interpret
590
 * the 2*ith and (2*i+1)th entries of the scalar as elements of
591
 * GF(3329)[X]/(X^2 - 17^(2*bitreverse(i)+1)).
592
 *
593
 * The value of 17^(2*bitreverse(i)+1) mod 3329 is stored in the precomputed
594
 * ModRoots table. Note that our Barrett transform only allows us to multipy
595
 * two reduced numbers together, so we need some intermediate reduction steps,
596
 * even if an uint64_t could hold 3 multiplied numbers.
597
 */
598
static void scalar_mult(scalar *out, const scalar *lhs,
599
                        const scalar *rhs)
600
0
{
601
0
    uint16_t *curr = out->c, *end = curr + DEGREE;
602
0
    const uint16_t *lc = lhs->c, *rc = rhs->c;
603
0
    const uint16_t *roots = kModRoots;
604
605
0
    do {
606
0
        uint32_t l0 = *lc++, r0 = *rc++;
607
0
        uint32_t l1 = *lc++, r1 = *rc++;
608
0
        uint32_t zetapow = *roots++;
609
610
0
        *curr++ = reduce(l0 * r0 + reduce(l1 * r1) * zetapow);
611
0
        *curr++ = reduce(l0 * r1 + l1 * r0);
612
0
    } while (curr < end);
613
0
}
614
615
/* Above, but add the result to an existing scalar */
616
static ossl_inline
617
void scalar_mult_add(scalar *out, const scalar *lhs,
618
                     const scalar *rhs)
619
0
{
620
0
    uint16_t *curr = out->c, *end = curr + DEGREE;
621
0
    const uint16_t *lc = lhs->c, *rc = rhs->c;
622
0
    const uint16_t *roots = kModRoots;
623
624
0
    do {
625
0
        uint32_t l0 = *lc++, r0 = *rc++;
626
0
        uint32_t l1 = *lc++, r1 = *rc++;
627
0
        uint16_t *c0 = curr++;
628
0
        uint16_t *c1 = curr++;
629
0
        uint32_t zetapow = *roots++;
630
631
0
        *c0 = reduce(*c0 + l0 * r0 + reduce(l1 * r1) * zetapow);
632
0
        *c1 = reduce(*c1 + l0 * r1 + l1 * r0);
633
0
    } while (curr < end);
634
0
}
635
636
/*-
637
 * FIPS 203, Section 4.2.1, Algorithm 5: "ByteEncode_d", for 2<=d<=12.
638
 * Here |bits| is |d|.  For efficiency, we handle the d=1 case separately.
639
 */
640
static void scalar_encode(uint8_t *out, const scalar *s, int bits)
641
0
{
642
0
    const uint16_t *curr = s->c, *end = curr + DEGREE;
643
0
    uint64_t accum = 0, element;
644
0
    int used = 0;
645
646
0
    do {
647
0
        element = *curr++;
648
0
        if (used + bits < 64) {
649
0
            accum |= element << used;
650
0
            used += bits;
651
0
        } else if (used + bits > 64) {
652
0
            out = OPENSSL_store_u64_le(out, accum | (element << used));
653
0
            accum = element >> (64 - used);
654
0
            used = (used + bits) - 64;
655
0
        } else {
656
0
            out = OPENSSL_store_u64_le(out, accum | (element << used));
657
0
            accum = 0;
658
0
            used = 0;
659
0
        }
660
0
    } while (curr < end);
661
0
}
662
663
/*
664
 * scalar_encode_1 is |scalar_encode| specialised for |bits| == 1.
665
 */
666
static void scalar_encode_1(uint8_t out[DEGREE / 8], const scalar *s)
667
0
{
668
0
    int i, j;
669
0
    uint8_t out_byte;
670
671
0
    for (i = 0; i < DEGREE; i += 8) {
672
0
        out_byte = 0;
673
0
        for (j = 0; j < 8; j++)
674
0
            out_byte |= bit0(s->c[i + j]) << j;
675
0
        *out = out_byte;
676
0
        out++;
677
0
    }
678
0
}
679
680
/*-
681
 * FIPS 203, Section 4.2.1, Algorithm 6: "ByteDecode_d", for 2<=d<12.
682
 * Here |bits| is |d|.  For efficiency, we handle the d=1 and d=12 cases
683
 * separately.
684
 *
685
 * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
686
 * |out|.
687
 */
688
static void scalar_decode(scalar *out, const uint8_t *in, int bits)
689
0
{
690
0
    uint16_t *curr = out->c, *end = curr + DEGREE;
691
0
    uint64_t accum = 0;
692
0
    int accum_bits = 0, todo = bits;
693
0
    uint16_t bitmask = (((uint16_t) 1) << bits) - 1, mask = bitmask;
694
0
    uint16_t element = 0;
695
696
0
    do {
697
0
        if (accum_bits == 0) {
698
0
            in = OPENSSL_load_u64_le(&accum, in);
699
0
            accum_bits = 64;
700
0
        }
701
0
        if (todo == bits && accum_bits >= bits) {
702
            /* No partial "element", and all the required bits available */
703
0
            *curr++ = ((uint16_t) accum) & mask;
704
0
            accum >>= bits;
705
0
            accum_bits -= bits;
706
0
        } else if (accum_bits >= todo) {
707
            /* A partial "element", and all the required bits available */
708
0
            *curr++ = element | ((((uint16_t) accum) & mask) << (bits - todo));
709
0
            accum >>= todo;
710
0
            accum_bits -= todo;
711
0
            element = 0;
712
0
            todo = bits;
713
0
            mask = bitmask;
714
0
        } else {
715
            /*
716
             * Only some of the requisite bits accumulated, store |accum_bits|
717
             * of these in |element|.  The accumulated bitcount becomes 0, but
718
             * as soon as we have more bits we'll want to merge accum_bits
719
             * fewer of them into the final |element|.
720
             *
721
             * Note that with a 64-bit accumulator and |bits| always 12 or
722
             * less, if we're here, the previous iteration had all the
723
             * requisite bits, and so there are no kept bits in |element|.
724
             */
725
0
            element = ((uint16_t) accum) & mask;
726
0
            todo -= accum_bits;
727
0
            mask = bitmask >> accum_bits;
728
0
            accum_bits = 0;
729
0
        }
730
0
    } while (curr < end);
731
0
}
732
733
static __owur
734
int scalar_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2])
735
0
{
736
0
    int i;
737
0
    uint16_t *c = out->c;
738
739
0
    for (i = 0; i < DEGREE / 2; ++i) {
740
0
        uint8_t b1 = *in++;
741
0
        uint8_t b2 = *in++;
742
0
        uint8_t b3 = *in++;
743
0
        int outOfRange1 = (*c++ = b1 | ((b2 & 0x0f) << 8)) >= kPrime;
744
0
        int outOfRange2 = (*c++ = (b2 >> 4) | (b3 << 4)) >= kPrime;
745
746
0
        if (outOfRange1 | outOfRange2)
747
0
            return 0;
748
0
    }
749
0
    return 1;
750
0
}
751
752
/*-
753
 * scalar_decode_decompress_add is a combination of decoding and decompression
754
 * both specialised for |bits| == 1, with the result added (and sum reduced) to
755
 * the output scalar.
756
 *
757
 * NOTE: this function MUST not leak an input-data-depedennt timing signal.
758
 * A timing leak in a related function in the reference Kyber implementation
759
 * made the "clangover" attack (CVE-2024-37880) possible, giving key recovery
760
 * for ML-KEM-512 in minutes, provided the attacker has access to precise
761
 * timing of a CPU performing chosen-ciphertext decap.  Admittedly this is only
762
 * a risk when private keys are reused (perhaps KEMTLS servers).
763
 */
764
static void
765
scalar_decode_decompress_add(scalar *out, const uint8_t in[DEGREE / 8])
766
0
{
767
0
    static const uint16_t half_q_plus_1 = (ML_KEM_PRIME >> 1) + 1;
768
0
    uint16_t *curr = out->c, *end = curr + DEGREE;
769
0
    uint16_t mask;
770
0
    uint8_t b;
771
772
    /*
773
     * Add |half_q_plus_1| if the bit is set, without exposing a side-channel,
774
     * avoiding the "clangover" attack.  See |constish_time_non_zero| for a
775
     * discussion on why the value barrier is by default omitted.
776
     */
777
0
#define decode_decompress_add_bit                               \
778
0
        mask = constish_time_non_zero(bit0(b));                 \
779
0
        *curr = reduce_once(*curr + (mask & half_q_plus_1));    \
780
0
        curr++;                                                 \
781
0
        b >>= 1
782
783
    /* Unrolled to process each byte in one iteration */
784
0
    do {
785
0
        b = *in++;
786
0
        decode_decompress_add_bit;
787
0
        decode_decompress_add_bit;
788
0
        decode_decompress_add_bit;
789
0
        decode_decompress_add_bit;
790
791
0
        decode_decompress_add_bit;
792
0
        decode_decompress_add_bit;
793
0
        decode_decompress_add_bit;
794
0
        decode_decompress_add_bit;
795
0
    } while (curr < end);
796
0
#undef decode_decompress_add_bit
797
0
}
798
799
/*
800
 * FIPS 203, Section 4.2.1, Equation (4.7): Compress_d.
801
 *
802
 * Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
803
 * numbers close to each other together. The formula used is
804
 * round(2^|bits|/kPrime*x) mod 2^|bits|.
805
 * Uses Barrett reduction to achieve constant time. Since we need both the
806
 * remainder (for rounding) and the quotient (as the result), we cannot use
807
 * |reduce| here, but need to do the Barrett reduction directly.
808
 */
809
static __owur uint16_t compress(uint16_t x, int bits)
810
0
{
811
0
    uint32_t shifted = (uint32_t)x << bits;
812
0
    uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
813
0
    uint32_t quotient = (uint32_t)(product >> kBarrettShift);
814
0
    uint32_t remainder = shifted - quotient * kPrime;
815
816
    /*
817
     * Adjust the quotient to round correctly:
818
     *   0 <= remainder <= kHalfPrime round to 0
819
     *   kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
820
     *   kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
821
     */
822
0
    quotient += 1 & constant_time_lt_32(kHalfPrime, remainder);
823
0
    quotient += 1 & constant_time_lt_32(kPrime + kHalfPrime, remainder);
824
0
    return quotient & ((1 << bits) - 1);
825
0
}
826
827
/*
828
 * FIPS 203, Section 4.2.1, Equation (4.8): Decompress_d.
829
830
 * Decompresses |x| by using a close equi-distant representative. The formula
831
 * is round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us
832
 * to implement this logic using only bit operations.
833
 */
834
static __owur uint16_t decompress(uint16_t x, int bits)
835
0
{
836
0
    uint32_t product = (uint32_t)x * kPrime;
837
0
    uint32_t power = 1 << bits;
838
    /* This is |product| % power, since |power| is a power of 2. */
839
0
    uint32_t remainder = product & (power - 1);
840
    /* This is |product| / power, since |power| is a power of 2. */
841
0
    uint32_t lower = product >> bits;
842
843
    /*
844
     * The rounding logic works since the first half of numbers mod |power|
845
     * have a 0 as first bit, and the second half has a 1 as first bit, since
846
     * |power| is a power of 2. As a 12 bit number, |remainder| is always
847
     * positive, so we will shift in 0s for a right shift.
848
     */
849
0
    return lower + (remainder >> (bits - 1));
850
0
}
851
852
/*-
853
 * FIPS 203, Section 4.2.1, Equation (4.7): "Compress_d".
854
 * In-place lossy rounding of scalars to 2^d bits.
855
 */
856
static void scalar_compress(scalar *s, int bits)
857
0
{
858
0
    int i;
859
860
0
    for (i = 0; i < DEGREE; i++)
861
0
        s->c[i] = compress(s->c[i], bits);
862
0
}
863
864
/*
865
 * FIPS 203, Section 4.2.1, Equation (4.8): "Decompress_d".
866
 * In-place approximate recovery of scalars from 2^d bit compression.
867
 */
868
static void scalar_decompress(scalar *s, int bits)
869
0
{
870
0
    int i;
871
872
0
    for (i = 0; i < DEGREE; i++)
873
0
        s->c[i] = decompress(s->c[i], bits);
874
0
}
875
876
/* Addition updating the LHS vector in-place. */
877
static void vector_add(scalar *lhs, const scalar *rhs, int rank)
878
0
{
879
0
    do {
880
0
        scalar_add(lhs++, rhs++);
881
0
    } while (--rank > 0);
882
0
}
883
884
/*
885
 * Encodes an entire vector into 32*|rank|*|bits| bytes. Note that since 256
886
 * (DEGREE) is divisible by 8, the individual vector entries will always fill a
887
 * whole number of bytes, so we do not need to worry about bit packing here.
888
 */
889
static void vector_encode(uint8_t *out, const scalar *a, int bits, int rank)
890
0
{
891
0
    int stride = bits * DEGREE / 8;
892
893
0
    for (; rank-- > 0; out += stride)
894
0
        scalar_encode(out, a++, bits);
895
0
}
896
897
/*
898
 * Decodes 32*|rank|*|bits| bytes from |in| into |out|. It returns early
899
 * if any parsed value is >= |ML_KEM_PRIME|.  The resulting scalars are
900
 * then decompressed and transformed via the NTT.
901
 *
902
 * Note: Used only in decrypt_cpa(), which returns void and so does not check
903
 * the return value of this function.  Side-channels are fine when the input
904
 * ciphertext to decap() is simply syntactically invalid.
905
 */
906
static void
907
vector_decode_decompress_ntt(scalar *out, const uint8_t *in, int bits, int rank)
908
0
{
909
0
    int stride = bits * DEGREE / 8;
910
911
0
    for (; rank-- > 0; in += stride, ++out) {
912
0
        scalar_decode(out, in, bits);
913
0
        scalar_decompress(out, bits);
914
0
        scalar_ntt(out);
915
0
    }
916
0
}
917
918
/* vector_decode(), specialised to bits == 12. */
919
static __owur
920
int vector_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2], int rank)
921
0
{
922
0
    int stride = 3 * DEGREE / 2;
923
924
0
    for (; rank-- > 0; in += stride)
925
0
        if (!scalar_decode_12(out++, in))
926
0
            return 0;
927
0
    return 1;
928
0
}
929
930
/* In-place compression of each scalar component */
931
static void vector_compress(scalar *a, int bits, int rank)
932
0
{
933
0
    do {
934
0
        scalar_compress(a++, bits);
935
0
    } while (--rank > 0);
936
0
}
937
938
/* The output scalar must not overlap with the inputs */
939
static void inner_product(scalar *out, const scalar *lhs, const scalar *rhs,
940
                          int rank)
941
0
{
942
0
    scalar_mult(out, lhs, rhs);
943
0
    while (--rank > 0)
944
0
        scalar_mult_add(out, ++lhs, ++rhs);
945
0
}
946
947
/*
948
 * Here, the output vector must not overlap with the inputs, the result is
949
 * directly subjected to inverse NTT.
950
 */
951
static void
952
matrix_mult_intt(scalar *out, const scalar *m, const scalar *a, int rank)
953
0
{
954
0
    const scalar *ar;
955
0
    int i, j;
956
957
0
    for (i = rank; i-- > 0; ++out) {
958
0
        scalar_mult(out, m++, ar = a);
959
0
        for (j = rank - 1; j > 0; --j)
960
0
            scalar_mult_add(out, m++, ++ar);
961
0
        scalar_inverse_ntt(out);
962
0
    }
963
0
}
964
965
/* Here, the output vector must not overlap with the inputs */
966
static void
967
matrix_mult_transpose_add(scalar *out, const scalar *m, const scalar *a, int rank)
968
0
{
969
0
    const scalar *mc = m, *mr, *ar;
970
0
    int i, j;
971
972
0
    for (i = rank; i-- > 0; ++out) {
973
0
        scalar_mult_add(out, mr = mc++, ar = a);
974
0
        for (j = rank; --j > 0; )
975
0
            scalar_mult_add(out, (mr += rank), ++ar);
976
0
    }
977
0
}
978
979
/*-
980
 * Expands the matrix from a seed for key generation and for encaps-CPA.
981
 * NOTE: FIPS 203 matrix "A" is the transpose of this matrix, computed
982
 * by appending the (i,j) indices to the seed in the opposite order!
983
 *
984
 * Where FIPS 203 computes t = A * s + e, we use the transpose of "m".
985
 */
986
static __owur
987
int matrix_expand(EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
988
0
{
989
0
    scalar *out = key->m;
990
0
    uint8_t input[ML_KEM_RANDOM_BYTES + 2];
991
0
    int rank = key->vinfo->rank;
992
0
    int i, j;
993
994
0
    memcpy(input, key->rho, ML_KEM_RANDOM_BYTES);
995
0
    for (i = 0; i < rank; i++) {
996
0
        for (j = 0; j < rank; j++) {
997
0
            input[ML_KEM_RANDOM_BYTES] = i;
998
0
            input[ML_KEM_RANDOM_BYTES + 1] = j;
999
0
            if (!EVP_DigestInit_ex(mdctx, key->shake128_md, NULL)
1000
0
                || !EVP_DigestUpdate(mdctx, input, sizeof(input))
1001
0
                || !sample_scalar(out++, mdctx))
1002
0
                return 0;
1003
0
        }
1004
0
    }
1005
0
    return 1;
1006
0
}
1007
1008
/*
1009
 * Algorithm 7 from the spec, with eta fixed to two and the PRF call
1010
 * included. Creates binominally distributed elements by sampling 2*|eta| bits,
1011
 * and setting the coefficient to the count of the first bits minus the count of
1012
 * the second bits, resulting in a centered binomial distribution. Since eta is
1013
 * two this gives -2/2 with a probability of 1/16, -1/1 with probability 1/4,
1014
 * and 0 with probability 3/8.
1015
 */
1016
static __owur
1017
int cbd_2(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
1018
          EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1019
0
{
1020
0
    uint16_t *curr = out->c, *end = curr + DEGREE;
1021
0
    uint8_t randbuf[4 * DEGREE / 8], *r = randbuf;  /* 64 * eta slots */
1022
0
    uint16_t value, mask;
1023
0
    uint8_t b;
1024
1025
0
    if (!prf(randbuf, sizeof(randbuf), in, mdctx, key))
1026
0
        return 0;
1027
1028
0
    do {
1029
0
        b = *r++;
1030
1031
        /*
1032
         * Add |kPrime| if |value| underflowed.  See |constish_time_non_zero|
1033
         * for a discussion on why the value barrier is by default omitted.
1034
         * While this could have been written reduce_once(value + kPrime), this
1035
         * is one extra addition and small range of |value| tempts some
1036
         * versions of Clang to emit a branch.
1037
         */
1038
0
        value = bit0(b) + bitn(1, b);
1039
0
        value -= bitn(2, b) + bitn(3, b);
1040
0
        mask = constish_time_non_zero(value >> 15);
1041
0
        *curr++ = value + (kPrime & mask);
1042
1043
0
        value = bitn(4, b) + bitn(5, b);
1044
0
        value -= bitn(6, b) + bitn(7, b);
1045
0
        mask = constish_time_non_zero(value >> 15);
1046
0
        *curr++ = value + (kPrime & mask);
1047
0
    } while (curr < end);
1048
0
    return 1;
1049
0
}
1050
1051
/*
1052
 * Algorithm 7 from the spec, with eta fixed to three and the PRF call
1053
 * included. Creates binominally distributed elements by sampling 3*|eta| bits,
1054
 * and setting the coefficient to the count of the first bits minus the count of
1055
 * the second bits, resulting in a centered binomial distribution.
1056
 */
1057
static __owur
1058
int cbd_3(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
1059
          EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1060
0
{
1061
0
    uint16_t *curr = out->c, *end = curr + DEGREE;
1062
0
    uint8_t randbuf[6 * DEGREE / 8], *r = randbuf;  /* 64 * eta slots */
1063
0
    uint8_t b1, b2, b3;
1064
0
    uint16_t value, mask;
1065
1066
0
    if (!prf(randbuf, sizeof(randbuf), in, mdctx, key))
1067
0
        return 0;
1068
1069
0
    do {
1070
0
        b1 = *r++;
1071
0
        b2 = *r++;
1072
0
        b3 = *r++;
1073
1074
        /*
1075
         * Add |kPrime| if |value| underflowed.  See |constish_time_non_zero|
1076
         * for a discussion on why the value barrier is by default omitted.
1077
         * While this could have been written reduce_once(value + kPrime), this
1078
         * is one extra addition and small range of |value| tempts some
1079
         * versions of Clang to emit a branch.
1080
         */
1081
0
        value = bit0(b1) + bitn(1, b1) + bitn(2, b1);
1082
0
        value -= bitn(3, b1)  + bitn(4, b1) + bitn(5, b1);
1083
0
        mask = constish_time_non_zero(value >> 15);
1084
0
        *curr++ = value + (kPrime & mask);
1085
1086
0
        value = bitn(6, b1) + bitn(7, b1) + bit0(b2);
1087
0
        value -= bitn(1, b2) + bitn(2, b2) + bitn(3, b2);
1088
0
        mask = constish_time_non_zero(value >> 15);
1089
0
        *curr++ = value + (kPrime & mask);
1090
1091
0
        value = bitn(4, b2) + bitn(5, b2) + bitn(6, b2);
1092
0
        value -= bitn(7, b2) + bit0(b3) + bitn(1, b3);
1093
0
        mask = constish_time_non_zero(value >> 15);
1094
0
        *curr++ = value + (kPrime & mask);
1095
1096
0
        value = bitn(2, b3) + bitn(3, b3) + bitn(4, b3);
1097
0
        value -= bitn(5, b3) + bitn(6, b3) + bitn(7, b3);
1098
0
        mask = constish_time_non_zero(value >> 15);
1099
0
        *curr++ = value + (kPrime & mask);
1100
0
    } while (curr < end);
1101
0
    return 1;
1102
0
}
1103
1104
/*
1105
 * Generates a secret vector by using |cbd| with the given seed to generate
1106
 * scalar elements and incrementing |counter| for each slot of the vector.
1107
 */
1108
static __owur
1109
int gencbd_vector(scalar *out, CBD_FUNC cbd, uint8_t *counter,
1110
                  const uint8_t seed[ML_KEM_RANDOM_BYTES], int rank,
1111
                  EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1112
0
{
1113
0
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1114
1115
0
    memcpy(input, seed, ML_KEM_RANDOM_BYTES);
1116
0
    do {
1117
0
        input[ML_KEM_RANDOM_BYTES] = (*counter)++;
1118
0
        if (!cbd(out++, input, mdctx, key))
1119
0
            return 0;
1120
0
    } while (--rank > 0);
1121
0
    return 1;
1122
0
}
1123
1124
/*
1125
 * As above plus NTT transform.
1126
 */
1127
static __owur
1128
int gencbd_vector_ntt(scalar *out, CBD_FUNC cbd, uint8_t *counter,
1129
                      const uint8_t seed[ML_KEM_RANDOM_BYTES], int rank,
1130
                      EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1131
0
{
1132
0
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1133
1134
0
    memcpy(input, seed, ML_KEM_RANDOM_BYTES);
1135
0
    do {
1136
0
        input[ML_KEM_RANDOM_BYTES] = (*counter)++;
1137
0
        if (!cbd(out, input, mdctx, key))
1138
0
            return 0;
1139
0
        scalar_ntt(out++);
1140
0
    } while (--rank > 0);
1141
0
    return 1;
1142
0
}
1143
1144
/*
 * Selects the sampler for the ETA1 noise of a given variant: ML-KEM-512 uses
 * eta = 3 (cbd_3).  The other variants, and every ETA2 use, have eta = 2
 * (cbd_2).
 */
#define CBD1(evp_type)  ((evp_type) != EVP_PKEY_ML_KEM_512 ? cbd_2 : cbd_3)
1146
1147
/*
1148
 * FIPS 203, Section 5.2, Algorithm 14: K-PKE.Encrypt.
1149
 *
1150
 * Encrypts a message with given randomness to the ciphertext in |out|. Without
1151
 * applying the Fujisaki-Okamoto transform this would not result in a CCA
1152
 * secure scheme, since lattice schemes are vulnerable to decryption failure
1153
 * oracles.
1154
 *
1155
 * The steps are re-ordered to make more efficient/localised use of storage.
1156
 *
1157
 * Note also that the input public key is assumed to hold a precomputed matrix
1158
 * |A| (our key->m, with the public key holding an expanded (16-bit per scalar
1159
 * coefficient) key->t vector).
1160
 *
1161
 * Caller passes storage in |tmp| for for two temporary vectors.
1162
 */
1163
static __owur
1164
int encrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
1165
                const uint8_t message[DEGREE / 8],
1166
                const uint8_t r[ML_KEM_RANDOM_BYTES], scalar *tmp,
1167
                EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1168
0
{
1169
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1170
0
    CBD_FUNC cbd_1 = CBD1(vinfo->evp_type);
1171
0
    int rank = vinfo->rank;
1172
    /* We can use tmp[0..rank-1] as storage for |y|, then |e1|, ... */
1173
0
    scalar *y = &tmp[0], *e1 = y, *e2 = y;
1174
    /* We can use tmp[rank]..tmp[2*rank - 1] for |u| */
1175
0
    scalar *u = &tmp[rank];
1176
0
    scalar v;
1177
0
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1178
0
    uint8_t counter = 0;
1179
0
    int du = vinfo->du;
1180
0
    int dv = vinfo->dv;
1181
1182
    /* FIPS 203 "y" vector */
1183
0
    if (!gencbd_vector_ntt(y, cbd_1, &counter, r, rank, mdctx, key))
1184
0
        return 0;
1185
    /* FIPS 203 "v" scalar */
1186
0
    inner_product(&v, key->t, y, rank);
1187
0
    scalar_inverse_ntt(&v);
1188
    /* FIPS 203 "u" vector */
1189
0
    matrix_mult_intt(u, key->m, y, rank);
1190
1191
    /* All done with |y|, now free to reuse tmp[0] for FIPS 203 |e1| */
1192
0
    if (!gencbd_vector(e1, cbd_2, &counter, r, rank, mdctx, key))
1193
0
        return 0;
1194
0
    vector_add(u, e1, rank);
1195
0
    vector_compress(u, du, rank);
1196
0
    vector_encode(out, u, du, rank);
1197
1198
    /* All done with |e1|, now free to reuse tmp[0] for FIPS 203 |e2| */
1199
0
    memcpy(input, r, ML_KEM_RANDOM_BYTES);
1200
0
    input[ML_KEM_RANDOM_BYTES] = counter;
1201
0
    if (!cbd_2(e2, input, mdctx, key))
1202
0
        return 0;
1203
0
    scalar_add(&v, e2);
1204
1205
    /* Combine message with |v| */
1206
0
    scalar_decode_decompress_add(&v, message);
1207
0
    scalar_compress(&v, dv);
1208
0
    scalar_encode(out + vinfo->u_vector_bytes, &v, dv);
1209
0
    return 1;
1210
0
}
1211
1212
/*
1213
 * FIPS 203, Section 5.3, Algorithm 15: K-PKE.Decrypt.
1214
 */
1215
static void
1216
decrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
1217
            const uint8_t *ctext, scalar *u, const ML_KEM_KEY *key)
1218
0
{
1219
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1220
0
    scalar v, mask;
1221
0
    int rank = vinfo->rank;
1222
0
    int du = vinfo->du;
1223
0
    int dv = vinfo->dv;
1224
1225
0
    vector_decode_decompress_ntt(u, ctext, du, rank);
1226
0
    scalar_decode(&v, ctext + vinfo->u_vector_bytes, dv);
1227
0
    scalar_decompress(&v, dv);
1228
0
    inner_product(&mask, key->s, u, rank);
1229
0
    scalar_inverse_ntt(&mask);
1230
0
    scalar_sub(&v, &mask);
1231
0
    scalar_compress(&v, 1);
1232
0
    scalar_encode_1(out, &v);
1233
0
}
1234
1235
/*-
1236
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1237
 * FIPS 203, Section 7.2, Algorithm 20: "ML-KEM.Encaps".
1238
 *
1239
 * Fills the |out| buffer with the |ek| output of "ML-KEM.KeyGen", or,
1240
 * equivalently, the |ek| input of "ML-KEM.Encaps", i.e. returns the
1241
 * wire-format of an ML-KEM public key.
1242
 */
1243
static void encode_pubkey(uint8_t *out, const ML_KEM_KEY *key)
1244
0
{
1245
0
    const uint8_t *rho = key->rho;
1246
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1247
1248
0
    vector_encode(out, key->t, 12, vinfo->rank);
1249
0
    memcpy(out + vinfo->vector_bytes, rho, ML_KEM_RANDOM_BYTES);
1250
0
}
1251
1252
/*-
1253
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1254
 *
1255
 * Fills the |out| buffer with the |dk| output of "ML-KEM.KeyGen".
1256
 * This matches the input format of parse_prvkey() below.
1257
 */
1258
static void encode_prvkey(uint8_t *out, const ML_KEM_KEY *key)
1259
0
{
1260
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1261
1262
0
    vector_encode(out, key->s, 12, vinfo->rank);
1263
0
    out += vinfo->vector_bytes;
1264
0
    encode_pubkey(out, key);
1265
0
    out += vinfo->pubkey_bytes;
1266
0
    memcpy(out, key->pkhash, ML_KEM_PKHASH_BYTES);
1267
0
    out += ML_KEM_PKHASH_BYTES;
1268
0
    memcpy(out, key->z, ML_KEM_RANDOM_BYTES);
1269
0
}
1270
1271
/*-
1272
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1273
 * FIPS 203, Section 7.2, Algorithm 20: "ML-KEM.Encaps".
1274
 *
1275
 * This function parses the |in| buffer as the |ek| output of "ML-KEM.KeyGen",
1276
 * or, equivalently, the |ek| input of "ML-KEM.Encaps", i.e. decodes the
1277
 * wire-format of the ML-KEM public key.
1278
 */
1279
static int parse_pubkey(const uint8_t *in, EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
1280
0
{
1281
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1282
1283
    /* Decode and check |t| */
1284
0
    if (!vector_decode_12(key->t, in, vinfo->rank))
1285
0
        return 0;
1286
    /* Save the matrix |m| recovery seed |rho| */
1287
0
    memcpy(key->rho, in + vinfo->vector_bytes, ML_KEM_RANDOM_BYTES);
1288
    /*
1289
     * Pre-compute the public key hash, needed for both encap and decap.
1290
     * Also pre-compute the matrix expansion, stored with the public key.
1291
     */
1292
0
    return hash_h(key->pkhash, in, vinfo->pubkey_bytes, mdctx, key)
1293
0
        && matrix_expand(mdctx, key);
1294
0
}
1295
1296
/*
1297
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1298
 *
1299
 * Parses the |in| buffer as a |dk| output of "ML-KEM.KeyGen".
1300
 * This matches the output format of encode_prvkey() above.
1301
 */
1302
static int parse_prvkey(const uint8_t *in, EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
1303
0
{
1304
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1305
1306
    /* Decode and check |s|. */
1307
0
    if (!vector_decode_12(key->s, in, vinfo->rank))
1308
0
        return 0;
1309
0
    in += vinfo->vector_bytes;
1310
1311
0
    if (!parse_pubkey(in, mdctx, key))
1312
0
        return 0;
1313
0
    in += vinfo->pubkey_bytes;
1314
1315
    /* Check public key hash. */
1316
0
    if (memcmp(key->pkhash, in, ML_KEM_PKHASH_BYTES) != 0)
1317
0
        return 0;
1318
0
    in += ML_KEM_PKHASH_BYTES;
1319
1320
0
    memcpy(key->z, in, ML_KEM_RANDOM_BYTES);
1321
0
    return 1;
1322
0
}
1323
1324
/*
 * FIPS 203, Section 6.1, Algorithm 16: "ML-KEM.KeyGen_internal".
 *
 * The implementation of Section 5.1, Algorithm 13, "K-PKE.KeyGen(d)" is
 * inlined.
 *
 * The caller MUST pass a pre-allocated digest context that is not shared with
 * any concurrent computation.
 *
 * This function optionally outputs the serialised wire-form |ek| public key
 * into the provided |pubenc| buffer, and generates the content of the |rho|,
 * |pkhash|, |t|, |m|, |s| and |z| components of the private |key| (which must
 * have preallocated space for these).  When |pubenc| is NULL, the public key
 * hash is computed incrementally without serialising the key in full.
 *
 * Keys are computed from a 32-byte random |d| plus the 1 byte rank for
 * domain separation.  These are concatenated and hashed to produce a pair of
 * 32-byte seeds public "rho", used to generate the matrix, and private "sigma",
 * used to generate the secret vector |s| (and the noise vector |e|).
 *
 * The second random input |z| is copied verbatim into the Fujisaki-Okamoto
 * (FO) transform "implicit-rejection" secret (the |z| component of the private
 * key), which thwarts chosen-ciphertext attacks, provided decap() runs in
 * constant time, with no side channel leaks, on all well-formed (valid length,
 * and correctly encoded) ciphertext inputs.
 *
 * |seed| holds |d| in its first ML_KEM_RANDOM_BYTES bytes, followed by |z|.
 * The |d| half is retained in key->d only when ML_KEM_KEY_RETAIN_SEED is set
 * in key->prov_flags; otherwise key->d is left NULL.
 *
 * Returns 1 on success and 0 on failure.
 */
static __owur
int genkey(const uint8_t seed[ML_KEM_SEED_BYTES],
           EVP_MD_CTX *mdctx, uint8_t *pubenc, ML_KEM_KEY *key)
{
    /* hash_g output: rho (first half) || sigma (second half). */
    uint8_t hashed[2 * ML_KEM_RANDOM_BYTES];
    const uint8_t *const sigma = hashed + ML_KEM_RANDOM_BYTES;
    uint8_t augmented_seed[ML_KEM_RANDOM_BYTES + 1];
    const ML_KEM_VINFO *vinfo = key->vinfo;
    CBD_FUNC cbd_1 = CBD1(vinfo->evp_type);
    int rank = vinfo->rank;
    /*
     * Shared by both sampler calls below, so |s| and |e| are derived from
     * distinct PRF inputs; the call order therefore matters.
     */
    uint8_t counter = 0;
    int ret = 0;

    /*
     * Use the "d" seed salted with the rank to derive the public and private
     * seeds rho and sigma.
     */
    memcpy(augmented_seed, seed, ML_KEM_RANDOM_BYTES);
    augmented_seed[ML_KEM_RANDOM_BYTES] = (uint8_t) rank;
    if (!hash_g(hashed, augmented_seed, sizeof(augmented_seed), mdctx, key))
        goto end;
    memcpy(key->rho, hashed, ML_KEM_RANDOM_BYTES);
    /* The |rho| matrix seed is public */
    CONSTTIME_DECLASSIFY(key->rho, ML_KEM_RANDOM_BYTES);

    /* FIPS 203 |e| vector is initial value of key->t */
    if (!matrix_expand(mdctx, key)
        || !gencbd_vector_ntt(key->s, cbd_1, &counter, sigma, rank, mdctx, key)
        || !gencbd_vector_ntt(key->t, cbd_1, &counter, sigma, rank, mdctx, key))
        goto end;

    /* To |e| we now add the product of transpose |m| and |s|, giving |t|. */
    matrix_mult_transpose_add(key->t, key->m, key->s, rank);
    /* The |t| vector is public */
    CONSTTIME_DECLASSIFY(key->t, vinfo->rank * sizeof(scalar));

    if (pubenc == NULL) {
        /* Incremental digest of public key without in-full serialisation. */
        if (!hash_h_pubkey(key->pkhash, mdctx, key))
            goto end;
    } else {
        encode_pubkey(pubenc, key);
        if (!hash_h(key->pkhash, pubenc, vinfo->pubkey_bytes, mdctx, key))
            goto end;
    }

    /* Save |z| portion of seed for "implicit rejection" on failure. */
    memcpy(key->z, seed + ML_KEM_RANDOM_BYTES, ML_KEM_RANDOM_BYTES);

    /*
     * Optionally save the |d| portion of the seed; its storage sits directly
     * after |z|.
     */
    key->d = key->z + ML_KEM_RANDOM_BYTES;
    if (key->prov_flags & ML_KEM_KEY_RETAIN_SEED) {
        memcpy(key->d, seed, ML_KEM_RANDOM_BYTES);
    } else {
        OPENSSL_cleanse(key->d, ML_KEM_RANDOM_BYTES);
        key->d = NULL;
    }

    ret = 1;
 end:
    /*
     * Wipe the secret material on every path: the copy of |d| (the final rank
     * byte is public) and |sigma| (|rho| in the first half of |hashed| is
     * public).
     */
    OPENSSL_cleanse((void *)augmented_seed, ML_KEM_RANDOM_BYTES);
    OPENSSL_cleanse((void *)sigma, ML_KEM_RANDOM_BYTES);
    return ret;
}
1413
1414
/*-
1415
 * FIPS 203, Section 6.2, Algorithm 17: "ML-KEM.Encaps_internal".
1416
 * This is the deterministic version with randomness supplied externally.
1417
 *
1418
 * The caller must pass space for two vectors in |tmp|.
1419
 * The |ctext| buffer have space for the ciphertext of the ML-KEM variant
1420
 * of the provided key.
1421
 */
1422
static
1423
int encap(uint8_t *ctext, uint8_t secret[ML_KEM_SHARED_SECRET_BYTES],
1424
          const uint8_t entropy[ML_KEM_RANDOM_BYTES],
1425
          scalar *tmp, EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1426
0
{
1427
0
    uint8_t input[ML_KEM_RANDOM_BYTES + ML_KEM_PKHASH_BYTES];
1428
0
    uint8_t Kr[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES];
1429
0
    uint8_t *r = Kr + ML_KEM_SHARED_SECRET_BYTES;
1430
0
    int ret;
1431
1432
0
    memcpy(input, entropy, ML_KEM_RANDOM_BYTES);
1433
0
    memcpy(input + ML_KEM_RANDOM_BYTES, key->pkhash, ML_KEM_PKHASH_BYTES);
1434
0
    ret = hash_g(Kr, input, sizeof(input), mdctx, key)
1435
0
        && encrypt_cpa(ctext, entropy, r, tmp, mdctx, key);
1436
1437
0
    if (ret)
1438
0
        memcpy(secret, Kr, ML_KEM_SHARED_SECRET_BYTES);
1439
0
    OPENSSL_cleanse((void *)input, sizeof(input));
1440
0
    return ret;
1441
0
}
1442
1443
/*
1444
 * FIPS 203, Section 6.3, Algorithm 18: ML-KEM.Decaps_internal
1445
 *
1446
 * Barring failure of the supporting SHA3/SHAKE primitives, this is fully
1447
 * deterministic, the randomness for the FO transform is extracted during
1448
 * private key generation.
1449
 *
1450
 * The caller must pass space for two vectors in |tmp|.
1451
 * The |ctext| and |tmp_ctext| buffers must each have space for the ciphertext
1452
 * of the key's ML-KEM variant.
1453
 */
1454
static
1455
int decap(uint8_t secret[ML_KEM_SHARED_SECRET_BYTES],
1456
          const uint8_t *ctext, uint8_t *tmp_ctext, scalar *tmp,
1457
          EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1458
0
{
1459
0
    uint8_t decrypted[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_PKHASH_BYTES];
1460
0
    uint8_t failure_key[ML_KEM_RANDOM_BYTES];
1461
0
    uint8_t Kr[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES];
1462
0
    uint8_t *r = Kr + ML_KEM_SHARED_SECRET_BYTES;
1463
0
    const uint8_t *pkhash = key->pkhash;
1464
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1465
0
    int i;
1466
0
    uint8_t mask;
1467
1468
    /*
1469
     * If our KDF is unavailable, fail early! Otherwise, keep going ignoring
1470
     * any further errors, returning success, and whatever we got for a shared
1471
     * secret.  The decrypt_cpa() function is just arithmetic on secret data,
1472
     * so should not be subject to failure that makes its output predictable.
1473
     *
1474
     * We guard against "should never happen" catastrophic failure of the
1475
     * "pure" function |hash_g| by overwriting the shared secret with the
1476
     * content of the failure key and returning early, if nevertheless hash_g
1477
     * fails.  This is not constant-time, but a failure of |hash_g| already
1478
     * implies loss of side-channel resistance.
1479
     *
1480
     * The same action is taken, if also |encrypt_cpa| should catastrophically
1481
     * fail, due to failure of the |PRF| underlying the CBD functions.
1482
     */
1483
0
    if (!kdf(failure_key, key->z, ctext, vinfo->ctext_bytes, mdctx, key))
1484
0
        return 0;
1485
0
    decrypt_cpa(decrypted, ctext, tmp, key);
1486
0
    memcpy(decrypted + ML_KEM_SHARED_SECRET_BYTES, pkhash, ML_KEM_PKHASH_BYTES);
1487
0
    if (!hash_g(Kr, decrypted, sizeof(decrypted), mdctx, key)
1488
0
        || !encrypt_cpa(tmp_ctext, decrypted, r, tmp, mdctx, key)) {
1489
0
        memcpy(secret, failure_key, ML_KEM_SHARED_SECRET_BYTES);
1490
0
        OPENSSL_cleanse(decrypted, ML_KEM_SHARED_SECRET_BYTES);
1491
0
        return 1;
1492
0
    }
1493
0
    mask = constant_time_eq_int_8(0,
1494
0
        CRYPTO_memcmp(ctext, tmp_ctext, vinfo->ctext_bytes));
1495
0
    for (i = 0; i < ML_KEM_SHARED_SECRET_BYTES; i++)
1496
0
        secret[i] = constant_time_select_8(mask, Kr[i], failure_key[i]);
1497
0
    OPENSSL_cleanse(decrypted, ML_KEM_SHARED_SECRET_BYTES);
1498
0
    OPENSSL_cleanse(Kr, sizeof(Kr));
1499
0
    return 1;
1500
0
}
1501
1502
/*
1503
 * After allocating storage for public or private key data, update the key
1504
 * component pointers to reference that storage.
1505
 */
1506
static __owur
1507
int add_storage(scalar *p, int private, ML_KEM_KEY *key)
1508
0
{
1509
0
    int rank = key->vinfo->rank;
1510
1511
0
    if (p == NULL)
1512
0
        return 0;
1513
1514
    /*
1515
     * We're adding key material, the seed buffer will now hold |rho| and
1516
     * |pkhash|.
1517
     */
1518
0
    memset(key->seedbuf, 0, sizeof(key->seedbuf));
1519
0
    key->rho = key->seedbuf;
1520
0
    key->pkhash = key->seedbuf + ML_KEM_RANDOM_BYTES;
1521
0
    key->d = key->z = NULL;
1522
1523
    /* A public key needs space for |t| and |m| */
1524
0
    key->m = (key->t = p) + rank;
1525
1526
    /*
1527
     * A private key also needs space for |s| and |z|.
1528
     * The |z| buffer always includes additional space for |d|, but a key's |d|
1529
     * pointer is left NULL when parsed from the NIST format, which omits that
1530
     * information.  Only keys generated from a (d, z) seed pair will have a
1531
     * non-NULL |d| pointer.
1532
     */
1533
0
    if (private)
1534
0
        key->z = (uint8_t *)(rank + (key->s = key->m + rank * rank));
1535
0
    return 1;
1536
0
}
1537
1538
/*
1539
 * After freeing the storage associated with a key that failed to be
1540
 * constructed, reset the internal pointers back to NULL.
1541
 */
1542
void
1543
ossl_ml_kem_key_reset(ML_KEM_KEY *key)
1544
0
{
1545
0
    if (key->t == NULL)
1546
0
        return;
1547
    /*-
1548
     * Cleanse any sensitive data:
1549
     * - The private vector |s| is immediately followed by the FO failure
1550
     *   secret |z|, and seed |d|, we can cleanse all three in one call.
1551
     *
1552
     * - Otherwise, when key->d is set, cleanse the stashed seed.
1553
     */
1554
0
    if (ossl_ml_kem_have_prvkey(key))
1555
0
        OPENSSL_cleanse(key->s,
1556
0
                        key->vinfo->rank * sizeof(scalar) + 2 * ML_KEM_RANDOM_BYTES);
1557
0
    OPENSSL_free(key->t);
1558
0
    key->d = key->z = (uint8_t *)(key->s = key->m = key->t = NULL);
1559
0
}
1560
1561
/*
1562
 * ----- API exported to the provider
1563
 *
1564
 * Parameters with an implicit fixed length in the internal static API of each
1565
 * variant have an explicit checked length argument at this layer.
1566
 */
1567
1568
/* Retrieve the parameters of one of the ML-KEM variants */
1569
const ML_KEM_VINFO *ossl_ml_kem_get_vinfo(int evp_type)
1570
0
{
1571
0
    switch (evp_type) {
1572
0
    case EVP_PKEY_ML_KEM_512:
1573
0
        return &vinfo_map[ML_KEM_512_VINFO];
1574
0
    case EVP_PKEY_ML_KEM_768:
1575
0
        return &vinfo_map[ML_KEM_768_VINFO];
1576
0
    case EVP_PKEY_ML_KEM_1024:
1577
0
        return &vinfo_map[ML_KEM_1024_VINFO];
1578
0
    }
1579
0
    return NULL;
1580
0
}
1581
1582
/*
 * Allocate a new ML-KEM key of the variant named by |evp_type|, with no key
 * material.  The SHAKE128, SHAKE256, SHA3-256 and SHA3-512 digests are
 * fetched from |libctx| using |properties|.  Returns NULL for an unknown
 * variant, on allocation failure, or if any digest cannot be fetched (the
 * partially built key is then freed).
 */
ML_KEM_KEY *ossl_ml_kem_key_new(OSSL_LIB_CTX *libctx, const char *properties,
                                int evp_type)
{
    const ML_KEM_VINFO *vinfo;
    ML_KEM_KEY *nkey;

    if ((vinfo = ossl_ml_kem_get_vinfo(evp_type)) == NULL
        || (nkey = OPENSSL_malloc(sizeof(*nkey))) == NULL)
        return NULL;

    /* Begin with no key material, seed, or pending encoded private key */
    nkey->s = nkey->m = nkey->t = NULL;
    nkey->d = nkey->z = nkey->rho = nkey->pkhash = nkey->encoded_dk = NULL;

    nkey->vinfo = vinfo;
    nkey->libctx = libctx;
    nkey->prov_flags = ML_KEM_KEY_PROV_FLAGS_DEFAULT;
    nkey->shake128_md = EVP_MD_fetch(libctx, "SHAKE128", properties);
    nkey->shake256_md = EVP_MD_fetch(libctx, "SHAKE256", properties);
    nkey->sha3_256_md = EVP_MD_fetch(libctx, "SHA3-256", properties);
    nkey->sha3_512_md = EVP_MD_fetch(libctx, "SHA3-512", properties);

    if (nkey->shake128_md == NULL
        || nkey->shake256_md == NULL
        || nkey->sha3_256_md == NULL
        || nkey->sha3_512_md == NULL) {
        /* Releases whichever digests were fetched, then the key itself */
        ossl_ml_kem_key_free(nkey);
        return NULL;
    }
    return nkey;
}
1613
1614
ML_KEM_KEY *ossl_ml_kem_key_dup(const ML_KEM_KEY *key, int selection)
1615
0
{
1616
0
    int ok = 0;
1617
0
    ML_KEM_KEY *ret;
1618
1619
    /*
1620
     * Partially decoded keys, not yet imported or loaded, should never be
1621
     * duplicated.
1622
     */
1623
0
    if (ossl_ml_kem_decoded_key(key))
1624
0
        return NULL;
1625
1626
0
    if (key == NULL
1627
0
        || (ret = OPENSSL_memdup(key, sizeof(*key))) == NULL)
1628
0
        return NULL;
1629
0
    ret->d = ret->z = ret->rho = ret->pkhash = NULL;
1630
0
    ret->s = ret->m = ret->t = NULL;
1631
1632
    /* Clear selection bits we can't fulfill */
1633
0
    if (!ossl_ml_kem_have_pubkey(key))
1634
0
        selection = 0;
1635
0
    else if (!ossl_ml_kem_have_prvkey(key))
1636
0
        selection &= ~OSSL_KEYMGMT_SELECT_PRIVATE_KEY;
1637
1638
0
    switch (selection & OSSL_KEYMGMT_SELECT_KEYPAIR) {
1639
0
    case 0:
1640
0
        ok = 1;
1641
0
        break;
1642
0
    case OSSL_KEYMGMT_SELECT_PUBLIC_KEY:
1643
0
        ok = add_storage(OPENSSL_memdup(key->t, key->vinfo->puballoc), 0, ret);
1644
0
        ret->rho = ret->seedbuf;
1645
0
        ret->pkhash = ret->rho + ML_KEM_RANDOM_BYTES;
1646
0
        break;
1647
0
    case OSSL_KEYMGMT_SELECT_PRIVATE_KEY:
1648
0
        ok = add_storage(OPENSSL_memdup(key->t, key->vinfo->prvalloc), 1, ret);
1649
        /* Duplicated keys retain |d|, if available */
1650
0
        if (key->d != NULL)
1651
0
            ret->d = ret->z + ML_KEM_RANDOM_BYTES;
1652
0
        break;
1653
0
    }
1654
1655
0
    if (!ok) {
1656
0
        OPENSSL_free(ret);
1657
0
        return NULL;
1658
0
    }
1659
1660
0
    EVP_MD_up_ref(ret->shake128_md);
1661
0
    EVP_MD_up_ref(ret->shake256_md);
1662
0
    EVP_MD_up_ref(ret->sha3_256_md);
1663
0
    EVP_MD_up_ref(ret->sha3_512_md);
1664
1665
0
    return ret;
1666
0
}
1667
1668
/*
 * Release |key| and everything it owns: the four fetched digests, any
 * pending encoded private key, and the key material storage.  Secret data
 * is cleansed before being freed.  A NULL |key| is a no-op.
 */
void ossl_ml_kem_key_free(ML_KEM_KEY *key)
{
    if (key != NULL) {
        EVP_MD_free(key->shake128_md);
        EVP_MD_free(key->shake256_md);
        EVP_MD_free(key->sha3_256_md);
        EVP_MD_free(key->sha3_512_md);

        /*
         * A partially decoded key may carry secrets outside the key storage:
         * in |seedbuf| (e.g. a seed stashed by ossl_ml_kem_set_seed()) and
         * in a pending encoded private key.
         */
        if (ossl_ml_kem_decoded_key(key)) {
            OPENSSL_cleanse(key->seedbuf, sizeof(key->seedbuf));
            if (ossl_ml_kem_have_dkenc(key)) {
                OPENSSL_cleanse(key->encoded_dk, key->vinfo->prvkey_bytes);
                OPENSSL_free(key->encoded_dk);
            }
        }
        ossl_ml_kem_key_reset(key);
        OPENSSL_free(key);
    }
}
1688
1689
/* Serialise the public component of an ML-KEM key */
1690
int ossl_ml_kem_encode_public_key(uint8_t *out, size_t len,
1691
                                  const ML_KEM_KEY *key)
1692
0
{
1693
0
    if (!ossl_ml_kem_have_pubkey(key)
1694
0
        || len != key->vinfo->pubkey_bytes)
1695
0
        return 0;
1696
0
    encode_pubkey(out, key);
1697
0
    return 1;
1698
0
}
1699
1700
/* Serialise an ML-KEM private key */
1701
int ossl_ml_kem_encode_private_key(uint8_t *out, size_t len,
1702
                                   const ML_KEM_KEY *key)
1703
0
{
1704
0
    if (!ossl_ml_kem_have_prvkey(key)
1705
0
        || len != key->vinfo->prvkey_bytes)
1706
0
        return 0;
1707
0
    encode_prvkey(out, key);
1708
0
    return 1;
1709
0
}
1710
1711
int ossl_ml_kem_encode_seed(uint8_t *out, size_t len,
1712
                            const ML_KEM_KEY *key)
1713
0
{
1714
0
    if (key == NULL || key->d == NULL || len != ML_KEM_SEED_BYTES)
1715
0
        return 0;
1716
    /*
1717
     * Both in the seed buffer, and in the allocated storage, the |d| component
1718
     * of the seed is stored last, so we must copy each separately.
1719
     */
1720
0
    memcpy(out, key->d, ML_KEM_RANDOM_BYTES);
1721
0
    out += ML_KEM_RANDOM_BYTES;
1722
0
    memcpy(out, key->z, ML_KEM_RANDOM_BYTES);
1723
0
    return 1;
1724
0
}
1725
1726
/*
1727
 * Stash the seed without (yet) performing a keygen, used during decoding, to
1728
 * avoid an extra keygen if we're only going to export the key again to load
1729
 * into another provider.
1730
 */
1731
ML_KEM_KEY *ossl_ml_kem_set_seed(const uint8_t *seed, size_t seedlen, ML_KEM_KEY *key)
1732
0
{
1733
0
    if (key == NULL
1734
0
        || ossl_ml_kem_have_pubkey(key)
1735
0
        || ossl_ml_kem_have_seed(key)
1736
0
        || seedlen != ML_KEM_SEED_BYTES)
1737
0
        return NULL;
1738
    /*
1739
     * With no public or private key material on hand, we can use the seed
1740
     * buffer for |z| and |d|, in that order.
1741
     */
1742
0
    key->z = key->seedbuf;
1743
0
    key->d = key->z + ML_KEM_RANDOM_BYTES;
1744
0
    memcpy(key->d, seed, ML_KEM_RANDOM_BYTES);
1745
0
    seed += ML_KEM_RANDOM_BYTES;
1746
0
    memcpy(key->z, seed, ML_KEM_RANDOM_BYTES);
1747
0
    return key;
1748
0
}
1749
1750
/* Parse input as a public key */
1751
int ossl_ml_kem_parse_public_key(const uint8_t *in, size_t len, ML_KEM_KEY *key)
1752
0
{
1753
0
    EVP_MD_CTX *mdctx = NULL;
1754
0
    const ML_KEM_VINFO *vinfo;
1755
0
    int ret = 0;
1756
1757
    /* Keys with key material are immutable */
1758
0
    if (key == NULL
1759
0
        || ossl_ml_kem_have_pubkey(key)
1760
0
        || ossl_ml_kem_have_dkenc(key))
1761
0
        return 0;
1762
0
    vinfo = key->vinfo;
1763
1764
0
    if (len != vinfo->pubkey_bytes
1765
0
        || (mdctx = EVP_MD_CTX_new()) == NULL)
1766
0
        return 0;
1767
1768
0
    if (add_storage(OPENSSL_malloc(vinfo->puballoc), 0, key))
1769
0
        ret = parse_pubkey(in, mdctx, key);
1770
1771
0
    if (!ret)
1772
0
        ossl_ml_kem_key_reset(key);
1773
0
    EVP_MD_CTX_free(mdctx);
1774
0
    return ret;
1775
0
}
1776
1777
/* Parse input as a new private key */
1778
int ossl_ml_kem_parse_private_key(const uint8_t *in, size_t len,
1779
                                  ML_KEM_KEY *key)
1780
0
{
1781
0
    EVP_MD_CTX *mdctx = NULL;
1782
0
    const ML_KEM_VINFO *vinfo;
1783
0
    int ret = 0;
1784
1785
    /* Keys with key material are immutable */
1786
0
    if (key == NULL
1787
0
        || ossl_ml_kem_have_pubkey(key)
1788
0
        || ossl_ml_kem_have_dkenc(key))
1789
0
        return 0;
1790
0
    vinfo = key->vinfo;
1791
1792
0
    if (len != vinfo->prvkey_bytes
1793
0
        || (mdctx = EVP_MD_CTX_new()) == NULL)
1794
0
        return 0;
1795
1796
0
    if (add_storage(OPENSSL_malloc(vinfo->prvalloc), 1, key))
1797
0
        ret = parse_prvkey(in, mdctx, key);
1798
1799
0
    if (!ret)
1800
0
        ossl_ml_kem_key_reset(key);
1801
0
    EVP_MD_CTX_free(mdctx);
1802
0
    return ret;
1803
0
}
1804
1805
/*
1806
 * Generate a new keypair, either from the saved seed (when non-null), or from
1807
 * the RNG.
1808
 */
1809
int ossl_ml_kem_genkey(uint8_t *pubenc, size_t publen, ML_KEM_KEY *key)
1810
0
{
1811
0
    uint8_t seed[ML_KEM_SEED_BYTES];
1812
0
    EVP_MD_CTX *mdctx = NULL;
1813
0
    const ML_KEM_VINFO *vinfo;
1814
0
    int ret = 0;
1815
1816
0
    if (key == NULL
1817
0
        || ossl_ml_kem_have_pubkey(key)
1818
0
        || ossl_ml_kem_have_dkenc(key))
1819
0
        return 0;
1820
0
    vinfo = key->vinfo;
1821
1822
0
    if (pubenc != NULL && publen != vinfo->pubkey_bytes)
1823
0
        return 0;
1824
1825
0
    if (ossl_ml_kem_have_seed(key)) {
1826
0
        if (!ossl_ml_kem_encode_seed(seed, sizeof(seed), key))
1827
0
            return 0;
1828
0
        key->d = key->z = NULL;
1829
0
    } else if (RAND_priv_bytes_ex(key->libctx, seed, sizeof(seed),
1830
0
                                  key->vinfo->secbits) <= 0) {
1831
0
        return 0;
1832
0
    }
1833
1834
0
    if ((mdctx = EVP_MD_CTX_new()) == NULL)
1835
0
        return 0;
1836
1837
    /*
1838
     * Data derived from (d, z) defaults secret, and to avoid side-channel
1839
     * leaks should not influence control flow.
1840
     */
1841
0
    CONSTTIME_SECRET(seed, ML_KEM_SEED_BYTES);
1842
1843
0
    if (add_storage(OPENSSL_malloc(vinfo->prvalloc), 1, key))
1844
0
        ret = genkey(seed, mdctx, pubenc, key);
1845
0
    OPENSSL_cleanse(seed, sizeof(seed));
1846
1847
    /* Declassify secret inputs and derived outputs before returning control */
1848
0
    CONSTTIME_DECLASSIFY(seed, ML_KEM_SEED_BYTES);
1849
1850
0
    EVP_MD_CTX_free(mdctx);
1851
0
    if (!ret) {
1852
0
        ossl_ml_kem_key_reset(key);
1853
0
        return 0;
1854
0
    }
1855
1856
    /* The public components are already declassified */
1857
0
    CONSTTIME_DECLASSIFY(key->s, vinfo->rank * sizeof(scalar));
1858
0
    CONSTTIME_DECLASSIFY(key->z, 2 * ML_KEM_RANDOM_BYTES);
1859
0
    return 1;
1860
0
}
1861
1862
/*
1863
 * FIPS 203, Section 6.2, Algorithm 17: ML-KEM.Encaps_internal
1864
 * This is the deterministic version with randomness supplied externally.
1865
 */
1866
int ossl_ml_kem_encap_seed(uint8_t *ctext, size_t clen,
1867
                           uint8_t *shared_secret, size_t slen,
1868
                           const uint8_t *entropy, size_t elen,
1869
                           const ML_KEM_KEY *key)
1870
0
{
1871
0
    const ML_KEM_VINFO *vinfo;
1872
0
    EVP_MD_CTX *mdctx;
1873
0
    int ret = 0;
1874
1875
0
    if (key == NULL || !ossl_ml_kem_have_pubkey(key))
1876
0
        return 0;
1877
0
    vinfo = key->vinfo;
1878
1879
0
    if (ctext == NULL || clen != vinfo->ctext_bytes
1880
0
        || shared_secret == NULL || slen != ML_KEM_SHARED_SECRET_BYTES
1881
0
        || entropy == NULL || elen != ML_KEM_RANDOM_BYTES
1882
0
        || (mdctx = EVP_MD_CTX_new()) == NULL)
1883
0
        return 0;
1884
    /*
1885
     * Data derived from the encap entropy defaults secret, and to avoid
1886
     * side-channel leaks should not influence control flow.
1887
     */
1888
0
    CONSTTIME_SECRET(entropy, elen);
1889
1890
    /*-
1891
     * This avoids the need to handle allocation failures for two (max 2KB
1892
     * each) vectors, that are never retained on return from this function.
1893
     * We stack-allocate these.
1894
     */
1895
0
#   define case_encap_seed(bits)                                            \
1896
0
    case EVP_PKEY_ML_KEM_##bits:                                            \
1897
0
        {                                                                   \
1898
0
            scalar tmp[2 * ML_KEM_##bits##_RANK];                           \
1899
0
                                                                            \
1900
0
            ret = encap(ctext, shared_secret, entropy, tmp, mdctx, key);    \
1901
0
            OPENSSL_cleanse((void *)tmp, sizeof(tmp));                      \
1902
0
            break;                                                          \
1903
0
        }
1904
0
    switch (vinfo->evp_type) {
1905
0
    case_encap_seed(512);
1906
0
    case_encap_seed(768);
1907
0
    case_encap_seed(1024);
1908
0
    }
1909
0
#   undef case_encap_seed
1910
1911
    /* Declassify secret inputs and derived outputs before returning control */
1912
0
    CONSTTIME_DECLASSIFY(entropy, elen);
1913
0
    CONSTTIME_DECLASSIFY(ctext, clen);
1914
0
    CONSTTIME_DECLASSIFY(shared_secret, slen);
1915
1916
0
    EVP_MD_CTX_free(mdctx);
1917
0
    return ret;
1918
0
}
1919
1920
int ossl_ml_kem_encap_rand(uint8_t *ctext, size_t clen,
1921
                           uint8_t *shared_secret, size_t slen,
1922
                           const ML_KEM_KEY *key)
1923
0
{
1924
0
    uint8_t r[ML_KEM_RANDOM_BYTES];
1925
1926
0
    if (key == NULL)
1927
0
        return 0;
1928
1929
0
    if (RAND_bytes_ex(key->libctx, r, ML_KEM_RANDOM_BYTES,
1930
0
                      key->vinfo->secbits) < 1)
1931
0
        return 0;
1932
1933
0
    return ossl_ml_kem_encap_seed(ctext, clen, shared_secret, slen,
1934
0
                                  r, sizeof(r), key);
1935
0
}
1936
1937
int ossl_ml_kem_decap(uint8_t *shared_secret, size_t slen,
1938
                      const uint8_t *ctext, size_t clen,
1939
                      const ML_KEM_KEY *key)
1940
0
{
1941
0
    const ML_KEM_VINFO *vinfo;
1942
0
    EVP_MD_CTX *mdctx;
1943
0
    int ret = 0;
1944
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
1945
    int classify_bytes;
1946
#endif
1947
1948
    /* Need a private key here */
1949
0
    if (!ossl_ml_kem_have_prvkey(key))
1950
0
        return 0;
1951
0
    vinfo = key->vinfo;
1952
1953
0
    if (shared_secret == NULL || slen != ML_KEM_SHARED_SECRET_BYTES
1954
0
        || ctext == NULL || clen != vinfo->ctext_bytes
1955
0
        || (mdctx = EVP_MD_CTX_new()) == NULL) {
1956
0
        (void)RAND_bytes_ex(key->libctx, shared_secret,
1957
0
                            ML_KEM_SHARED_SECRET_BYTES, vinfo->secbits);
1958
0
        return 0;
1959
0
    }
1960
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
1961
    /*
1962
     * Data derived from |s| and |z| defaults secret, and to avoid side-channel
1963
     * leaks should not influence control flow.
1964
     */
1965
    classify_bytes = 2 * sizeof(scalar) + ML_KEM_RANDOM_BYTES;
1966
    CONSTTIME_SECRET(key->s, classify_bytes);
1967
#endif
1968
1969
    /*-
1970
     * This avoids the need to handle allocation failures for two (max 2KB
1971
     * each) vectors and an encoded ciphertext (max 1568 bytes), that are never
1972
     * retained on return from this function.
1973
     * We stack-allocate these.
1974
     */
1975
0
#   define case_decap(bits)                                             \
1976
0
    case EVP_PKEY_ML_KEM_##bits:                                        \
1977
0
        {                                                               \
1978
0
            uint8_t cbuf[CTEXT_BYTES(bits)];                            \
1979
0
            scalar tmp[2 * ML_KEM_##bits##_RANK];                       \
1980
0
                                                                        \
1981
0
            ret = decap(shared_secret, ctext, cbuf, tmp, mdctx, key);   \
1982
0
            OPENSSL_cleanse((void *)tmp, sizeof(tmp));                  \
1983
0
            break;                                                      \
1984
0
        }
1985
0
    switch (vinfo->evp_type) {
1986
0
    case_decap(512);
1987
0
    case_decap(768);
1988
0
    case_decap(1024);
1989
0
    }
1990
1991
    /* Declassify secret inputs and derived outputs before returning control */
1992
0
    CONSTTIME_DECLASSIFY(key->s, classify_bytes);
1993
0
    CONSTTIME_DECLASSIFY(shared_secret, slen);
1994
0
    EVP_MD_CTX_free(mdctx);
1995
1996
0
    return ret;
1997
0
#   undef case_decap
1998
0
}
1999
2000
int ossl_ml_kem_pubkey_cmp(const ML_KEM_KEY *key1, const ML_KEM_KEY *key2)
2001
0
{
2002
    /*
2003
     * This handles any unexpected differences in the ML-KEM variant rank,
2004
     * giving different key component structures, barring SHA3-256 hash
2005
     * collisions, the keys are the same size.
2006
     */
2007
0
    if (ossl_ml_kem_have_pubkey(key1) && ossl_ml_kem_have_pubkey(key2))
2008
0
        return memcmp(key1->pkhash, key2->pkhash, ML_KEM_PKHASH_BYTES) == 0;
2009
2010
    /*
2011
     * No match if just one of the public keys is not available, otherwise both
2012
     * are unavailable, and for now such keys are considered equal.
2013
     */
2014
0
    return (ossl_ml_kem_have_pubkey(key1) ^ ossl_ml_kem_have_pubkey(key2));
2015
0
}