Coverage Report

Created: 2025-06-22 06:56

/src/openssl/crypto/ml_kem/ml_kem.c
Line
Count
Source (jump to first uncovered line)
1
/*
2
 * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
3
 *
4
 * Licensed under the Apache License 2.0 (the "License").  You may not use
5
 * this file except in compliance with the License.  You can obtain a copy
6
 * in the file LICENSE in the source distribution or at
7
 * https://www.openssl.org/source/license.html
8
 */
9
10
#include <openssl/byteorder.h>
11
#include <openssl/rand.h>
12
#include <openssl/proverr.h>
13
#include "crypto/ml_kem.h"
14
#include "internal/common.h"
15
#include "internal/constant_time.h"
16
#include "internal/sha3.h"
17
18
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
19
#include <valgrind/memcheck.h>
20
#endif
21
22
#if ML_KEM_SEED_BYTES != ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES
23
# error "ML-KEM keygen seed length != shared secret + random bytes length"
24
#endif
25
#if ML_KEM_SHARED_SECRET_BYTES != ML_KEM_RANDOM_BYTES
26
# error "Invalid unequal lengths of ML-KEM shared secret and random inputs"
27
#endif
28
29
#if UINT_MAX < UINT32_MAX
30
# error "Unsupported compiler: sizeof(unsigned int) < sizeof(uint32_t)"
31
#endif
32
33
/* Handy function-like bit-extraction macros */
34
0
#define bit0(b)     ((b) & 1)
35
0
#define bitn(n, b)  (((b) >> (n)) & 1) /* bit |n| of |b|, as 0 or 1 */
36
37
/*
38
 * 12 bits are sufficient to losslessly represent values in [0, q-1].
39
 * INVERSE_DEGREE is (n/2)^-1 mod q; used in inverse NTT.
40
 */
41
0
#define DEGREE          ML_KEM_DEGREE
42
#define INVERSE_DEGREE  (ML_KEM_PRIME - 2 * 13)
43
#define LOG2PRIME       12
44
#define BARRETT_SHIFT   (2 * LOG2PRIME)
45
46
#ifdef SHA3_BLOCKSIZE
47
# define SHAKE128_BLOCKSIZE SHA3_BLOCKSIZE(128)
48
#endif
49
50
/*
51
 * Return whether a value that can only be 0 or 1 is non-zero, in constant time
52
 * in practice!  The return value is a mask that is all ones if true, and all
53
 * zeros otherwise (twos-complement arithmetic assumed for unsigned values).
54
 *
55
 * Although this is used in constant-time selects, we omit a value barrier
56
 * here.  Value barriers impede auto-vectorization (likely because it forces
57
 * the value to transit through a general-purpose register). On AArch64, this
58
 * is a difference of 2x.
59
 *
60
 * We usually add value barriers to selects because Clang turns consecutive
61
 * selects with the same condition into a branch instead of CMOV/CSEL. This
62
 * condition does not occur in Kyber, so omitting it seems to be safe so far,
63
 * but see |cbd_2|, |cbd_3|, where reduction needs to be specialised to the
64
 * sign of the input, rather than adding |q| in advance, and using the generic
65
 * |reduce_once|.  (David Benjamin, Chromium)
66
 */
67
#if 0
68
# define constish_time_non_zero(b) (~constant_time_is_zero(b)) /* no ';': used in expressions */
69
#else
70
0
# define constish_time_non_zero(b) (0u - (b))
71
#endif
72
73
/*
74
 * The scalar rejection-sampling buffer size needs to be a multiple of 12, but
75
 * is otherwise arbitrary, the preferred block size matches the internal buffer
76
 * size of SHAKE128, avoiding internal buffering and copying in SHAKE128. That
77
 * block size of (1600 - 256)/8 bytes, or 168, just happens to divide by 12!
78
 *
79
 * If the blocksize is unknown, or is not divisible by 12, 168 is used as a
80
 * fallback.
81
 */
82
#if defined(SHAKE128_BLOCKSIZE) && (SHAKE128_BLOCKSIZE) % 12 == 0
83
# define SCALAR_SAMPLING_BUFSIZE (SHAKE128_BLOCKSIZE)
84
#else
85
# define SCALAR_SAMPLING_BUFSIZE 168
86
#endif
87
88
/*
89
 * Structure of keys
90
 */
91
typedef struct ossl_ml_kem_scalar_st {
92
    /* On every function entry and exit, 0 <= c[i] < ML_KEM_PRIME. */
93
    uint16_t c[ML_KEM_DEGREE];
94
} scalar;
95
96
/* Key material allocation layout */
97
#define DECLARE_ML_KEM_KEYDATA(name, rank, private_sz) \
98
    struct name##_alloc { \
99
        /* Public vector |t| */ \
100
        scalar tbuf[(rank)]; \
101
        /* Pre-computed matrix |m| (FIPS 203 |A| transpose) */ \
102
        scalar mbuf[(rank)*(rank)] \
103
        /* optional private key data */ \
104
        private_sz \
105
    }
