Coverage Report

Created: 2026-03-09 06:55

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/openssl/crypto/ml_kem/ml_kem.c
Line
Count
Source
1
/*
2
 * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
3
 *
4
 * Licensed under the Apache License 2.0 (the "License").  You may not use
5
 * this file except in compliance with the License.  You can obtain a copy
6
 * in the file LICENSE in the source distribution or at
7
 * https://www.openssl.org/source/license.html
8
 */
9
10
#include <openssl/byteorder.h>
11
#include <openssl/rand.h>
12
#include <openssl/proverr.h>
13
#include "crypto/ml_kem.h"
14
#include "internal/common.h"
15
#include "internal/constant_time.h"
16
#include "internal/sha3.h"
17
18
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
19
#include <valgrind/memcheck.h>
20
#endif
21
22
#if ML_KEM_SEED_BYTES != ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES
23
#error "ML-KEM keygen seed length != shared secret + random bytes length"
24
#endif
25
#if ML_KEM_SHARED_SECRET_BYTES != ML_KEM_RANDOM_BYTES
26
#error "Invalid unequal lengths of ML-KEM shared secret and random inputs"
27
#endif
28
29
#if UINT_MAX < UINT32_MAX
30
#error "Unsupported compiler: sizeof(unsigned int) < sizeof(uint32_t)"
31
#endif
32
33
/* Handy function-like bit-extraction macros */
34
0
#define bit0(b) ((b) & 1)
35
0
#define bitn(n, b) (((b) >> n) & 1)
36
37
/*
38
 * 12 bits are sufficient to losslessly represent values in [0, q-1].
39
 * INVERSE_DEGREE is (n/2)^-1 mod q; used in inverse NTT.
40
 */
41
0
#define DEGREE ML_KEM_DEGREE
42
#define INVERSE_DEGREE (ML_KEM_PRIME - 2 * 13)
43
#define LOG2PRIME 12
44
#define BARRETT_SHIFT (2 * LOG2PRIME)
45
46
#ifdef SHA3_BLOCKSIZE
47
#define SHAKE128_BLOCKSIZE SHA3_BLOCKSIZE(128)
48
#endif
49
50
/*
51
 * Return whether a value that can only be 0 or 1 is non-zero, in constant time
52
 * in practice!  The return value is a mask that is all ones if true, and all
53
 * zeros otherwise (twos-complement arithmetic assumed for unsigned values).
54
 *
55
 * Although this is used in constant-time selects, we omit a value barrier
56
 * here.  Value barriers impede auto-vectorization (likely because it forces
57
 * the value to transit through a general-purpose register). On AArch64, this
58
 * is a difference of 2x.
59
 *
60
 * We usually add value barriers to selects because Clang turns consecutive
61
 * selects with the same condition into a branch instead of CMOV/CSEL. This
62
 * condition does not occur in Kyber, so omitting it seems to be safe so far,
63
 * but see |cbd_2|, |cbd_3|, where reduction needs to be specialised to the
64
 * sign of the input, rather than adding |q| in advance, and using the generic
65
 * |reduce_once|.  (David Benjamin, Chromium)
66
 */
67
#if 0
68
#define constish_time_non_zero(b) (~constant_time_is_zero(b));
69
#else
70
0
#define constish_time_non_zero(b) (0u - (b))
71
#endif
72
73
/*
74
 * The scalar rejection-sampling buffer size needs to be a multiple of 12, but
75
 * is otherwise arbitrary, the preferred block size matches the internal buffer
76
 * size of SHAKE128, avoiding internal buffering and copying in SHAKE128. That
77
 * block size of (1600 - 256)/8 bytes, or 168, just happens to divide by 12!
78
 *
79
 * If the blocksize is unknown, or is not divisible by 12, 168 is used as a
80
 * fallback.
81
 */
82
#if defined(SHAKE128_BLOCKSIZE) && (SHAKE128_BLOCKSIZE) % 12 == 0
83
#define SCALAR_SAMPLING_BUFSIZE (SHAKE128_BLOCKSIZE)
84
#else
85
#define SCALAR_SAMPLING_BUFSIZE 168
86
#endif
87
88
/*
89
 * Structure of keys
90
 */
91
typedef struct ossl_ml_kem_scalar_st {
92
    /* On every function entry and exit, 0 <= c[i] < ML_KEM_PRIME. */
93
    uint16_t c[ML_KEM_DEGREE];
94
} scalar;
95
96
/* Key material allocation layout */
97
#define DECLARE_ML_KEM_PUBKEYDATA(name, rank)                  \
98
    struct name##_alloc {                                      \
99
        /* Public vector |t| */                                \
100
        scalar tbuf[(rank)];                                   \
101
        /* Pre-computed matrix |m| (FIPS 203 |A| transpose) */ \
102
        scalar mbuf[(rank) * (rank)];                          \
103
    }
104
105
#define DECLARE_ML_KEM_PRVKEYDATA(name, rank)  \
106
    struct name##_alloc {                      \
107
        scalar sbuf[rank];                     \
108
        uint8_t zbuf[2 * ML_KEM_RANDOM_BYTES]; \
109
    }
110
111
/* Declare variant-specific public and private storage */
112
#define DECLARE_ML_KEM_VARIANT_KEYDATA(bits)                        \
113
    DECLARE_ML_KEM_PUBKEYDATA(pubkey_##bits, ML_KEM_##bits##_RANK); \
114
    DECLARE_ML_KEM_PRVKEYDATA(prvkey_##bits, ML_KEM_##bits##_RANK)
115
116
DECLARE_ML_KEM_VARIANT_KEYDATA(512);
117
DECLARE_ML_KEM_VARIANT_KEYDATA(768);
118
DECLARE_ML_KEM_VARIANT_KEYDATA(1024);
119
#undef DECLARE_ML_KEM_VARIANT_KEYDATA
120
#undef DECLARE_ML_KEM_PUBKEYDATA
121
#undef DECLARE_ML_KEM_PRVKEYDATA
122
123
typedef __owur int (*CBD_FUNC)(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
124
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key);
125
static void scalar_encode(uint8_t *out, const scalar *s, int bits);
126
127
/*
128
 * The wire-form of a losslessly encoded vector uses 12-bits per element.
129
 *
130
 * The wire-form public key consists of the lossless encoding of the public
131
 * vector |t|, followed by the public seed |rho|.
132
 *
133
 * Our serialised private key concatenates serialisations of the private vector
134
 * |s|, the public key, the public key hash, and the failure secret |z|.
135
 */
136
#define VECTOR_BYTES(b) ((3 * DEGREE / 2) * ML_KEM_##b##_RANK)
137
#define PUBKEY_BYTES(b) (VECTOR_BYTES(b) + ML_KEM_RANDOM_BYTES)
138
#define PRVKEY_BYTES(b) (2 * PUBKEY_BYTES(b) + ML_KEM_PKHASH_BYTES)
139
140
/*
141
 * Encapsulation produces a vector "u" and a scalar "v", whose coordinates
142
 * (numbers modulo the ML-KEM prime "q") are lossily encoded using as "du" and
143
 * "dv" bits, respectively.  This encoding is the ciphertext input for
144
 * decapsulation.
145
 */
146
#define U_VECTOR_BYTES(b) ((DEGREE / 8) * ML_KEM_##b##_DU * ML_KEM_##b##_RANK)
147
#define V_SCALAR_BYTES(b) ((DEGREE / 8) * ML_KEM_##b##_DV)
148
#define CTEXT_BYTES(b) (U_VECTOR_BYTES(b) + V_SCALAR_BYTES(b))
149
150
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
151
152
/*
153
 * CONSTTIME_SECRET takes a pointer and a number of bytes and marks that region
154
 * of memory as secret. Secret data is tracked as it flows to registers and
155
 * other parts of a memory. If secret data is used as a condition for a branch,
156
 * or as a memory index, it will trigger warnings in valgrind.
157
 */
158
#define CONSTTIME_SECRET(ptr, len) VALGRIND_MAKE_MEM_UNDEFINED(ptr, len)
159
160
/*
161
 * CONSTTIME_DECLASSIFY takes a pointer and a number of bytes and marks that
162
 * region of memory as public. Public data is not subject to constant-time
163
 * rules.
164
 */
165
#define CONSTTIME_DECLASSIFY(ptr, len) VALGRIND_MAKE_MEM_DEFINED(ptr, len)
166
167
#else
168
169
#define CONSTTIME_SECRET(ptr, len)
170
#define CONSTTIME_DECLASSIFY(ptr, len)
171
172
#endif
173
174
/*
175
 * Indices of slots in the vinfo tables below
176
 */
177
0
#define ML_KEM_512_VINFO 0
178
0
#define ML_KEM_768_VINFO 1
179
0
#define ML_KEM_1024_VINFO 2
180
181
/*
182
 * Per-variant fixed parameters
183
 */
184
static const ML_KEM_VINFO vinfo_map[3] = {
185
    { "ML-KEM-512",
186
        PRVKEY_BYTES(512),
187
        sizeof(struct prvkey_512_alloc),
188
        PUBKEY_BYTES(512),
189
        sizeof(struct pubkey_512_alloc),
190
        CTEXT_BYTES(512),
191
        VECTOR_BYTES(512),
192
        U_VECTOR_BYTES(512),
193
        EVP_PKEY_ML_KEM_512,
194
        ML_KEM_512_BITS,
195
        ML_KEM_512_RANK,
196
        ML_KEM_512_DU,
197
        ML_KEM_512_DV,
198
        ML_KEM_512_SECBITS,
199
        ML_KEM_512_SECURITY_CATEGORY },
200
    { "ML-KEM-768",
201
        PRVKEY_BYTES(768),
202
        sizeof(struct prvkey_768_alloc),
203
        PUBKEY_BYTES(768),
204
        sizeof(struct pubkey_768_alloc),
205
        CTEXT_BYTES(768),
206
        VECTOR_BYTES(768),
207
        U_VECTOR_BYTES(768),
208
        EVP_PKEY_ML_KEM_768,
209
        ML_KEM_768_BITS,
210
        ML_KEM_768_RANK,
211
        ML_KEM_768_DU,
212
        ML_KEM_768_DV,
213
        ML_KEM_768_SECBITS,
214
        ML_KEM_768_SECURITY_CATEGORY },
215
    { "ML-KEM-1024",
216
        PRVKEY_BYTES(1024),
217
        sizeof(struct prvkey_1024_alloc),
218
        PUBKEY_BYTES(1024),
219
        sizeof(struct pubkey_1024_alloc),
220
        CTEXT_BYTES(1024),
221
        VECTOR_BYTES(1024),
222
        U_VECTOR_BYTES(1024),
223
        EVP_PKEY_ML_KEM_1024,
224
        ML_KEM_1024_BITS,
225
        ML_KEM_1024_RANK,
226
        ML_KEM_1024_DU,
227
        ML_KEM_1024_DV,
228
        ML_KEM_1024_SECBITS,
229
        ML_KEM_1024_SECURITY_CATEGORY }
230
};
231
232
/*
233
 * Remainders modulo `kPrime`, for sufficiently small inputs, are computed in
234
 * constant time via Barrett reduction, and a final call to reduce_once(),
235
 * which reduces inputs that are at most 2*kPrime and is also constant-time.
236
 */
237
static const int kPrime = ML_KEM_PRIME;
238
static const unsigned int kBarrettShift = BARRETT_SHIFT;
239
static const size_t kBarrettMultiplier = (1 << BARRETT_SHIFT) / ML_KEM_PRIME;
240
static const uint16_t kHalfPrime = (ML_KEM_PRIME - 1) / 2;
241
static const uint16_t kInverseDegree = INVERSE_DEGREE;
242
243
/*
244
 * Python helper:
245
 *
246
 * p = 3329
247
 * def bitreverse(i):
248
 *     ret = 0
249
 *     for n in range(7):
250
 *         bit = i & 1
251
 *         ret <<= 1
252
 *         ret |= bit
253
 *         i >>= 1
254
 *     return ret
255
 */
256
257
/*-
258
 * First precomputed array from Appendix A of FIPS 203, or else Python:
259
 * kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)]
260
 */
261
static const uint16_t kNTTRoots[128] = {
262
    0x001, 0x6c1, 0xa14, 0xcd9, 0xa52, 0x276, 0x769, 0x350,
263
    0x426, 0x77f, 0x0c1, 0x31d, 0xae2, 0xcbc, 0x239, 0x6d2,
264
    0x128, 0x98f, 0x53b, 0x5c4, 0xbe6, 0x038, 0x8c0, 0x535,
265
    0x592, 0x82e, 0x217, 0xb42, 0x959, 0xb3f, 0x7b6, 0x335,
266
    0x121, 0x14b, 0xcb5, 0x6dc, 0x4ad, 0x900, 0x8e5, 0x807,
267
    0x28a, 0x7b9, 0x9d1, 0x278, 0xb31, 0x021, 0x528, 0x77b,
268
    0x90f, 0x59b, 0x327, 0x1c4, 0x59e, 0xb34, 0x5fe, 0x962,
269
    0xa57, 0xa39, 0x5c9, 0x288, 0x9aa, 0xc26, 0x4cb, 0x38e,
270
    0x011, 0xac9, 0x247, 0xa59, 0x665, 0x2d3, 0x8f0, 0x44c,
271
    0x581, 0xa66, 0xcd1, 0x0e9, 0x2f4, 0x86c, 0xbc7, 0xbea,
272
    0x6a7, 0x673, 0xae5, 0x6fd, 0x737, 0x3b8, 0x5b5, 0xa7f,
273
    0x3ab, 0x904, 0x985, 0x954, 0x2dd, 0x921, 0x10c, 0x281,
274
    0x630, 0x8fa, 0x7f5, 0xc94, 0x177, 0x9f5, 0x82a, 0x66d,
275
    0x427, 0x13f, 0xad5, 0x2f5, 0x833, 0x231, 0x9a2, 0xa22,
276
    0xaf4, 0x444, 0x193, 0x402, 0x477, 0x866, 0xad7, 0x376,
277
    0x6ba, 0x4bc, 0x752, 0x405, 0x83e, 0xb77, 0x375, 0x86a
278
};
279
280
/*
281
 * InverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)]
282
 * Listed in order of use in the inverse NTT loop (index 0 is skipped):
283
 *
284
 *  0, 64, 65, ..., 127, 32, 33, ..., 63, 16, 17, ..., 31, 8, 9, ...
285
 */
286
static const uint16_t kInverseNTTRoots[128] = {
287
    0x001, 0x497, 0x98c, 0x18a, 0x4c3, 0x8fc, 0x5af, 0x845,
288
    0x647, 0x98b, 0x22a, 0x49b, 0x88a, 0x8ff, 0xb6e, 0x8bd,
289
    0x20d, 0x2df, 0x35f, 0xad0, 0x4ce, 0xa0c, 0x22c, 0xbc2,
290
    0x8da, 0x694, 0x4d7, 0x30c, 0xb8a, 0x06d, 0x50c, 0x407,
291
    0x6d1, 0xa80, 0xbf5, 0x3e0, 0xa24, 0x3ad, 0x37c, 0x3fd,
292
    0x956, 0x282, 0x74c, 0x949, 0x5ca, 0x604, 0x21c, 0x68e,
293
    0x65a, 0x117, 0x13a, 0x495, 0xa0d, 0xc18, 0x030, 0x29b,
294
    0x780, 0x8b5, 0x411, 0xa2e, 0x69c, 0x2a8, 0xaba, 0x238,
295
    0xcf0, 0x973, 0x836, 0x0db, 0x357, 0xa79, 0x738, 0x2c8,
296
    0x2aa, 0x39f, 0x703, 0x1cd, 0x763, 0xb3d, 0x9da, 0x766,
297
    0x3f2, 0x586, 0x7d9, 0xce0, 0x1d0, 0xa89, 0x330, 0x548,
298
    0xa77, 0x4fa, 0x41c, 0x401, 0x854, 0x625, 0x04c, 0xbb6,
299
    0xbe0, 0x9cc, 0x54b, 0x1c2, 0x3a8, 0x1bf, 0xaea, 0x4d3,
300
    0x76f, 0x7cc, 0x441, 0xcc9, 0x11b, 0x73d, 0x7c6, 0x372,
301
    0xbd9, 0x62f, 0xac8, 0x045, 0x21f, 0x9e4, 0xc40, 0x582,
302
    0x8db, 0x9b1, 0x598, 0xa8b, 0x2af, 0x028, 0x2ed, 0x640
303
};
304
305
/*
306
 * Second precomputed array from Appendix A of FIPS 203 (normalised positive),
307
 * or else Python:
308
 * ModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)]
309
 */
310
static const uint16_t kModRoots[128] = {
311
    0x011, 0xcf0, 0xac9, 0x238, 0x247, 0xaba, 0xa59, 0x2a8,
312
    0x665, 0x69c, 0x2d3, 0xa2e, 0x8f0, 0x411, 0x44c, 0x8b5,
313
    0x581, 0x780, 0xa66, 0x29b, 0xcd1, 0x030, 0x0e9, 0xc18,
314
    0x2f4, 0xa0d, 0x86c, 0x495, 0xbc7, 0x13a, 0xbea, 0x117,
315
    0x6a7, 0x65a, 0x673, 0x68e, 0xae5, 0x21c, 0x6fd, 0x604,
316
    0x737, 0x5ca, 0x3b8, 0x949, 0x5b5, 0x74c, 0xa7f, 0x282,
317
    0x3ab, 0x956, 0x904, 0x3fd, 0x985, 0x37c, 0x954, 0x3ad,
318
    0x2dd, 0xa24, 0x921, 0x3e0, 0x10c, 0xbf5, 0x281, 0xa80,
319
    0x630, 0x6d1, 0x8fa, 0x407, 0x7f5, 0x50c, 0xc94, 0x06d,
320
    0x177, 0xb8a, 0x9f5, 0x30c, 0x82a, 0x4d7, 0x66d, 0x694,
321
    0x427, 0x8da, 0x13f, 0xbc2, 0xad5, 0x22c, 0x2f5, 0xa0c,
322
    0x833, 0x4ce, 0x231, 0xad0, 0x9a2, 0x35f, 0xa22, 0x2df,
323
    0xaf4, 0x20d, 0x444, 0x8bd, 0x193, 0xb6e, 0x402, 0x8ff,
324
    0x477, 0x88a, 0x866, 0x49b, 0xad7, 0x22a, 0x376, 0x98b,
325
    0x6ba, 0x647, 0x4bc, 0x845, 0x752, 0x5af, 0x405, 0x8fc,
326
    0x83e, 0x4c3, 0xb77, 0x18a, 0x375, 0x98c, 0x86a, 0x497
327
};
328
329
/*
330
 * single_keccak hashes |inlen| bytes from |in| and writes |outlen| bytes of
331
 * output to |out|. If the |md| specifies a fixed-output function, like
332
 * SHA3-256, then |outlen| must be the correct length for that function.
333
 */
334
static __owur int single_keccak(uint8_t *out, size_t outlen, const uint8_t *in, size_t inlen,
335
    EVP_MD_CTX *mdctx)
336
0
{
337
0
    unsigned int sz = (unsigned int)outlen;
338
339
0
    if (!EVP_DigestUpdate(mdctx, in, inlen))
340
0
        return 0;
341
0
    if (EVP_MD_xof(EVP_MD_CTX_get0_md(mdctx)))
342
0
        return EVP_DigestFinalXOF(mdctx, out, outlen);
343
0
    return EVP_DigestFinal_ex(mdctx, out, &sz)
344
0
        && ossl_assert((size_t)sz == outlen);
345
0
}
346
347
/*
348
 * FIPS 203, Section 4.1, equation (4.3): PRF. Takes 32+1 input bytes, and uses
349
 * SHAKE256 to produce the input to SamplePolyCBD_eta: FIPS 203, algorithm 8.
350
 */
351
static __owur int prf(uint8_t *out, size_t len, const uint8_t in[ML_KEM_RANDOM_BYTES + 1],
352
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
353
0
{
354
0
    return EVP_DigestInit_ex(mdctx, key->shake256_md, NULL)
355
0
        && single_keccak(out, len, in, ML_KEM_RANDOM_BYTES + 1, mdctx);
356
0
}
357
358
/*
359
 * FIPS 203, Section 4.1, equation (4.4): H.  SHA3-256 hash of a variable
360
 * length input, producing 32 bytes of output.
361
 */
362
static __owur int hash_h(uint8_t out[ML_KEM_PKHASH_BYTES], const uint8_t *in, size_t len,
363
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
364
0
{
365
0
    return EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL)
366
0
        && single_keccak(out, ML_KEM_PKHASH_BYTES, in, len, mdctx);
367
0
}
368
369
/* Incremental hash_h of expanded public key */
370
static int
371
hash_h_pubkey(uint8_t pkhash[ML_KEM_PKHASH_BYTES],
372
    EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
373
0
{
374
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
375
0
    const scalar *t = key->t, *end = t + vinfo->rank;
376
0
    unsigned int sz;
377
378
0
    if (!EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL))
379
0
        return 0;
380
381
0
    do {
382
0
        uint8_t buf[3 * DEGREE / 2];
383
384
0
        scalar_encode(buf, t++, 12);
385
0
        if (!EVP_DigestUpdate(mdctx, buf, sizeof(buf)))
386
0
            return 0;
387
0
    } while (t < end);
388
389
0
    if (!EVP_DigestUpdate(mdctx, key->rho, ML_KEM_RANDOM_BYTES))
390
0
        return 0;
391
0
    return EVP_DigestFinal_ex(mdctx, pkhash, &sz)
392
0
        && ossl_assert(sz == ML_KEM_PKHASH_BYTES);
393
0
}
394
395
/*
396
 * FIPS 203, Section 4.1, equation (4.5): G.  SHA3-512 hash of a variable
397
 * length input, producing 64 bytes of output, in particular the seeds
398
 * (d,z) for key generation.
399
 */
400
static __owur int hash_g(uint8_t out[ML_KEM_SEED_BYTES], const uint8_t *in, size_t len,
401
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
402
0
{
403
0
    return EVP_DigestInit_ex(mdctx, key->sha3_512_md, NULL)
404
0
        && single_keccak(out, ML_KEM_SEED_BYTES, in, len, mdctx);
405
0
}
406
407
/*
408
 * FIPS 203, Section 4.1, equation (4.4): J. SHAKE256 taking a variable length
409
 * input to compute a 32-byte implicit rejection shared secret, of the same
410
 * length as the expected shared secret.  (Computed even on success to avoid
411
 * side-channel leaks).
412
 */
413
static __owur int kdf(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
414
    const uint8_t z[ML_KEM_RANDOM_BYTES],
415
    const uint8_t *ctext, size_t len,
416
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
417
0
{
418
0
    return EVP_DigestInit_ex(mdctx, key->shake256_md, NULL)
419
0
        && EVP_DigestUpdate(mdctx, z, ML_KEM_RANDOM_BYTES)
420
0
        && EVP_DigestUpdate(mdctx, ctext, len)
421
0
        && EVP_DigestFinalXOF(mdctx, out, ML_KEM_SHARED_SECRET_BYTES);
422
0
}
423
424
/*
425
 * FIPS 203, Section 4.2.2, Algorithm 7: "SampleNTT" (steps 3-17, steps 1, 2
426
 * are performed by the caller). Rejection-samples a Keccak stream to get
427
 * uniformly distributed elements in the range [0,q). This is used for matrix
428
 * expansion and only operates on public inputs.
429
 */
430
static __owur int sample_scalar(scalar *out, EVP_MD_CTX *mdctx)
431
0
{
432
0
    uint16_t *curr = out->c, *endout = curr + DEGREE;
433
0
    uint8_t buf[SCALAR_SAMPLING_BUFSIZE], *in;
434
0
    uint8_t *endin = buf + sizeof(buf);
435
0
    uint16_t d;
436
0
    uint8_t b1, b2, b3;
437
438
0
    do {
439
0
        if (!EVP_DigestSqueeze(mdctx, in = buf, sizeof(buf)))
440
0
            return 0;
441
0
        do {
442
0
            b1 = *in++;
443
0
            b2 = *in++;
444
0
            b3 = *in++;
445
446
0
            if (curr >= endout)
447
0
                break;
448
0
            if ((d = ((b2 & 0x0f) << 8) + b1) < kPrime)
449
0
                *curr++ = d;
450
0
            if (curr >= endout)
451
0
                break;
452
0
            if ((d = (b3 << 4) + (b2 >> 4)) < kPrime)
453
0
                *curr++ = d;
454
0
        } while (in < endin);
455
0
    } while (curr < endout);
456
0
    return 1;
457
0
}
458
459
/*-
460
 * reduce_once reduces 0 <= x < 2*kPrime, mod kPrime.
461
 *
462
 * Subtract |q| if the input is larger, without exposing a side-channel,
463
 * avoiding the "clangover" attack.  See |constish_time_non_zero| for a
464
 * discussion on why the value barrier is by default omitted.
465
 */
466
static __owur uint16_t reduce_once(uint16_t x)
467
0
{
468
0
    const uint16_t subtracted = x - kPrime;
469
0
    uint16_t mask = constish_time_non_zero(subtracted >> 15);
470
471
0
    return (mask & x) | (~mask & subtracted);
472
0
}
473
474
/*
475
 * Constant-time reduce x mod kPrime using Barrett reduction. x must be less
476
 * than kPrime + 2 * kPrime^2.  This is sufficient to reduce a product of
477
 * two already reduced uint16_t values, in fact it is sufficient for each
478
 * to be less than 2^12, because (kPrime * (2 * kPrime + 1)) > 2^24.
479
 */
480
static __owur uint16_t reduce(uint32_t x)
481
0
{
482
0
    uint64_t product = (uint64_t)x * kBarrettMultiplier;
483
0
    uint32_t quotient = (uint32_t)(product >> kBarrettShift);
484
0
    uint32_t remainder = x - quotient * kPrime;
485
486
0
    return reduce_once(remainder);
487
0
}
488
489
/* Multiply a scalar by a constant. */
490
static void scalar_mult_const(scalar *s, uint16_t a)
491
0
{
492
0
    uint16_t *curr = s->c, *end = curr + DEGREE, tmp;
493
494
0
    do {
495
0
        tmp = reduce(*curr * a);
496
0
        *curr++ = tmp;
497
0
    } while (curr < end);
498
0
}
499
500
/*-
501
 * FIPS 203, Section 4.3, Algorithm 9: "NTT".
502
 * In-place number theoretic transform of a given scalar.  Note that ML-KEM's
503
 * kPrime 3329 does not have a 512th root of unity, so this transform leaves
504
 * off the last iteration of the usual FFT code, with the 128 relevant roots of
505
 * unity being stored in NTTRoots.  This means the output should be seen as 128
506
 * elements in GF(3329^2), with the coefficients of the elements being
507
 * consecutive entries in |s->c|.
508
 */
509
static void scalar_ntt(scalar *s)
510
0
{
511
0
    const uint16_t *roots = kNTTRoots;
512
0
    uint16_t *end = s->c + DEGREE;
513
0
    int offset = DEGREE / 2;
514
515
0
    do {
516
0
        uint16_t *curr = s->c, *peer;
517
518
0
        do {
519
0
            uint16_t *pause = curr + offset, even, odd;
520
0
            uint32_t zeta = *++roots;
521
522
0
            peer = pause;
523
0
            do {
524
0
                even = *curr;
525
0
                odd = reduce(*peer * zeta);
526
0
                *peer++ = reduce_once(even - odd + kPrime);
527
0
                *curr++ = reduce_once(odd + even);
528
0
            } while (curr < pause);
529
0
        } while ((curr = peer) < end);
530
0
    } while ((offset >>= 1) >= 2);
531
0
}
532
533
/*-
534
 * FIPS 203, Section 4.3, Algorithm 10: "NTT^(-1)".
535
 * In-place inverse number theoretic transform of a given scalar, with pairs of
536
 * entries of s->v being interpreted as elements of GF(3329^2). Just as with
537
 * the number theoretic transform, this leaves off the first step of the normal
538
 * iFFT to account for the fact that 3329 does not have a 512th root of unity,
539
 * using the precomputed 128 roots of unity stored in InverseNTTRoots.
540
 */
541
static void scalar_inverse_ntt(scalar *s)
542
0
{
543
0
    const uint16_t *roots = kInverseNTTRoots;
544
0
    uint16_t *end = s->c + DEGREE;
545
0
    int offset = 2;
546
547
0
    do {
548
0
        uint16_t *curr = s->c, *peer;
549
550
0
        do {
551
0
            uint16_t *pause = curr + offset, even, odd;
552
0
            uint32_t zeta = *++roots;
553
554
0
            peer = pause;
555
0
            do {
556
0
                even = *curr;
557
0
                odd = *peer;
558
0
                *peer++ = reduce(zeta * (even - odd + kPrime));
559
0
                *curr++ = reduce_once(odd + even);
560
0
            } while (curr < pause);
561
0
        } while ((curr = peer) < end);
562
0
    } while ((offset <<= 1) < DEGREE);
563
0
    scalar_mult_const(s, kInverseDegree);
564
0
}
565
566
/* Addition updating the LHS scalar in-place. */
567
static void scalar_add(scalar *lhs, const scalar *rhs)
568
0
{
569
0
    int i;
570
571
0
    for (i = 0; i < DEGREE; i++)
572
0
        lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
573
0
}
574
575
/* Subtraction updating the LHS scalar in-place. */
576
static void scalar_sub(scalar *lhs, const scalar *rhs)
577
0
{
578
0
    int i;
579
580
0
    for (i = 0; i < DEGREE; i++)
581
0
        lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime);
582
0
}
583
584
/*
585
 * Multiplying two scalars in the number theoretically transformed state. Since
586
 * 3329 does not have a 512th root of unity, this means we have to interpret
587
 * the 2*ith and (2*i+1)th entries of the scalar as elements of
588
 * GF(3329)[X]/(X^2 - 17^(2*bitreverse(i)+1)).
589
 *
590
 * The value of 17^(2*bitreverse(i)+1) mod 3329 is stored in the precomputed
591
 * ModRoots table. Note that our Barrett transform only allows us to multiply
592
 * two reduced numbers together, so we need some intermediate reduction steps,
593
 * even if an uint64_t could hold 3 multiplied numbers.
594
 */
595
static void scalar_mult(scalar *out, const scalar *lhs,
596
    const scalar *rhs)
597
0
{
598
0
    uint16_t *curr = out->c, *end = curr + DEGREE;
599
0
    const uint16_t *lc = lhs->c, *rc = rhs->c;
600
0
    const uint16_t *roots = kModRoots;
601
602
0
    do {
603
0
        uint32_t l0 = *lc++, r0 = *rc++;
604
0
        uint32_t l1 = *lc++, r1 = *rc++;
605
0
        uint32_t zetapow = *roots++;
606
607
0
        *curr++ = reduce(l0 * r0 + reduce(l1 * r1) * zetapow);
608
0
        *curr++ = reduce(l0 * r1 + l1 * r0);
609
0
    } while (curr < end);
610
0
}
611
612
/* Above, but add the result to an existing scalar */
613
static ossl_inline void scalar_mult_add(scalar *out, const scalar *lhs,
614
    const scalar *rhs)
615
0
{
616
0
    uint16_t *curr = out->c, *end = curr + DEGREE;
617
0
    const uint16_t *lc = lhs->c, *rc = rhs->c;
618
0
    const uint16_t *roots = kModRoots;
619
620
0
    do {
621
0
        uint32_t l0 = *lc++, r0 = *rc++;
622
0
        uint32_t l1 = *lc++, r1 = *rc++;
623
0
        uint16_t *c0 = curr++;
624
0
        uint16_t *c1 = curr++;
625
0
        uint32_t zetapow = *roots++;
626
627
0
        *c0 = reduce(*c0 + l0 * r0 + reduce(l1 * r1) * zetapow);
628
0
        *c1 = reduce(*c1 + l0 * r1 + l1 * r0);
629
0
    } while (curr < end);
630
0
}
631
632
/*-
633
 * FIPS 203, Section 4.2.1, Algorithm 5: "ByteEncode_d", for 2<=d<=12.
634
 * Here |bits| is |d|.  For efficiency, we handle the d=1 case separately.
635
 */
636
static void scalar_encode(uint8_t *out, const scalar *s, int bits)
637
0
{
638
0
    const uint16_t *curr = s->c, *end = curr + DEGREE;
639
0
    uint64_t accum = 0, element;
640
0
    int used = 0;
641
642
0
    do {
643
0
        element = *curr++;
644
0
        if (used + bits < 64) {
645
0
            accum |= element << used;
646
0
            used += bits;
647
0
        } else if (used + bits > 64) {
648
0
            out = OPENSSL_store_u64_le(out, accum | (element << used));
649
0
            accum = element >> (64 - used);
650
0
            used = (used + bits) - 64;
651
0
        } else {
652
0
            out = OPENSSL_store_u64_le(out, accum | (element << used));
653
0
            accum = 0;
654
0
            used = 0;
655
0
        }
656
0
    } while (curr < end);
657
0
}
658
659
/*
660
 * scalar_encode_1 is |scalar_encode| specialised for |bits| == 1.
661
 */
662
static void scalar_encode_1(uint8_t out[DEGREE / 8], const scalar *s)
663
0
{
664
0
    int i, j;
665
0
    uint8_t out_byte;
666
667
0
    for (i = 0; i < DEGREE; i += 8) {
668
0
        out_byte = 0;
669
0
        for (j = 0; j < 8; j++)
670
0
            out_byte |= bit0(s->c[i + j]) << j;
671
0
        *out = out_byte;
672
0
        out++;
673
0
    }
674
0
}
675
676
/*-
677
 * FIPS 203, Section 4.2.1, Algorithm 6: "ByteDecode_d", for 2<=d<12.
678
 * Here |bits| is |d|.  For efficiency, we handle the d=1 and d=12 cases
679
 * separately.
680
 *
681
 * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
682
 * |out|.
683
 */
684
static void scalar_decode(scalar *out, const uint8_t *in, int bits)
685
0
{
686
0
    uint16_t *curr = out->c, *end = curr + DEGREE;
687
0
    uint64_t accum = 0;
688
0
    int accum_bits = 0, todo = bits;
689
0
    uint16_t bitmask = (((uint16_t)1) << bits) - 1, mask = bitmask;
690
0
    uint16_t element = 0;
691
692
0
    do {
693
0
        if (accum_bits == 0) {
694
0
            in = OPENSSL_load_u64_le(&accum, in);
695
0
            accum_bits = 64;
696
0
        }
697
0
        if (todo == bits && accum_bits >= bits) {
698
            /* No partial "element", and all the required bits available */
699
0
            *curr++ = ((uint16_t)accum) & mask;
700
0
            accum >>= bits;
701
0
            accum_bits -= bits;
702
0
        } else if (accum_bits >= todo) {
703
            /* A partial "element", and all the required bits available */
704
0
            *curr++ = element | ((((uint16_t)accum) & mask) << (bits - todo));
705
0
            accum >>= todo;
706
0
            accum_bits -= todo;
707
0
            element = 0;
708
0
            todo = bits;
709
0
            mask = bitmask;
710
0
        } else {
711
            /*
712
             * Only some of the requisite bits accumulated, store |accum_bits|
713
             * of these in |element|.  The accumulated bitcount becomes 0, but
714
             * as soon as we have more bits we'll want to merge accum_bits
715
             * fewer of them into the final |element|.
716
             *
717
             * Note that with a 64-bit accumulator and |bits| always 12 or
718
             * less, if we're here, the previous iteration had all the
719
             * requisite bits, and so there are no kept bits in |element|.
720
             */
721
0
            element = ((uint16_t)accum) & mask;
722
0
            todo -= accum_bits;
723
0
            mask = bitmask >> accum_bits;
724
0
            accum_bits = 0;
725
0
        }
726
0
    } while (curr < end);
727
0
}
728
729
static __owur int scalar_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2])
730
0
{
731
0
    int i;
732
0
    uint16_t *c = out->c;
733
734
0
    for (i = 0; i < DEGREE / 2; ++i) {
735
0
        uint8_t b1 = *in++;
736
0
        uint8_t b2 = *in++;
737
0
        uint8_t b3 = *in++;
738
0
        int outOfRange1 = (*c++ = b1 | ((b2 & 0x0f) << 8)) >= kPrime;
739
0
        int outOfRange2 = (*c++ = (b2 >> 4) | (b3 << 4)) >= kPrime;
740
741
0
        if (outOfRange1 | outOfRange2)
742
0
            return 0;
743
0
    }
744
0
    return 1;
745
0
}
746
747
/*-
748
 * scalar_decode_decompress_add is a combination of decoding and decompression
749
 * both specialised for |bits| == 1, with the result added (and sum reduced) to
750
 * the output scalar.
751
 *
752
 * NOTE: this function MUST not leak an input-data-dependent timing signal.
753
 * A timing leak in a related function in the reference Kyber implementation
754
 * made the "clangover" attack (CVE-2024-37880) possible, giving key recovery
755
 * for ML-KEM-512 in minutes, provided the attacker has access to precise
756
 * timing of a CPU performing chosen-ciphertext decap.  Admittedly this is only
757
 * a risk when private keys are reused (perhaps KEMTLS servers).
758
 */
759
static void
760
scalar_decode_decompress_add(scalar *out, const uint8_t in[DEGREE / 8])
761
0
{
762
0
    static const uint16_t half_q_plus_1 = (ML_KEM_PRIME >> 1) + 1;
763
0
    uint16_t *curr = out->c, *end = curr + DEGREE;
764
0
    uint16_t mask;
765
0
    uint8_t b;
766
767
    /*
768
     * Add |half_q_plus_1| if the bit is set, without exposing a side-channel,
769
     * avoiding the "clangover" attack.  See |constish_time_non_zero| for a
770
     * discussion on why the value barrier is by default omitted.
771
     */
772
0
#define decode_decompress_add_bit                        \
773
0
    mask = constish_time_non_zero(bit0(b));              \
774
0
    *curr = reduce_once(*curr + (mask & half_q_plus_1)); \
775
0
    curr++;                                              \
776
0
    b >>= 1
777
778
    /* Unrolled to process each byte in one iteration */
779
0
    do {
780
0
        b = *in++;
781
0
        decode_decompress_add_bit;
782
0
        decode_decompress_add_bit;
783
0
        decode_decompress_add_bit;
784
0
        decode_decompress_add_bit;
785
786
0
        decode_decompress_add_bit;
787
0
        decode_decompress_add_bit;
788
0
        decode_decompress_add_bit;
789
0
        decode_decompress_add_bit;
790
0
    } while (curr < end);
791
0
#undef decode_decompress_add_bit
792
0
}
793
794
/*
795
 * FIPS 203, Section 4.2.1, Equation (4.7): Compress_d.
796
 *
797
 * Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
798
 * numbers close to each other together. The formula used is
799
 * round(2^|bits|/kPrime*x) mod 2^|bits|.
800
 * Uses Barrett reduction to achieve constant time. Since we need both the
801
 * remainder (for rounding) and the quotient (as the result), we cannot use
802
 * |reduce| here, but need to do the Barrett reduction directly.
803
 */
804
/*
 * Lossily compress |x| (mod q) into |bits| bits: round(2^bits / q * x)
 * mod 2^bits.  Constant-time via direct Barrett reduction, since both the
 * quotient (result) and remainder (rounding decision) are needed.
 */
static __owur uint16_t compress(uint16_t x, int bits)
{
    uint32_t shifted = (uint32_t)x << bits;
    uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
    uint32_t quotient = (uint32_t)(product >> kBarrettShift);
    /* Barrett remainder; may be up to just under 2*q, hence two adjustments */
    uint32_t remainder = shifted - quotient * kPrime;

    /*
     * Adjust the quotient to round correctly:
     *   0 <= remainder <= kHalfPrime round to 0
     *   kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
     *   kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
     */
    quotient += 1 & constant_time_lt_32(kHalfPrime, remainder);
    quotient += 1 & constant_time_lt_32(kPrime + kHalfPrime, remainder);
    /* Reduce mod 2^bits by masking */
    return quotient & ((1 << bits) - 1);
}
821
822
/*
823
 * FIPS 203, Section 4.2.1, Equation (4.8): Decompress_d.
824
825
 * Decompresses |x| by using a close equi-distant representative. The formula
826
 * is round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us
827
 * to implement this logic using only bit operations.
828
 */
829
/*
 * Decompress a |bits|-bit value |x| to a representative mod q:
 * round(q / 2^bits * x).  Division by a power of two allows a pure
 * bit-operation implementation (no data-dependent branches).
 */
static __owur uint16_t decompress(uint16_t x, int bits)
{
    uint32_t product = (uint32_t)x * kPrime;
    uint32_t power = 1 << bits;
    /* This is |product| % power, since |power| is a power of 2. */
    uint32_t remainder = product & (power - 1);
    /* This is |product| / power, since |power| is a power of 2. */
    uint32_t lower = product >> bits;

    /*
     * The rounding logic works since the first half of numbers mod |power|
     * have a 0 as first bit, and the second half has a 1 as first bit, since
     * |power| is a power of 2. As a 12 bit number, |remainder| is always
     * positive, so we will shift in 0s for a right shift.
     */
    return lower + (remainder >> (bits - 1));
}
846
847
/*-
848
 * FIPS 203, Section 4.2.1, Equation (4.7): "Compress_d".
849
 * In-place lossy rounding of scalars to 2^d bits.
850
 */
851
/*
 * FIPS 203, Section 4.2.1, Equation (4.7): "Compress_d".
 * Lossily round every coefficient of |s| down to |bits| bits, in place.
 */
static void scalar_compress(scalar *s, int bits)
{
    uint16_t *coeff = s->c;
    uint16_t *const stop = coeff + DEGREE;

    while (coeff < stop) {
        *coeff = compress(*coeff, bits);
        ++coeff;
    }
}
858
859
/*
860
 * FIPS 203, Section 4.2.1, Equation (4.8): "Decompress_d".
861
 * In-place approximate recovery of scalars from 2^d bit compression.
862
 */
863
/*
 * FIPS 203, Section 4.2.1, Equation (4.8): "Decompress_d".
 * Recover an approximate representative mod q for every coefficient of |s|,
 * in place, from its |bits|-bit compressed form.
 */
static void scalar_decompress(scalar *s, int bits)
{
    uint16_t *coeff = s->c;
    uint16_t *const stop = coeff + DEGREE;

    while (coeff < stop) {
        *coeff = decompress(*coeff, bits);
        ++coeff;
    }
}
870
871
/* Addition updating the LHS vector in-place. */
872
/* Component-wise vector addition, accumulating into the LHS vector. */
static void vector_add(scalar *lhs, const scalar *rhs, int rank)
{
    int i = 0;

    do {
        scalar_add(&lhs[i], &rhs[i]);
    } while (++i < rank);
}
878
879
/*
880
 * Encodes an entire vector into 32*|rank|*|bits| bytes. Note that since 256
881
 * (DEGREE) is divisible by 8, the individual vector entries will always fill a
882
 * whole number of bytes, so we do not need to worry about bit packing here.
883
 */
884
/*
 * Serialise an entire vector into 32*|rank|*|bits| bytes.  DEGREE (256) is
 * divisible by 8, so each scalar occupies a whole number of bytes and no
 * cross-scalar bit packing is needed.
 */
static void vector_encode(uint8_t *out, const scalar *a, int bits, int rank)
{
    const int bytes_per_scalar = bits * DEGREE / 8;
    int i;

    for (i = 0; i < rank; i++)
        scalar_encode(out + i * bytes_per_scalar, a + i, bits);
}
891
892
/*
893
 * Decodes 32*|rank|*|bits| bytes from |in| into |out|. It returns early
894
 * if any parsed value is >= |ML_KEM_PRIME|.  The resulting scalars are
895
 * then decompressed and transformed via the NTT.
896
 *
897
 * Note: Used only in decrypt_cpa(), which returns void and so does not check
898
 * the return value of this function.  Side-channels are fine when the input
899
 * ciphertext to decap() is simply syntactically invalid.
900
 */
901
static void
902
vector_decode_decompress_ntt(scalar *out, const uint8_t *in, int bits, int rank)
903
0
{
904
0
    int stride = bits * DEGREE / 8;
905
906
0
    for (; rank-- > 0; in += stride, ++out) {
907
0
        scalar_decode(out, in, bits);
908
0
        scalar_decompress(out, bits);
909
0
        scalar_ntt(out);
910
0
    }
911
0
}
912
913
/* vector_decode(), specialised to bits == 12. */
914
/*
 * vector_decode(), specialised to bits == 12.
 * Returns 0 (failure) as soon as any decoded coefficient is out of range,
 * 1 on success.
 */
static __owur int vector_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2], int rank)
{
    const int bytes_per_scalar = 3 * DEGREE / 2;
    int i;

    for (i = 0; i < rank; i++) {
        if (!scalar_decode_12(&out[i], in + i * bytes_per_scalar))
            return 0;
    }
    return 1;
}
923
924
/* In-place compression of each scalar component */
925
/* Compress every scalar component of the vector in place. */
static void vector_compress(scalar *a, int bits, int rank)
{
    int i = 0;

    do {
        scalar_compress(&a[i], bits);
    } while (++i < rank);
}
931
932
/* The output scalar must not overlap with the inputs */
933
/*
 * NTT-domain inner product of two vectors of length |rank|.
 * |out| must not overlap either input.
 */