106
107
/* Declare variant-specific public and private storage */
108
#define DECLARE_ML_KEM_VARIANT_KEYDATA(bits) \
109
    DECLARE_ML_KEM_KEYDATA(pubkey_##bits, ML_KEM_##bits##_RANK,;); \
110
    DECLARE_ML_KEM_KEYDATA(prvkey_##bits, ML_KEM_##bits##_RANK,;\
111
        scalar sbuf[ML_KEM_##bits##_RANK]; \
112
        uint8_t zbuf[2 * ML_KEM_RANDOM_BYTES];)
113
DECLARE_ML_KEM_VARIANT_KEYDATA(512);
114
DECLARE_ML_KEM_VARIANT_KEYDATA(768);
115
DECLARE_ML_KEM_VARIANT_KEYDATA(1024);
116
#undef DECLARE_ML_KEM_VARIANT_KEYDATA
117
#undef DECLARE_ML_KEM_KEYDATA
118
119
typedef __owur
120
int (*CBD_FUNC)(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
121
                EVP_MD_CTX *mdctx, const ML_KEM_KEY *key);
122
static void scalar_encode(uint8_t *out, const scalar *s, int bits);
123
124
/*
125
 * The wire-form of a losslessly encoded vector uses 12-bits per element.
126
 *
127
 * The wire-form public key consists of the lossless encoding of the public
128
 * vector |t|, followed by the public seed |rho|.
129
 *
130
 * Our serialised private key concatenates serialisations of the private vector
131
 * |s|, the public key, the public key hash, and the failure secret |z|.
132
 */
133
#define VECTOR_BYTES(b)     ((3 * DEGREE / 2) * ML_KEM_##b##_RANK)
134
#define PUBKEY_BYTES(b)     (VECTOR_BYTES(b) + ML_KEM_RANDOM_BYTES)
135
#define PRVKEY_BYTES(b)     (2 * PUBKEY_BYTES(b) + ML_KEM_PKHASH_BYTES)
136
137
/*
138
 * Encapsulation produces a vector "u" and a scalar "v", whose coordinates
139
 * (numbers modulo the ML-KEM prime "q") are lossily encoded using "du" and
140
 * "dv" bits, respectively.  This encoding is the ciphertext input for
141
 * decapsulation.
142
 */
143
#define U_VECTOR_BYTES(b)   ((DEGREE / 8) * ML_KEM_##b##_DU * ML_KEM_##b##_RANK)
144
#define V_SCALAR_BYTES(b)   ((DEGREE / 8) * ML_KEM_##b##_DV)
145
#define CTEXT_BYTES(b)      (U_VECTOR_BYTES(b) + V_SCALAR_BYTES(b))
146
147
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
148
149
/*
150
 * CONSTTIME_SECRET takes a pointer and a number of bytes and marks that region
151
 * of memory as secret. Secret data is tracked as it flows to registers and
152
 * other parts of a memory. If secret data is used as a condition for a branch,
153
 * or as a memory index, it will trigger warnings in valgrind.
154
 */
155
# define CONSTTIME_SECRET(ptr, len) VALGRIND_MAKE_MEM_UNDEFINED(ptr, len)
156
157
/*
158
 * CONSTTIME_DECLASSIFY takes a pointer and a number of bytes and marks that
159
 * region of memory as public. Public data is not subject to constant-time
160
 * rules.
161
 */
162
# define CONSTTIME_DECLASSIFY(ptr, len) VALGRIND_MAKE_MEM_DEFINED(ptr, len)
163
164
#else
165
166
# define CONSTTIME_SECRET(ptr, len)
167
# define CONSTTIME_DECLASSIFY(ptr, len)
168
169
#endif
170
171
/*
172
 * Indices of slots in the vinfo tables below
173
 */
174
0
#define ML_KEM_512_VINFO    0
175
0
#define ML_KEM_768_VINFO    1
176
0
#define ML_KEM_1024_VINFO   2
177
178
/*
179
 * Per-variant fixed parameters
180
 */
181
static const ML_KEM_VINFO vinfo_map[3] = {
182
    {
183
        "ML-KEM-512",
184
        PRVKEY_BYTES(512),
185
        sizeof(struct prvkey_512_alloc),
186
        PUBKEY_BYTES(512),
187
        sizeof(struct pubkey_512_alloc),
188
        CTEXT_BYTES(512),
189
        VECTOR_BYTES(512),
190
        U_VECTOR_BYTES(512),
191
        EVP_PKEY_ML_KEM_512,
192
        ML_KEM_512_BITS,
193
        ML_KEM_512_RANK,
194
        ML_KEM_512_DU,
195
        ML_KEM_512_DV,
196
        ML_KEM_512_SECBITS,
197
        ML_KEM_512_SECURITY_CATEGORY
198
    },
199
    {
200
        "ML-KEM-768",
201
        PRVKEY_BYTES(768),
202
        sizeof(struct prvkey_768_alloc),
203
        PUBKEY_BYTES(768),
204
        sizeof(struct pubkey_768_alloc),
205
        CTEXT_BYTES(768),
206
        VECTOR_BYTES(768),
207
        U_VECTOR_BYTES(768),
208
        EVP_PKEY_ML_KEM_768,
209
        ML_KEM_768_BITS,
210
        ML_KEM_768_RANK,
211
        ML_KEM_768_DU,
212
        ML_KEM_768_DV,
213
        ML_KEM_768_SECBITS,
214
        ML_KEM_768_SECURITY_CATEGORY
215
    },
216
    {
217
        "ML-KEM-1024",
218
        PRVKEY_BYTES(1024),
219
        sizeof(struct prvkey_1024_alloc),
220
        PUBKEY_BYTES(1024),
221
        sizeof(struct pubkey_1024_alloc),
222
        CTEXT_BYTES(1024),
223
        VECTOR_BYTES(1024),
224
        U_VECTOR_BYTES(1024),
225
        EVP_PKEY_ML_KEM_1024,
226
        ML_KEM_1024_BITS,
227
        ML_KEM_1024_RANK,
228
        ML_KEM_1024_DU,
229
        ML_KEM_1024_DV,
230
        ML_KEM_1024_SECBITS,
231
        ML_KEM_1024_SECURITY_CATEGORY
232
    }
233
};
234
235
/*
236
 * Remainders modulo `kPrime`, for sufficiently small inputs, are computed in
237
 * constant time via Barrett reduction, and a final call to reduce_once(),
238
 * which reduces inputs that are at most 2*kPrime and is also constant-time.
239
 */
240
static const int kPrime = ML_KEM_PRIME;
241
static const unsigned int kBarrettShift = BARRETT_SHIFT;
242
static const size_t   kBarrettMultiplier = (1 << BARRETT_SHIFT) / ML_KEM_PRIME;
243
static const uint16_t kHalfPrime = (ML_KEM_PRIME - 1) / 2;
244
static const uint16_t kInverseDegree = INVERSE_DEGREE;
245
246
/*
247
 * Python helper:
248
 *
249
 * p = 3329
250
 * def bitreverse(i):
251
 *     ret = 0
252
 *     for n in range(7):
253
 *         bit = i & 1
254
 *         ret <<= 1
255
 *         ret |= bit
256
 *         i >>= 1
257
 *     return ret
258
 */
259
260
/*-
261
 * First precomputed array from Appendix A of FIPS 203, or else Python:
262
 * kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)]
263
 */
264
static const uint16_t kNTTRoots[128] = {
265
    1,    1729, 2580, 3289, 2642, 630,  1897, 848,
266
    1062, 1919, 193,  797,  2786, 3260, 569,  1746,
267
    296,  2447, 1339, 1476, 3046, 56,   2240, 1333,
268
    1426, 2094, 535,  2882, 2393, 2879, 1974, 821,
269
    289,  331,  3253, 1756, 1197, 2304, 2277, 2055,
270
    650,  1977, 2513, 632,  2865, 33,   1320, 1915,
271
    2319, 1435, 807,  452,  1438, 2868, 1534, 2402,
272
    2647, 2617, 1481, 648,  2474, 3110, 1227, 910,
273
    17,   2761, 583,  2649, 1637, 723,  2288, 1100,
274
    1409, 2662, 3281, 233,  756,  2156, 3015, 3050,
275
    1703, 1651, 2789, 1789, 1847, 952,  1461, 2687,
276
    939,  2308, 2437, 2388, 733,  2337, 268,  641,
277
    1584, 2298, 2037, 3220, 375,  2549, 2090, 1645,
278
    1063, 319,  2773, 757,  2099, 561,  2466, 2594,
279
    2804, 1092, 403,  1026, 1143, 2150, 2775, 886,
280
    1722, 1212, 1874, 1029, 2110, 2935, 885,  2154,
281
};
282
283
/*
284
 * InverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)]
285
 * Listed in order of use in the inverse NTT loop (index 0 is skipped):
286
 *
287
 *  0, 64, 65, ..., 127, 32, 33, ..., 63, 16, 17, ..., 31, 8, 9, ...
288
 */
289
static const uint16_t kInverseNTTRoots[128] = {
290
    1,    1175, 2444, 394,  1219, 2300, 1455, 2117,
291
    1607, 2443, 554,  1179, 2186, 2303, 2926, 2237,
292
    525,  735,  863,  2768, 1230, 2572, 556,  3010,
293
    2266, 1684, 1239, 780,  2954, 109,  1292, 1031,
294
    1745, 2688, 3061, 992,  2596, 941,  892,  1021,
295
    2390, 642,  1868, 2377, 1482, 1540, 540,  1678,
296
    1626, 279,  314,  1173, 2573, 3096, 48,   667,
297
    1920, 2229, 1041, 2606, 1692, 680,  2746, 568,
298
    3312, 2419, 2102, 219,  855,  2681, 1848, 712,
299
    682,  927,  1795, 461,  1891, 2877, 2522, 1894,
300
    1010, 1414, 2009, 3296, 464,  2697, 816,  1352,
301
    2679, 1274, 1052, 1025, 2132, 1573, 76,   2998,
302
    3040, 2508, 1355, 450,  936,  447,  2794, 1235,
303
    1903, 1996, 1089, 3273, 283,  1853, 1990, 882,
304
    3033, 1583, 2760, 69,   543,  2532, 3136, 1410,
305
    2267, 2481, 1432, 2699, 687,  40,   749,  1600,
306
};
307
308
/*
309
 * Second precomputed array from Appendix A of FIPS 203 (normalised positive),
310
 * or else Python:
311
 * ModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)]
312
 */
313
static const uint16_t kModRoots[128] = {
314
    17,   3312, 2761, 568,  583,  2746, 2649, 680,  1637, 1692, 723,  2606,
315
    2288, 1041, 1100, 2229, 1409, 1920, 2662, 667,  3281, 48,   233,  3096,
316
    756,  2573, 2156, 1173, 3015, 314,  3050, 279,  1703, 1626, 1651, 1678,
317
    2789, 540,  1789, 1540, 1847, 1482, 952,  2377, 1461, 1868, 2687, 642,
318
    939,  2390, 2308, 1021, 2437, 892,  2388, 941,  733,  2596, 2337, 992,
319
    268,  3061, 641,  2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
320
    375,  2954, 2549, 780,  2090, 1239, 1645, 1684, 1063, 2266, 319,  3010,
321
    2773, 556,  757,  2572, 2099, 1230, 561,  2768, 2466, 863,  2594, 735,
322
    2804, 525,  1092, 2237, 403,  2926, 1026, 2303, 1143, 2186, 2150, 1179,
323
    2775, 554,  886,  2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
324
    2110, 1219, 2935, 394,  885,  2444, 2154, 1175,
325
};
326
327
/*
328
 * single_keccak hashes |inlen| bytes from |in| and writes |outlen| bytes of
329
 * output to |out|. If the |md| specifies a fixed-output function, like
330
 * SHA3-256, then |outlen| must be the correct length for that function.
331
 */
332
static __owur
333
int single_keccak(uint8_t *out, size_t outlen, const uint8_t *in, size_t inlen,
334
                  EVP_MD_CTX *mdctx)
335
0
{
336
0
    unsigned int sz = (unsigned int) outlen;
337
338
0
    if (!EVP_DigestUpdate(mdctx, in, inlen))
339
0
        return 0;
340
0
    if (EVP_MD_xof(EVP_MD_CTX_get0_md(mdctx)))
341
0
        return EVP_DigestFinalXOF(mdctx, out, outlen);
342
0
    return EVP_DigestFinal_ex(mdctx, out, &sz)
343
0
        && ossl_assert((size_t) sz == outlen);
344
0
}
345
346
/*
347
 * FIPS 203, Section 4.1, equation (4.3): PRF. Takes 32+1 input bytes, and uses
348
 * SHAKE256 to produce the input to SamplePolyCBD_eta: FIPS 203, algorithm 8.
349
 */
350
static __owur
351
int prf(uint8_t *out, size_t len, const uint8_t in[ML_KEM_RANDOM_BYTES + 1],
352
        EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
353
0
{
354
0
    return EVP_DigestInit_ex(mdctx, key->shake256_md, NULL)
355
0
        && single_keccak(out, len, in, ML_KEM_RANDOM_BYTES + 1, mdctx);
356
0
}
357
358
/*
359
 * FIPS 203, Section 4.1, equation (4.4): H.  SHA3-256 hash of a variable
360
 * length input, producing 32 bytes of output.
361
 */
362
static __owur
363
int hash_h(uint8_t out[ML_KEM_PKHASH_BYTES], const uint8_t *in, size_t len,
364
           EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
365
0
{
366
0
    return EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL)
367
0
        && single_keccak(out, ML_KEM_PKHASH_BYTES, in, len, mdctx);
368
0
}
369
370
/* Incremental hash_h of expanded public key */
371
static int
372
hash_h_pubkey(uint8_t pkhash[ML_KEM_PKHASH_BYTES],
373
              EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
374
0
{
375
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
376
0
    const scalar *t = key->t, *end = t + vinfo->rank;
377
0
    unsigned int sz;
378
379
0
    if (!EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL))
380
0
        return 0;
381
382
0
    do {
383
0
        uint8_t buf[3 * DEGREE / 2];
384
385
0
        scalar_encode(buf, t++, 12);
386
0
        if (!EVP_DigestUpdate(mdctx, buf, sizeof(buf)))
387
0
            return 0;
388
0
    } while (t < end);
389
390
0
    if (!EVP_DigestUpdate(mdctx, key->rho, ML_KEM_RANDOM_BYTES))
391
0
        return 0;
392
0
    return EVP_DigestFinal_ex(mdctx, pkhash, &sz)
393
0
        && ossl_assert(sz == ML_KEM_PKHASH_BYTES);
394
0
}
395
396
/*
397
 * FIPS 203, Section 4.1, equation (4.5): G.  SHA3-512 hash of a variable
398
 * length input, producing 64 bytes of output, in particular the seeds
399
 * (d,z) for key generation.
400
 */
401
static __owur
402
int hash_g(uint8_t out[ML_KEM_SEED_BYTES], const uint8_t *in, size_t len,
403
           EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
404
0
{
405
0
    return EVP_DigestInit_ex(mdctx, key->sha3_512_md, NULL)
406
0
        && single_keccak(out, ML_KEM_SEED_BYTES, in, len, mdctx);
407
0
}
408
409
/*
410
 * FIPS 203, Section 4.1, equation (4.4): J. SHAKE256 taking a variable length
411
 * input to compute a 32-byte implicit rejection shared secret, of the same
412
 * length as the expected shared secret.  (Computed even on success to avoid
413
 * side-channel leaks).
414
 */
415
static __owur
416
int kdf(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
417
        const uint8_t z[ML_KEM_RANDOM_BYTES],
418
        const uint8_t *ctext, size_t len,
419
        EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
420
0
{
421
0
    return EVP_DigestInit_ex(mdctx, key->shake256_md, NULL)
422
0
        && EVP_DigestUpdate(mdctx, z, ML_KEM_RANDOM_BYTES)
423
0
        && EVP_DigestUpdate(mdctx, ctext, len)
424
0
        && EVP_DigestFinalXOF(mdctx, out, ML_KEM_SHARED_SECRET_BYTES);
425
0
}
426
427
/*
428
 * FIPS 203, Section 4.2.2, Algorithm 7: "SampleNTT" (steps 3-17, steps 1, 2
429
 * are performed by the caller). Rejection-samples a Keccak stream to get
430
 * uniformly distributed elements in the range [0,q). This is used for matrix
431
 * expansion and only operates on public inputs.
432
 */
433
static __owur
434
int sample_scalar(scalar *out, EVP_MD_CTX *mdctx)
435
0
{
436
0
    uint16_t *curr = out->c, *endout = curr + DEGREE;
437
0
    uint8_t buf[SCALAR_SAMPLING_BUFSIZE], *in;
438
0
    uint8_t *endin = buf + sizeof(buf);
439
0
    uint16_t d;
440
0
    uint8_t b1, b2, b3;
441
442
0
    do {
443
0
        if (!EVP_DigestSqueeze(mdctx, in = buf, sizeof(buf)))
444
0
            return 0;
445
0
        do {
446
0
            b1 = *in++;
447
0
            b2 = *in++;
448
0
            b3 = *in++;
449
450
0
            if (curr >= endout)
451
0
                break;
452
0
            if ((d = ((b2 & 0x0f) << 8) + b1) < kPrime)
453
0
                *curr++ = d;
454
0
            if (curr >= endout)
455
0
                break;
456
0
            if ((d = (b3 << 4) + (b2 >> 4)) < kPrime)
457
0
                *curr++ = d;
458
0
        } while (in < endin);
459
0
    } while (curr < endout);
460
0
    return 1;
461
0
}
462
463
/*-
464
 * reduce_once reduces 0 <= x < 2*kPrime, mod kPrime.
465
 *
466
 * Subtract |q| if the input is larger, without exposing a side-channel,
467
 * avoiding the "clangover" attack.  See |constish_time_non_zero| for a
468
 * discussion on why the value barrier is by default omitted.
469
 */
470
static __owur uint16_t reduce_once(uint16_t x)
471
0
{
472
0
    const uint16_t subtracted = x - kPrime;
473
0
    uint16_t mask = constish_time_non_zero(subtracted >> 15);
474
475
0
    return (mask & x) | (~mask & subtracted);
476
0
}
477
478
/*
479
 * Constant-time reduce x mod kPrime using Barrett reduction. x must be less
480
 * than kPrime + 2 * kPrime^2.  This is sufficient to reduce a product of
481
 * two already reduced uint16_t values, in fact it is sufficient for each
482
 * to be less than 2^12, because (kPrime * (2 * kPrime + 1)) > 2^24.
483
 */
484
static __owur uint16_t reduce(uint32_t x)
485
0
{
486
0
    uint64_t product = (uint64_t)x * kBarrettMultiplier;
487
0
    uint32_t quotient = (uint32_t)(product >> kBarrettShift);
488
0
    uint32_t remainder = x - quotient * kPrime;
489
490
0
    return reduce_once(remainder);
491
0
}
492
493
/* Multiply a scalar by a constant. */
494
static void scalar_mult_const(scalar *s, uint16_t a)
495
0
{
496
0
    uint16_t *curr = s->c, *end = curr + DEGREE, tmp;
497
498
0
    do {
499
0
        tmp = reduce(*curr * a);
500
0
        *curr++ = tmp;
501
0
    } while (curr < end);
502
0
}
503
504
/*-
505
 * FIPS 203, Section 4.3, Algorithm 9: "NTT".
506
 * In-place number theoretic transform of a given scalar.  Note that ML-KEM's
507
 * kPrime 3329 does not have a 512th root of unity, so this transform leaves
508
 * off the last iteration of the usual FFT code, with the 128 relevant roots of
509
 * unity being stored in NTTRoots.  This means the output should be seen as 128
510
 * elements in GF(3329^2), with the coefficients of the elements being
511
 * consecutive entries in |s->c|.
512
 */
513
static void scalar_ntt(scalar *s)
514
0
{
515
0
    const uint16_t *roots = kNTTRoots;
516
0
    uint16_t *end = s->c + DEGREE;
517
0
    int offset = DEGREE / 2;
518
519
0
    do {
520
0
        uint16_t *curr = s->c, *peer;
521
522
0
        do {
523
0
            uint16_t *pause = curr + offset, even, odd;
524
0
            uint32_t zeta = *++roots;
525
526
0
            peer = pause;
527
0
            do {
528
0
                even = *curr;
529
0
                odd = reduce(*peer * zeta);
530
0
                *peer++ = reduce_once(even - odd + kPrime);
531
0
                *curr++ = reduce_once(odd + even);
532
0
            } while (curr < pause);
533
0
        } while ((curr = peer) < end);
534
0
    } while ((offset >>= 1) >= 2);
535
0
}
536
537
/*-
538
 * FIPS 203, Section 4.3, Algorithm 10: "NTT^(-1)".
539
 * In-place inverse number theoretic transform of a given scalar, with pairs of
540
 * entries of s->c being interpreted as elements of GF(3329^2). Just as with
541
 * the number theoretic transform, this leaves off the first step of the normal
542
 * iFFT to account for the fact that 3329 does not have a 512th root of unity,
543
 * using the precomputed 128 roots of unity stored in InverseNTTRoots.
544
 */
545
static void scalar_inverse_ntt(scalar *s)
546
0
{
547
0
    const uint16_t *roots = kInverseNTTRoots;
548
0
    uint16_t *end = s->c + DEGREE;
549
0
    int offset = 2;
550
551
0
    do {
552
0
        uint16_t *curr = s->c, *peer;
553
554
0
        do {
555
0
            uint16_t *pause = curr + offset, even, odd;
556
0
            uint32_t zeta = *++roots;
557
558
0
            peer = pause;
559
0
            do {
560
0
                even = *curr;
561
0
                odd = *peer;
562
0
                *peer++ = reduce(zeta * (even - odd + kPrime));
563
0
                *curr++ = reduce_once(odd + even);
564
0
            } while (curr < pause);
565
0
        } while ((curr = peer) < end);
566
0
    } while ((offset <<= 1) < DEGREE);
567
0
    scalar_mult_const(s, kInverseDegree);
568
0
}
569
570
/* Addition updating the LHS scalar in-place. */
571
static void scalar_add(scalar *lhs, const scalar *rhs)
572
0
{
573
0
    int i;
574
575
0
    for (i = 0; i < DEGREE; i++)
576
0
        lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
577
0
}
578
579
/* Subtraction updating the LHS scalar in-place. */
580
static void scalar_sub(scalar *lhs, const scalar *rhs)
581
0
{
582
0
    int i;
583
584
0
    for (i = 0; i < DEGREE; i++)
585
0
        lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime);
586
0
}
587
588
/*
589
 * Multiplying two scalars in the number theoretically transformed state. Since
590
 * 3329 does not have a 512th root of unity, this means we have to interpret
591
 * the 2*ith and (2*i+1)th entries of the scalar as elements of
592
 * GF(3329)[X]/(X^2 - 17^(2*bitreverse(i)+1)).
593
 *
594
 * The value of 17^(2*bitreverse(i)+1) mod 3329 is stored in the precomputed
595
 * ModRoots table. Note that our Barrett transform only allows us to multiply
596
 * two reduced numbers together, so we need some intermediate reduction steps,
597
 * even if an uint64_t could hold 3 multiplied numbers.
598
 */
599
static void scalar_mult(scalar *out, const scalar *lhs,
600
                        const scalar *rhs)
601
0
{
602
0
    uint16_t *curr = out->c, *end = curr + DEGREE;
603
0
    const uint16_t *lc = lhs->c, *rc = rhs->c;
604
0
    const uint16_t *roots = kModRoots;
605
606
0
    do {
607
0
        uint32_t l0 = *lc++, r0 = *rc++;
608
0
        uint32_t l1 = *lc++, r1 = *rc++;
609
0
        uint32_t zetapow = *roots++;
610
611
0
        *curr++ = reduce(l0 * r0 + reduce(l1 * r1) * zetapow);
612
0
        *curr++ = reduce(l0 * r1 + l1 * r0);
613
0
    } while (curr < end);
614
0
}
615
616
/* Above, but add the result to an existing scalar */
617
static ossl_inline
618
void scalar_mult_add(scalar *out, const scalar *lhs,
619
                     const scalar *rhs)
620
0
{
621
0
    uint16_t *curr = out->c, *end = curr + DEGREE;
622
0
    const uint16_t *lc = lhs->c, *rc = rhs->c;
623
0
    const uint16_t *roots = kModRoots;
624
625
0
    do {
626
0
        uint32_t l0 = *lc++, r0 = *rc++;
627
0
        uint32_t l1 = *lc++, r1 = *rc++;
628
0
        uint16_t *c0 = curr++;
629
0
        uint16_t *c1 = curr++;
630
0
        uint32_t zetapow = *roots++;
631
632
0
        *c0 = reduce(*c0 + l0 * r0 + reduce(l1 * r1) * zetapow);
633
0
        *c1 = reduce(*c1 + l0 * r1 + l1 * r0);
634
0
    } while (curr < end);
635
0
}
636
637
/*-
638
 * FIPS 203, Section 4.2.1, Algorithm 5: "ByteEncode_d", for 2<=d<=12.
639
 * Here |bits| is |d|.  For efficiency, we handle the d=1 case separately.
640
 */
641
static void scalar_encode(uint8_t *out, const scalar *s, int bits)
642
0
{
643
0
    const uint16_t *curr = s->c, *end = curr + DEGREE;
644
0
    uint64_t accum = 0, element;
645
0
    int used = 0;
646
647
0
    do {
648
0
        element = *curr++;
649
0
        if (used + bits < 64) {
650
0
            accum |= element << used;
651
0
            used += bits;
652
0
        } else if (used + bits > 64) {
653
0
            out = OPENSSL_store_u64_le(out, accum | (element << used));
654
0
            accum = element >> (64 - used);
655
0
            used = (used + bits) - 64;
656
0
        } else {
657
0
            out = OPENSSL_store_u64_le(out, accum | (element << used));
658
0
            accum = 0;
659
0
            used = 0;
660
0
        }
661
0
    } while (curr < end);
662
0
}
663
664
/*
665
 * scalar_encode_1 is |scalar_encode| specialised for |bits| == 1.
666
 */
667
static void scalar_encode_1(uint8_t out[DEGREE / 8], const scalar *s)
668
0
{
669
0
    int i, j;
670
0
    uint8_t out_byte;
671
672
0
    for (i = 0; i < DEGREE; i += 8) {
673
0
        out_byte = 0;
674
0
        for (j = 0; j < 8; j++)
675
0
            out_byte |= bit0(s->c[i + j]) << j;
676
0
        *out = out_byte;
677
0
        out++;
678
0
    }
679
0
}
680
681
/*-
682
 * FIPS 203, Section 4.2.1, Algorithm 6: "ByteDecode_d", for 2<=d<12.
683
 * Here |bits| is |d|.  For efficiency, we handle the d=1 and d=12 cases
684
 * separately.
685
 *
686
 * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
687
 * |out|.
688
 */
689
static void scalar_decode(scalar *out, const uint8_t *in, int bits)
690
0
{
691
0
    uint16_t *curr = out->c, *end = curr + DEGREE;
692
0
    uint64_t accum = 0;
693
0
    int accum_bits = 0, todo = bits;
694
0
    uint16_t bitmask = (((uint16_t) 1) << bits) - 1, mask = bitmask;
695
0
    uint16_t element = 0;
696
697
0
    do {
698
0
        if (accum_bits == 0) {
699
0
            in = OPENSSL_load_u64_le(&accum, in);
700
0
            accum_bits = 64;
701
0
        }
702
0
        if (todo == bits && accum_bits >= bits) {
703
            /* No partial "element", and all the required bits available */
704
0
            *curr++ = ((uint16_t) accum) & mask;
705
0
            accum >>= bits;
706
0
            accum_bits -= bits;
707
0
        } else if (accum_bits >= todo) {
708
            /* A partial "element", and all the required bits available */
709
0
            *curr++ = element | ((((uint16_t) accum) & mask) << (bits - todo));
710
0
            accum >>= todo;
711
0
            accum_bits -= todo;
712
0
            element = 0;
713
0
            todo = bits;
714
0
            mask = bitmask;
715
0
        } else {
716
            /*
717
             * Only some of the requisite bits accumulated, store |accum_bits|
718
             * of these in |element|.  The accumulated bitcount becomes 0, but
719
             * as soon as we have more bits we'll want to merge accum_bits
720
             * fewer of them into the final |element|.
721
             *
722
             * Note that with a 64-bit accumulator and |bits| always 12 or
723
             * less, if we're here, the previous iteration had all the
724
             * requisite bits, and so there are no kept bits in |element|.
725
             */
726
0
            element = ((uint16_t) accum) & mask;
727
0
            todo -= accum_bits;
728
0
            mask = bitmask >> accum_bits;
729
0
            accum_bits = 0;
730
0
        }
731
0
    } while (curr < end);
732
0
}
733
734
static __owur
735
int scalar_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2])
736
0
{
737
0
    int i;
738
0
    uint16_t *c = out->c;
739
740
0
    for (i = 0; i < DEGREE / 2; ++i) {
741
0
        uint8_t b1 = *in++;
742
0
        uint8_t b2 = *in++;
743
0
        uint8_t b3 = *in++;
744
0
        int outOfRange1 = (*c++ = b1 | ((b2 & 0x0f) << 8)) >= kPrime;
745
0
        int outOfRange2 = (*c++ = (b2 >> 4) | (b3 << 4)) >= kPrime;
746
747
0
        if (outOfRange1 | outOfRange2)
748
0
            return 0;
749
0
    }
750
0
    return 1;
751
0
}
752
753
/*-
754
 * scalar_decode_decompress_add is a combination of decoding and decompression
755
 * both specialised for |bits| == 1, with the result added (and sum reduced) to
756
 * the output scalar.
757
 *
758
 * NOTE: this function MUST not leak an input-data-dependent timing signal.
759
 * A timing leak in a related function in the reference Kyber implementation
760
 * made the "clangover" attack (CVE-2024-37880) possible, giving key recovery
761
 * for ML-KEM-512 in minutes, provided the attacker has access to precise
762
 * timing of a CPU performing chosen-ciphertext decap.  Admittedly this is only
763
 * a risk when private keys are reused (perhaps KEMTLS servers).
764
 */
765
static void
766
scalar_decode_decompress_add(scalar *out, const uint8_t in[DEGREE / 8])
767
0
{
768
0
    static const uint16_t half_q_plus_1 = (ML_KEM_PRIME >> 1) + 1;
769
0
    uint16_t *curr = out->c, *end = curr + DEGREE;
770
0
    uint16_t mask;
771
0
    uint8_t b;
772
773
    /*
774
     * Add |half_q_plus_1| if the bit is set, without exposing a side-channel,
775
     * avoiding the "clangover" attack.  See |constish_time_non_zero| for a
776
     * discussion on why the value barrier is by default omitted.
777
     */
778
0
#define decode_decompress_add_bit                               \
779
0
        mask = constish_time_non_zero(bit0(b));                 \
780
0
        *curr = reduce_once(*curr + (mask & half_q_plus_1));    \
781
0
        curr++;                                                 \
782
0
        b >>= 1
783
784
    /* Unrolled to process each byte in one iteration */
785
0
    do {
786
0
        b = *in++;
787
0
        decode_decompress_add_bit;
788
0
        decode_decompress_add_bit;
789
0
        decode_decompress_add_bit;
790
0
        decode_decompress_add_bit;
791
792
0
        decode_decompress_add_bit;
793
0
        decode_decompress_add_bit;
794
0
        decode_decompress_add_bit;
795
0
        decode_decompress_add_bit;
796
0
    } while (curr < end);
797
0
#undef decode_decompress_add_bit
798
0
}
799
800
/*
801
 * FIPS 203, Section 4.2.1, Equation (4.7): Compress_d.
802
 *
803
 * Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
804
 * numbers close to each other together. The formula used is
805
 * round(2^|bits|/kPrime*x) mod 2^|bits|.
806
 * Uses Barrett reduction to achieve constant time. Since we need both the
807
 * remainder (for rounding) and the quotient (as the result), we cannot use
808
 * |reduce| here, but need to do the Barrett reduction directly.
809
 */
810
static __owur uint16_t compress(uint16_t x, int bits)
811
0
{
812
0
    uint32_t shifted = (uint32_t)x << bits;
813
0
    uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
814
0
    uint32_t quotient = (uint32_t)(product >> kBarrettShift);
815
0
    uint32_t remainder = shifted - quotient * kPrime;
816
817
    /*
818
     * Adjust the quotient to round correctly:
819
     *   0 <= remainder <= kHalfPrime round to 0
820
     *   kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
821
     *   kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
822
     */
823
0
    quotient += 1 & constant_time_lt_32(kHalfPrime, remainder);
824
0
    quotient += 1 & constant_time_lt_32(kPrime + kHalfPrime, remainder);
825
0
    return quotient & ((1 << bits) - 1);
826
0
}
827
828
/*
829
 * FIPS 203, Section 4.2.1, Equation (4.8): Decompress_d.
830
831
 * Decompresses |x| by using a close equi-distant representative. The formula
832
 * is round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us
833
 * to implement this logic using only bit operations.
834
 */
835
static __owur uint16_t decompress(uint16_t x, int bits)
836
0
{
837
0
    uint32_t product = (uint32_t)x * kPrime;
838
0
    uint32_t power = 1 << bits;
839
    /* This is |product| % power, since |power| is a power of 2. */
840
0
    uint32_t remainder = product & (power - 1);
841
    /* This is |product| / power, since |power| is a power of 2. */
842
0
    uint32_t lower = product >> bits;
843
844
    /*
845
     * The rounding logic works since the first half of numbers mod |power|
846
     * have a 0 as first bit, and the second half has a 1 as first bit, since
847
     * |power| is a power of 2. As a 12 bit number, |remainder| is always
848
     * positive, so we will shift in 0s for a right shift.
849
     */
850
0
    return lower + (remainder >> (bits - 1));
851
0
}
852
853
/*-
854
 * FIPS 203, Section 4.2.1, Equation (4.7): "Compress_d".
855
 * In-place lossy rounding of scalars to 2^d bits.
856
 */
857
static void scalar_compress(scalar *s, int bits)
858
0
{
859
0
    int i;
860
861
0
    for (i = 0; i < DEGREE; i++)
862
0
        s->c[i] = compress(s->c[i], bits);
863
0
}
864
865
/*
866
 * FIPS 203, Section 4.2.1, Equation (4.8): "Decompress_d".
867
 * In-place approximate recovery of scalars from 2^d bit compression.
868
 */
869
static void scalar_decompress(scalar *s, int bits)
870
0
{
871
0
    int i;
872
873
0
    for (i = 0; i < DEGREE; i++)
874
0
        s->c[i] = decompress(s->c[i], bits);
875
0
}
876
877
/* Addition updating the LHS vector in-place. */
878
static void vector_add(scalar *lhs, const scalar *rhs, int rank)
879
0
{
880
0
    do {
881
0
        scalar_add(lhs++, rhs++);
882
0
    } while (--rank > 0);
883
0
}
884
885
/*
886
 * Encodes an entire vector into 32*|rank|*|bits| bytes. Note that since 256
887
 * (DEGREE) is divisible by 8, the individual vector entries will always fill a
888
 * whole number of bytes, so we do not need to worry about bit packing here.
889
 */
890
static void vector_encode(uint8_t *out, const scalar *a, int bits, int rank)
891
0
{
892
0
    int stride = bits * DEGREE / 8;
893
894
0
    for (; rank-- > 0; out += stride)
895
0
        scalar_encode(out, a++, bits);
896
0
}
897
898
/*
899
 * Decodes 32*|rank|*|bits| bytes from |in| into |out|. It returns early
900
 * if any parsed value is >= |ML_KEM_PRIME|.  The resulting scalars are
901
 * then decompressed and transformed via the NTT.
902
 *
903
 * Note: Used only in decrypt_cpa(), which returns void and so does not check
904
 * the return value of this function.  Side-channels are fine when the input
905
 * ciphertext to decap() is simply syntactically invalid.
906
 */
907
static void
908
vector_decode_decompress_ntt(scalar *out, const uint8_t *in, int bits, int rank)
909
0
{
910
0
    int stride = bits * DEGREE / 8;
911
912
0
    for (; rank-- > 0; in += stride, ++out) {
913
0
        scalar_decode(out, in, bits);
914
0
        scalar_decompress(out, bits);
915
0
        scalar_ntt(out);
916
0
    }
917
0
}
918
919
/* vector_decode(), specialised to bits == 12. */
920
static __owur
921
int vector_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2], int rank)
922
0
{
923
0
    int stride = 3 * DEGREE / 2;
924
925
0
    for (; rank-- > 0; in += stride)
926
0
        if (!scalar_decode_12(out++, in))
927
0
            return 0;
928
0
    return 1;
929
0
}
930
931
/* In-place compression of each scalar component */
932
static void vector_compress(scalar *a, int bits, int rank)
933
0
{
934
0
    do {
935
0
        scalar_compress(a++, bits);
936
0
    } while (--rank > 0);
937
0
}
938
939
/* The output scalar must not overlap with the inputs */
940
static void inner_product(scalar *out, const scalar *lhs, const scalar *rhs,
941
                          int rank)
942
0
{
943
0
    scalar_mult(out, lhs, rhs);
944
0
    while (--rank > 0)
945
0
        scalar_mult_add(out, ++lhs, ++rhs);
946
0
}
947
948
/*
949
 * Here, the output vector must not overlap with the inputs, the result is
950
 * directly subjected to inverse NTT.
951
 */
952
static void
953
matrix_mult_intt(scalar *out, const scalar *m, const scalar *a, int rank)
954
0
{
955
0
    const scalar *ar;
956
0
    int i, j;
957
958
0
    for (i = rank; i-- > 0; ++out) {
959
0
        scalar_mult(out, m++, ar = a);
960
0
        for (j = rank - 1; j > 0; --j)
961
0
            scalar_mult_add(out, m++, ++ar);
962
0
        scalar_inverse_ntt(out);
963
0
    }
964
0
}
965
966
/* Here, the output vector must not overlap with the inputs */
967
static void
968
matrix_mult_transpose_add(scalar *out, const scalar *m, const scalar *a, int rank)
969
0
{
970
0
    const scalar *mc = m, *mr, *ar;
971
0
    int i, j;
972
973
0
    for (i = rank; i-- > 0; ++out) {
974
0
        scalar_mult_add(out, mr = mc++, ar = a);
975
0
        for (j = rank; --j > 0; )
976
0
            scalar_mult_add(out, (mr += rank), ++ar);
977
0
    }
978
0
}
979
980
/*-
981
 * Expands the matrix from a seed for key generation and for encaps-CPA.
982
 * NOTE: FIPS 203 matrix "A" is the transpose of this matrix, computed
983
 * by appending the (i,j) indices to the seed in the opposite order!
984
 *
985
 * Where FIPS 203 computes t = A * s + e, we use the transpose of "m".
986
 */