static void inner_product(scalar *out, const scalar *lhs, const scalar *rhs,
    int rank)
{
    int i;

    /* First term initialises |out|, the rest accumulate into it. */
    scalar_mult(out, &lhs[0], &rhs[0]);
    for (i = 1; i < rank; i++)
        scalar_mult_add(out, &lhs[i], &rhs[i]);
}
940
941
/*
942
 * Here, the output vector must not overlap with the inputs, the result is
943
 * directly subjected to inverse NTT.
944
 */
945
/*
 * Row-major matrix-vector product followed by an inverse NTT of each result
 * entry: out[i] = InverseNTT( sum_j m[i * rank + j] * a[j] ).
 * The output vector must not overlap with the inputs.
 */
static void
matrix_mult_intt(scalar *out, const scalar *m, const scalar *a, int rank)
{
    const scalar *ar;
    int i, j;

    for (i = rank; i-- > 0; ++out) {
        /* First column initialises |out|; |m| advances row-major. */
        scalar_mult(out, m++, ar = a);
        for (j = rank - 1; j > 0; --j)
            scalar_mult_add(out, m++, ++ar);
        scalar_inverse_ntt(out);
    }
}
958
959
/* Here, the output vector must not overlap with the inputs */
960
/*
 * Transposed matrix-vector product, accumulated into |out|:
 * out[i] += sum_j m[j * rank + i] * a[j].  Walks column |i| of the
 * row-major matrix |m| with stride |rank|.  The output vector must not
 * overlap with the inputs.  Note that even the first term accumulates
 * (scalar_mult_add): |out| must hold its initial value (e.g. the error
 * vector |e| in key generation).
 */