987
static __owur
988
int matrix_expand(EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
989
0
{
990
0
    scalar *out = key->m;
991
0
    uint8_t input[ML_KEM_RANDOM_BYTES + 2];
992
0
    int rank = key->vinfo->rank;
993
0
    int i, j;
994
995
0
    memcpy(input, key->rho, ML_KEM_RANDOM_BYTES);
996
0
    for (i = 0; i < rank; i++) {
997
0
        for (j = 0; j < rank; j++) {
998
0
            input[ML_KEM_RANDOM_BYTES] = i;
999
0
            input[ML_KEM_RANDOM_BYTES + 1] = j;
1000
0
            if (!EVP_DigestInit_ex(mdctx, key->shake128_md, NULL)
1001
0
                || !EVP_DigestUpdate(mdctx, input, sizeof(input))
1002
0
                || !sample_scalar(out++, mdctx))
1003
0
                return 0;
1004
0
        }
1005
0
    }
1006
0
    return 1;
1007
0
}
1008
1009
/*
1010
 * Algorithm 7 from the spec, with eta fixed to two and the PRF call
1011
 * included. Creates binominally distributed elements by sampling 2*|eta| bits,
1012
 * and setting the coefficient to the count of the first bits minus the count of
1013
 * the second bits, resulting in a centered binomial distribution. Since eta is
1014
 * two this gives -2/2 with a probability of 1/16, -1/1 with probability 1/4,
1015
 * and 0 with probability 3/8.
1016
 */
1017
static __owur
1018
int cbd_2(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
1019
          EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1020
0
{
1021
0
    uint16_t *curr = out->c, *end = curr + DEGREE;
1022
0
    uint8_t randbuf[4 * DEGREE / 8], *r = randbuf;  /* 64 * eta slots */
1023
0
    uint16_t value, mask;
1024
0
    uint8_t b;
1025
1026
0
    if (!prf(randbuf, sizeof(randbuf), in, mdctx, key))
1027
0
        return 0;
1028
1029
0
    do {
1030
0
        b = *r++;
1031
1032
        /*
1033
         * Add |kPrime| if |value| underflowed.  See |constish_time_non_zero|
1034
         * for a discussion on why the value barrier is by default omitted.
1035
         * While this could have been written reduce_once(value + kPrime), this
1036
         * is one extra addition and small range of |value| tempts some
1037
         * versions of Clang to emit a branch.
1038
         */
1039
0
        value = bit0(b) + bitn(1, b);
1040
0
        value -= bitn(2, b) + bitn(3, b);
1041
0
        mask = constish_time_non_zero(value >> 15);
1042
0
        *curr++ = value + (kPrime & mask);
1043
1044
0
        value = bitn(4, b) + bitn(5, b);
1045
0
        value -= bitn(6, b) + bitn(7, b);
1046
0
        mask = constish_time_non_zero(value >> 15);
1047
0
        *curr++ = value + (kPrime & mask);
1048
0
    } while (curr < end);
1049
0
    return 1;
1050
0
}
1051
1052
/*
1053
 * Algorithm 7 from the spec, with eta fixed to three and the PRF call
1054
 * included. Creates binominally distributed elements by sampling 3*|eta| bits,
1055
 * and setting the coefficient to the count of the first bits minus the count of
1056
 * the second bits, resulting in a centered binomial distribution.
1057
 */
1058
static __owur
1059
int cbd_3(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
1060
          EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1061
0
{
1062
0
    uint16_t *curr = out->c, *end = curr + DEGREE;
1063
0
    uint8_t randbuf[6 * DEGREE / 8], *r = randbuf;  /* 64 * eta slots */
1064
0
    uint8_t b1, b2, b3;
1065
0
    uint16_t value, mask;
1066
1067
0
    if (!prf(randbuf, sizeof(randbuf), in, mdctx, key))
1068
0
        return 0;
1069
1070
0
    do {
1071
0
        b1 = *r++;
1072
0
        b2 = *r++;
1073
0
        b3 = *r++;
1074
1075
        /*
1076
         * Add |kPrime| if |value| underflowed.  See |constish_time_non_zero|
1077
         * for a discussion on why the value barrier is by default omitted.
1078
         * While this could have been written reduce_once(value + kPrime), this
1079
         * is one extra addition and small range of |value| tempts some
1080
         * versions of Clang to emit a branch.
1081
         */
1082
0
        value = bit0(b1) + bitn(1, b1) + bitn(2, b1);
1083
0
        value -= bitn(3, b1)  + bitn(4, b1) + bitn(5, b1);
1084
0
        mask = constish_time_non_zero(value >> 15);
1085
0
        *curr++ = value + (kPrime & mask);
1086
1087
0
        value = bitn(6, b1) + bitn(7, b1) + bit0(b2);
1088
0
        value -= bitn(1, b2) + bitn(2, b2) + bitn(3, b2);
1089
0
        mask = constish_time_non_zero(value >> 15);
1090
0
        *curr++ = value + (kPrime & mask);
1091
1092
0
        value = bitn(4, b2) + bitn(5, b2) + bitn(6, b2);
1093
0
        value -= bitn(7, b2) + bit0(b3) + bitn(1, b3);
1094
0
        mask = constish_time_non_zero(value >> 15);
1095
0
        *curr++ = value + (kPrime & mask);
1096
1097
0
        value = bitn(2, b3) + bitn(3, b3) + bitn(4, b3);
1098
0
        value -= bitn(5, b3) + bitn(6, b3) + bitn(7, b3);
1099
0
        mask = constish_time_non_zero(value >> 15);
1100
0
        *curr++ = value + (kPrime & mask);
1101
0
    } while (curr < end);
1102
0
    return 1;
1103
0
}
1104
1105
/*
1106
 * Generates a secret vector by using |cbd| with the given seed to generate
1107
 * scalar elements and incrementing |counter| for each slot of the vector.
1108
 */
1109
static __owur
1110
int gencbd_vector(scalar *out, CBD_FUNC cbd, uint8_t *counter,
1111
                  const uint8_t seed[ML_KEM_RANDOM_BYTES], int rank,
1112
                  EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1113
0
{
1114
0
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1115
1116
0
    memcpy(input, seed, ML_KEM_RANDOM_BYTES);
1117
0
    do {
1118
0
        input[ML_KEM_RANDOM_BYTES] = (*counter)++;
1119
0
        if (!cbd(out++, input, mdctx, key))
1120
0
            return 0;
1121
0
    } while (--rank > 0);
1122
0
    return 1;
1123
0
}
1124
1125
/*
1126
 * As above plus NTT transform.
1127
 */
1128
static __owur
1129
int gencbd_vector_ntt(scalar *out, CBD_FUNC cbd, uint8_t *counter,
1130
                      const uint8_t seed[ML_KEM_RANDOM_BYTES], int rank,
1131
                      EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1132
0
{
1133
0
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1134
1135
0
    memcpy(input, seed, ML_KEM_RANDOM_BYTES);
1136
0
    do {
1137
0
        input[ML_KEM_RANDOM_BYTES] = (*counter)++;
1138
0
        if (!cbd(out, input, mdctx, key))
1139
0
            return 0;
1140
0
        scalar_ntt(out++);
1141
0
    } while (--rank > 0);
1142
0
    return 1;
1143
0
}
1144
1145
/*
 * Selects the CBD sampler for |ETA1|: ML-KEM-512 uses eta = 3, every other
 * variant uses eta = 2 (as do all |ETA2| samples).
 */
#define CBD1(evp_type) \
    ((evp_type) == EVP_PKEY_ML_KEM_512 ? cbd_3 : cbd_2)
1147
1148
/*
1149
 * FIPS 203, Section 5.2, Algorithm 14: K-PKE.Encrypt.
1150
 *
1151
 * Encrypts a message with given randomness to the ciphertext in |out|. Without
1152
 * applying the Fujisaki-Okamoto transform this would not result in a CCA
1153
 * secure scheme, since lattice schemes are vulnerable to decryption failure
1154
 * oracles.
1155
 *
1156
 * The steps are re-ordered to make more efficient/localised use of storage.
1157
 *
1158
 * Note also that the input public key is assumed to hold a precomputed matrix
1159
 * |A| (our key->m, with the public key holding an expanded (16-bit per scalar
1160
 * coefficient) key->t vector).
1161
 *
1162
 * Caller passes storage in |tmp| for for two temporary vectors.
1163
 */
1164
static __owur
1165
int encrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
1166
                const uint8_t message[DEGREE / 8],
1167
                const uint8_t r[ML_KEM_RANDOM_BYTES], scalar *tmp,
1168
                EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1169
0
{
1170
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1171
0
    CBD_FUNC cbd_1 = CBD1(vinfo->evp_type);
1172
0
    int rank = vinfo->rank;
1173
    /* We can use tmp[0..rank-1] as storage for |y|, then |e1|, ... */
1174
0
    scalar *y = &tmp[0], *e1 = y, *e2 = y;
1175
    /* We can use tmp[rank]..tmp[2*rank - 1] for |u| */
1176
0
    scalar *u = &tmp[rank];
1177
0
    scalar v;
1178
0
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1179
0
    uint8_t counter = 0;
1180
0
    int du = vinfo->du;
1181
0
    int dv = vinfo->dv;
1182
1183
    /* FIPS 203 "y" vector */
1184
0
    if (!gencbd_vector_ntt(y, cbd_1, &counter, r, rank, mdctx, key))
1185
0
        return 0;
1186
    /* FIPS 203 "v" scalar */
1187
0
    inner_product(&v, key->t, y, rank);
1188
0
    scalar_inverse_ntt(&v);
1189
    /* FIPS 203 "u" vector */
1190
0
    matrix_mult_intt(u, key->m, y, rank);
1191
1192
    /* All done with |y|, now free to reuse tmp[0] for FIPS 203 |e1| */
1193
0
    if (!gencbd_vector(e1, cbd_2, &counter, r, rank, mdctx, key))
1194
0
        return 0;
1195
0
    vector_add(u, e1, rank);
1196
0
    vector_compress(u, du, rank);
1197
0
    vector_encode(out, u, du, rank);
1198
1199
    /* All done with |e1|, now free to reuse tmp[0] for FIPS 203 |e2| */
1200
0
    memcpy(input, r, ML_KEM_RANDOM_BYTES);
1201
0
    input[ML_KEM_RANDOM_BYTES] = counter;
1202
0
    if (!cbd_2(e2, input, mdctx, key))
1203
0
        return 0;
1204
0
    scalar_add(&v, e2);
1205
1206
    /* Combine message with |v| */
1207
0
    scalar_decode_decompress_add(&v, message);
1208
0
    scalar_compress(&v, dv);
1209
0
    scalar_encode(out + vinfo->u_vector_bytes, &v, dv);
1210
0
    return 1;
1211
0
}
1212
1213
/*
1214
 * FIPS 203, Section 5.3, Algorithm 15: K-PKE.Decrypt.
1215
 */
1216
static void
1217
decrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
1218
            const uint8_t *ctext, scalar *u, const ML_KEM_KEY *key)
1219
0
{
1220
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1221
0
    scalar v, mask;
1222
0
    int rank = vinfo->rank;
1223
0
    int du = vinfo->du;
1224
0
    int dv = vinfo->dv;
1225
1226
0
    vector_decode_decompress_ntt(u, ctext, du, rank);
1227
0
    scalar_decode(&v, ctext + vinfo->u_vector_bytes, dv);
1228
0
    scalar_decompress(&v, dv);
1229
0
    inner_product(&mask, key->s, u, rank);
1230
0
    scalar_inverse_ntt(&mask);
1231
0
    scalar_sub(&v, &mask);
1232
0
    scalar_compress(&v, 1);
1233
0
    scalar_encode_1(out, &v);
1234
0
}
1235
1236
/*-
1237
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1238
 * FIPS 203, Section 7.2, Algorithm 20: "ML-KEM.Encaps".
1239
 *
1240
 * Fills the |out| buffer with the |ek| output of "ML-KEM.KeyGen", or,
1241
 * equivalently, the |ek| input of "ML-KEM.Encaps", i.e. returns the
1242
 * wire-format of an ML-KEM public key.
1243
 */
1244
static void encode_pubkey(uint8_t *out, const ML_KEM_KEY *key)
1245
0
{
1246
0
    const uint8_t *rho = key->rho;
1247
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1248
1249
0
    vector_encode(out, key->t, 12, vinfo->rank);
1250
0
    memcpy(out + vinfo->vector_bytes, rho, ML_KEM_RANDOM_BYTES);
1251
0
}
1252
1253
/*-
1254
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1255
 *
1256
 * Fills the |out| buffer with the |dk| output of "ML-KEM.KeyGen".
1257
 * This matches the input format of parse_prvkey() below.
1258
 */
1259
static void encode_prvkey(uint8_t *out, const ML_KEM_KEY *key)
1260
0
{
1261
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1262
1263
0
    vector_encode(out, key->s, 12, vinfo->rank);
1264
0
    out += vinfo->vector_bytes;
1265
0
    encode_pubkey(out, key);
1266
0
    out += vinfo->pubkey_bytes;
1267
0
    memcpy(out, key->pkhash, ML_KEM_PKHASH_BYTES);
1268
0
    out += ML_KEM_PKHASH_BYTES;
1269
0
    memcpy(out, key->z, ML_KEM_RANDOM_BYTES);
1270
0
}
1271
1272
/*-
1273
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1274
 * FIPS 203, Section 7.2, Algorithm 20: "ML-KEM.Encaps".
1275
 *
1276
 * This function parses the |in| buffer as the |ek| output of "ML-KEM.KeyGen",
1277
 * or, equivalently, the |ek| input of "ML-KEM.Encaps", i.e. decodes the
1278
 * wire-format of the ML-KEM public key.
1279
 */
1280
static int parse_pubkey(const uint8_t *in, EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
1281
0
{
1282
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1283
1284
    /* Decode and check |t| */
1285
0
    if (!vector_decode_12(key->t, in, vinfo->rank)) {
1286
0
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1287
0
                       "%s invalid public 't' vector",
1288
0
                       vinfo->algorithm_name);
1289
0
        return 0;
1290
0
    }
1291
    /* Save the matrix |m| recovery seed |rho| */
1292
0
    memcpy(key->rho, in + vinfo->vector_bytes, ML_KEM_RANDOM_BYTES);
1293
    /*
1294
     * Pre-compute the public key hash, needed for both encap and decap.
1295
     * Also pre-compute the matrix expansion, stored with the public key.
1296
     */
1297
0
    if (!hash_h(key->pkhash, in, vinfo->pubkey_bytes, mdctx, key)
1298
0
        || !matrix_expand(mdctx, key)) {
1299
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1300
0
                       "internal error while parsing %s public key",
1301
0
                       vinfo->algorithm_name);
1302
0
        return 0;
1303
0
    }
1304
0
    return 1;
1305
0
}
1306
1307
/*
1308
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1309
 *
1310
 * Parses the |in| buffer as a |dk| output of "ML-KEM.KeyGen".
1311
 * This matches the output format of encode_prvkey() above.
1312
 */
1313
static int parse_prvkey(const uint8_t *in, EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
1314
0
{
1315
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1316
1317
    /* Decode and check |s|. */
1318
0
    if (!vector_decode_12(key->s, in, vinfo->rank)) {
1319
0
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1320
0
                       "%s invalid private 's' vector",
1321
0
                       vinfo->algorithm_name);
1322
0
        return 0;
1323
0
    }
1324
0
    in += vinfo->vector_bytes;
1325
1326
0
    if (!parse_pubkey(in, mdctx, key))
1327
0
        return 0;
1328
0
    in += vinfo->pubkey_bytes;
1329
1330
    /* Check public key hash. */
1331
0
    if (memcmp(key->pkhash, in, ML_KEM_PKHASH_BYTES) != 0) {
1332
0
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1333
0
                       "%s public key hash mismatch",
1334
0
                       vinfo->algorithm_name);
1335
0
        return 0;
1336
0
    }
1337
0
    in += ML_KEM_PKHASH_BYTES;
1338
1339
0
    memcpy(key->z, in, ML_KEM_RANDOM_BYTES);
1340
0
    return 1;
1341
0
}
1342
1343
/*
1344
 * FIPS 203, Section 6.1, Algorithm 16: "ML-KEM.KeyGen_internal".
1345
 *
1346
 * The implementation of Section 5.1, Algorithm 13, "K-PKE.KeyGen(d)" is
1347
 * inlined.
1348
 *
1349
 * The caller MUST pass a pre-allocated digest context that is not shared with
1350
 * any concurrent computation.
1351
 *
1352
 * This function optionally outputs the serialised wire-form |ek| public key
1353
 * into the provided |pubenc| buffer, and generates the content of the |rho|,
1354
 * |pkhash|, |t|, |m|, |s| and |z| components of the private |key| (which must
1355
 * have preallocated space for these).
1356
 *
1357
 * Keys are computed from a 32-byte random |d| plus the 1 byte rank for
1358
 * domain separation.  These are concatenated and hashed to produce a pair of
1359
 * 32-byte seeds public "rho", used to generate the matrix, and private "sigma",
1360
 * used to generate the secret vector |s|.
1361
 *
1362
 * The second random input |z| is copied verbatim into the Fujisaki-Okamoto
1363
 * (FO) transform "implicit-rejection" secret (the |z| component of the private
1364
 * key), which thwarts chosen-ciphertext attacks, provided decap() runs in
1365
 * constant time, with no side channel leaks, on all well-formed (valid length,
1366
 * and correctly encoded) ciphertext inputs.
1367
 */
1368
static __owur
1369
int genkey(const uint8_t seed[ML_KEM_SEED_BYTES],
1370
           EVP_MD_CTX *mdctx, uint8_t *pubenc, ML_KEM_KEY *key)
1371
0
{
1372
0
    uint8_t hashed[2 * ML_KEM_RANDOM_BYTES];
1373
0
    const uint8_t *const sigma = hashed + ML_KEM_RANDOM_BYTES;
1374
0
    uint8_t augmented_seed[ML_KEM_RANDOM_BYTES + 1];
1375
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1376
0
    CBD_FUNC cbd_1 = CBD1(vinfo->evp_type);
1377
0
    int rank = vinfo->rank;
1378
0
    uint8_t counter = 0;
1379
0
    int ret = 0;
1380
1381
    /*
1382
     * Use the "d" seed salted with the rank to derive the public and private
1383
     * seeds rho and sigma.
1384
     */
1385
0
    memcpy(augmented_seed, seed, ML_KEM_RANDOM_BYTES);
1386
0
    augmented_seed[ML_KEM_RANDOM_BYTES] = (uint8_t) rank;
1387
0
    if (!hash_g(hashed, augmented_seed, sizeof(augmented_seed), mdctx, key))
1388
0
        goto end;
1389
0
    memcpy(key->rho, hashed, ML_KEM_RANDOM_BYTES);
1390
    /* The |rho| matrix seed is public */
1391
0
    CONSTTIME_DECLASSIFY(key->rho, ML_KEM_RANDOM_BYTES);
1392
1393
    /* FIPS 203 |e| vector is initial value of key->t */
1394
0
    if (!matrix_expand(mdctx, key)
1395
0
        || !gencbd_vector_ntt(key->s, cbd_1, &counter, sigma, rank, mdctx, key)
1396
0
        || !gencbd_vector_ntt(key->t, cbd_1, &counter, sigma, rank, mdctx, key))
1397
0
        goto end;
1398
1399
    /* To |e| we now add the product of transpose |m| and |s|, giving |t|. */
1400
0
    matrix_mult_transpose_add(key->t, key->m, key->s, rank);
1401
    /* The |t| vector is public */
1402
0
    CONSTTIME_DECLASSIFY(key->t, vinfo->rank * sizeof(scalar));
1403
1404
0
    if (pubenc == NULL) {
1405
        /* Incremental digest of public key without in-full serialisation. */
1406
0
        if (!hash_h_pubkey(key->pkhash, mdctx, key))
1407
0
            goto end;
1408
0
    } else {
1409
0
        encode_pubkey(pubenc, key);
1410
0
        if (!hash_h(key->pkhash, pubenc, vinfo->pubkey_bytes, mdctx, key))
1411
0
            goto end;
1412
0
    }
1413
1414
    /* Save |z| portion of seed for "implicit rejection" on failure. */
1415
0
    memcpy(key->z, seed + ML_KEM_RANDOM_BYTES, ML_KEM_RANDOM_BYTES);
1416
1417
    /* Optionally save the |d| portion of the seed */
1418
0
    key->d = key->z + ML_KEM_RANDOM_BYTES;
1419
0
    if (key->prov_flags & ML_KEM_KEY_RETAIN_SEED) {
1420
0
        memcpy(key->d, seed, ML_KEM_RANDOM_BYTES);
1421
0
    } else {
1422
0
        OPENSSL_cleanse(key->d, ML_KEM_RANDOM_BYTES);
1423
0
        key->d = NULL;
1424
0
    }
1425
1426
0
    ret = 1;
1427
0
 end:
1428
0
    OPENSSL_cleanse((void *)augmented_seed, ML_KEM_RANDOM_BYTES);
1429
0
    OPENSSL_cleanse((void *)sigma, ML_KEM_RANDOM_BYTES);
1430
0
    if (ret == 0) {
1431
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1432
0
                       "internal error while generating %s private key",
1433
0
                       vinfo->algorithm_name);
1434
0
    }
1435
0
    return ret;
1436
0
}
1437
1438
/*-
1439
 * FIPS 203, Section 6.2, Algorithm 17: "ML-KEM.Encaps_internal".
1440
 * This is the deterministic version with randomness supplied externally.
1441
 *
1442
 * The caller must pass space for two vectors in |tmp|.
1443
 * The |ctext| buffer have space for the ciphertext of the ML-KEM variant
1444
 * of the provided key.
1445
 */
1446
static
1447
int encap(uint8_t *ctext, uint8_t secret[ML_KEM_SHARED_SECRET_BYTES],
1448
          const uint8_t entropy[ML_KEM_RANDOM_BYTES],
1449
          scalar *tmp, EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1450
0
{
1451
0
    uint8_t input[ML_KEM_RANDOM_BYTES + ML_KEM_PKHASH_BYTES];
1452
0
    uint8_t Kr[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES];
1453
0
    uint8_t *r = Kr + ML_KEM_SHARED_SECRET_BYTES;
1454
0
    int ret;
1455
1456
0
    memcpy(input, entropy, ML_KEM_RANDOM_BYTES);
1457
0
    memcpy(input + ML_KEM_RANDOM_BYTES, key->pkhash, ML_KEM_PKHASH_BYTES);
1458
0
    ret = hash_g(Kr, input, sizeof(input), mdctx, key)
1459
0
        && encrypt_cpa(ctext, entropy, r, tmp, mdctx, key);
1460
0
    OPENSSL_cleanse((void *)input, sizeof(input));
1461
1462
0
    if (ret)
1463
0
        memcpy(secret, Kr, ML_KEM_SHARED_SECRET_BYTES);
1464
0
    else
1465
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1466
0
                       "internal error while performing %s encapsulation",
1467
0
                       key->vinfo->algorithm_name);
1468
0
    return ret;
1469
0
}
1470
1471
/*
1472
 * FIPS 203, Section 6.3, Algorithm 18: ML-KEM.Decaps_internal
1473
 *
1474
 * Barring failure of the supporting SHA3/SHAKE primitives, this is fully
1475
 * deterministic, the randomness for the FO transform is extracted during
1476
 * private key generation.
1477
 *
1478
 * The caller must pass space for two vectors in |tmp|.
1479
 * The |ctext| and |tmp_ctext| buffers must each have space for the ciphertext
1480
 * of the key's ML-KEM variant.
1481
 */
1482
static
1483
int decap(uint8_t secret[ML_KEM_SHARED_SECRET_BYTES],
1484
          const uint8_t *ctext, uint8_t *tmp_ctext, scalar *tmp,
1485
          EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1486
0
{
1487
0
    uint8_t decrypted[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_PKHASH_BYTES];
1488
0
    uint8_t failure_key[ML_KEM_RANDOM_BYTES];
1489
0
    uint8_t Kr[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES];
1490
0
    uint8_t *r = Kr + ML_KEM_SHARED_SECRET_BYTES;
1491
0
    const uint8_t *pkhash = key->pkhash;
1492
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1493
0
    int i;
1494
0
    uint8_t mask;
1495
1496
    /*
1497
     * If our KDF is unavailable, fail early! Otherwise, keep going ignoring
1498
     * any further errors, returning success, and whatever we got for a shared
1499
     * secret.  The decrypt_cpa() function is just arithmetic on secret data,
1500
     * so should not be subject to failure that makes its output predictable.
1501
     *
1502
     * We guard against "should never happen" catastrophic failure of the
1503
     * "pure" function |hash_g| by overwriting the shared secret with the
1504
     * content of the failure key and returning early, if nevertheless hash_g
1505
     * fails.  This is not constant-time, but a failure of |hash_g| already
1506
     * implies loss of side-channel resistance.
1507
     *
1508
     * The same action is taken, if also |encrypt_cpa| should catastrophically
1509
     * fail, due to failure of the |PRF| underlying the CBD functions.
1510
     */
1511
0
    if (!kdf(failure_key, key->z, ctext, vinfo->ctext_bytes, mdctx, key)) {
1512
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1513
0
                       "internal error while performing %s decapsulation",
1514
0
                       vinfo->algorithm_name);
1515
0
        return 0;
1516
0
    }
1517
0
    decrypt_cpa(decrypted, ctext, tmp, key);
1518
0
    memcpy(decrypted + ML_KEM_SHARED_SECRET_BYTES, pkhash, ML_KEM_PKHASH_BYTES);
1519
0
    if (!hash_g(Kr, decrypted, sizeof(decrypted), mdctx, key)
1520
0
        || !encrypt_cpa(tmp_ctext, decrypted, r, tmp, mdctx, key)) {
1521
0
        memcpy(secret, failure_key, ML_KEM_SHARED_SECRET_BYTES);
1522
0
        OPENSSL_cleanse(decrypted, ML_KEM_SHARED_SECRET_BYTES);
1523
0
        return 1;
1524
0
    }
1525
0
    mask = constant_time_eq_int_8(0,
1526
0
        CRYPTO_memcmp(ctext, tmp_ctext, vinfo->ctext_bytes));
1527
0
    for (i = 0; i < ML_KEM_SHARED_SECRET_BYTES; i++)
1528
0
        secret[i] = constant_time_select_8(mask, Kr[i], failure_key[i]);
1529
0
    OPENSSL_cleanse(decrypted, ML_KEM_SHARED_SECRET_BYTES);
1530
0
    OPENSSL_cleanse(Kr, sizeof(Kr));
1531
0
    return 1;
1532
0
}
1533
1534
/*
1535
 * After allocating storage for public or private key data, update the key
1536
 * component pointers to reference that storage.
1537
 */
1538
static __owur
1539
int add_storage(scalar *p, int private, ML_KEM_KEY *key)
1540
0
{
1541
0
    int rank = key->vinfo->rank;
1542
1543
0
    if (p == NULL)
1544
0
        return 0;
1545
1546
    /*
1547
     * We're adding key material, the seed buffer will now hold |rho| and
1548
     * |pkhash|.
1549
     */
1550
0
    memset(key->seedbuf, 0, sizeof(key->seedbuf));
1551
0
    key->rho = key->seedbuf;
1552
0
    key->pkhash = key->seedbuf + ML_KEM_RANDOM_BYTES;
1553
0
    key->d = key->z = NULL;
1554
1555
    /* A public key needs space for |t| and |m| */
1556
0
    key->m = (key->t = p) + rank;
1557
1558
    /*
1559
     * A private key also needs space for |s| and |z|.
1560
     * The |z| buffer always includes additional space for |d|, but a key's |d|
1561
     * pointer is left NULL when parsed from the NIST format, which omits that
1562
     * information.  Only keys generated from a (d, z) seed pair will have a
1563
     * non-NULL |d| pointer.
1564
     */
1565
0
    if (private)
1566
0
        key->z = (uint8_t *)(rank + (key->s = key->m + rank * rank));
1567
0
    return 1;
1568
0
}
1569
1570
/*
1571
 * After freeing the storage associated with a key that failed to be
1572
 * constructed, reset the internal pointers back to NULL.
1573
 */
1574
void
1575
ossl_ml_kem_key_reset(ML_KEM_KEY *key)
1576
0
{
1577
0
    if (key->t == NULL)
1578
0
        return;
1579
    /*-
1580
     * Cleanse any sensitive data:
1581
     * - The private vector |s| is immediately followed by the FO failure
1582
     *   secret |z|, and seed |d|, we can cleanse all three in one call.
1583
     *
1584
     * - Otherwise, when key->d is set, cleanse the stashed seed.
1585
     */
1586
0
    if (ossl_ml_kem_have_prvkey(key))
1587
0
        OPENSSL_cleanse(key->s,
1588
0
                        key->vinfo->rank * sizeof(scalar) + 2 * ML_KEM_RANDOM_BYTES);
1589
0
    OPENSSL_free(key->t);
1590
0
    key->d = key->z = (uint8_t *)(key->s = key->m = key->t = NULL);
1591
0
}
1592
1593
/*
1594
 * ----- API exported to the provider
1595
 *
1596
 * Parameters with an implicit fixed length in the internal static API of each
1597
 * variant have an explicit checked length argument at this layer.
1598
 */
1599
1600
/* Retrieve the parameters of one of the ML-KEM variants */
1601
const ML_KEM_VINFO *ossl_ml_kem_get_vinfo(int evp_type)
1602
0
{
1603
0
    switch (evp_type) {
1604
0
    case EVP_PKEY_ML_KEM_512:
1605
0
        return &vinfo_map[ML_KEM_512_VINFO];
1606
0
    case EVP_PKEY_ML_KEM_768:
1607
0
        return &vinfo_map[ML_KEM_768_VINFO];
1608
0
    case EVP_PKEY_ML_KEM_1024:
1609
0
        return &vinfo_map[ML_KEM_1024_VINFO];
1610
0
    }
1611
0
    return NULL;
1612
0
}
1613
1614
/*
 * Allocate an empty key object for the ML-KEM variant named by |evp_type|.
 * The SHAKE128, SHAKE256, SHA3-256 and SHA3-512 digests are fetched from
 * |libctx| using |properties| and stored in the key.
 *
 * Returns the new key, or NULL when the variant is unsupported or a digest
 * could not be fetched (an error is raised in both cases), or when the
 * allocation of the key object itself fails.
 */
ML_KEM_KEY *ossl_ml_kem_key_new(OSSL_LIB_CTX *libctx, const char *properties,
                                int evp_type)
{
    const ML_KEM_VINFO *vinfo = ossl_ml_kem_get_vinfo(evp_type);
    ML_KEM_KEY *key;

    if (vinfo == NULL) {
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_PASSED_INVALID_ARGUMENT,
                       "unsupported ML-KEM key type: %d", evp_type);
        return NULL;
    }

    key = OPENSSL_malloc(sizeof(*key));
    if (key == NULL)
        return NULL;

    key->vinfo = vinfo;
    key->libctx = libctx;
    key->prov_flags = ML_KEM_KEY_PROV_FLAGS_DEFAULT;
    key->shake128_md = EVP_MD_fetch(libctx, "SHAKE128", properties);
    key->shake256_md = EVP_MD_fetch(libctx, "SHAKE256", properties);
    key->sha3_256_md = EVP_MD_fetch(libctx, "SHA3-256", properties);
    key->sha3_512_md = EVP_MD_fetch(libctx, "SHA3-512", properties);

    /* A fresh key carries no key material of any kind */
    key->encoded_dk = NULL;
    key->pkhash = NULL;
    key->rho = NULL;
    key->z = NULL;
    key->d = NULL;
    key->t = NULL;
    key->m = NULL;
    key->s = NULL;

    if (key->shake128_md == NULL
        || key->shake256_md == NULL
        || key->sha3_256_md == NULL
        || key->sha3_512_md == NULL) {
        /* Releases whichever digests were fetched, then the key itself */
        ossl_ml_kem_key_free(key);
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
                       "missing SHA3 digest algorithms while creating %s key",
                       vinfo->algorithm_name);
        return NULL;
    }

    return key;
}
1651
1652
ML_KEM_KEY *ossl_ml_kem_key_dup(const ML_KEM_KEY *key, int selection)
1653
0
{
1654
0
    int ok = 0;
1655
0
    ML_KEM_KEY *ret;
1656
1657
    /*
1658
     * Partially decoded keys, not yet imported or loaded, should never be
1659
     * duplicated.
1660
     */
1661
0
    if (ossl_ml_kem_decoded_key(key))
1662
0
        return NULL;
1663
1664
0
    if (key == NULL
1665
0
        || (ret = OPENSSL_memdup(key, sizeof(*key))) == NULL)
1666
0
        return NULL;
1667
0
    ret->d = ret->z = ret->rho = ret->pkhash = NULL;
1668
0
    ret->s = ret->m = ret->t = NULL;
1669
1670
    /* Clear selection bits we can't fulfill */
1671
0
    if (!ossl_ml_kem_have_pubkey(key))
1672
0
        selection = 0;
1673
0
    else if (!ossl_ml_kem_have_prvkey(key))
1674
0
        selection &= ~OSSL_KEYMGMT_SELECT_PRIVATE_KEY;
1675
1676
0
    switch (selection & OSSL_KEYMGMT_SELECT_KEYPAIR) {
1677
0
    case 0:
1678
0
        ok = 1;
1679
0
        break;
1680
0
    case OSSL_KEYMGMT_SELECT_PUBLIC_KEY:
1681
0
        ok = add_storage(OPENSSL_memdup(key->t, key->vinfo->puballoc), 0, ret);
1682
0
        ret->rho = ret->seedbuf;
1683
0
        ret->pkhash = ret->rho + ML_KEM_RANDOM_BYTES;
1684
0
        break;
1685
0
    case OSSL_KEYMGMT_SELECT_PRIVATE_KEY:
1686
0
        ok = add_storage(OPENSSL_memdup(key->t, key->vinfo->prvalloc), 1, ret);
1687
        /* Duplicated keys retain |d|, if available */
1688
0
        if (key->d != NULL)
1689
0
            ret->d = ret->z + ML_KEM_RANDOM_BYTES;
1690
0
        break;
1691
0
    }
1692
1693
0
    if (!ok) {
1694
0
        OPENSSL_free(ret);
1695
0
        return NULL;
1696
0
    }
1697
1698
0
    EVP_MD_up_ref(ret->shake128_md);
1699
0
    EVP_MD_up_ref(ret->shake256_md);
1700
0
    EVP_MD_up_ref(ret->sha3_256_md);
1701
0
    EVP_MD_up_ref(ret->sha3_512_md);
1702
1703
0
    return ret;
1704
0
}
1705
1706
/*
 * Free |key| together with everything it owns: the fetched digests, any
 * stashed seed or encoded private key of a partially decoded key, and the
 * allocated key material.  Sensitive buffers are cleansed before release.
 * Passing NULL is a no-op.
 */
void ossl_ml_kem_key_free(ML_KEM_KEY *key)
{
    if (key == NULL)
        return;

    /* A partially decoded key keeps its secrets outside the main storage */
    if (ossl_ml_kem_decoded_key(key)) {
        OPENSSL_cleanse(key->seedbuf, sizeof(key->seedbuf));
        if (ossl_ml_kem_have_dkenc(key)) {
            OPENSSL_cleanse(key->encoded_dk, key->vinfo->prvkey_bytes);
            OPENSSL_free(key->encoded_dk);
        }
    }

    /* Cleanses and releases the allocated key material, if any */
    ossl_ml_kem_key_reset(key);

    EVP_MD_free(key->shake128_md);
    EVP_MD_free(key->shake256_md);
    EVP_MD_free(key->sha3_256_md);
    EVP_MD_free(key->sha3_512_md);

    OPENSSL_free(key);
}
1726
1727
/* Serialise the public component of an ML-KEM key */
1728
int ossl_ml_kem_encode_public_key(uint8_t *out, size_t len,
1729
                                  const ML_KEM_KEY *key)
1730
0
{
1731
0
    if (!ossl_ml_kem_have_pubkey(key)
1732
0
        || len != key->vinfo->pubkey_bytes)
1733
0
        return 0;
1734
0
    encode_pubkey(out, key);
1735
0
    return 1;
1736
0
}
1737
1738
/* Serialise an ML-KEM private key */
1739
int ossl_ml_kem_encode_private_key(uint8_t *out, size_t len,
1740
                                   const ML_KEM_KEY *key)