static void
matrix_mult_transpose_add(scalar *out, const scalar *m, const scalar *a, int rank)
{
    const scalar *mc = m, *mr, *ar;
    int i, j;

    for (i = rank; i-- > 0; ++out) {
        /* |mc| tracks the top of column i; |mr| strides down that column. */
        scalar_mult_add(out, mr = mc++, ar = a);
        for (j = rank; --j > 0;)
            scalar_mult_add(out, (mr += rank), ++ar);
    }
}
972
973
/*-
974
 * Expands the matrix from a seed for key generation and for encaps-CPA.
975
 * NOTE: FIPS 203 matrix "A" is the transpose of this matrix, computed
976
 * by appending the (i,j) indices to the seed in the opposite order!
977
 *
978
 * Where FIPS 203 computes t = A * s + e, we use the transpose of "m".
979
 */
980
/*
 * Expand the |rank| x |rank| matrix key->m from the public seed key->rho by
 * SHAKE128-sampling each entry.  Each entry (i, j) is sampled from
 * SHAKE128(rho || i || j).  Returns 1 on success, 0 if any digest operation
 * or sampling step fails.
 *
 * NOTE: FIPS 203 appends the indices in the opposite order; our |m| is the
 * transpose of the FIPS 203 matrix "A" (see the comment above this function
 * in the original source).
 */