1741
0
{
1742
0
    if (!ossl_ml_kem_have_prvkey(key)
1743
0
        || len != key->vinfo->prvkey_bytes)
1744
0
        return 0;
1745
0
    encode_prvkey(out, key);
1746
0
    return 1;
1747
0
}
1748
1749
int ossl_ml_kem_encode_seed(uint8_t *out, size_t len,
1750
                            const ML_KEM_KEY *key)
1751
0
{
1752
0
    if (key == NULL || key->d == NULL || len != ML_KEM_SEED_BYTES)
1753
0
        return 0;
1754
    /*
1755
     * Both in the seed buffer, and in the allocated storage, the |d| component
1756
     * of the seed is stored last, so we must copy each separately.
1757
     */
1758
0
    memcpy(out, key->d, ML_KEM_RANDOM_BYTES);
1759
0
    out += ML_KEM_RANDOM_BYTES;
1760
0
    memcpy(out, key->z, ML_KEM_RANDOM_BYTES);
1761
0
    return 1;
1762
0
}
1763
1764
/*
1765
 * Stash the seed without (yet) performing a keygen, used during decoding, to
1766
 * avoid an extra keygen if we're only going to export the key again to load
1767
 * into another provider.
1768
 */
1769
ML_KEM_KEY *ossl_ml_kem_set_seed(const uint8_t *seed, size_t seedlen, ML_KEM_KEY *key)
1770
0
{
1771
0
    if (key == NULL
1772
0
        || ossl_ml_kem_have_pubkey(key)
1773
0
        || ossl_ml_kem_have_seed(key)
1774
0
        || seedlen != ML_KEM_SEED_BYTES)
1775
0
        return NULL;
1776
    /*
1777
     * With no public or private key material on hand, we can use the seed
1778
     * buffer for |z| and |d|, in that order.
1779
     */
1780
0
    key->z = key->seedbuf;
1781
0
    key->d = key->z + ML_KEM_RANDOM_BYTES;
1782
0
    memcpy(key->d, seed, ML_KEM_RANDOM_BYTES);
1783
0
    seed += ML_KEM_RANDOM_BYTES;
1784
0
    memcpy(key->z, seed, ML_KEM_RANDOM_BYTES);
1785
0
    return key;
1786
0
}
1787
1788
/* Parse input as a public key */
1789
int ossl_ml_kem_parse_public_key(const uint8_t *in, size_t len, ML_KEM_KEY *key)
1790
0
{
1791
0
    EVP_MD_CTX *mdctx = NULL;
1792
0
    const ML_KEM_VINFO *vinfo;
1793
0
    int ret = 0;
1794
1795
    /* Keys with key material are immutable */
1796
0
    if (key == NULL
1797
0
        || ossl_ml_kem_have_pubkey(key)
1798
0
        || ossl_ml_kem_have_dkenc(key))
1799
0
        return 0;
1800
0
    vinfo = key->vinfo;
1801
1802
0
    if (len != vinfo->pubkey_bytes
1803
0
        || (mdctx = EVP_MD_CTX_new()) == NULL)
1804
0
        return 0;
1805
1806
0
    if (add_storage(OPENSSL_malloc(vinfo->puballoc), 0, key))
1807
0
        ret = parse_pubkey(in, mdctx, key);
1808
1809
0
    if (!ret)
1810
0
        ossl_ml_kem_key_reset(key);
1811
0
    EVP_MD_CTX_free(mdctx);
1812
0
    return ret;
1813
0
}
1814
1815
/* Parse input as a new private key */
1816
int ossl_ml_kem_parse_private_key(const uint8_t *in, size_t len,
1817
                                  ML_KEM_KEY *key)