static __owur int matrix_expand(EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
{
    scalar *out = key->m;
    /* Seed plus the two one-byte domain-separation indices */
    uint8_t input[ML_KEM_RANDOM_BYTES + 2];
    int rank = key->vinfo->rank;
    int i, j;

    memcpy(input, key->rho, ML_KEM_RANDOM_BYTES);
    for (i = 0; i < rank; i++) {
        for (j = 0; j < rank; j++) {
            input[ML_KEM_RANDOM_BYTES] = i;
            input[ML_KEM_RANDOM_BYTES + 1] = j;
            if (!EVP_DigestInit_ex(mdctx, key->shake128_md, NULL)
                || !EVP_DigestUpdate(mdctx, input, sizeof(input))
                || !sample_scalar(out++, mdctx))
                return 0;
        }
    }
    return 1;
}
1000
1001
/*
1002
 * Algorithm 7 from the spec, with eta fixed to two and the PRF call
1003
 * included. Creates binominally distributed elements by sampling 2*|eta| bits,
1004
 * and setting the coefficient to the count of the first bits minus the count of
1005
 * the second bits, resulting in a centered binomial distribution. Since eta is
1006
 * two this gives -2/2 with a probability of 1/16, -1/1 with probability 1/4,
1007
 * and 0 with probability 3/8.
1008
 */
1009
/*
 * FIPS 203, Algorithm 7 (SamplePolyCBD), with eta fixed to 2 and the PRF
 * call included.  Each byte of PRF output yields two coefficients: each is
 * (count of 2 bits) - (count of the next 2 bits), giving a centered binomial
 * distribution in [-2, 2] (stored mod q).  Returns 1 on success, 0 if the
 * PRF fails.
 */
static __owur int cbd_2(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
{
    uint16_t *curr = out->c, *end = curr + DEGREE;
    uint8_t randbuf[4 * DEGREE / 8], *r = randbuf; /* 64 * eta slots */
    uint16_t value, mask;
    uint8_t b;

    if (!prf(randbuf, sizeof(randbuf), in, mdctx, key))
        return 0;

    do {
        b = *r++;

        /*
         * Add |kPrime| if |value| underflowed.  See |constish_time_non_zero|
         * for a discussion on why the value barrier is by default omitted.
         * While this could have been written reduce_once(value + kPrime), this
         * is one extra addition and small range of |value| tempts some
         * versions of Clang to emit a branch.
         */
        value = bit0(b) + bitn(1, b);
        value -= bitn(2, b) + bitn(3, b);
        /* value >> 15 is the sign bit of the 16-bit subtraction result */
        mask = constish_time_non_zero(value >> 15);
        *curr++ = value + (kPrime & mask);

        value = bitn(4, b) + bitn(5, b);
        value -= bitn(6, b) + bitn(7, b);
        mask = constish_time_non_zero(value >> 15);
        *curr++ = value + (kPrime & mask);
    } while (curr < end);
    return 1;
}
1042
1043
/*
1044
 * Algorithm 7 from the spec, with eta fixed to three and the PRF call
1045
 * included. Creates binominally distributed elements by sampling 3*|eta| bits,
1046
 * and setting the coefficient to the count of the first bits minus the count of
1047
 * the second bits, resulting in a centered binomial distribution.
1048
 */
1049
/*
 * FIPS 203, Algorithm 7 (SamplePolyCBD), with eta fixed to 3 and the PRF
 * call included.  Every three bytes of PRF output yield four coefficients:
 * each is (count of 3 bits) - (count of the next 3 bits), giving a centered
 * binomial distribution in [-3, 3] (stored mod q).  Returns 1 on success,
 * 0 if the PRF fails.
 */
static __owur int cbd_3(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
{
    uint16_t *curr = out->c, *end = curr + DEGREE;
    uint8_t randbuf[6 * DEGREE / 8], *r = randbuf; /* 64 * eta slots */
    uint8_t b1, b2, b3;
    uint16_t value, mask;

    if (!prf(randbuf, sizeof(randbuf), in, mdctx, key))
        return 0;

    do {
        /* Consume 24 bits (4 coefficients x 6 bits) per iteration */
        b1 = *r++;
        b2 = *r++;
        b3 = *r++;

        /*
         * Add |kPrime| if |value| underflowed.  See |constish_time_non_zero|
         * for a discussion on why the value barrier is by default omitted.
         * While this could have been written reduce_once(value + kPrime), this
         * is one extra addition and small range of |value| tempts some
         * versions of Clang to emit a branch.
         */
        value = bit0(b1) + bitn(1, b1) + bitn(2, b1);
        value -= bitn(3, b1) + bitn(4, b1) + bitn(5, b1);
        /* value >> 15 is the sign bit of the 16-bit subtraction result */
        mask = constish_time_non_zero(value >> 15);
        *curr++ = value + (kPrime & mask);

        value = bitn(6, b1) + bitn(7, b1) + bit0(b2);
        value -= bitn(1, b2) + bitn(2, b2) + bitn(3, b2);
        mask = constish_time_non_zero(value >> 15);
        *curr++ = value + (kPrime & mask);

        value = bitn(4, b2) + bitn(5, b2) + bitn(6, b2);
        value -= bitn(7, b2) + bit0(b3) + bitn(1, b3);
        mask = constish_time_non_zero(value >> 15);
        *curr++ = value + (kPrime & mask);

        value = bitn(2, b3) + bitn(3, b3) + bitn(4, b3);
        value -= bitn(5, b3) + bitn(6, b3) + bitn(7, b3);
        mask = constish_time_non_zero(value >> 15);
        *curr++ = value + (kPrime & mask);
    } while (curr < end);
    return 1;
}
1094
1095
/*
1096
 * Generates a secret vector by using |cbd| with the given seed to generate
1097
 * scalar elements and incrementing |counter| for each slot of the vector.
1098
 */
1099
static __owur int gencbd_vector(scalar *out, CBD_FUNC cbd, uint8_t *counter,
1100
    const uint8_t seed[ML_KEM_RANDOM_BYTES], int rank,
1101
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1102
0
{
1103
0
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1104
1105
0
    memcpy(input, seed, ML_KEM_RANDOM_BYTES);
1106
0
    do {
1107
0
        input[ML_KEM_RANDOM_BYTES] = (*counter)++;
1108
0
        if (!cbd(out++, input, mdctx, key))
1109
0
            return 0;
1110
0
    } while (--rank > 0);
1111
0
    return 1;
1112
0
}
1113
1114
/*
1115
 * As above plus NTT transform.
1116
 */
1117
static __owur int gencbd_vector_ntt(scalar *out, CBD_FUNC cbd, uint8_t *counter,
1118
    const uint8_t seed[ML_KEM_RANDOM_BYTES], int rank,
1119
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1120
0
{
1121
0
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1122
1123
0
    memcpy(input, seed, ML_KEM_RANDOM_BYTES);
1124
0
    do {
1125
0
        input[ML_KEM_RANDOM_BYTES] = (*counter)++;
1126
0
        if (!cbd(out, input, mdctx, key))
1127
0
            return 0;
1128
0
        scalar_ntt(out++);
1129
0
    } while (--rank > 0);
1130
0
    return 1;
1131
0
}
1132
1133
/* The |ETA1| value for ML-KEM-512 is 3, the rest and all ETA2 values are 2. */
1134
0
#define CBD1(evp_type) ((evp_type) == EVP_PKEY_ML_KEM_512 ? cbd_3 : cbd_2)
1135
1136
/*
1137
 * FIPS 203, Section 5.2, Algorithm 14: K-PKE.Encrypt.
1138
 *
1139
 * Encrypts a message with given randomness to the ciphertext in |out|. Without
1140
 * applying the Fujisaki-Okamoto transform this would not result in a CCA
1141
 * secure scheme, since lattice schemes are vulnerable to decryption failure
1142
 * oracles.
1143
 *
1144
 * The steps are re-ordered to make more efficient/localised use of storage.
1145
 *
1146
 * Note also that the input public key is assumed to hold a precomputed matrix
1147
 * |A| (our key->m, with the public key holding an expanded (16-bit per scalar
1148
 * coefficient) key->t vector).
1149
 *
1150
 * Caller passes storage in |tmp| for for two temporary vectors.
1151
 */
1152
/*
 * FIPS 203, Section 5.2, Algorithm 14: K-PKE.Encrypt.
 * Encrypts |message| under the public key with randomness |r| into |out|.
 * The caller passes scratch for two vectors in |tmp|; tmp[0..rank-1] is
 * sequentially reused for |y|, |e1| and |e2|, and tmp[rank..2*rank-1] holds
 * |u| — so the statement order below is load-bearing.  Returns 1 on
 * success, 0 only on PRF/digest failure.
 */
static __owur int encrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
    const uint8_t message[DEGREE / 8],
    const uint8_t r[ML_KEM_RANDOM_BYTES], scalar *tmp,
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
{
    const ML_KEM_VINFO *vinfo = key->vinfo;
    CBD_FUNC cbd_1 = CBD1(vinfo->evp_type);
    int rank = vinfo->rank;
    /* We can use tmp[0..rank-1] as storage for |y|, then |e1|, ... */
    scalar *y = &tmp[0], *e1 = y, *e2 = y;
    /* We can use tmp[rank]..tmp[2*rank - 1] for |u| */
    scalar *u = &tmp[rank];
    scalar v;
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
    uint8_t counter = 0;
    int du = vinfo->du;
    int dv = vinfo->dv;

    /* FIPS 203 "y" vector */
    if (!gencbd_vector_ntt(y, cbd_1, &counter, r, rank, mdctx, key))
        return 0;
    /* FIPS 203 "v" scalar */
    inner_product(&v, key->t, y, rank);
    scalar_inverse_ntt(&v);
    /* FIPS 203 "u" vector */
    matrix_mult_intt(u, key->m, y, rank);

    /* All done with |y|, now free to reuse tmp[0] for FIPS 203 |e1| */
    if (!gencbd_vector(e1, cbd_2, &counter, r, rank, mdctx, key))
        return 0;
    vector_add(u, e1, rank);
    vector_compress(u, du, rank);
    /* First ciphertext component: compressed |u| */
    vector_encode(out, u, du, rank);

    /* All done with |e1|, now free to reuse tmp[0] for FIPS 203 |e2| */
    memcpy(input, r, ML_KEM_RANDOM_BYTES);
    input[ML_KEM_RANDOM_BYTES] = counter;
    if (!cbd_2(e2, input, mdctx, key))
        return 0;
    scalar_add(&v, e2);

    /* Combine message with |v| */
    scalar_decode_decompress_add(&v, message);
    scalar_compress(&v, dv);
    /* Second ciphertext component: compressed |v|, after the |u| bytes */
    scalar_encode(out + vinfo->u_vector_bytes, &v, dv);
    return 1;
}
1199
1200
/*
1201
 * FIPS 203, Section 5.3, Algorithm 15: K-PKE.Decrypt.
1202
 */
1203
/*
 * FIPS 203, Section 5.3, Algorithm 15: K-PKE.Decrypt.
 * Recovers the 32-byte message from |ctext| using the private vector
 * key->s.  |u| is caller-provided scratch for one vector.  Computes
 * m = Compress_1(v - InverseNTT(<s, NTT(u)>)).
 */
static void
decrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
    const uint8_t *ctext, scalar *u, const ML_KEM_KEY *key)
{
    const ML_KEM_VINFO *vinfo = key->vinfo;
    scalar v, mask;
    int rank = vinfo->rank;
    int du = vinfo->du;
    int dv = vinfo->dv;

    /* Recover |u| (NTT domain) and |v| from the two ciphertext components */
    vector_decode_decompress_ntt(u, ctext, du, rank);
    scalar_decode(&v, ctext + vinfo->u_vector_bytes, dv);
    scalar_decompress(&v, dv);
    /* |mask| = InverseNTT(<s, u>); subtract it from |v| to expose the message */
    inner_product(&mask, key->s, u, rank);
    scalar_inverse_ntt(&mask);
    scalar_sub(&v, &mask);
    scalar_compress(&v, 1);
    scalar_encode_1(out, &v);
}
1222
1223
/*-
1224
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1225
 * FIPS 203, Section 7.2, Algorithm 20: "ML-KEM.Encaps".
1226
 *
1227
 * Fills the |out| buffer with the |ek| output of "ML-KEM.KeyGen", or,
1228
 * equivalently, the |ek| input of "ML-KEM.Encaps", i.e. returns the
1229
 * wire-format of an ML-KEM public key.
1230
 */
1231
/*
 * FIPS 203, Algorithms 19/20: serialise the wire-format public key |ek|:
 * the 12-bit-encoded |t| vector followed by the matrix seed |rho|.
 */
static void encode_pubkey(uint8_t *out, const ML_KEM_KEY *key)
{
    const ML_KEM_VINFO *vinfo = key->vinfo;

    vector_encode(out, key->t, 12, vinfo->rank);
    out += vinfo->vector_bytes;
    memcpy(out, key->rho, ML_KEM_RANDOM_BYTES);
}
1239
1240
/*-
1241
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1242
 *
1243
 * Fills the |out| buffer with the |dk| output of "ML-KEM.KeyGen".
1244
 * This matches the input format of parse_prvkey() below.
1245
 */
1246
/*
 * FIPS 203, Algorithm 19: serialise the wire-format private key |dk|:
 * 12-bit-encoded |s|, then the public key, the public key hash, and the
 * implicit-rejection secret |z|.  Matches the input format of
 * parse_prvkey().
 */
static void encode_prvkey(uint8_t *out, const ML_KEM_KEY *key)
{
    const ML_KEM_VINFO *vinfo = key->vinfo;
    uint8_t *p = out;

    vector_encode(p, key->s, 12, vinfo->rank);
    p += vinfo->vector_bytes;
    encode_pubkey(p, key);
    p += vinfo->pubkey_bytes;
    memcpy(p, key->pkhash, ML_KEM_PKHASH_BYTES);
    p += ML_KEM_PKHASH_BYTES;
    memcpy(p, key->z, ML_KEM_RANDOM_BYTES);
}
1258
1259
/*-
1260
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1261
 * FIPS 203, Section 7.2, Algorithm 20: "ML-KEM.Encaps".
1262
 *
1263
 * This function parses the |in| buffer as the |ek| output of "ML-KEM.KeyGen",
1264
 * or, equivalently, the |ek| input of "ML-KEM.Encaps", i.e. decodes the
1265
 * wire-format of the ML-KEM public key.
1266
 */
1267
/*
 * Parse a wire-format public key |ek| into |key|: decode and range-check
 * the |t| vector, save |rho|, and precompute the public key hash and the
 * expanded matrix |m|.  Returns 1 on success, 0 (with an error raised) on
 * malformed input or internal digest failure.
 */
static int parse_pubkey(const uint8_t *in, EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
{
    const ML_KEM_VINFO *vinfo = key->vinfo;

    /* Decode and check |t| */
    if (!vector_decode_12(key->t, in, vinfo->rank)) {
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
            "%s invalid public 't' vector",
            vinfo->algorithm_name);
        return 0;
    }
    /* Save the matrix |m| recovery seed |rho| */
    memcpy(key->rho, in + vinfo->vector_bytes, ML_KEM_RANDOM_BYTES);
    /*
     * Pre-compute the public key hash, needed for both encap and decap.
     * Also pre-compute the matrix expansion, stored with the public key.
     */
    if (!hash_h(key->pkhash, in, vinfo->pubkey_bytes, mdctx, key)
        || !matrix_expand(mdctx, key)) {
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
            "internal error while parsing %s public key",
            vinfo->algorithm_name);
        return 0;
    }
    return 1;
}
1293
1294
/*
1295
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1296
 *
1297
 * Parses the |in| buffer as a |dk| output of "ML-KEM.KeyGen".
1298
 * This matches the output format of encode_prvkey() above.
1299
 */
1300
/*
 * Parse a wire-format private key |dk| into |key|: decode and check |s|,
 * parse the embedded public key, verify the stored public key hash against
 * the one recomputed by parse_pubkey(), and save |z|.  Returns 1 on
 * success, 0 (with an error raised) on any mismatch or decode failure.
 */
static int parse_prvkey(const uint8_t *in, EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
{
    const ML_KEM_VINFO *vinfo = key->vinfo;

    /* Decode and check |s|. */
    if (!vector_decode_12(key->s, in, vinfo->rank)) {
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
            "%s invalid private 's' vector",
            vinfo->algorithm_name);
        return 0;
    }
    in += vinfo->vector_bytes;

    if (!parse_pubkey(in, mdctx, key))
        return 0;
    in += vinfo->pubkey_bytes;

    /* Check public key hash (parse_pubkey() recomputed key->pkhash). */
    if (memcmp(key->pkhash, in, ML_KEM_PKHASH_BYTES) != 0) {
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
            "%s public key hash mismatch",
            vinfo->algorithm_name);
        return 0;
    }
    in += ML_KEM_PKHASH_BYTES;

    /* Save the FO implicit-rejection secret |z| */
    memcpy(key->z, in, ML_KEM_RANDOM_BYTES);
    return 1;
}
1329
1330
/*
1331
 * FIPS 203, Section 6.1, Algorithm 16: "ML-KEM.KeyGen_internal".
1332
 *
1333
 * The implementation of Section 5.1, Algorithm 13, "K-PKE.KeyGen(d)" is
1334
 * inlined.
1335
 *
1336
 * The caller MUST pass a pre-allocated digest context that is not shared with
1337
 * any concurrent computation.
1338
 *
1339
 * This function optionally outputs the serialised wire-form |ek| public key
1340
 * into the provided |pubenc| buffer, and generates the content of the |rho|,
1341
 * |pkhash|, |t|, |m|, |s| and |z| components of the private |key| (which must
1342
 * have preallocated space for these).
1343
 *
1344
 * Keys are computed from a 32-byte random |d| plus the 1 byte rank for
1345
 * domain separation.  These are concatenated and hashed to produce a pair of
1346
 * 32-byte seeds public "rho", used to generate the matrix, and private "sigma",
1347
 * used to generate the secret vector |s|.
1348
 *
1349
 * The second random input |z| is copied verbatim into the Fujisaki-Okamoto
1350
 * (FO) transform "implicit-rejection" secret (the |z| component of the private
1351
 * key), which thwarts chosen-ciphertext attacks, provided decap() runs in
1352
 * constant time, with no side channel leaks, on all well-formed (valid length,
1353
 * and correctly encoded) ciphertext inputs.
1354
 */
1355
/*
 * FIPS 203, Section 6.1, Algorithm 16: "ML-KEM.KeyGen_internal"
 * (K-PKE.KeyGen inlined).  Derives (rho, sigma) from the |d| half of the
 * seed salted with the rank, expands the matrix, samples |s| and |e|, and
 * computes t = transpose(m) * s + e.  Optionally serialises the public key
 * into |pubenc|.  Returns 1 on success; on failure returns 0 with an error
 * raised, after cleansing the secret seed material.
 */
static __owur int genkey(const uint8_t seed[ML_KEM_SEED_BYTES],
    EVP_MD_CTX *mdctx, uint8_t *pubenc, ML_KEM_KEY *key)
{
    uint8_t hashed[2 * ML_KEM_RANDOM_BYTES];
    /* |sigma| is the private second half of the G() output */
    const uint8_t *const sigma = hashed + ML_KEM_RANDOM_BYTES;
    uint8_t augmented_seed[ML_KEM_RANDOM_BYTES + 1];
    const ML_KEM_VINFO *vinfo = key->vinfo;
    CBD_FUNC cbd_1 = CBD1(vinfo->evp_type);
    int rank = vinfo->rank;
    uint8_t counter = 0;
    int ret = 0;

    /*
     * Use the "d" seed salted with the rank to derive the public and private
     * seeds rho and sigma.
     */
    memcpy(augmented_seed, seed, ML_KEM_RANDOM_BYTES);
    augmented_seed[ML_KEM_RANDOM_BYTES] = (uint8_t)rank;
    if (!hash_g(hashed, augmented_seed, sizeof(augmented_seed), mdctx, key))
        goto end;
    memcpy(key->rho, hashed, ML_KEM_RANDOM_BYTES);
    /* The |rho| matrix seed is public */
    CONSTTIME_DECLASSIFY(key->rho, ML_KEM_RANDOM_BYTES);

    /* FIPS 203 |e| vector is initial value of key->t */
    if (!matrix_expand(mdctx, key)
        || !gencbd_vector_ntt(key->s, cbd_1, &counter, sigma, rank, mdctx, key)
        || !gencbd_vector_ntt(key->t, cbd_1, &counter, sigma, rank, mdctx, key))
        goto end;

    /* To |e| we now add the product of transpose |m| and |s|, giving |t|. */
    matrix_mult_transpose_add(key->t, key->m, key->s, rank);
    /* The |t| vector is public */
    CONSTTIME_DECLASSIFY(key->t, vinfo->rank * sizeof(scalar));

    if (pubenc == NULL) {
        /* Incremental digest of public key without in-full serialisation. */
        if (!hash_h_pubkey(key->pkhash, mdctx, key))
            goto end;
    } else {
        encode_pubkey(pubenc, key);
        if (!hash_h(key->pkhash, pubenc, vinfo->pubkey_bytes, mdctx, key))
            goto end;
    }

    /* Save |z| portion of seed for "implicit rejection" on failure. */
    memcpy(key->z, seed + ML_KEM_RANDOM_BYTES, ML_KEM_RANDOM_BYTES);

    /* Save the |d| portion of the seed */
    key->d = key->z + ML_KEM_RANDOM_BYTES;
    memcpy(key->d, seed, ML_KEM_RANDOM_BYTES);

    ret = 1;
end:
    /* Cleanse the secret |d| copy and the private seed |sigma| */
    OPENSSL_cleanse((void *)augmented_seed, ML_KEM_RANDOM_BYTES);
    OPENSSL_cleanse((void *)sigma, ML_KEM_RANDOM_BYTES);
    if (ret == 0) {
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
            "internal error while generating %s private key",
            vinfo->algorithm_name);
    }
    return ret;
}
1418
1419
/*-
1420
 * FIPS 203, Section 6.2, Algorithm 17: "ML-KEM.Encaps_internal".
1421
 * This is the deterministic version with randomness supplied externally.
1422
 *
1423
 * The caller must pass space for two vectors in |tmp|.
1424
 * The |ctext| buffer have space for the ciphertext of the ML-KEM variant
1425
 * of the provided key.
1426
 */
1427
/*
 * FIPS 203, Section 6.2, Algorithm 17: "ML-KEM.Encaps_internal".
 * Deterministic encapsulation with externally supplied |entropy|:
 * (K, r) = G(entropy || H(ek)); the ciphertext is the CPA encryption of
 * |entropy| under randomness |r|, and |secret| receives K.  Returns 1 on
 * success, 0 (with an error raised) on digest/PRF failure.
 */
static int encap(uint8_t *ctext, uint8_t secret[ML_KEM_SHARED_SECRET_BYTES],
    const uint8_t entropy[ML_KEM_RANDOM_BYTES],
    scalar *tmp, EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
{
    uint8_t input[ML_KEM_RANDOM_BYTES + ML_KEM_PKHASH_BYTES];
    /* Kr = K || r: shared secret followed by the CPA encryption randomness */
    uint8_t Kr[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES];
    uint8_t *r = Kr + ML_KEM_SHARED_SECRET_BYTES;
    int ret;

    memcpy(input, entropy, ML_KEM_RANDOM_BYTES);
    memcpy(input + ML_KEM_RANDOM_BYTES, key->pkhash, ML_KEM_PKHASH_BYTES);
    ret = hash_g(Kr, input, sizeof(input), mdctx, key)
        && encrypt_cpa(ctext, entropy, r, tmp, mdctx, key);
    OPENSSL_cleanse((void *)input, sizeof(input));

    if (ret)
        memcpy(secret, Kr, ML_KEM_SHARED_SECRET_BYTES);
    else
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
            "internal error while performing %s encapsulation",
            key->vinfo->algorithm_name);
    return ret;
}
1450
1451
/*
1452
 * FIPS 203, Section 6.3, Algorithm 18: ML-KEM.Decaps_internal
1453
 *
1454
 * Barring failure of the supporting SHA3/SHAKE primitives, this is fully
1455
 * deterministic, the randomness for the FO transform is extracted during
1456
 * private key generation.
1457
 *
1458
 * The caller must pass space for two vectors in |tmp|.
1459
 * The |ctext| and |tmp_ctext| buffers must each have space for the ciphertext
1460
 * of the key's ML-KEM variant.
1461
 */
1462
/*
 * FIPS 203, Section 6.3, Algorithm 18: ML-KEM.Decaps_internal.
 * CPA-decrypts the ciphertext, re-encrypts the recovered message, and
 * selects in constant time between the derived secret (ciphertexts match)
 * and the implicit-rejection |failure_key| (mismatch).  Returns 1 in all
 * cases except unavailability of the KDF itself.
 */
static int decap(uint8_t secret[ML_KEM_SHARED_SECRET_BYTES],
    const uint8_t *ctext, uint8_t *tmp_ctext, scalar *tmp,
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
{
    /* decrypted message || public key hash, the input of G() */
    uint8_t decrypted[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_PKHASH_BYTES];
    uint8_t failure_key[ML_KEM_RANDOM_BYTES];
    /* Kr = K || r: candidate secret followed by re-encryption randomness */
    uint8_t Kr[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES];
    uint8_t *r = Kr + ML_KEM_SHARED_SECRET_BYTES;
    const uint8_t *pkhash = key->pkhash;
    const ML_KEM_VINFO *vinfo = key->vinfo;
    int i;
    uint8_t mask;

    /*
     * If our KDF is unavailable, fail early! Otherwise, keep going ignoring
     * any further errors, returning success, and whatever we got for a shared
     * secret.  The decrypt_cpa() function is just arithmetic on secret data,
     * so should not be subject to failure that makes its output predictable.
     *
     * We guard against "should never happen" catastrophic failure of the
     * "pure" function |hash_g| by overwriting the shared secret with the
     * content of the failure key and returning early, if nevertheless hash_g
     * fails.  This is not constant-time, but a failure of |hash_g| already
     * implies loss of side-channel resistance.
     *
     * The same action is taken, if also |encrypt_cpa| should catastrophically
     * fail, due to failure of the |PRF| underlying the CBD functions.
     */
    if (!kdf(failure_key, key->z, ctext, vinfo->ctext_bytes, mdctx, key)) {
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
            "internal error while performing %s decapsulation",
            vinfo->algorithm_name);
        return 0;
    }
    decrypt_cpa(decrypted, ctext, tmp, key);
    memcpy(decrypted + ML_KEM_SHARED_SECRET_BYTES, pkhash, ML_KEM_PKHASH_BYTES);
    if (!hash_g(Kr, decrypted, sizeof(decrypted), mdctx, key)
        || !encrypt_cpa(tmp_ctext, decrypted, r, tmp, mdctx, key)) {
        memcpy(secret, failure_key, ML_KEM_SHARED_SECRET_BYTES);
        OPENSSL_cleanse(decrypted, ML_KEM_SHARED_SECRET_BYTES);
        return 1;
    }
    /* Branch-free select between Kr and the failure key, FO transform */
    mask = constant_time_eq_int_8(0,
        CRYPTO_memcmp(ctext, tmp_ctext, vinfo->ctext_bytes));
    for (i = 0; i < ML_KEM_SHARED_SECRET_BYTES; i++)
        secret[i] = constant_time_select_8(mask, Kr[i], failure_key[i]);
    OPENSSL_cleanse(decrypted, ML_KEM_SHARED_SECRET_BYTES);
    OPENSSL_cleanse(Kr, sizeof(Kr));
    return 1;
}
1512
1513
/*
1514
 * After allocating storage for public or private key data, update the key
1515
 * component pointers to reference that storage.
1516
 *
1517
 * The caller should only store private data in `priv` *after* a successful
1518
 * (non-zero) return from this function.
1519
 */
1520
/*
 * After allocating storage for public (|pub|) or private (|priv|) key data,
 * wire the key's component pointers (t, m, s, z) into that storage.
 * Returns 0 (freeing both buffers) if a required allocation is NULL,
 * 1 on success.  The caller should only store private data in |priv|
 * *after* a successful (non-zero) return.
 */
static __owur int add_storage(scalar *pub, scalar *priv, int private, ML_KEM_KEY *key)
{
    int rank = key->vinfo->rank;

    if (pub == NULL || (private && priv == NULL)) {
        /*
         * One of these could be allocated correctly. It is legal to call free with a NULL
         * pointer, so always attempt to free both allocations here
         */
        OPENSSL_free(pub);
        OPENSSL_secure_free(priv);
        return 0;
    }

    /*
     * We're adding key material, set up rho and pkhash to point to the rho_pkhash buffer
     */
    memset(key->rho_pkhash, 0, sizeof(key->rho_pkhash));
    key->rho = key->rho_pkhash;
    key->pkhash = key->rho_pkhash + ML_KEM_RANDOM_BYTES;
    key->d = key->z = NULL;

    /* A public key needs space for |t| and |m| */
    key->m = (key->t = pub) + rank;

    /*
     * A private key also needs space for |s| and |z|.
     * The |z| buffer always includes additional space for |d|, but a key's |d|
     * pointer is left NULL when parsed from the NIST format, which omits that
     * information.  Only keys generated from a (d, z) seed pair will have a
     * non-NULL |d| pointer.
     */
    if (private)
        /* |z| lives immediately after the |rank| scalars of |s| in |priv| */
        key->z = (uint8_t *)(rank + (key->s = priv));
    return 1;
}
1556
1557
/*
1558
 * After freeing the storage associated with a key that failed to be
1559
 * constructed, reset the internal pointers back to NULL.
1560
 */
1561
/*
 * Free (and where sensitive, cleanse) all storage attached to |key| and
 * reset its internal pointers back to NULL, e.g. after a failed key
 * construction.  The key structure itself is not freed.
 */
void ossl_ml_kem_key_reset(ML_KEM_KEY *key)
{
    /*
     * seedbuf can be allocated and contain |z| and |d| if the key is
     * being created from a private key encoding.  Similarly a pending
     * serialised (encoded) private key may be queued up to load.
     * Clear and free that data now.
     */
    if (key->seedbuf != NULL)
        OPENSSL_secure_clear_free(key->seedbuf, ML_KEM_SEED_BYTES);
    if (ossl_ml_kem_have_dkenc(key))
        OPENSSL_secure_clear_free(key->encoded_dk, key->vinfo->prvkey_bytes);

    /*-
     * Cleanse any sensitive data:
     * - The private vector |s| is immediately followed by the FO failure
     *   secret |z|, and seed |d|, we can cleanse all three in one call.
     */
    if (key->t != NULL) {
        if (ossl_ml_kem_have_prvkey(key))
            OPENSSL_secure_clear_free(key->s, key->vinfo->prvalloc);
        OPENSSL_free(key->t);
    }
    /* Null out every component pointer; |m| aliases into the |t| buffer */
    key->d = key->z = key->seedbuf = key->encoded_dk = (uint8_t *)(key->s = key->m = key->t = NULL);
}
1586
1587
/*
1588
 * ----- API exported to the provider
1589
 *
1590
 * Parameters with an implicit fixed length in the internal static API of each
1591
 * variant have an explicit checked length argument at this layer.
1592
 */
1593
1594
/* Retrieve the parameters of one of the ML-KEM variants */
1595
const ML_KEM_VINFO *ossl_ml_kem_get_vinfo(int evp_type)
1596
0
{
1597
0
    switch (evp_type) {
1598
0
    case EVP_PKEY_ML_KEM_512:
1599
0
        return &vinfo_map[ML_KEM_512_VINFO];
1600
0
    case EVP_PKEY_ML_KEM_768:
1601
0
        return &vinfo_map[ML_KEM_768_VINFO];
1602
0
    case EVP_PKEY_ML_KEM_1024:
1603
0
        return &vinfo_map[ML_KEM_1024_VINFO];
1604
0
    }
1605
0
    return NULL;
1606
0
}
1607
1608
ML_KEM_KEY *ossl_ml_kem_key_new(OSSL_LIB_CTX *libctx, const char *properties,
    int evp_type)
{
    const ML_KEM_VINFO *vinfo = ossl_ml_kem_get_vinfo(evp_type);
    ML_KEM_KEY *key;
    int have_all_mds;

    if (vinfo == NULL) {
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_PASSED_INVALID_ARGUMENT,
            "unsupported ML-KEM key type: %d", evp_type);
        return NULL;
    }

    key = OPENSSL_malloc(sizeof(*key));
    if (key == NULL)
        return NULL;

    /* Record the variant parameters and fetch the required SHA3 digests. */
    key->vinfo = vinfo;
    key->libctx = libctx;
    key->prov_flags = ML_KEM_KEY_PROV_FLAGS_DEFAULT;
    key->shake128_md = EVP_MD_fetch(libctx, "SHAKE128", properties);
    key->shake256_md = EVP_MD_fetch(libctx, "SHAKE256", properties);
    key->sha3_256_md = EVP_MD_fetch(libctx, "SHA3-256", properties);
    key->sha3_512_md = EVP_MD_fetch(libctx, "SHA3-512", properties);
    /* No key material yet: every component pointer starts out NULL. */
    key->d = key->z = key->rho = key->pkhash = key->encoded_dk = key->seedbuf = NULL;
    key->s = key->m = key->t = NULL;

    have_all_mds = key->shake128_md != NULL
        && key->shake256_md != NULL
        && key->sha3_256_md != NULL
        && key->sha3_512_md != NULL;
    if (have_all_mds)
        return key;

    /* One or more digest fetches failed; release everything acquired. */
    ossl_ml_kem_key_free(key);
    ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
        "missing SHA3 digest algorithms while creating %s key",
        vinfo->algorithm_name);
    return NULL;
}
1645
1646
ML_KEM_KEY *ossl_ml_kem_key_dup(const ML_KEM_KEY *key, int selection)
{
    int ok = 0;
    ML_KEM_KEY *ret;
    void *tmp_pub;
    void *tmp_priv;

    if (key == NULL)
        return NULL;
    /*
     * Partially decoded keys, not yet imported or loaded, should never be
     * duplicated.
     */
    if (ossl_ml_kem_decoded_key(key))
        return NULL;

    /* Shallow-copy the key structure; component pointers fixed up below. */
    else if ((ret = OPENSSL_memdup(key, sizeof(*key))) == NULL)
        return NULL;

    /* Detach the copy from the original's buffers before allocating fresh ones */
    ret->d = ret->z = ret->rho = ret->pkhash = NULL;
    ret->s = ret->m = ret->t = NULL;

    /* Clear selection bits we can't fulfill */
    if (!ossl_ml_kem_have_pubkey(key))
        selection = 0;
    else if (!ossl_ml_kem_have_prvkey(key))
        selection &= ~OSSL_KEYMGMT_SELECT_PRIVATE_KEY;

    switch (selection & OSSL_KEYMGMT_SELECT_KEYPAIR) {
    case 0:
        /* Nothing selected (or nothing available): a bare key is fine. */
        ok = 1;
        break;
    case OSSL_KEYMGMT_SELECT_PUBLIC_KEY:
        /* add_storage() takes ownership of the duplicated public buffer. */
        ok = add_storage(OPENSSL_memdup(key->t, key->vinfo->puballoc), NULL, 0, ret);
        break;
    case OSSL_KEYMGMT_SELECT_PRIVATE_KEY:
        tmp_pub = OPENSSL_memdup(key->t, key->vinfo->puballoc);
        if (tmp_pub == NULL)
            break;
        tmp_priv = OPENSSL_secure_malloc(key->vinfo->prvalloc);
        if (tmp_priv == NULL) {
            OPENSSL_free(tmp_pub);
            break;
        }
        /* Copy |s| (and the trailing |z|/|d|) only once storage is attached. */
        if ((ok = add_storage(tmp_pub, tmp_priv, 1, ret)) != 0)
            memcpy(tmp_priv, key->s, key->vinfo->prvalloc);
        /* Duplicated keys retain |d|, if available */
        if (key->d != NULL)
            ret->d = ret->z + ML_KEM_RANDOM_BYTES;
        break;
    }

    if (!ok) {
        /* Component buffers were never attached; only the shell to free. */
        OPENSSL_free(ret);
        return NULL;
    }

    /* The copy shares the fetched digests; take an extra reference on each. */
    EVP_MD_up_ref(ret->shake128_md);
    EVP_MD_up_ref(ret->shake256_md);
    EVP_MD_up_ref(ret->sha3_256_md);
    EVP_MD_up_ref(ret->sha3_512_md);

    return ret;
}
1710
1711
void ossl_ml_kem_key_free(ML_KEM_KEY *key)
{
    if (key == NULL)
        return;

    /* Release the references to the fetched digest implementations. */
    EVP_MD_free(key->sha3_512_md);
    EVP_MD_free(key->sha3_256_md);
    EVP_MD_free(key->shake256_md);
    EVP_MD_free(key->shake128_md);

    /* Cleanse and free any key material, then the key structure itself. */
    ossl_ml_kem_key_reset(key);
    OPENSSL_free(key);
}
1724
1725
/* Serialise the public component of an ML-KEM key */
1726
int ossl_ml_kem_encode_public_key(uint8_t *out, size_t len,
1727
    const ML_KEM_KEY *key)
1728
0
{
1729
0
    if (!ossl_ml_kem_have_pubkey(key)
1730
0
        || len != key->vinfo->pubkey_bytes)
1731
0
        return 0;
1732
0
    encode_pubkey(out, key);
1733
0
    return 1;
1734
0
}
1735
1736
/* Serialise an ML-KEM private key */
1737
int ossl_ml_kem_encode_private_key(uint8_t *out, size_t len,
1738
    const ML_KEM_KEY *key)
1739
0
{
1740
0
    if (!ossl_ml_kem_have_prvkey(key)
1741
0
        || len != key->vinfo->prvkey_bytes)
1742
0
        return 0;
1743
0
    encode_prvkey(out, key);
1744
0
    return 1;
1745
0
}
1746
1747
int ossl_ml_kem_encode_seed(uint8_t *out, size_t len,
1748
    const ML_KEM_KEY *key)
1749
0
{
1750
0
    if (key == NULL || key->d == NULL || len != ML_KEM_SEED_BYTES)
1751
0
        return 0;
1752
    /*
1753
     * Both in the seed buffer, and in the allocated storage, the |d| component
1754
     * of the seed is stored last, so we must copy each separately.
1755
     */
1756
0
    memcpy(out, key->d, ML_KEM_RANDOM_BYTES);
1757
0
    out += ML_KEM_RANDOM_BYTES;
1758
0
    memcpy(out, key->z, ML_KEM_RANDOM_BYTES);
1759
0
    return 1;
1760
0
}
1761
1762
/*
1763
 * Stash the seed without (yet) performing a keygen, used during decoding, to
1764
 * avoid an extra keygen if we're only going to export the key again to load
1765
 * into another provider.
1766
 */