1818
0
{
1819
0
    EVP_MD_CTX *mdctx = NULL;
1820
0
    const ML_KEM_VINFO *vinfo;
1821
0
    int ret = 0;
1822
1823
    /* Keys with key material are immutable */
1824
0
    if (key == NULL
1825
0
        || ossl_ml_kem_have_pubkey(key)
1826
0
        || ossl_ml_kem_have_dkenc(key))
1827
0
        return 0;
1828
0
    vinfo = key->vinfo;
1829
1830
0
    if (len != vinfo->prvkey_bytes
1831
0
        || (mdctx = EVP_MD_CTX_new()) == NULL)
1832
0
        return 0;
1833
1834
0
    if (add_storage(OPENSSL_malloc(vinfo->prvalloc), 1, key))
1835
0
        ret = parse_prvkey(in, mdctx, key);
1836
1837
0
    if (!ret)
1838
0
        ossl_ml_kem_key_reset(key);
1839
0
    EVP_MD_CTX_free(mdctx);
1840
0
    return ret;
1841
0
}
1842
1843
/*
1844
 * Generate a new keypair, either from the saved seed (when non-null), or from
1845
 * the RNG.
1846
 */
1847
int ossl_ml_kem_genkey(uint8_t *pubenc, size_t publen, ML_KEM_KEY *key)
1848
0
{
1849
0
    uint8_t seed[ML_KEM_SEED_BYTES];
1850
0
    EVP_MD_CTX *mdctx = NULL;
1851
0
    const ML_KEM_VINFO *vinfo;
1852
0
    int ret = 0;
1853
1854
0
    if (key == NULL
1855
0
        || ossl_ml_kem_have_pubkey(key)
1856
0
        || ossl_ml_kem_have_dkenc(key))
1857
0
        return 0;
1858
0
    vinfo = key->vinfo;
1859
1860
0
    if (pubenc != NULL && publen != vinfo->pubkey_bytes)
1861
0
        return 0;
1862
1863
0
    if (ossl_ml_kem_have_seed(key)) {
1864
0
        if (!ossl_ml_kem_encode_seed(seed, sizeof(seed), key))
1865
0
            return 0;
1866
0
        key->d = key->z = NULL;
1867
0
    } else if (RAND_priv_bytes_ex(key->libctx, seed, sizeof(seed),
1868
0
                                  key->vinfo->secbits) <= 0) {
1869
0
        return 0;
1870
0
    }
1871
1872
0
    if ((mdctx = EVP_MD_CTX_new()) == NULL)
1873
0
        return 0;
1874
1875
    /*
1876
     * Data derived from (d, z) defaults secret, and to avoid side-channel
1877
     * leaks should not influence control flow.
1878
     */
1879
0
    CONSTTIME_SECRET(seed, ML_KEM_SEED_BYTES);
1880
1881
0
    if (add_storage(OPENSSL_malloc(vinfo->prvalloc), 1, key))
1882
0
        ret = genkey(seed, mdctx, pubenc, key);
1883
0
    OPENSSL_cleanse(seed, sizeof(seed));
1884
1885
    /* Declassify secret inputs and derived outputs before returning control */
1886
0
    CONSTTIME_DECLASSIFY(seed, ML_KEM_SEED_BYTES);
1887
1888
0
    EVP_MD_CTX_free(mdctx);
1889
0
    if (!ret) {
1890
0
        ossl_ml_kem_key_reset(key);
1891
0
        return 0;
1892
0
    }
1893
1894
    /* The public components are already declassified */
1895
0
    CONSTTIME_DECLASSIFY(key->s, vinfo->rank * sizeof(scalar));
1896
0
    CONSTTIME_DECLASSIFY(key->z, 2 * ML_KEM_RANDOM_BYTES);
1897
0
    return 1;
1898
0
}
1899
1900
/*
1901
 * FIPS 203, Section 6.2, Algorithm 17: ML-KEM.Encaps_internal
1902
 * This is the deterministic version with randomness supplied externally.
1903
 */