1767
ML_KEM_KEY *ossl_ml_kem_set_seed(const uint8_t *seed, size_t seedlen, ML_KEM_KEY *key)
{
    /* Only bare keys with no stashed seed can accept one, and only if sized right. */
    if (key == NULL
        || ossl_ml_kem_have_pubkey(key)
        || ossl_ml_kem_have_seed(key)
        || seedlen != ML_KEM_SEED_BYTES)
        return NULL;

    /* Lazily allocate secure storage for the (z, d) pair. */
    if (key->seedbuf == NULL
        && (key->seedbuf = OPENSSL_secure_malloc(seedlen)) == NULL)
        return NULL;

    /* Internal layout is z || d, while the input seed is d || z. */
    key->z = key->seedbuf;
    key->d = key->z + ML_KEM_RANDOM_BYTES;
    memcpy(key->d, seed, ML_KEM_RANDOM_BYTES);
    memcpy(key->z, seed + ML_KEM_RANDOM_BYTES, ML_KEM_RANDOM_BYTES);
    return key;
}
1788
1789
/* Parse input as a public key */
1790
int ossl_ml_kem_parse_public_key(const uint8_t *in, size_t len, ML_KEM_KEY *key)
{
    const ML_KEM_VINFO *vinfo;
    EVP_MD_CTX *mdctx;
    int ret = 0;

    /* Keys with key material are immutable */
    if (key == NULL
        || ossl_ml_kem_have_pubkey(key)
        || ossl_ml_kem_have_dkenc(key))
        return 0;
    vinfo = key->vinfo;

    if (len != vinfo->pubkey_bytes)
        return 0;
    if ((mdctx = EVP_MD_CTX_new()) == NULL)
        return 0;

    /* Attach public storage (add_storage owns the buffer), then parse into it. */
    if (add_storage(OPENSSL_malloc(vinfo->puballoc), NULL, 0, key))
        ret = parse_pubkey(in, mdctx, key);

    /* On any failure, return the key to its pristine empty state. */
    if (ret == 0)
        ossl_ml_kem_key_reset(key);
    EVP_MD_CTX_free(mdctx);
    return ret;
}
1815
1816
/* Parse input as a new private key */
1817
int ossl_ml_kem_parse_private_key(const uint8_t *in, size_t len,
1818
    ML_KEM_KEY *key)