1904
int ossl_ml_kem_encap_seed(uint8_t *ctext, size_t clen,
1905
                           uint8_t *shared_secret, size_t slen,
1906
                           const uint8_t *entropy, size_t elen,
1907
                           const ML_KEM_KEY *key)
1908
0
{
1909
0
    const ML_KEM_VINFO *vinfo;
1910
0
    EVP_MD_CTX *mdctx;
1911
0
    int ret = 0;
1912
1913
0
    if (key == NULL || !ossl_ml_kem_have_pubkey(key))
1914
0
        return 0;
1915
0
    vinfo = key->vinfo;
1916
1917
0
    if (ctext == NULL || clen != vinfo->ctext_bytes
1918
0
        || shared_secret == NULL || slen != ML_KEM_SHARED_SECRET_BYTES
1919
0
        || entropy == NULL || elen != ML_KEM_RANDOM_BYTES
1920
0
        || (mdctx = EVP_MD_CTX_new()) == NULL)
1921
0
        return 0;
1922
    /*
1923
     * Data derived from the encap entropy defaults secret, and to avoid
1924
     * side-channel leaks should not influence control flow.
1925
     */
1926
0
    CONSTTIME_SECRET(entropy, elen);
1927
1928
    /*-
1929
     * This avoids the need to handle allocation failures for two (max 2KB
1930
     * each) vectors, that are never retained on return from this function.
1931
     * We stack-allocate these.
1932
     */
1933
0
#   define case_encap_seed(bits)                                            \
1934
0
    case EVP_PKEY_ML_KEM_##bits:                                            \
1935
0
        {                                                                   \
1936
0
            scalar tmp[2 * ML_KEM_##bits##_RANK];                           \
1937
0
                                                                            \
1938
0
            ret = encap(ctext, shared_secret, entropy, tmp, mdctx, key);    \
1939
0
            OPENSSL_cleanse((void *)tmp, sizeof(tmp));                      \
1940
0
            break;                                                          \
1941
0
        }
1942
0
    switch (vinfo->evp_type) {
1943
0
    case_encap_seed(512);
1944
0
    case_encap_seed(768);
1945
0
    case_encap_seed(1024);
1946
0
    }
1947
0
#   undef case_encap_seed
1948
1949
    /* Declassify secret inputs and derived outputs before returning control */
1950
0
    CONSTTIME_DECLASSIFY(entropy, elen);
1951
0
    CONSTTIME_DECLASSIFY(ctext, clen);
1952
0
    CONSTTIME_DECLASSIFY(shared_secret, slen);
1953
1954
0
    EVP_MD_CTX_free(mdctx);
1955
0
    return ret;
1956
0
}
1957
1958
int ossl_ml_kem_encap_rand(uint8_t *ctext, size_t clen,
1959
                           uint8_t *shared_secret, size_t slen,
1960
                           const ML_KEM_KEY *key)
1961
0
{
1962
0
    uint8_t r[ML_KEM_RANDOM_BYTES];
1963
1964
0
    if (key == NULL)
1965
0
        return 0;
1966
1967
0
    if (RAND_bytes_ex(key->libctx, r, ML_KEM_RANDOM_BYTES,
1968
0
                      key->vinfo->secbits) < 1)
1969
0
        return 0;
1970
1971
0
    return ossl_ml_kem_encap_seed(ctext, clen, shared_secret, slen,
1972
0
                                  r, sizeof(r), key);
1973
0
}
1974
1975
int ossl_ml_kem_decap(uint8_t *shared_secret, size_t slen,
1976
                      const uint8_t *ctext, size_t clen,
1977
                      const ML_KEM_KEY *key)
1978
0
{
1979
0
    const ML_KEM_VINFO *vinfo;
1980
0
    EVP_MD_CTX *mdctx;
1981
0
    int ret = 0;
1982
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
1983
    int classify_bytes;
1984
#endif
1985
1986
    /* Need a private key here */
1987
0
    if (!ossl_ml_kem_have_prvkey(key))
1988
0
        return 0;
1989
0
    vinfo = key->vinfo;
1990
1991
0
    if (shared_secret == NULL || slen != ML_KEM_SHARED_SECRET_BYTES
1992
0
        || ctext == NULL || clen != vinfo->ctext_bytes
1993
0
        || (mdctx = EVP_MD_CTX_new()) == NULL) {
1994
0
        (void)RAND_bytes_ex(key->libctx, shared_secret,
1995
0
                            ML_KEM_SHARED_SECRET_BYTES, vinfo->secbits);
1996
0
        return 0;
1997
0
    }
1998
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
1999
    /*
2000
     * Data derived from |s| and |z| defaults secret, and to avoid side-channel
2001
     * leaks should not influence control flow.
2002
     */
2003
    classify_bytes = 2 * sizeof(scalar) + ML_KEM_RANDOM_BYTES;
2004
    CONSTTIME_SECRET(key->s, classify_bytes);
2005
#endif
2006
2007
    /*-
2008
     * This avoids the need to handle allocation failures for two (max 2KB
2009
     * each) vectors and an encoded ciphertext (max 1568 bytes), that are never
2010
     * retained on return from this function.
2011
     * We stack-allocate these.
2012
     */
2013
0
#   define case_decap(bits)                                             \
2014
0
    case EVP_PKEY_ML_KEM_##bits:                                        \
2015
0
        {                                                               \
2016
0
            uint8_t cbuf[CTEXT_BYTES(bits)];                            \
2017
0
            scalar tmp[2 * ML_KEM_##bits##_RANK];                       \
2018
0
                                                                        \
2019
0
            ret = decap(shared_secret, ctext, cbuf, tmp, mdctx, key);   \
2020
0
            OPENSSL_cleanse((void *)tmp, sizeof(tmp));                  \
2021
0
            break;                                                      \
2022
0
        }
2023
0
    switch (vinfo->evp_type) {
2024
0
    case_decap(512);
2025
0
    case_decap(768);
2026
0
    case_decap(1024);
2027
0
    }
2028
2029
    /* Declassify secret inputs and derived outputs before returning control */
2030
0
    CONSTTIME_DECLASSIFY(key->s, classify_bytes);
2031
0
    CONSTTIME_DECLASSIFY(shared_secret, slen);
2032
0
    EVP_MD_CTX_free(mdctx);
2033
2034
0
    return ret;
2035
0
#   undef case_decap
2036
0
}
2037
2038
int ossl_ml_kem_pubkey_cmp(const ML_KEM_KEY *key1, const ML_KEM_KEY *key2)
2039
0
{
2040
    /*
2041
     * This handles any unexpected differences in the ML-KEM variant rank,
2042
     * giving different key component structures, barring SHA3-256 hash
2043
     * collisions, the keys are the same size.
2044
     */
2045
0
    if (ossl_ml_kem_have_pubkey(key1) && ossl_ml_kem_have_pubkey(key2))
2046
0
        return memcmp(key1->pkhash, key2->pkhash, ML_KEM_PKHASH_BYTES) == 0;
2047
2048
    /*
2049
     * No match if just one of the public keys is not available, otherwise both
2050
     * are unavailable, and for now such keys are considered equal.
2051
     */
2052
0
    return (ossl_ml_kem_have_pubkey(key1) ^ ossl_ml_kem_have_pubkey(key2));
2053
0
}