1819
0
{
1820
0
    EVP_MD_CTX *mdctx = NULL;
1821
0
    const ML_KEM_VINFO *vinfo;
1822
0
    int ret = 0;
1823
1824
    /* Keys with key material are immutable */
1825
0
    if (key == NULL
1826
0
        || ossl_ml_kem_have_pubkey(key)
1827
0
        || ossl_ml_kem_have_dkenc(key))
1828
0
        return 0;
1829
0
    vinfo = key->vinfo;
1830
1831
0
    if (len != vinfo->prvkey_bytes
1832
0
        || (mdctx = EVP_MD_CTX_new()) == NULL)
1833
0
        return 0;
1834
1835
0
    if (add_storage(OPENSSL_malloc(vinfo->puballoc),
1836
0
            OPENSSL_secure_malloc(vinfo->prvalloc), 1, key))
1837
0
        ret = parse_prvkey(in, mdctx, key);
1838
1839
0
    if (!ret)
1840
0
        ossl_ml_kem_key_reset(key);
1841
0
    EVP_MD_CTX_free(mdctx);
1842
0
    return ret;
1843
0
}
1844
1845
/*
1846
 * Generate a new keypair, either from the saved seed (when non-null), or from
1847
 * the RNG.
1848
 */
1849
int ossl_ml_kem_genkey(uint8_t *pubenc, size_t publen, ML_KEM_KEY *key)
{
    uint8_t seed[ML_KEM_SEED_BYTES];
    EVP_MD_CTX *mdctx = NULL;
    const ML_KEM_VINFO *vinfo;
    int ret = 0;

    /* Keys with key material are immutable */
    if (key == NULL
        || ossl_ml_kem_have_pubkey(key)
        || ossl_ml_kem_have_dkenc(key))
        return 0;
    vinfo = key->vinfo;

    /* An optional |pubenc| output buffer must be exactly sized */
    if (pubenc != NULL && publen != vinfo->pubkey_bytes)
        return 0;

    if (key->seedbuf != NULL) {
        /* Use the stashed (d, z) seed, then release it before regenerating */
        if (!ossl_ml_kem_encode_seed(seed, sizeof(seed), key))
            return 0;
        ossl_ml_kem_key_reset(key);
    } else if (RAND_priv_bytes_ex(key->libctx, seed, sizeof(seed),
                   key->vinfo->secbits)
        <= 0) {
        return 0;
    }

    if ((mdctx = EVP_MD_CTX_new()) == NULL)
        return 0;

    /*
     * Data derived from (d, z) defaults secret, and to avoid side-channel
     * leaks should not influence control flow.
     */
    CONSTTIME_SECRET(seed, ML_KEM_SEED_BYTES);

    if (add_storage(OPENSSL_malloc(vinfo->puballoc),
            OPENSSL_secure_malloc(vinfo->prvalloc), 1, key))
        ret = genkey(seed, mdctx, pubenc, key);
    /* The stack copy of the seed is no longer needed; wipe it */
    OPENSSL_cleanse(seed, sizeof(seed));

    /* Declassify secret inputs and derived outputs before returning control */
    CONSTTIME_DECLASSIFY(seed, ML_KEM_SEED_BYTES);

    EVP_MD_CTX_free(mdctx);
    if (!ret) {
        ossl_ml_kem_key_reset(key);
        return 0;
    }

    /* The public components are already declassified */
    CONSTTIME_DECLASSIFY(key->s, vinfo->rank * sizeof(scalar));
    CONSTTIME_DECLASSIFY(key->z, 2 * ML_KEM_RANDOM_BYTES);
    return 1;
}
1903
1904
/*
1905
 * FIPS 203, Section 6.2, Algorithm 17: ML-KEM.Encaps_internal
1906
 * This is the deterministic version with randomness supplied externally.
1907
 */
1908
int ossl_ml_kem_encap_seed(uint8_t *ctext, size_t clen,
    uint8_t *shared_secret, size_t slen,
    const uint8_t *entropy, size_t elen,
    const ML_KEM_KEY *key)
{
    const ML_KEM_VINFO *vinfo;
    EVP_MD_CTX *mdctx;
    int ret = 0;

    /* Encapsulation needs only the public key */
    if (key == NULL || !ossl_ml_kem_have_pubkey(key))
        return 0;
    vinfo = key->vinfo;

    /* All buffers must be present and exactly sized */
    if (ctext == NULL || clen != vinfo->ctext_bytes
        || shared_secret == NULL || slen != ML_KEM_SHARED_SECRET_BYTES
        || entropy == NULL || elen != ML_KEM_RANDOM_BYTES
        || (mdctx = EVP_MD_CTX_new()) == NULL)
        return 0;
    /*
     * Data derived from the encap entropy defaults secret, and to avoid
     * side-channel leaks should not influence control flow.
     */
    CONSTTIME_SECRET(entropy, elen);

    /*-
     * This avoids the need to handle allocation failures for two (max 2KB
     * each) vectors, that are never retained on return from this function.
     * We stack-allocate these.
     */
#define case_encap_seed(bits)                                        \
    {                                                                \
        scalar tmp[2 * ML_KEM_##bits##_RANK];                        \
                                                                     \
        ret = encap(ctext, shared_secret, entropy, tmp, mdctx, key); \
        OPENSSL_cleanse((void *)tmp, sizeof(tmp));                   \
    }
    switch (vinfo->evp_type) {
    case EVP_PKEY_ML_KEM_512:
        case_encap_seed(512);
        break;
    case EVP_PKEY_ML_KEM_768:
        case_encap_seed(768);
        break;
    case EVP_PKEY_ML_KEM_1024:
        case_encap_seed(1024);
        break;
    }
#undef case_encap_seed

    /* Declassify secret inputs and derived outputs before returning control */
    CONSTTIME_DECLASSIFY(entropy, elen);
    CONSTTIME_DECLASSIFY(ctext, clen);
    CONSTTIME_DECLASSIFY(shared_secret, slen);

    EVP_MD_CTX_free(mdctx);
    return ret;
}
1965
1966
int ossl_ml_kem_encap_rand(uint8_t *ctext, size_t clen,
1967
    uint8_t *shared_secret, size_t slen,
1968
    const ML_KEM_KEY *key)
1969
0
{
1970
0
    uint8_t r[ML_KEM_RANDOM_BYTES];
1971
1972
0
    if (key == NULL)
1973
0
        return 0;
1974
1975
0
    if (RAND_bytes_ex(key->libctx, r, ML_KEM_RANDOM_BYTES,
1976
0
            key->vinfo->secbits)
1977
0
        < 1)
1978
0
        return 0;
1979
1980
0
    return ossl_ml_kem_encap_seed(ctext, clen, shared_secret, slen,
1981
0
        r, sizeof(r), key);
1982
0
}
1983
1984
int ossl_ml_kem_decap(uint8_t *shared_secret, size_t slen,
1985
    const uint8_t *ctext, size_t clen,
1986
    const ML_KEM_KEY *key)
1987
0
{
1988
0
    const ML_KEM_VINFO *vinfo;
1989
0
    EVP_MD_CTX *mdctx;
1990
0
    int ret = 0;
1991
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
1992
    int classify_bytes = 2 * sizeof(scalar) + ML_KEM_RANDOM_BYTES;
1993
#endif
1994
1995
    /* Need a private key here */
1996
0
    if (!ossl_ml_kem_have_prvkey(key))
1997
0
        return 0;
1998
0
    vinfo = key->vinfo;
1999
2000
0
    if (shared_secret == NULL || slen != ML_KEM_SHARED_SECRET_BYTES
2001
0
        || ctext == NULL || clen != vinfo->ctext_bytes
2002
0
        || (mdctx = EVP_MD_CTX_new()) == NULL) {
2003
0
        (void)RAND_bytes_ex(key->libctx, shared_secret,
2004
0
            ML_KEM_SHARED_SECRET_BYTES, vinfo->secbits);
2005
0
        return 0;
2006
0
    }
2007
    /*
2008
     * Data derived from |s| and |z| defaults secret, and to avoid side-channel
2009
     * leaks should not influence control flow.
2010
     */
2011
0
    CONSTTIME_SECRET(key->s, classify_bytes);
2012
2013
    /*-
2014
     * This avoids the need to handle allocation failures for two (max 2KB
2015
     * each) vectors and an encoded ciphertext (max 1568 bytes), that are never
2016
     * retained on return from this function.
2017
     * We stack-allocate these.
2018
     */
2019
0
#define case_decap(bits)                                          \
2020
0
    {                                                             \
2021
0
        uint8_t cbuf[CTEXT_BYTES(bits)];                          \
2022
0
        scalar tmp[2 * ML_KEM_##bits##_RANK];                     \
2023
0
                                                                  \
2024
0
        ret = decap(shared_secret, ctext, cbuf, tmp, mdctx, key); \
2025
0
        OPENSSL_cleanse((void *)tmp, sizeof(tmp));                \
2026
0
    }
2027
0
    switch (vinfo->evp_type) {
2028
0
    case EVP_PKEY_ML_KEM_512:
2029
0
        case_decap(512);
2030
0
        break;
2031
0
    case EVP_PKEY_ML_KEM_768:
2032
0
        case_decap(768);
2033
0
        break;
2034
0
    case EVP_PKEY_ML_KEM_1024:
2035
0
        case_decap(1024);
2036
0
        break;
2037
0
    }
2038
0
#undef case_decap
2039
2040
    /* Declassify secret inputs and derived outputs before returning control */
2041
0
    CONSTTIME_DECLASSIFY(key->s, classify_bytes);
2042
0
    CONSTTIME_DECLASSIFY(shared_secret, slen);
2043
0
    EVP_MD_CTX_free(mdctx);
2044
2045
0
    return ret;
2046
0
}
2047
2048
int ossl_ml_kem_pubkey_cmp(const ML_KEM_KEY *key1, const ML_KEM_KEY *key2)
2049
0
{
2050
    /*
2051
     * This handles any unexpected differences in the ML-KEM variant rank,
2052
     * giving different key component structures, barring SHA3-256 hash
2053
     * collisions, the keys are the same size.
2054
     */
2055
0
    if (ossl_ml_kem_have_pubkey(key1) && ossl_ml_kem_have_pubkey(key2))
2056
0
        return memcmp(key1->pkhash, key2->pkhash, ML_KEM_PKHASH_BYTES) == 0;
2057
2058
    /*
2059
     * No match if just one of the public keys is not available, otherwise both
2060
     * are unavailable, and for now such keys are considered equal.
2061
     */
2062
0
    return (!(ossl_ml_kem_have_pubkey(key1) ^ ossl_ml_kem_have_pubkey(key2)));
2063
0
}