Coverage Report

Created: 2026-04-22 06:14

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/openssl/crypto/ml_kem/ml_kem.c
Line
Count
Source
1
/*
2
 * Copyright 2024-2026 The OpenSSL Project Authors. All Rights Reserved.
3
 *
4
 * Licensed under the Apache License 2.0 (the "License").  You may not use
5
 * this file except in compliance with the License.  You can obtain a copy
6
 * in the file LICENSE in the source distribution or at
7
 * https://www.openssl.org/source/license.html
8
 */
9
10
#include <openssl/byteorder.h>
11
#include <openssl/rand.h>
12
#include <openssl/proverr.h>
13
#include "crypto/ml_kem.h"
14
#include "internal/common.h"
15
#include "internal/constant_time.h"
16
#include "internal/sha3.h"
17
18
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
19
#include <valgrind/memcheck.h>
20
#endif
21
22
#if ML_KEM_SEED_BYTES != ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES
23
#error "ML-KEM keygen seed length != shared secret + random bytes length"
24
#endif
25
#if ML_KEM_SHARED_SECRET_BYTES != ML_KEM_RANDOM_BYTES
26
#error "Invalid unequal lengths of ML-KEM shared secret and random inputs"
27
#endif
28
29
#if UINT_MAX < UINT32_MAX
30
#error "Unsupported compiler: sizeof(unsigned int) < sizeof(uint32_t)"
31
#endif
32
33
/*
 * Handy function-like bit-extraction macros.
 *
 * bit0(b) yields the least significant bit of |b|.
 * bitn(n, b) yields bit number |n| of |b| (bit 0 is the least significant).
 *
 * Every macro argument is parenthesised, so callers may pass arbitrary
 * expressions (e.g. |i | 1|) without operator-precedence surprises.
 */
#define bit0(b) ((b) & 1)
#define bitn(n, b) (((b) >> (n)) & 1)
36
37
/*
38
 * 12 bits are sufficient to losslessly represent values in [0, q-1].
39
 * INVERSE_DEGREE is (n/2)^-1 mod q; used in inverse NTT.
40
 */
41
0
#define DEGREE ML_KEM_DEGREE
42
#define INVERSE_DEGREE (ML_KEM_PRIME - 2 * 13)
43
#define LOG2PRIME 12
44
#define BARRETT_SHIFT (2 * LOG2PRIME)
45
46
#ifdef SHA3_BLOCKSIZE
47
#define SHAKE128_BLOCKSIZE SHA3_BLOCKSIZE(128)
48
#endif
49
50
/*
51
 * Return whether a value that can only be 0 or 1 is non-zero, in constant time
52
 * in practice!  The return value is a mask that is all ones if true, and all
53
 * zeros otherwise (twos-complement arithmetic assumed for unsigned values).
54
 *
55
 * Although this is used in constant-time selects, we omit a value barrier
56
 * here.  Value barriers impede auto-vectorization (likely because it forces
57
 * the value to transit through a general-purpose register). On AArch64, this
58
 * is a difference of 2x.
59
 *
60
 * We usually add value barriers to selects because Clang turns consecutive
61
 * selects with the same condition into a branch instead of CMOV/CSEL. This
62
 * condition does not occur in Kyber, so omitting it seems to be safe so far,
63
 * but see |cbd_2|, |cbd_3|, where reduction needs to be specialised to the
64
 * sign of the input, rather than adding |q| in advance, and using the generic
65
 * |reduce_once|.  (David Benjamin, Chromium)
66
 */
67
/*
 * Two interchangeable definitions: the disabled one goes through
 * |constant_time_is_zero|, the active one negates the 0/1 input directly.
 * Both are plain expressions (no trailing semicolon) so that either can be
 * used inside a larger expression.
 */
#if 0
#define constish_time_non_zero(b) (~constant_time_is_zero(b))
#else
#define constish_time_non_zero(b) (0u - (b))
#endif
72
73
/*
74
 * The scalar rejection-sampling buffer size needs to be a multiple of 12, but
75
 * is otherwise arbitrary, the preferred block size matches the internal buffer
76
 * size of SHAKE128, avoiding internal buffering and copying in SHAKE128. That
77
 * block size of (1600 - 256)/8 bytes, or 168, just happens to divide by 12!
78
 *
79
 * If the blocksize is unknown, or is not divisible by 12, 168 is used as a
80
 * fallback.
81
 */
82
/* Fall back to 168 unless a SHAKE128 block size divisible by 12 is known */
#if !defined(SHAKE128_BLOCKSIZE) || (SHAKE128_BLOCKSIZE) % 12 != 0
#define SCALAR_SAMPLING_BUFSIZE 168
#else
#define SCALAR_SAMPLING_BUFSIZE (SHAKE128_BLOCKSIZE)
#endif
87
88
/*
89
 * Structure of keys
90
 */
91
typedef struct ossl_ml_kem_scalar_st {
92
    /* On every function entry and exit, 0 <= c[i] < ML_KEM_PRIME. */
93
    uint16_t c[ML_KEM_DEGREE];
94
} scalar;
95
96
/* Key material allocation layout */
97
#define DECLARE_ML_KEM_PUBKEYDATA(name, rank)                  \
98
    struct name##_alloc {                                      \
99
        /* Public vector |t| */                                \
100
        scalar tbuf[(rank)];                                   \
101
        /* Pre-computed matrix |m| (FIPS 203 |A| transpose) */ \
102
        scalar mbuf[(rank) * (rank)];                          \
103
    }
104
105
#define DECLARE_ML_KEM_PRVKEYDATA(name, rank)  \
106
    struct name##_alloc {                      \
107
        scalar sbuf[rank];                     \
108
        uint8_t zbuf[2 * ML_KEM_RANDOM_BYTES]; \
109
    }
110
111
/* Declare variant-specific public and private storage */
112
#define DECLARE_ML_KEM_VARIANT_KEYDATA(bits)                        \
113
    DECLARE_ML_KEM_PUBKEYDATA(pubkey_##bits, ML_KEM_##bits##_RANK); \
114
    DECLARE_ML_KEM_PRVKEYDATA(prvkey_##bits, ML_KEM_##bits##_RANK)
115
116
DECLARE_ML_KEM_VARIANT_KEYDATA(512);
117
DECLARE_ML_KEM_VARIANT_KEYDATA(768);
118
DECLARE_ML_KEM_VARIANT_KEYDATA(1024);
119
#undef DECLARE_ML_KEM_VARIANT_KEYDATA
120
#undef DECLARE_ML_KEM_PUBKEYDATA
121
#undef DECLARE_ML_KEM_PRVKEYDATA
122
123
typedef __owur int (*CBD_FUNC)(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
124
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key);
125
static void scalar_encode(uint8_t *out, const scalar *s, int bits);
126
127
/*
128
 * The wire-form of a losslessly encoded vector uses 12-bits per element.
129
 *
130
 * The wire-form public key consists of the lossless encoding of the public
131
 * vector |t|, followed by the public seed |rho|.
132
 *
133
 * Our serialised private key concatenates serialisations of the private vector
134
 * |s|, the public key, the public key hash, and the failure secret |z|.
135
 */
136
#define VECTOR_BYTES(b) ((3 * DEGREE / 2) * ML_KEM_##b##_RANK)
137
#define PUBKEY_BYTES(b) (VECTOR_BYTES(b) + ML_KEM_RANDOM_BYTES)
138
#define PRVKEY_BYTES(b) (2 * PUBKEY_BYTES(b) + ML_KEM_PKHASH_BYTES)
139
140
/*
141
 * Encapsulation produces a vector "u" and a scalar "v", whose coordinates
142
 * (numbers modulo the ML-KEM prime "q") are lossily encoded using "du" and
143
 * "dv" bits, respectively.  This encoding is the ciphertext input for
144
 * decapsulation.
145
 */
146
#define U_VECTOR_BYTES(b) ((DEGREE / 8) * ML_KEM_##b##_DU * ML_KEM_##b##_RANK)
147
#define V_SCALAR_BYTES(b) ((DEGREE / 8) * ML_KEM_##b##_DV)
148
#define CTEXT_BYTES(b) (U_VECTOR_BYTES(b) + V_SCALAR_BYTES(b))
149
150
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
151
152
/*
153
 * CONSTTIME_SECRET takes a pointer and a number of bytes and marks that region
154
 * of memory as secret. Secret data is tracked as it flows to registers and
155
 * other parts of memory. If secret data is used as a condition for a branch,
156
 * or as a memory index, it will trigger warnings in valgrind.
157
 */
158
#define CONSTTIME_SECRET(ptr, len) VALGRIND_MAKE_MEM_UNDEFINED(ptr, len)
159
160
/*
161
 * CONSTTIME_DECLASSIFY takes a pointer and a number of bytes and marks that
162
 * region of memory as public. Public data is not subject to constant-time
163
 * rules.
164
 */
165
#define CONSTTIME_DECLASSIFY(ptr, len) VALGRIND_MAKE_MEM_DEFINED(ptr, len)
166
167
#else
168
169
#define CONSTTIME_SECRET(ptr, len)
170
#define CONSTTIME_DECLASSIFY(ptr, len)
171
172
#endif
173
174
/*
175
 * Indices of slots in the vinfo tables below
176
 */
177
0
#define ML_KEM_512_VINFO 0
178
0
#define ML_KEM_768_VINFO 1
179
0
#define ML_KEM_1024_VINFO 2
180
181
/*
182
 * Per-variant fixed parameters
183
 */
184
static const ML_KEM_VINFO vinfo_map[3] = {
185
    { "ML-KEM-512",
186
        PRVKEY_BYTES(512),
187
        sizeof(struct prvkey_512_alloc),
188
        PUBKEY_BYTES(512),
189
        sizeof(struct pubkey_512_alloc),
190
        CTEXT_BYTES(512),
191
        VECTOR_BYTES(512),
192
        U_VECTOR_BYTES(512),
193
        EVP_PKEY_ML_KEM_512,
194
        ML_KEM_512_BITS,
195
        ML_KEM_512_RANK,
196
        ML_KEM_512_DU,
197
        ML_KEM_512_DV,
198
        ML_KEM_512_SECBITS,
199
        ML_KEM_512_SECURITY_CATEGORY },
200
    { "ML-KEM-768",
201
        PRVKEY_BYTES(768),
202
        sizeof(struct prvkey_768_alloc),
203
        PUBKEY_BYTES(768),
204
        sizeof(struct pubkey_768_alloc),
205
        CTEXT_BYTES(768),
206
        VECTOR_BYTES(768),
207
        U_VECTOR_BYTES(768),
208
        EVP_PKEY_ML_KEM_768,
209
        ML_KEM_768_BITS,
210
        ML_KEM_768_RANK,
211
        ML_KEM_768_DU,
212
        ML_KEM_768_DV,
213
        ML_KEM_768_SECBITS,
214
        ML_KEM_768_SECURITY_CATEGORY },
215
    { "ML-KEM-1024",
216
        PRVKEY_BYTES(1024),
217
        sizeof(struct prvkey_1024_alloc),
218
        PUBKEY_BYTES(1024),
219
        sizeof(struct pubkey_1024_alloc),
220
        CTEXT_BYTES(1024),
221
        VECTOR_BYTES(1024),
222
        U_VECTOR_BYTES(1024),
223
        EVP_PKEY_ML_KEM_1024,
224
        ML_KEM_1024_BITS,
225
        ML_KEM_1024_RANK,
226
        ML_KEM_1024_DU,
227
        ML_KEM_1024_DV,
228
        ML_KEM_1024_SECBITS,
229
        ML_KEM_1024_SECURITY_CATEGORY }
230
};
231
232
/*
233
 * Remainders modulo `kPrime`, for sufficiently small inputs, are computed in
234
 * constant time via Barrett reduction, and a final call to reduce_once(),
235
 * which reduces inputs that are less than 2*kPrime and is also constant-time.
236
 */
237
static const int kPrime = ML_KEM_PRIME;
238
static const unsigned int kBarrettShift = BARRETT_SHIFT;
239
static const size_t kBarrettMultiplier = (1 << BARRETT_SHIFT) / ML_KEM_PRIME;
240
static const uint16_t kHalfPrime = (ML_KEM_PRIME - 1) / 2;
241
static const uint16_t kInverseDegree = INVERSE_DEGREE;
242
243
/*
244
 * Python helper:
245
 *
246
 * p = 3329
247
 * def bitreverse(i):
248
 *     ret = 0
249
 *     for n in range(7):
250
 *         bit = i & 1
251
 *         ret <<= 1
252
 *         ret |= bit
253
 *         i >>= 1
254
 *     return ret
255
 */
256
257
/*-
258
 * First precomputed array from Appendix A of FIPS 203, or else Python:
259
 * kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)]
260
 */
261
static const uint16_t kNTTRoots[128] = {
262
    0x001, 0x6c1, 0xa14, 0xcd9, 0xa52, 0x276, 0x769, 0x350,
263
    0x426, 0x77f, 0x0c1, 0x31d, 0xae2, 0xcbc, 0x239, 0x6d2,
264
    0x128, 0x98f, 0x53b, 0x5c4, 0xbe6, 0x038, 0x8c0, 0x535,
265
    0x592, 0x82e, 0x217, 0xb42, 0x959, 0xb3f, 0x7b6, 0x335,
266
    0x121, 0x14b, 0xcb5, 0x6dc, 0x4ad, 0x900, 0x8e5, 0x807,
267
    0x28a, 0x7b9, 0x9d1, 0x278, 0xb31, 0x021, 0x528, 0x77b,
268
    0x90f, 0x59b, 0x327, 0x1c4, 0x59e, 0xb34, 0x5fe, 0x962,
269
    0xa57, 0xa39, 0x5c9, 0x288, 0x9aa, 0xc26, 0x4cb, 0x38e,
270
    0x011, 0xac9, 0x247, 0xa59, 0x665, 0x2d3, 0x8f0, 0x44c,
271
    0x581, 0xa66, 0xcd1, 0x0e9, 0x2f4, 0x86c, 0xbc7, 0xbea,
272
    0x6a7, 0x673, 0xae5, 0x6fd, 0x737, 0x3b8, 0x5b5, 0xa7f,
273
    0x3ab, 0x904, 0x985, 0x954, 0x2dd, 0x921, 0x10c, 0x281,
274
    0x630, 0x8fa, 0x7f5, 0xc94, 0x177, 0x9f5, 0x82a, 0x66d,
275
    0x427, 0x13f, 0xad5, 0x2f5, 0x833, 0x231, 0x9a2, 0xa22,
276
    0xaf4, 0x444, 0x193, 0x402, 0x477, 0x866, 0xad7, 0x376,
277
    0x6ba, 0x4bc, 0x752, 0x405, 0x83e, 0xb77, 0x375, 0x86a
278
};
279
280
/*
281
 * InverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)]
282
 * Listed in order of use in the inverse NTT loop (index 0 is skipped):
283
 *
284
 *  0, 64, 65, ..., 127, 32, 33, ..., 63, 16, 17, ..., 31, 8, 9, ...
285
 */
286
static const uint16_t kInverseNTTRoots[128] = {
287
    0x001, 0x497, 0x98c, 0x18a, 0x4c3, 0x8fc, 0x5af, 0x845,
288
    0x647, 0x98b, 0x22a, 0x49b, 0x88a, 0x8ff, 0xb6e, 0x8bd,
289
    0x20d, 0x2df, 0x35f, 0xad0, 0x4ce, 0xa0c, 0x22c, 0xbc2,
290
    0x8da, 0x694, 0x4d7, 0x30c, 0xb8a, 0x06d, 0x50c, 0x407,
291
    0x6d1, 0xa80, 0xbf5, 0x3e0, 0xa24, 0x3ad, 0x37c, 0x3fd,
292
    0x956, 0x282, 0x74c, 0x949, 0x5ca, 0x604, 0x21c, 0x68e,
293
    0x65a, 0x117, 0x13a, 0x495, 0xa0d, 0xc18, 0x030, 0x29b,
294
    0x780, 0x8b5, 0x411, 0xa2e, 0x69c, 0x2a8, 0xaba, 0x238,
295
    0xcf0, 0x973, 0x836, 0x0db, 0x357, 0xa79, 0x738, 0x2c8,
296
    0x2aa, 0x39f, 0x703, 0x1cd, 0x763, 0xb3d, 0x9da, 0x766,
297
    0x3f2, 0x586, 0x7d9, 0xce0, 0x1d0, 0xa89, 0x330, 0x548,
298
    0xa77, 0x4fa, 0x41c, 0x401, 0x854, 0x625, 0x04c, 0xbb6,
299
    0xbe0, 0x9cc, 0x54b, 0x1c2, 0x3a8, 0x1bf, 0xaea, 0x4d3,
300
    0x76f, 0x7cc, 0x441, 0xcc9, 0x11b, 0x73d, 0x7c6, 0x372,
301
    0xbd9, 0x62f, 0xac8, 0x045, 0x21f, 0x9e4, 0xc40, 0x582,
302
    0x8db, 0x9b1, 0x598, 0xa8b, 0x2af, 0x028, 0x2ed, 0x640
303
};
304
305
/*
306
 * Second precomputed array from Appendix A of FIPS 203 (normalised positive),
307
 * or else Python:
308
 * ModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)]
309
 */
310
static const uint16_t kModRoots[128] = {
311
    0x011, 0xcf0, 0xac9, 0x238, 0x247, 0xaba, 0xa59, 0x2a8,
312
    0x665, 0x69c, 0x2d3, 0xa2e, 0x8f0, 0x411, 0x44c, 0x8b5,
313
    0x581, 0x780, 0xa66, 0x29b, 0xcd1, 0x030, 0x0e9, 0xc18,
314
    0x2f4, 0xa0d, 0x86c, 0x495, 0xbc7, 0x13a, 0xbea, 0x117,
315
    0x6a7, 0x65a, 0x673, 0x68e, 0xae5, 0x21c, 0x6fd, 0x604,
316
    0x737, 0x5ca, 0x3b8, 0x949, 0x5b5, 0x74c, 0xa7f, 0x282,
317
    0x3ab, 0x956, 0x904, 0x3fd, 0x985, 0x37c, 0x954, 0x3ad,
318
    0x2dd, 0xa24, 0x921, 0x3e0, 0x10c, 0xbf5, 0x281, 0xa80,
319
    0x630, 0x6d1, 0x8fa, 0x407, 0x7f5, 0x50c, 0xc94, 0x06d,
320
    0x177, 0xb8a, 0x9f5, 0x30c, 0x82a, 0x4d7, 0x66d, 0x694,
321
    0x427, 0x8da, 0x13f, 0xbc2, 0xad5, 0x22c, 0x2f5, 0xa0c,
322
    0x833, 0x4ce, 0x231, 0xad0, 0x9a2, 0x35f, 0xa22, 0x2df,
323
    0xaf4, 0x20d, 0x444, 0x8bd, 0x193, 0xb6e, 0x402, 0x8ff,
324
    0x477, 0x88a, 0x866, 0x49b, 0xad7, 0x22a, 0x376, 0x98b,
325
    0x6ba, 0x647, 0x4bc, 0x845, 0x752, 0x5af, 0x405, 0x8fc,
326
    0x83e, 0x4c3, 0xb77, 0x18a, 0x375, 0x98c, 0x86a, 0x497
327
};
328
329
/*
330
 * single_keccak hashes |inlen| bytes from |in| and writes |outlen| bytes of
331
 * output to |out|. If the |md| specifies a fixed-output function, like
332
 * SHA3-256, then |outlen| must be the correct length for that function.
333
 */
334
static __owur int single_keccak(uint8_t *out, size_t outlen, const uint8_t *in, size_t inlen,
335
    EVP_MD_CTX *mdctx)
336
0
{
337
0
    unsigned int sz = (unsigned int)outlen;
338
339
0
    if (!EVP_DigestUpdate(mdctx, in, inlen))
340
0
        return 0;
341
0
    if (EVP_MD_xof(EVP_MD_CTX_get0_md(mdctx)))
342
0
        return EVP_DigestFinalXOF(mdctx, out, outlen);
343
0
    return EVP_DigestFinal_ex(mdctx, out, &sz)
344
0
        && ossl_assert((size_t)sz == outlen);
345
0
}
346
347
/*
348
 * FIPS 203, Section 4.1, equation (4.3): PRF. Takes 32+1 input bytes, and uses
349
 * SHAKE256 to produce the input to SamplePolyCBD_eta: FIPS 203, algorithm 8.
350
 */
351
static __owur int prf(uint8_t *out, size_t len, const uint8_t in[ML_KEM_RANDOM_BYTES + 1],
352
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
353
0
{
354
0
    return EVP_DigestInit_ex(mdctx, key->shake256_md, NULL)
355
0
        && single_keccak(out, len, in, ML_KEM_RANDOM_BYTES + 1, mdctx);
356
0
}
357
358
/*
359
 * FIPS 203, Section 4.1, equation (4.4): H.  SHA3-256 hash of a variable
360
 * length input, producing 32 bytes of output.
361
 */
362
static __owur int hash_h(uint8_t out[ML_KEM_PKHASH_BYTES], const uint8_t *in, size_t len,
363
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
364
0
{
365
0
    return EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL)
366
0
        && single_keccak(out, ML_KEM_PKHASH_BYTES, in, len, mdctx);
367
0
}
368
369
/* Incremental hash_h of expanded public key */
370
static int
371
hash_h_pubkey(uint8_t pkhash[ML_KEM_PKHASH_BYTES],
372
    EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
373
0
{
374
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
375
0
    const scalar *t = key->t, *end = t + vinfo->rank;
376
0
    unsigned int sz;
377
378
0
    if (!EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL))
379
0
        return 0;
380
381
0
    do {
382
0
        uint8_t buf[3 * DEGREE / 2];
383
384
0
        scalar_encode(buf, t++, 12);
385
0
        if (!EVP_DigestUpdate(mdctx, buf, sizeof(buf)))
386
0
            return 0;
387
0
    } while (t < end);
388
389
0
    if (!EVP_DigestUpdate(mdctx, key->rho, ML_KEM_RANDOM_BYTES))
390
0
        return 0;
391
0
    return EVP_DigestFinal_ex(mdctx, pkhash, &sz)
392
0
        && ossl_assert(sz == ML_KEM_PKHASH_BYTES);
393
0
}
394
395
/*
396
 * FIPS 203, Section 4.1, equation (4.5): G.  SHA3-512 hash of a variable
397
 * length input, producing 64 bytes of output, in particular the seeds
398
 * (d,z) for key generation.
399
 */
400
static __owur int hash_g(uint8_t out[ML_KEM_SEED_BYTES], const uint8_t *in, size_t len,
401
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
402
0
{
403
0
    return EVP_DigestInit_ex(mdctx, key->sha3_512_md, NULL)
404
0
        && single_keccak(out, ML_KEM_SEED_BYTES, in, len, mdctx);
405
0
}
406
407
/*
408
 * FIPS 203, Section 4.1, equation (4.4): J. SHAKE256 taking a variable length
409
 * input to compute a 32-byte implicit rejection shared secret, of the same
410
 * length as the expected shared secret.  (Computed even on success to avoid
411
 * side-channel leaks).
412
 */
413
static __owur int kdf(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
414
    const uint8_t z[ML_KEM_RANDOM_BYTES],
415
    const uint8_t *ctext, size_t len,
416
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
417
0
{
418
0
    return EVP_DigestInit_ex(mdctx, key->shake256_md, NULL)
419
0
        && EVP_DigestUpdate(mdctx, z, ML_KEM_RANDOM_BYTES)
420
0
        && EVP_DigestUpdate(mdctx, ctext, len)
421
0
        && EVP_DigestFinalXOF(mdctx, out, ML_KEM_SHARED_SECRET_BYTES);
422
0
}
423
424
/*
425
 * FIPS 203, Section 4.2.2, Algorithm 7: "SampleNTT" (steps 3-17, steps 1, 2
426
 * are performed by the caller). Rejection-samples a Keccak stream to get
427
 * uniformly distributed elements in the range [0,q). This is used for matrix
428
 * expansion and only operates on public inputs.
429
 */
430
static __owur int sample_scalar(scalar *out, EVP_MD_CTX *mdctx)
431
0
{
432
0
    uint16_t *curr = out->c, *endout = curr + DEGREE;
433
0
    uint8_t buf[SCALAR_SAMPLING_BUFSIZE], *in;
434
0
    uint8_t *endin = buf + sizeof(buf);
435
0
    uint16_t d;
436
0
    uint8_t b1, b2, b3;
437
438
0
    do {
439
0
        if (!EVP_DigestSqueeze(mdctx, in = buf, sizeof(buf)))
440
0
            return 0;
441
0
        do {
442
0
            b1 = *in++;
443
0
            b2 = *in++;
444
0
            b3 = *in++;
445
446
0
            if (curr >= endout)
447
0
                break;
448
0
            if ((d = ((b2 & 0x0f) << 8) + b1) < kPrime)
449
0
                *curr++ = d;
450
0
            if (curr >= endout)
451
0
                break;
452
0
            if ((d = (b3 << 4) + (b2 >> 4)) < kPrime)
453
0
                *curr++ = d;
454
0
        } while (in < endin);
455
0
    } while (curr < endout);
456
0
    return 1;
457
0
}
458
459
static CRYPTO_ONCE ml_kem_ntt_once = CRYPTO_ONCE_STATIC_INIT;
460
461
#if defined(_ARCH_PPC64)
462
#include "crypto/ppc_arch.h"
463
#endif
464
465
#if defined(MLKEM_NTT_PPC_ASM) && defined(_ARCH_PPC64)
466
/*
467
 * PPC64LE Platform supports.
468
 */
469
typedef void (*ml_kem_scalar_ntt_fn)(scalar *p);
470
typedef void (*ml_kem_scalar_inverse_ntt_fn)(scalar *p);
471
472
static void scalar_ntt_generic(scalar *p);
473
static void scalar_inverse_ntt_generic(scalar *p);
474
475
static ml_kem_scalar_ntt_fn scalar_ntt = scalar_ntt_generic;
476
static ml_kem_scalar_inverse_ntt_fn scalar_inverse_ntt = scalar_inverse_ntt_generic;
477
478
void mlkem_ntt_ppc(uint16_t *c);
479
void mlkem_inverse_ntt_ppc(uint16_t *c);
480
481
static void scalar_ntt_ppc(scalar *s)
482
{
483
    mlkem_ntt_ppc(s->c);
484
}
485
486
static void scalar_inverse_ntt_ppc(scalar *s)
487
{
488
    mlkem_inverse_ntt_ppc(s->c);
489
}
490
#else
491
#define scalar_ntt_generic scalar_ntt
492
#define scalar_inverse_ntt_generic scalar_inverse_ntt
493
#endif
494
495
/*
 * Point the NTT function pointers at the PPC64le implementations when they
 * are compiled in and the CPU advertises PPC_CRYPTO207.  In every other
 * configuration this function does nothing and the generic scalar
 * implementations remain in use.
 */
static void ml_kem_ntt_init(void)
{
#if defined(MLKEM_NTT_PPC_ASM) && defined(_ARCH_PPC64) \
    && (defined(__LITTLE_ENDIAN__) || (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__))
    if ((OPENSSL_ppccap_P & PPC_CRYPTO207) != 0) {
        scalar_inverse_ntt = scalar_inverse_ntt_ppc;
        scalar_ntt = scalar_ntt_ppc;
    }
#endif
}
510
511
/*-
512
 * reduce_once reduces 0 <= x < 2*kPrime, mod kPrime.
513
 *
514
 * Subtract |q| if the input is larger, without exposing a side-channel,
515
 * avoiding the "clangover" attack.  See |constish_time_non_zero| for a
516
 * discussion on why the value barrier is by default omitted.
517
 */
518
static __owur uint16_t reduce_once(uint16_t x)
519
0
{
520
0
    const uint16_t subtracted = x - kPrime;
521
0
    uint16_t mask = constish_time_non_zero(subtracted >> 15);
522
523
0
    return (mask & x) | (~mask & subtracted);
524
0
}
525
526
/*
527
 * Constant-time reduce x mod kPrime using Barrett reduction. x must be less
528
 * than kPrime + 2 * kPrime^2.  This is sufficient to reduce a product of
529
 * two already reduced u_int16 values, in fact it is sufficient for each
530
 * to be less than 2^12, because (kPrime * (2 * kPrime + 1)) > 2^24.
531
 */
532
static __owur uint16_t reduce(uint32_t x)
533
0
{
534
0
    uint64_t product = (uint64_t)x * kBarrettMultiplier;
535
0
    uint32_t quotient = (uint32_t)(product >> kBarrettShift);
536
0
    uint32_t remainder = x - quotient * kPrime;
537
538
0
    return reduce_once(remainder);
539
0
}
540
541
/* Multiply a scalar by a constant. */
542
static void scalar_mult_const(scalar *s, uint16_t a)
543
0
{
544
0
    uint16_t *curr = s->c, *end = curr + DEGREE, tmp;
545
546
0
    do {
547
0
        tmp = reduce(*curr * a);
548
0
        *curr++ = tmp;
549
0
    } while (curr < end);
550
0
}
551
552
/*-
553
 * FIPS 203, Section 4.3, Algorithm 9: "NTT".
554
 * In-place number theoretic transform of a given scalar.  Note that ML-KEM's
555
 * kPrime 3329 does not have a 512th root of unity, so this transform leaves
556
 * off the last iteration of the usual FFT code, with the 128 relevant roots of
557
 * unity being stored in NTTRoots.  This means the output should be seen as 128
558
 * elements in GF(3329^2), with the coefficients of the elements being
559
 * consecutive entries in |s->c|.
560
 */
561
static void scalar_ntt_generic(scalar *s)
562
0
{
563
0
    const uint16_t *roots = kNTTRoots;
564
0
    uint16_t *end = s->c + DEGREE;
565
0
    int offset = DEGREE / 2;
566
567
0
    do {
568
0
        uint16_t *curr = s->c, *peer;
569
570
0
        do {
571
0
            uint16_t *pause = curr + offset, even, odd;
572
0
            uint32_t zeta = *++roots;
573
574
0
            peer = pause;
575
0
            do {
576
0
                even = *curr;
577
0
                odd = reduce(*peer * zeta);
578
0
                *peer++ = reduce_once(even - odd + kPrime);
579
0
                *curr++ = reduce_once(odd + even);
580
0
            } while (curr < pause);
581
0
        } while ((curr = peer) < end);
582
0
    } while ((offset >>= 1) >= 2);
583
0
}
584
585
/*-
586
 * FIPS 203, Section 4.3, Algorithm 10: "NTT^(-1)".
587
 * In-place inverse number theoretic transform of a given scalar, with pairs of
588
 * entries of s->v being interpreted as elements of GF(3329^2). Just as with
589
 * the number theoretic transform, this leaves off the first step of the normal
590
 * iFFT to account for the fact that 3329 does not have a 512th root of unity,
591
 * using the precomputed 128 roots of unity stored in InverseNTTRoots.
592
 */
593
static void scalar_inverse_ntt_generic(scalar *s)
594
0
{
595
0
    const uint16_t *roots = kInverseNTTRoots;
596
0
    uint16_t *end = s->c + DEGREE;
597
0
    int offset = 2;
598
599
0
    do {
600
0
        uint16_t *curr = s->c, *peer;
601
602
0
        do {
603
0
            uint16_t *pause = curr + offset, even, odd;
604
0
            uint32_t zeta = *++roots;
605
606
0
            peer = pause;
607
0
            do {
608
0
                even = *curr;
609
0
                odd = *peer;
610
0
                *peer++ = reduce(zeta * (even - odd + kPrime));
611
0
                *curr++ = reduce_once(odd + even);
612
0
            } while (curr < pause);
613
0
        } while ((curr = peer) < end);
614
0
    } while ((offset <<= 1) < DEGREE);
615
0
    scalar_mult_const(s, kInverseDegree);
616
0
}
617
618
/* Addition updating the LHS scalar in-place. */
619
static void scalar_add(scalar *lhs, const scalar *rhs)
620
0
{
621
0
    int i;
622
623
0
    for (i = 0; i < DEGREE; i++)
624
0
        lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
625
0
}
626
627
/* Subtraction updating the LHS scalar in-place. */
628
static void scalar_sub(scalar *lhs, const scalar *rhs)
629
0
{
630
0
    int i;
631
632
0
    for (i = 0; i < DEGREE; i++)
633
0
        lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime);
634
0
}
635
636
/*
637
 * Multiplying two scalars in the number theoretically transformed state. Since
638
 * 3329 does not have a 512th root of unity, this means we have to interpret
639
 * the 2*ith and (2*i+1)th entries of the scalar as elements of
640
 * GF(3329)[X]/(X^2 - 17^(2*bitreverse(i)+1)).
641
 *
642
 * The value of 17^(2*bitreverse(i)+1) mod 3329 is stored in the precomputed
643
 * ModRoots table. Note that our Barrett transform only allows us to multiply
644
 * two reduced numbers together, so we need some intermediate reduction steps,
645
 * even if an uint64_t could hold 3 multiplied numbers.
646
 */
647
static void scalar_mult(scalar *out, const scalar *lhs,
648
    const scalar *rhs)
649
0
{
650
0
    uint16_t *curr = out->c, *end = curr + DEGREE;
651
0
    const uint16_t *lc = lhs->c, *rc = rhs->c;
652
0
    const uint16_t *roots = kModRoots;
653
654
0
    do {
655
0
        uint32_t l0 = *lc++, r0 = *rc++;
656
0
        uint32_t l1 = *lc++, r1 = *rc++;
657
0
        uint32_t zetapow = *roots++;
658
659
0
        *curr++ = reduce(l0 * r0 + reduce(l1 * r1) * zetapow);
660
0
        *curr++ = reduce(l0 * r1 + l1 * r0);
661
0
    } while (curr < end);
662
0
}
663
664
/* Above, but add the result to an existing scalar */
665
static ossl_inline void scalar_mult_add(scalar *out, const scalar *lhs,
666
    const scalar *rhs)
667
0
{
668
0
    uint16_t *curr = out->c, *end = curr + DEGREE;
669
0
    const uint16_t *lc = lhs->c, *rc = rhs->c;
670
0
    const uint16_t *roots = kModRoots;
671
672
0
    do {
673
0
        uint32_t l0 = *lc++, r0 = *rc++;
674
0
        uint32_t l1 = *lc++, r1 = *rc++;
675
0
        uint16_t *c0 = curr++;
676
0
        uint16_t *c1 = curr++;
677
0
        uint32_t zetapow = *roots++;
678
679
0
        *c0 = reduce(*c0 + l0 * r0 + reduce(l1 * r1) * zetapow);
680
0
        *c1 = reduce(*c1 + l0 * r1 + l1 * r0);
681
0
    } while (curr < end);
682
0
}
683
684
/*-
685
 * FIPS 203, Section 4.2.1, Algorithm 5: "ByteEncode_d", for 2<=d<=12.
686
 * Here |bits| is |d|.  For efficiency, we handle the d=1 case separately.
687
 */
688
static void scalar_encode(uint8_t *out, const scalar *s, int bits)
689
0
{
690
0
    const uint16_t *curr = s->c, *end = curr + DEGREE;
691
0
    uint64_t accum = 0, element;
692
0
    int used = 0;
693
694
0
    do {
695
0
        element = *curr++;
696
0
        if (used + bits < 64) {
697
0
            accum |= element << used;
698
0
            used += bits;
699
0
        } else if (used + bits > 64) {
700
0
            out = OPENSSL_store_u64_le(out, accum | (element << used));
701
0
            accum = element >> (64 - used);
702
0
            used = (used + bits) - 64;
703
0
        } else {
704
0
            out = OPENSSL_store_u64_le(out, accum | (element << used));
705
0
            accum = 0;
706
0
            used = 0;
707
0
        }
708
0
    } while (curr < end);
709
0
}
710
711
/*
712
 * scalar_encode_1 is |scalar_encode| specialised for |bits| == 1.
713
 */
714
static void scalar_encode_1(uint8_t out[DEGREE / 8], const scalar *s)
715
0
{
716
0
    int i, j;
717
0
    uint8_t out_byte;
718
719
0
    for (i = 0; i < DEGREE; i += 8) {
720
0
        out_byte = 0;
721
0
        for (j = 0; j < 8; j++)
722
0
            out_byte |= bit0(s->c[i + j]) << j;
723
0
        *out = out_byte;
724
0
        out++;
725
0
    }
726
0
}
727
728
/*-
729
 * FIPS 203, Section 4.2.1, Algorithm 6: "ByteDecode_d", for 2<=d<12.
730
 * Here |bits| is |d|.  For efficiency, we handle the d=1 and d=12 cases
731
 * separately.
732
 *
733
 * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
734
 * |out|.
735
 */
736
static void scalar_decode(scalar *out, const uint8_t *in, int bits)
737
0
{
738
0
    uint16_t *curr = out->c, *end = curr + DEGREE;
739
0
    uint64_t accum = 0;
740
0
    int accum_bits = 0, todo = bits;
741
0
    uint16_t bitmask = (((uint16_t)1) << bits) - 1, mask = bitmask;
742
0
    uint16_t element = 0;
743
744
0
    do {
745
0
        if (accum_bits == 0) {
746
0
            in = OPENSSL_load_u64_le(&accum, in);
747
0
            accum_bits = 64;
748
0
        }
749
0
        if (todo == bits && accum_bits >= bits) {
750
            /* No partial "element", and all the required bits available */
751
0
            *curr++ = ((uint16_t)accum) & mask;
752
0
            accum >>= bits;
753
0
            accum_bits -= bits;
754
0
        } else if (accum_bits >= todo) {
755
            /* A partial "element", and all the required bits available */
756
0
            *curr++ = element | ((((uint16_t)accum) & mask) << (bits - todo));
757
0
            accum >>= todo;
758
0
            accum_bits -= todo;
759
0
            element = 0;
760
0
            todo = bits;
761
0
            mask = bitmask;
762
0
        } else {
763
            /*
764
             * Only some of the requisite bits accumulated, store |accum_bits|
765
             * of these in |element|.  The accumulated bitcount becomes 0, but
766
             * as soon as we have more bits we'll want to merge accum_bits
767
             * fewer of them into the final |element|.
768
             *
769
             * Note that with a 64-bit accumulator and |bits| always 12 or
770
             * less, if we're here, the previous iteration had all the
771
             * requisite bits, and so there are no kept bits in |element|.
772
             */
773
0
            element = ((uint16_t)accum) & mask;
774
0
            todo -= accum_bits;
775
0
            mask = bitmask >> accum_bits;
776
0
            accum_bits = 0;
777
0
        }
778
0
    } while (curr < end);
779
0
}
780
781
/*
 * Decode |DEGREE| 12-bit values from |in| into |out|, three input bytes per
 * pair of coefficients.  Both values of a pair are stored before they are
 * range-checked.
 *
 * Returns 1 if every decoded value is below kPrime, otherwise 0 as soon as a
 * pair containing an out-of-range value has been stored.
 */
static __owur int scalar_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2])
{
    int i;

    for (i = 0; i < DEGREE; i += 2, in += 3) {
        const uint16_t lo = in[0] | ((in[1] & 0x0f) << 8);
        const uint16_t hi = (in[1] >> 4) | (in[2] << 4);

        out->c[i] = lo;
        out->c[i + 1] = hi;
        if ((lo >= kPrime) | (hi >= kPrime))
            return 0;
    }
    return 1;
}
798
799
/*-
800
 * scalar_decode_decompress_add is a combination of decoding and decompression
801
 * both specialised for |bits| == 1, with the result added (and sum reduced) to
802
 * the output scalar.
803
 *
804
 * NOTE: this function MUST not leak an input-data-depedennt timing signal.
805
 * A timing leak in a related function in the reference Kyber implementation
806
 * made the "clangover" attack (CVE-2024-37880) possible, giving key recovery
807
 * for ML-KEM-512 in minutes, provided the attacker has access to precise
808
 * timing of a CPU performing chosen-ciphertext decap.  Admittedly this is only
809
 * a risk when private keys are reused (perhaps KEMTLS servers).
810
 */
811
static void
812
scalar_decode_decompress_add(scalar *out, const uint8_t in[DEGREE / 8])
813
0
{
814
0
    static const uint16_t half_q_plus_1 = (ML_KEM_PRIME >> 1) + 1;
815
0
    uint16_t *curr = out->c, *end = curr + DEGREE;
816
0
    uint16_t mask;
817
0
    uint8_t b;
818
819
    /*
820
     * Add |half_q_plus_1| if the bit is set, without exposing a side-channel,
821
     * avoiding the "clangover" attack.  See |constish_time_non_zero| for a
822
     * discussion on why the value barrier is by default omitted.
823
     */
824
0
#define decode_decompress_add_bit                        \
825
0
    mask = constish_time_non_zero(bit0(b));              \
826
0
    *curr = reduce_once(*curr + (mask & half_q_plus_1)); \
827
0
    curr++;                                              \
828
0
    b >>= 1
829
830
    /* Unrolled to process each byte in one iteration */
831
0
    do {
832
0
        b = *in++;
833
0
        decode_decompress_add_bit;
834
0
        decode_decompress_add_bit;
835
0
        decode_decompress_add_bit;
836
0
        decode_decompress_add_bit;
837
838
0
        decode_decompress_add_bit;
839
0
        decode_decompress_add_bit;
840
0
        decode_decompress_add_bit;
841
0
        decode_decompress_add_bit;
842
0
    } while (curr < end);
843
0
#undef decode_decompress_add_bit
844
0
}
845
846
/*
847
 * FIPS 203, Section 4.2.1, Equation (4.7): Compress_d.
848
 *
849
 * Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
850
 * numbers close to each other together. The formula used is
851
 * round(2^|bits|/kPrime*x) mod 2^|bits|.
852
 * Uses Barrett reduction to achieve constant time. Since we need both the
853
 * remainder (for rounding) and the quotient (as the result), we cannot use
854
 * |reduce| here, but need to do the Barrett reduction directly.
855
 */
856
static __owur uint16_t compress(uint16_t x, int bits)
857
0
{
858
0
    uint32_t shifted = (uint32_t)x << bits;
859
0
    uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
860
0
    uint32_t quotient = (uint32_t)(product >> kBarrettShift);
861
0
    uint32_t remainder = shifted - quotient * kPrime;
862
863
    /*
864
     * Adjust the quotient to round correctly:
865
     *   0 <= remainder <= kHalfPrime round to 0
866
     *   kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
867
     *   kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
868
     */
869
0
    quotient += 1 & constant_time_lt_32(kHalfPrime, remainder);
870
0
    quotient += 1 & constant_time_lt_32(kPrime + kHalfPrime, remainder);
871
0
    return quotient & ((1 << bits) - 1);
872
0
}
873
874
/*
875
 * FIPS 203, Section 4.2.1, Equation (4.8): Decompress_d.
876
877
 * Decompresses |x| by using a close equi-distant representative. The formula
878
 * is round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us
879
 * to implement this logic using only bit operations.
880
 */
881
static __owur uint16_t decompress(uint16_t x, int bits)
882
0
{
883
0
    uint32_t product = (uint32_t)x * kPrime;
884
0
    uint32_t power = 1 << bits;
885
    /* This is |product| % power, since |power| is a power of 2. */
886
0
    uint32_t remainder = product & (power - 1);
887
    /* This is |product| / power, since |power| is a power of 2. */
888
0
    uint32_t lower = product >> bits;
889
890
    /*
891
     * The rounding logic works since the first half of numbers mod |power|
892
     * have a 0 as first bit, and the second half has a 1 as first bit, since
893
     * |power| is a power of 2. As a 12 bit number, |remainder| is always
894
     * positive, so we will shift in 0s for a right shift.
895
     */
896
0
    return lower + (remainder >> (bits - 1));
897
0
}
898
899
/*-
900
 * FIPS 203, Section 4.2.1, Equation (4.7): "Compress_d".
901
 * In-place lossy rounding of scalars to 2^d bits.
902
 */
903
static void scalar_compress(scalar *s, int bits)
904
0
{
905
0
    int i;
906
907
0
    for (i = 0; i < DEGREE; i++)
908
0
        s->c[i] = compress(s->c[i], bits);
909
0
}
910
911
/*
912
 * FIPS 203, Section 4.2.1, Equation (4.8): "Decompress_d".
913
 * In-place approximate recovery of scalars from 2^d bit compression.
914
 */
915
static void scalar_decompress(scalar *s, int bits)
916
0
{
917
0
    int i;
918
919
0
    for (i = 0; i < DEGREE; i++)
920
0
        s->c[i] = decompress(s->c[i], bits);
921
0
}
922
923
/* Addition updating the LHS vector in-place. */
924
static void vector_add(scalar *lhs, const scalar *rhs, int rank)
925
0
{
926
0
    do {
927
0
        scalar_add(lhs++, rhs++);
928
0
    } while (--rank > 0);
929
0
}
930
931
/*
932
 * Encodes an entire vector into 32*|rank|*|bits| bytes. Note that since 256
933
 * (DEGREE) is divisible by 8, the individual vector entries will always fill a
934
 * whole number of bytes, so we do not need to worry about bit packing here.
935
 */
936
static void vector_encode(uint8_t *out, const scalar *a, int bits, int rank)
937
0
{
938
0
    int stride = bits * DEGREE / 8;
939
940
0
    for (; rank-- > 0; out += stride)
941
0
        scalar_encode(out, a++, bits);
942
0
}
943
944
/*
945
 * Decodes 32*|rank|*|bits| bytes from |in| into |out|. It returns early
946
 * if any parsed value is >= |ML_KEM_PRIME|.  The resulting scalars are
947
 * then decompressed and transformed via the NTT.
948
 *
949
 * Note: Used only in decrypt_cpa(), which returns void and so does not check
950
 * the return value of this function.  Side-channels are fine when the input
951
 * ciphertext to decap() is simply syntactically invalid.
952
 */
953
static void
954
vector_decode_decompress_ntt(scalar *out, const uint8_t *in, int bits, int rank)
955
0
{
956
0
    int stride = bits * DEGREE / 8;
957
958
0
    for (; rank-- > 0; in += stride, ++out) {
959
0
        scalar_decode(out, in, bits);
960
0
        scalar_decompress(out, bits);
961
0
        scalar_ntt(out);
962
0
    }
963
0
}
964
965
/* vector_decode(), specialised to bits == 12. */
966
static __owur int vector_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2], int rank)
967
0
{
968
0
    int stride = 3 * DEGREE / 2;
969
970
0
    for (; rank-- > 0; in += stride)
971
0
        if (!scalar_decode_12(out++, in))
972
0
            return 0;
973
0
    return 1;
974
0
}
975
976
/* In-place compression of each scalar component */
977
static void vector_compress(scalar *a, int bits, int rank)
978
0
{
979
0
    do {
980
0
        scalar_compress(a++, bits);
981
0
    } while (--rank > 0);
982
0
}
983
984
/* The output scalar must not overlap with the inputs */
985
static void inner_product(scalar *out, const scalar *lhs, const scalar *rhs,
986
    int rank)
987
0
{
988
0
    scalar_mult(out, lhs, rhs);
989
0
    while (--rank > 0)
990
0
        scalar_mult_add(out, ++lhs, ++rhs);
991
0
}
992
993
/*
994
 * Here, the output vector must not overlap with the inputs, the result is
995
 * directly subjected to inverse NTT.
996
 */
997
static void
998
matrix_mult_intt(scalar *out, const scalar *m, const scalar *a, int rank)
999
0
{
1000
0
    const scalar *ar;
1001
0
    int i, j;
1002
1003
0
    for (i = rank; i-- > 0; ++out) {
1004
0
        scalar_mult(out, m++, ar = a);
1005
0
        for (j = rank - 1; j > 0; --j)
1006
0
            scalar_mult_add(out, m++, ++ar);
1007
0
        scalar_inverse_ntt(out);
1008
0
    }
1009
0
}
1010
1011
/* Here, the output vector must not overlap with the inputs */
1012
static void
1013
matrix_mult_transpose_add(scalar *out, const scalar *m, const scalar *a, int rank)
1014
0
{
1015
0
    const scalar *mc = m, *mr, *ar;
1016
0
    int i, j;
1017
1018
0
    for (i = rank; i-- > 0; ++out) {
1019
0
        scalar_mult_add(out, mr = mc++, ar = a);
1020
0
        for (j = rank; --j > 0;)
1021
0
            scalar_mult_add(out, (mr += rank), ++ar);
1022
0
    }
1023
0
}
1024
1025
/*-
1026
 * Expands the matrix from a seed for key generation and for encaps-CPA.
1027
 * NOTE: FIPS 203 matrix "A" is the transpose of this matrix, computed
1028
 * by appending the (i,j) indices to the seed in the opposite order!
1029
 *
1030
 * Where FIPS 203 computes t = A * s + e, we use the transpose of "m".
1031
 */
1032
static __owur int matrix_expand(EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
1033
0
{
1034
0
    scalar *out = key->m;
1035
0
    uint8_t input[ML_KEM_RANDOM_BYTES + 2];
1036
0
    int rank = key->vinfo->rank;
1037
0
    int i, j;
1038
1039
0
    memcpy(input, key->rho, ML_KEM_RANDOM_BYTES);
1040
0
    for (i = 0; i < rank; i++) {
1041
0
        for (j = 0; j < rank; j++) {
1042
0
            input[ML_KEM_RANDOM_BYTES] = i;
1043
0
            input[ML_KEM_RANDOM_BYTES + 1] = j;
1044
0
            if (!EVP_DigestInit_ex(mdctx, key->shake128_md, NULL)
1045
0
                || !EVP_DigestUpdate(mdctx, input, sizeof(input))
1046
0
                || !sample_scalar(out++, mdctx))
1047
0
                return 0;
1048
0
        }
1049
0
    }
1050
0
    return 1;
1051
0
}
1052
1053
/*
1054
 * Algorithm 7 from the spec, with eta fixed to two and the PRF call
1055
 * included. Creates binominally distributed elements by sampling 2*|eta| bits,
1056
 * and setting the coefficient to the count of the first bits minus the count of
1057
 * the second bits, resulting in a centered binomial distribution. Since eta is
1058
 * two this gives -2/2 with a probability of 1/16, -1/1 with probability 1/4,
1059
 * and 0 with probability 3/8.
1060
 */
1061
static __owur int cbd_2(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
1062
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1063
0
{
1064
0
    uint16_t *curr = out->c, *end = curr + DEGREE;
1065
0
    uint8_t randbuf[4 * DEGREE / 8], *r = randbuf; /* 64 * eta slots */
1066
0
    uint16_t value, mask;
1067
0
    uint8_t b;
1068
1069
0
    if (!prf(randbuf, sizeof(randbuf), in, mdctx, key))
1070
0
        return 0;
1071
1072
0
    do {
1073
0
        b = *r++;
1074
1075
        /*
1076
         * Add |kPrime| if |value| underflowed.  See |constish_time_non_zero|
1077
         * for a discussion on why the value barrier is by default omitted.
1078
         * While this could have been written reduce_once(value + kPrime), this
1079
         * is one extra addition and small range of |value| tempts some
1080
         * versions of Clang to emit a branch.
1081
         */
1082
0
        value = bit0(b) + bitn(1, b);
1083
0
        value -= bitn(2, b) + bitn(3, b);
1084
0
        mask = constish_time_non_zero(value >> 15);
1085
0
        *curr++ = value + (kPrime & mask);
1086
1087
0
        value = bitn(4, b) + bitn(5, b);
1088
0
        value -= bitn(6, b) + bitn(7, b);
1089
0
        mask = constish_time_non_zero(value >> 15);
1090
0
        *curr++ = value + (kPrime & mask);
1091
0
    } while (curr < end);
1092
0
    return 1;
1093
0
}
1094
1095
/*
1096
 * Algorithm 7 from the spec, with eta fixed to three and the PRF call
1097
 * included. Creates binominally distributed elements by sampling 3*|eta| bits,
1098
 * and setting the coefficient to the count of the first bits minus the count of
1099
 * the second bits, resulting in a centered binomial distribution.
1100
 */
1101
static __owur int cbd_3(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
1102
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1103
0
{
1104
0
    uint16_t *curr = out->c, *end = curr + DEGREE;
1105
0
    uint8_t randbuf[6 * DEGREE / 8], *r = randbuf; /* 64 * eta slots */
1106
0
    uint8_t b1, b2, b3;
1107
0
    uint16_t value, mask;
1108
1109
0
    if (!prf(randbuf, sizeof(randbuf), in, mdctx, key))
1110
0
        return 0;
1111
1112
0
    do {
1113
0
        b1 = *r++;
1114
0
        b2 = *r++;
1115
0
        b3 = *r++;
1116
1117
        /*
1118
         * Add |kPrime| if |value| underflowed.  See |constish_time_non_zero|
1119
         * for a discussion on why the value barrier is by default omitted.
1120
         * While this could have been written reduce_once(value + kPrime), this
1121
         * is one extra addition and small range of |value| tempts some
1122
         * versions of Clang to emit a branch.
1123
         */
1124
0
        value = bit0(b1) + bitn(1, b1) + bitn(2, b1);
1125
0
        value -= bitn(3, b1) + bitn(4, b1) + bitn(5, b1);
1126
0
        mask = constish_time_non_zero(value >> 15);
1127
0
        *curr++ = value + (kPrime & mask);
1128
1129
0
        value = bitn(6, b1) + bitn(7, b1) + bit0(b2);
1130
0
        value -= bitn(1, b2) + bitn(2, b2) + bitn(3, b2);
1131
0
        mask = constish_time_non_zero(value >> 15);
1132
0
        *curr++ = value + (kPrime & mask);
1133
1134
0
        value = bitn(4, b2) + bitn(5, b2) + bitn(6, b2);
1135
0
        value -= bitn(7, b2) + bit0(b3) + bitn(1, b3);
1136
0
        mask = constish_time_non_zero(value >> 15);
1137
0
        *curr++ = value + (kPrime & mask);
1138
1139
0
        value = bitn(2, b3) + bitn(3, b3) + bitn(4, b3);
1140
0
        value -= bitn(5, b3) + bitn(6, b3) + bitn(7, b3);
1141
0
        mask = constish_time_non_zero(value >> 15);
1142
0
        *curr++ = value + (kPrime & mask);
1143
0
    } while (curr < end);
1144
0
    return 1;
1145
0
}
1146
1147
/*
1148
 * Generates a secret vector by using |cbd| with the given seed to generate
1149
 * scalar elements and incrementing |counter| for each slot of the vector.
1150
 */
1151
static __owur int gencbd_vector(scalar *out, CBD_FUNC cbd, uint8_t *counter,
1152
    const uint8_t seed[ML_KEM_RANDOM_BYTES], int rank,
1153
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1154
0
{
1155
0
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1156
1157
0
    memcpy(input, seed, ML_KEM_RANDOM_BYTES);
1158
0
    do {
1159
0
        input[ML_KEM_RANDOM_BYTES] = (*counter)++;
1160
0
        if (!cbd(out++, input, mdctx, key))
1161
0
            return 0;
1162
0
    } while (--rank > 0);
1163
0
    return 1;
1164
0
}
1165
1166
/*
1167
 * As above plus NTT transform.
1168
 */
1169
static __owur int gencbd_vector_ntt(scalar *out, CBD_FUNC cbd, uint8_t *counter,
1170
    const uint8_t seed[ML_KEM_RANDOM_BYTES], int rank,
1171
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1172
0
{
1173
0
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1174
1175
0
    memcpy(input, seed, ML_KEM_RANDOM_BYTES);
1176
0
    do {
1177
0
        input[ML_KEM_RANDOM_BYTES] = (*counter)++;
1178
0
        if (!cbd(out, input, mdctx, key))
1179
0
            return 0;
1180
0
        scalar_ntt(out++);
1181
0
    } while (--rank > 0);
1182
0
    return 1;
1183
0
}
1184
1185
/* The |ETA1| value for ML-KEM-512 is 3, the rest and all ETA2 values are 2. */
1186
0
#define CBD1(evp_type) ((evp_type) == EVP_PKEY_ML_KEM_512 ? cbd_3 : cbd_2)
1187
1188
/*
1189
 * FIPS 203, Section 5.2, Algorithm 14: K-PKE.Encrypt.
1190
 *
1191
 * Encrypts a message with given randomness to the ciphertext in |out|. Without
1192
 * applying the Fujisaki-Okamoto transform this would not result in a CCA
1193
 * secure scheme, since lattice schemes are vulnerable to decryption failure
1194
 * oracles.
1195
 *
1196
 * The steps are re-ordered to make more efficient/localised use of storage.
1197
 *
1198
 * Note also that the input public key is assumed to hold a precomputed matrix
1199
 * |A| (our key->m, with the public key holding an expanded (16-bit per scalar
1200
 * coefficient) key->t vector).
1201
 *
1202
 * Caller passes storage in |tmp| for for two temporary vectors.
1203
 */
1204
static __owur int encrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
1205
    const uint8_t message[DEGREE / 8],
1206
    const uint8_t r[ML_KEM_RANDOM_BYTES], scalar *tmp,
1207
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1208
0
{
1209
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1210
0
    CBD_FUNC cbd_1 = CBD1(vinfo->evp_type);
1211
0
    int rank = vinfo->rank;
1212
    /* We can use tmp[0..rank-1] as storage for |y|, then |e1|, ... */
1213
0
    scalar *y = &tmp[0], *e1 = y, *e2 = y;
1214
    /* We can use tmp[rank]..tmp[2*rank - 1] for |u| */
1215
0
    scalar *u = &tmp[rank];
1216
0
    scalar v;
1217
0
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1218
0
    uint8_t counter = 0;
1219
0
    int du = vinfo->du;
1220
0
    int dv = vinfo->dv;
1221
1222
    /* FIPS 203 "y" vector */
1223
0
    if (!gencbd_vector_ntt(y, cbd_1, &counter, r, rank, mdctx, key))
1224
0
        return 0;
1225
    /* FIPS 203 "v" scalar */
1226
0
    inner_product(&v, key->t, y, rank);
1227
0
    scalar_inverse_ntt(&v);
1228
    /* FIPS 203 "u" vector */
1229
0
    matrix_mult_intt(u, key->m, y, rank);
1230
1231
    /* All done with |y|, now free to reuse tmp[0] for FIPS 203 |e1| */
1232
0
    if (!gencbd_vector(e1, cbd_2, &counter, r, rank, mdctx, key))
1233
0
        return 0;
1234
0
    vector_add(u, e1, rank);
1235
0
    vector_compress(u, du, rank);
1236
0
    vector_encode(out, u, du, rank);
1237
1238
    /* All done with |e1|, now free to reuse tmp[0] for FIPS 203 |e2| */
1239
0
    memcpy(input, r, ML_KEM_RANDOM_BYTES);
1240
0
    input[ML_KEM_RANDOM_BYTES] = counter;
1241
0
    if (!cbd_2(e2, input, mdctx, key))
1242
0
        return 0;
1243
0
    scalar_add(&v, e2);
1244
1245
    /* Combine message with |v| */
1246
0
    scalar_decode_decompress_add(&v, message);
1247
0
    scalar_compress(&v, dv);
1248
0
    scalar_encode(out + vinfo->u_vector_bytes, &v, dv);
1249
0
    return 1;
1250
0
}
1251
1252
/*
1253
 * FIPS 203, Section 5.3, Algorithm 15: K-PKE.Decrypt.
1254
 */
1255
static void
1256
decrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
1257
    const uint8_t *ctext, scalar *u, const ML_KEM_KEY *key)
1258
0
{
1259
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1260
0
    scalar v, mask;
1261
0
    int rank = vinfo->rank;
1262
0
    int du = vinfo->du;
1263
0
    int dv = vinfo->dv;
1264
1265
0
    vector_decode_decompress_ntt(u, ctext, du, rank);
1266
0
    scalar_decode(&v, ctext + vinfo->u_vector_bytes, dv);
1267
0
    scalar_decompress(&v, dv);
1268
0
    inner_product(&mask, key->s, u, rank);
1269
0
    scalar_inverse_ntt(&mask);
1270
0
    scalar_sub(&v, &mask);
1271
0
    scalar_compress(&v, 1);
1272
0
    scalar_encode_1(out, &v);
1273
0
}
1274
1275
/*-
1276
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1277
 * FIPS 203, Section 7.2, Algorithm 20: "ML-KEM.Encaps".
1278
 *
1279
 * Fills the |out| buffer with the |ek| output of "ML-KEM.KeyGen", or,
1280
 * equivalently, the |ek| input of "ML-KEM.Encaps", i.e. returns the
1281
 * wire-format of an ML-KEM public key.
1282
 */
1283
static void encode_pubkey(uint8_t *out, const ML_KEM_KEY *key)
1284
0
{
1285
0
    const uint8_t *rho = key->rho;
1286
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1287
1288
0
    vector_encode(out, key->t, 12, vinfo->rank);
1289
0
    memcpy(out + vinfo->vector_bytes, rho, ML_KEM_RANDOM_BYTES);
1290
0
}
1291
1292
/*-
1293
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1294
 *
1295
 * Fills the |out| buffer with the |dk| output of "ML-KEM.KeyGen".
1296
 * This matches the input format of parse_prvkey() below.
1297
 */
1298
static void encode_prvkey(uint8_t *out, const ML_KEM_KEY *key)
1299
0
{
1300
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1301
1302
0
    vector_encode(out, key->s, 12, vinfo->rank);
1303
0
    out += vinfo->vector_bytes;
1304
0
    encode_pubkey(out, key);
1305
0
    out += vinfo->pubkey_bytes;
1306
0
    memcpy(out, key->pkhash, ML_KEM_PKHASH_BYTES);
1307
0
    out += ML_KEM_PKHASH_BYTES;
1308
0
    memcpy(out, key->z, ML_KEM_RANDOM_BYTES);
1309
0
}
1310
1311
/*-
1312
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1313
 * FIPS 203, Section 7.2, Algorithm 20: "ML-KEM.Encaps".
1314
 *
1315
 * This function parses the |in| buffer as the |ek| output of "ML-KEM.KeyGen",
1316
 * or, equivalently, the |ek| input of "ML-KEM.Encaps", i.e. decodes the
1317
 * wire-format of the ML-KEM public key.
1318
 */
1319
static int parse_pubkey(const uint8_t *in, EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
1320
0
{
1321
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1322
1323
    /* Decode and check |t| */
1324
0
    if (!vector_decode_12(key->t, in, vinfo->rank)) {
1325
0
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1326
0
            "%s invalid public 't' vector",
1327
0
            vinfo->algorithm_name);
1328
0
        return 0;
1329
0
    }
1330
    /* Save the matrix |m| recovery seed |rho| */
1331
0
    memcpy(key->rho, in + vinfo->vector_bytes, ML_KEM_RANDOM_BYTES);
1332
    /*
1333
     * Pre-compute the public key hash, needed for both encap and decap.
1334
     * Also pre-compute the matrix expansion, stored with the public key.
1335
     */
1336
0
    if (!hash_h(key->pkhash, in, vinfo->pubkey_bytes, mdctx, key)
1337
0
        || !matrix_expand(mdctx, key)) {
1338
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1339
0
            "internal error while parsing %s public key",
1340
0
            vinfo->algorithm_name);
1341
0
        return 0;
1342
0
    }
1343
0
    return 1;
1344
0
}
1345
1346
/*
1347
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1348
 *
1349
 * Parses the |in| buffer as a |dk| output of "ML-KEM.KeyGen".
1350
 * This matches the output format of encode_prvkey() above.
1351
 */
1352
static int parse_prvkey(const uint8_t *in, EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
1353
0
{
1354
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1355
1356
    /* Decode and check |s|. */
1357
0
    if (!vector_decode_12(key->s, in, vinfo->rank)) {
1358
0
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1359
0
            "%s invalid private 's' vector",
1360
0
            vinfo->algorithm_name);
1361
0
        return 0;
1362
0
    }
1363
0
    in += vinfo->vector_bytes;
1364
1365
0
    if (!parse_pubkey(in, mdctx, key))
1366
0
        return 0;
1367
0
    in += vinfo->pubkey_bytes;
1368
1369
    /* Check public key hash. */
1370
0
    if (memcmp(key->pkhash, in, ML_KEM_PKHASH_BYTES) != 0) {
1371
0
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1372
0
            "%s public key hash mismatch",
1373
0
            vinfo->algorithm_name);
1374
0
        return 0;
1375
0
    }
1376
0
    in += ML_KEM_PKHASH_BYTES;
1377
1378
0
    memcpy(key->z, in, ML_KEM_RANDOM_BYTES);
1379
0
    return 1;
1380
0
}
1381
1382
/*
1383
 * FIPS 203, Section 6.1, Algorithm 16: "ML-KEM.KeyGen_internal".
1384
 *
1385
 * The implementation of Section 5.1, Algorithm 13, "K-PKE.KeyGen(d)" is
1386
 * inlined.
1387
 *
1388
 * The caller MUST pass a pre-allocated digest context that is not shared with
1389
 * any concurrent computation.
1390
 *
1391
 * This function optionally outputs the serialised wire-form |ek| public key
1392
 * into the provided |pubenc| buffer, and generates the content of the |rho|,
1393
 * |pkhash|, |t|, |m|, |s| and |z| components of the private |key| (which must
1394
 * have preallocated space for these).
1395
 *
1396
 * Keys are computed from a 32-byte random |d| plus the 1 byte rank for
1397
 * domain separation.  These are concatenated and hashed to produce a pair of
1398
 * 32-byte seeds public "rho", used to generate the matrix, and private "sigma",
1399
 * used to generate the secret vector |s|.
1400
 *
1401
 * The second random input |z| is copied verbatim into the Fujisaki-Okamoto
1402
 * (FO) transform "implicit-rejection" secret (the |z| component of the private
1403
 * key), which thwarts chosen-ciphertext attacks, provided decap() runs in
1404
 * constant time, with no side channel leaks, on all well-formed (valid length,
1405
 * and correctly encoded) ciphertext inputs.
1406
 */
1407
static __owur int genkey(const uint8_t seed[ML_KEM_SEED_BYTES],
1408
    EVP_MD_CTX *mdctx, uint8_t *pubenc, ML_KEM_KEY *key)
1409
0
{
1410
0
    uint8_t hashed[2 * ML_KEM_RANDOM_BYTES];
1411
0
    const uint8_t *const sigma = hashed + ML_KEM_RANDOM_BYTES;
1412
0
    uint8_t augmented_seed[ML_KEM_RANDOM_BYTES + 1];
1413
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1414
0
    CBD_FUNC cbd_1 = CBD1(vinfo->evp_type);
1415
0
    int rank = vinfo->rank;
1416
0
    uint8_t counter = 0;
1417
0
    int ret = 0;
1418
1419
    /*
1420
     * Use the "d" seed salted with the rank to derive the public and private
1421
     * seeds rho and sigma.
1422
     */
1423
0
    memcpy(augmented_seed, seed, ML_KEM_RANDOM_BYTES);
1424
0
    augmented_seed[ML_KEM_RANDOM_BYTES] = (uint8_t)rank;
1425
0
    if (!hash_g(hashed, augmented_seed, sizeof(augmented_seed), mdctx, key))
1426
0
        goto end;
1427
0
    memcpy(key->rho, hashed, ML_KEM_RANDOM_BYTES);
1428
    /* The |rho| matrix seed is public */
1429
0
    CONSTTIME_DECLASSIFY(key->rho, ML_KEM_RANDOM_BYTES);
1430
1431
    /* FIPS 203 |e| vector is initial value of key->t */
1432
0
    if (!matrix_expand(mdctx, key)
1433
0
        || !gencbd_vector_ntt(key->s, cbd_1, &counter, sigma, rank, mdctx, key)
1434
0
        || !gencbd_vector_ntt(key->t, cbd_1, &counter, sigma, rank, mdctx, key))
1435
0
        goto end;
1436
1437
    /* To |e| we now add the product of transpose |m| and |s|, giving |t|. */
1438
0
    matrix_mult_transpose_add(key->t, key->m, key->s, rank);
1439
    /* The |t| vector is public */
1440
0
    CONSTTIME_DECLASSIFY(key->t, vinfo->rank * sizeof(scalar));
1441
1442
0
    if (pubenc == NULL) {
1443
        /* Incremental digest of public key without in-full serialisation. */
1444
0
        if (!hash_h_pubkey(key->pkhash, mdctx, key))
1445
0
            goto end;
1446
0
    } else {
1447
0
        encode_pubkey(pubenc, key);
1448
0
        if (!hash_h(key->pkhash, pubenc, vinfo->pubkey_bytes, mdctx, key))
1449
0
            goto end;
1450
0
    }
1451
1452
    /* Save |z| portion of seed for "implicit rejection" on failure. */
1453
0
    memcpy(key->z, seed + ML_KEM_RANDOM_BYTES, ML_KEM_RANDOM_BYTES);
1454
1455
    /* Save the |d| portion of the seed */
1456
0
    key->d = key->z + ML_KEM_RANDOM_BYTES;
1457
0
    memcpy(key->d, seed, ML_KEM_RANDOM_BYTES);
1458
1459
0
    ret = 1;
1460
0
end:
1461
0
    OPENSSL_cleanse((void *)augmented_seed, ML_KEM_RANDOM_BYTES);
1462
0
    OPENSSL_cleanse((void *)sigma, ML_KEM_RANDOM_BYTES);
1463
0
    if (ret == 0) {
1464
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1465
0
            "internal error while generating %s private key",
1466
0
            vinfo->algorithm_name);
1467
0
    }
1468
0
    return ret;
1469
0
}
1470
1471
/*-
1472
 * FIPS 203, Section 6.2, Algorithm 17: "ML-KEM.Encaps_internal".
1473
 * This is the deterministic version with randomness supplied externally.
1474
 *
1475
 * The caller must pass space for two vectors in |tmp|.
1476
 * The |ctext| buffer have space for the ciphertext of the ML-KEM variant
1477
 * of the provided key.
1478
 */
1479
static int encap(uint8_t *ctext, uint8_t secret[ML_KEM_SHARED_SECRET_BYTES],
1480
    const uint8_t entropy[ML_KEM_RANDOM_BYTES],
1481
    scalar *tmp, EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1482
0
{
1483
0
    uint8_t input[ML_KEM_RANDOM_BYTES + ML_KEM_PKHASH_BYTES];
1484
0
    uint8_t Kr[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES];
1485
0
    uint8_t *r = Kr + ML_KEM_SHARED_SECRET_BYTES;
1486
0
    int ret;
1487
1488
0
    memcpy(input, entropy, ML_KEM_RANDOM_BYTES);
1489
0
    memcpy(input + ML_KEM_RANDOM_BYTES, key->pkhash, ML_KEM_PKHASH_BYTES);
1490
0
    ret = hash_g(Kr, input, sizeof(input), mdctx, key)
1491
0
        && encrypt_cpa(ctext, entropy, r, tmp, mdctx, key);
1492
0
    OPENSSL_cleanse((void *)input, sizeof(input));
1493
1494
0
    if (ret)
1495
0
        memcpy(secret, Kr, ML_KEM_SHARED_SECRET_BYTES);
1496
0
    else
1497
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1498
0
            "internal error while performing %s encapsulation",
1499
0
            key->vinfo->algorithm_name);
1500
0
    return ret;
1501
0
}
1502
1503
/*
1504
 * Hash the input message |m'| and public key digest |h|
1505
 * to obtain |K| and |r|.
1506
 */
1507
static int hash_kr(uint8_t *out, uint8_t *in,
1508
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1509
0
{
1510
0
    unsigned int sz, wanted;
1511
1512
0
    wanted = ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES;
1513
0
    return (EVP_DigestInit_ex(mdctx, key->sha3_512_md, NULL)
1514
0
        && EVP_DigestUpdate(mdctx, in, ML_KEM_RANDOM_BYTES)
1515
0
        && EVP_DigestUpdate(mdctx, key->pkhash, ML_KEM_PKHASH_BYTES)
1516
0
        && EVP_DigestFinal_ex(mdctx, out, &sz)
1517
0
        && ossl_assert(sz == wanted));
1518
0
}
1519
1520
/*-
1521
 * Decap needs space for: Kbar | K | r | m'
1522
 * We slice up a single buffer to hold them all.
1523
 * We don't need to cleanse the public pkhash value.
1524
 */
1525
0
#define DECAP_BUFFER_SZ (2 * ML_KEM_SHARED_SECRET_BYTES + 2 * ML_KEM_RANDOM_BYTES)
1526
1527
/*
1528
 * FIPS 203, Section 6.3, Algorithm 18: ML-KEM.Decaps_internal
1529
 *
1530
 * Barring failure of the supporting SHA3/SHAKE primitives, this is fully
1531
 * deterministic, the randomness for the FO transform is extracted during
1532
 * private key generation.
1533
 *
1534
 * The caller must pass space for two vectors in |tmp|.
1535
 * The |ctext| and |tmp_ctext| buffers must each have space for the ciphertext
1536
 * of the key's ML-KEM variant.
1537
 */
1538
static int decap(uint8_t secret[ML_KEM_SHARED_SECRET_BYTES],
1539
    const uint8_t *ctext, uint8_t *tmp_ctext, scalar *tmp,
1540
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1541
0
{
1542
0
    uint8_t buf[DECAP_BUFFER_SZ];
1543
0
    uint8_t *failure_key = buf; /* Kbar */
1544
0
    uint8_t *Kr = failure_key + ML_KEM_SHARED_SECRET_BYTES;
1545
0
    uint8_t *r = Kr + ML_KEM_SHARED_SECRET_BYTES;
1546
0
    uint8_t *m = r + ML_KEM_RANDOM_BYTES; /* m' */
1547
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1548
0
    int i;
1549
0
    uint8_t mask;
1550
1551
    /*
1552
     * If our KDF is unavailable, fail early! Otherwise, keep going ignoring
1553
     * any further errors, returning success, and whatever we got for a shared
1554
     * secret.  The decrypt_cpa() function is just arithmetic on secret data,
1555
     * so should not be subject to failure that makes its output predictable.
1556
     *
1557
     * We guard against "should never happen" catastrophic failure of the
1558
     * "pure" function |hash_g| by overwriting the shared secret with the
1559
     * content of the failure key and returning early, if nevertheless hash_g
1560
     * fails.  This is not constant-time, but a failure of |hash_g| already
1561
     * implies loss of side-channel resistance.
1562
     *
1563
     * The same action is taken, if also |encrypt_cpa| should catastrophically
1564
     * fail, due to failure of the |PRF| underlying the CBD functions.
1565
     */
1566
0
    if (!kdf(failure_key, key->z, ctext, vinfo->ctext_bytes, mdctx, key)) {
1567
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1568
0
            "internal error while performing %s decapsulation",
1569
0
            vinfo->algorithm_name);
1570
0
        return 0;
1571
0
    }
1572
0
    decrypt_cpa(m, ctext, tmp, key);
1573
0
    if (!hash_kr(Kr, m, mdctx, key)
1574
0
        || !encrypt_cpa(tmp_ctext, m, r, tmp, mdctx, key)) {
1575
0
        memcpy(secret, failure_key, ML_KEM_SHARED_SECRET_BYTES);
1576
0
        goto end;
1577
0
    }
1578
0
    mask = constant_time_eq_int_8(0,
1579
0
        CRYPTO_memcmp(ctext, tmp_ctext, vinfo->ctext_bytes));
1580
0
    for (i = 0; i < ML_KEM_SHARED_SECRET_BYTES; i++)
1581
0
        secret[i] = constant_time_select_8(mask, Kr[i], failure_key[i]);
1582
0
end:
1583
0
    OPENSSL_cleanse(buf, DECAP_BUFFER_SZ);
1584
0
    return 1;
1585
0
}
1586
1587
/*
1588
 * After allocating storage for public or private key data, update the key
1589
 * component pointers to reference that storage.
1590
 *
1591
 * The caller should only store private data in `priv` *after* a successful
1592
 * (non-zero) return from this function.
1593
 */
1594
static __owur int add_storage(scalar *pub, scalar *priv, int private, ML_KEM_KEY *key)
1595
0
{
1596
0
    int rank = key->vinfo->rank;
1597
1598
0
    if (pub == NULL || (private && priv == NULL)) {
1599
        /*
1600
         * One of these could be allocated correctly. It is legal to call free with a NULL
1601
         * pointer, so always attempt to free both allocations here
1602
         */
1603
0
        OPENSSL_free(pub);
1604
0
        OPENSSL_secure_free(priv);
1605
0
        return 0;
1606
0
    }
1607
1608
    /*
1609
     * We're adding key material, set up rho and pkhash to point to the rho_pkhash buffer
1610
     */
1611
0
    memset(key->rho_pkhash, 0, sizeof(key->rho_pkhash));
1612
0
    key->rho = key->rho_pkhash;
1613
0
    key->pkhash = key->rho_pkhash + ML_KEM_RANDOM_BYTES;
1614
0
    key->d = key->z = NULL;
1615
1616
    /* A public key needs space for |t| and |m| */
1617
0
    key->m = (key->t = pub) + rank;
1618
1619
    /*
1620
     * A private key also needs space for |s| and |z|.
1621
     * The |z| buffer always includes additional space for |d|, but a key's |d|
1622
     * pointer is left NULL when parsed from the NIST format, which omits that
1623
     * information.  Only keys generated from a (d, z) seed pair will have a
1624
     * non-NULL |d| pointer.
1625
     */
1626
0
    if (private)
1627
0
        key->z = (uint8_t *)(rank + (key->s = priv));
1628
0
    return 1;
1629
0
}
1630
1631
/*
1632
 * After freeing the storage associated with a key that failed to be
1633
 * constructed, reset the internal pointers back to NULL.
1634
 */
1635
void ossl_ml_kem_key_reset(ML_KEM_KEY *key)
1636
0
{
1637
    /*
1638
     * seedbuf can be allocated and contain |z| and |d| if the key is
1639
     * being created from a private key encoding.  Similarly a pending
1640
     * serialised (encoded) private key may be queued up to load.
1641
     * Clear and free that data now.
1642
     */
1643
0
    if (key->seedbuf != NULL)
1644
0
        OPENSSL_secure_clear_free(key->seedbuf, ML_KEM_SEED_BYTES);
1645
0
    if (ossl_ml_kem_have_dkenc(key))
1646
0
        OPENSSL_secure_clear_free(key->encoded_dk, key->vinfo->prvkey_bytes);
1647
1648
    /*-
1649
     * Cleanse any sensitive data:
1650
     * - The private vector |s| is immediately followed by the FO failure
1651
     *   secret |z|, and seed |d|, we can cleanse all three in one call.
1652
     */
1653
0
    if (key->t != NULL) {
1654
0
        if (ossl_ml_kem_have_prvkey(key))
1655
0
            OPENSSL_secure_clear_free(key->s, key->vinfo->prvalloc);
1656
0
        OPENSSL_free(key->t);
1657
0
    }
1658
0
    key->d = key->z = key->seedbuf = key->encoded_dk = (uint8_t *)(key->s = key->m = key->t = NULL);
1659
0
}
1660
1661
/*
1662
 * ----- API exported to the provider
1663
 *
1664
 * Parameters with an implicit fixed length in the internal static API of each
1665
 * variant have an explicit checked length argument at this layer.
1666
 */
1667
1668
/* Retrieve the parameters of one of the ML-KEM variants */
1669
const ML_KEM_VINFO *ossl_ml_kem_get_vinfo(int evp_type)
1670
0
{
1671
0
    (void)CRYPTO_THREAD_run_once(&ml_kem_ntt_once, ml_kem_ntt_init);
1672
1673
0
    switch (evp_type) {
1674
0
    case EVP_PKEY_ML_KEM_512:
1675
0
        return &vinfo_map[ML_KEM_512_VINFO];
1676
0
    case EVP_PKEY_ML_KEM_768:
1677
0
        return &vinfo_map[ML_KEM_768_VINFO];
1678
0
    case EVP_PKEY_ML_KEM_1024:
1679
0
        return &vinfo_map[ML_KEM_1024_VINFO];
1680
0
    }
1681
0
    return NULL;
1682
0
}
1683
1684
/*
1685
 * @brief Fetch digest algorithms based on a propq.
1686
 * For the import case ossl_ml_kem_key_new() gets passed a NULL propq,
1687
 * so the propq is optionally deferred to the import using OSSL_PARAM.
1688
 */
1689
int ossl_ml_kem_key_fetch_digest(ML_KEM_KEY *key, const char *propq)
1690
0
{
1691
0
    if (key->shake128_md != NULL) {
1692
0
        EVP_MD_free(key->shake128_md);
1693
0
        EVP_MD_free(key->shake256_md);
1694
0
        EVP_MD_free(key->sha3_256_md);
1695
0
        EVP_MD_free(key->sha3_512_md);
1696
0
    }
1697
0
    key->shake128_md = EVP_MD_fetch(key->libctx, "SHAKE128", propq);
1698
0
    key->shake256_md = EVP_MD_fetch(key->libctx, "SHAKE256", propq);
1699
0
    key->sha3_256_md = EVP_MD_fetch(key->libctx, "SHA3-256", propq);
1700
0
    key->sha3_512_md = EVP_MD_fetch(key->libctx, "SHA3-512", propq);
1701
0
    return (key->shake128_md != NULL && key->shake256_md != NULL
1702
0
        && key->sha3_256_md != NULL && key->sha3_512_md != NULL);
1703
0
}
1704
1705
/*
 * Allocate an empty ML-KEM key object for the variant named by |evp_type|.
 *
 * The key is bound to |libctx|, starts with no key material, and has its
 * digests fetched with the property query |properties|.  Returns NULL, with
 * an error raised, for an unsupported |evp_type| or when the SHA3 digests are
 * not available; also returns NULL when the allocation fails.
 */
ML_KEM_KEY *ossl_ml_kem_key_new(OSSL_LIB_CTX *libctx, const char *properties,
    int evp_type)
{
    ML_KEM_KEY *key = NULL;
    const ML_KEM_VINFO *vinfo = ossl_ml_kem_get_vinfo(evp_type);

    if (vinfo == NULL) {
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_PASSED_INVALID_ARGUMENT,
            "unsupported ML-KEM key type: %d", evp_type);
        return NULL;
    }

    key = OPENSSL_malloc(sizeof(*key));
    if (key == NULL)
        return NULL;

    /* Variant, context and provider flags */
    key->vinfo = vinfo;
    key->libctx = libctx;
    key->prov_flags = ML_KEM_KEY_PROV_FLAGS_DEFAULT;

    /* No key material yet */
    key->seedbuf = NULL;
    key->encoded_dk = NULL;
    key->pkhash = NULL;
    key->rho = NULL;
    key->z = NULL;
    key->d = NULL;
    key->t = NULL;
    key->m = NULL;
    key->s = NULL;

    /* Digest handles must be NULL before the first fetch */
    key->sha3_512_md = NULL;
    key->sha3_256_md = NULL;
    key->shake256_md = NULL;
    key->shake128_md = NULL;

    if (!ossl_ml_kem_key_fetch_digest(key, properties)) {
        ossl_ml_kem_key_free(key);
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
            "missing SHA3 digest algorithms while creating %s key",
            vinfo->algorithm_name);
        return NULL;
    }

    return key;
}
1735
1736
ML_KEM_KEY *ossl_ml_kem_key_dup(const ML_KEM_KEY *key, int selection)
1737
0
{
1738
0
    int ok = 0;
1739
0
    ML_KEM_KEY *ret;
1740
0
    void *tmp_pub;
1741
0
    void *tmp_priv;
1742
1743
0
    if (key == NULL)
1744
0
        return NULL;
1745
    /*
1746
     * Partially decoded keys, not yet imported or loaded, should never be
1747
     * duplicated.
1748
     */
1749
0
    if (ossl_ml_kem_decoded_key(key))
1750
0
        return NULL;
1751
1752
0
    else if ((ret = OPENSSL_memdup(key, sizeof(*key))) == NULL)
1753
0
        return NULL;
1754
1755
0
    ret->d = ret->z = ret->rho = ret->pkhash = NULL;
1756
0
    ret->s = ret->m = ret->t = NULL;
1757
1758
    /* Clear selection bits we can't fulfill */
1759
0
    if (!ossl_ml_kem_have_pubkey(key))
1760
0
        selection = 0;
1761
0
    else if (!ossl_ml_kem_have_prvkey(key))
1762
0
        selection &= ~OSSL_KEYMGMT_SELECT_PRIVATE_KEY;
1763
1764
0
    switch (selection & OSSL_KEYMGMT_SELECT_KEYPAIR) {
1765
0
    case 0:
1766
0
        ok = 1;
1767
0
        break;
1768
0
    case OSSL_KEYMGMT_SELECT_PUBLIC_KEY:
1769
0
        ok = add_storage(OPENSSL_memdup(key->t, key->vinfo->puballoc), NULL, 0, ret);
1770
0
        break;
1771
0
    case OSSL_KEYMGMT_SELECT_PRIVATE_KEY:
1772
0
        tmp_pub = OPENSSL_memdup(key->t, key->vinfo->puballoc);
1773
0
        if (tmp_pub == NULL)
1774
0
            break;
1775
0
        tmp_priv = OPENSSL_secure_malloc(key->vinfo->prvalloc);
1776
0
        if (tmp_priv == NULL) {
1777
0
            OPENSSL_free(tmp_pub);
1778
0
            break;
1779
0
        }
1780
0
        if ((ok = add_storage(tmp_pub, tmp_priv, 1, ret)) != 0)
1781
0
            memcpy(tmp_priv, key->s, key->vinfo->prvalloc);
1782
        /* Duplicated keys retain |d|, if available */
1783
0
        if (key->d != NULL)
1784
0
            ret->d = ret->z + ML_KEM_RANDOM_BYTES;
1785
0
        break;
1786
0
    }
1787
1788
0
    if (!ok) {
1789
0
        OPENSSL_free(ret);
1790
0
        return NULL;
1791
0
    }
1792
1793
0
    EVP_MD_up_ref(ret->shake128_md);
1794
0
    EVP_MD_up_ref(ret->shake256_md);
1795
0
    EVP_MD_up_ref(ret->sha3_256_md);
1796
0
    EVP_MD_up_ref(ret->sha3_512_md);
1797
1798
0
    return ret;
1799
0
}
1800
1801
/*
 * Destroy a key object: drop its digest references, clear and release all
 * key material, then free the object itself.  A NULL |key| is a no-op.
 */
void ossl_ml_kem_key_free(ML_KEM_KEY *key)
{
    if (key != NULL) {
        EVP_MD_free(key->shake128_md);
        EVP_MD_free(key->shake256_md);
        EVP_MD_free(key->sha3_256_md);
        EVP_MD_free(key->sha3_512_md);

        /* Securely disposes of any public/private key storage */
        ossl_ml_kem_key_reset(key);
        OPENSSL_free(key);
    }
}
1814
1815
/* Serialise the public component of an ML-KEM key */
1816
int ossl_ml_kem_encode_public_key(uint8_t *out, size_t len,
1817
    const ML_KEM_KEY *key)
1818
0
{
1819
0
    if (!ossl_ml_kem_have_pubkey(key)
1820
0
        || len != key->vinfo->pubkey_bytes)
1821
0
        return 0;
1822
0
    encode_pubkey(out, key);
1823
0
    return 1;
1824
0
}
1825
1826
/* Serialise an ML-KEM private key */
1827
int ossl_ml_kem_encode_private_key(uint8_t *out, size_t len,
1828
    const ML_KEM_KEY *key)
1829
0
{
1830
0
    if (!ossl_ml_kem_have_prvkey(key)
1831
0
        || len != key->vinfo->prvkey_bytes)
1832
0
        return 0;
1833
0
    encode_prvkey(out, key);
1834
0
    return 1;
1835
0
}
1836
1837
int ossl_ml_kem_encode_seed(uint8_t *out, size_t len,
1838
    const ML_KEM_KEY *key)
1839
0
{
1840
0
    if (key == NULL || key->d == NULL || len != ML_KEM_SEED_BYTES)
1841
0
        return 0;
1842
    /*
1843
     * Both in the seed buffer, and in the allocated storage, the |d| component
1844
     * of the seed is stored last, so we must copy each separately.
1845
     */
1846
0
    memcpy(out, key->d, ML_KEM_RANDOM_BYTES);
1847
0
    out += ML_KEM_RANDOM_BYTES;
1848
0
    memcpy(out, key->z, ML_KEM_RANDOM_BYTES);
1849
0
    return 1;
1850
0
}
1851
1852
/*
1853
 * Stash the seed without (yet) performing a keygen, used during decoding, to
1854
 * avoid an extra keygen if we're only going to export the key again to load
1855
 * into another provider.
1856
 */
1857
ML_KEM_KEY *ossl_ml_kem_set_seed(const uint8_t *seed, size_t seedlen, ML_KEM_KEY *key)
1858
0
{
1859
0
    if (key == NULL
1860
0
        || ossl_ml_kem_have_pubkey(key)
1861
0
        || ossl_ml_kem_have_seed(key)
1862
0
        || seedlen != ML_KEM_SEED_BYTES)
1863
0
        return NULL;
1864
1865
0
    if (key->seedbuf == NULL) {
1866
0
        key->seedbuf = OPENSSL_secure_malloc(seedlen);
1867
0
        if (key->seedbuf == NULL)
1868
0
            return NULL;
1869
0
    }
1870
1871
0
    key->z = key->seedbuf;
1872
0
    key->d = key->z + ML_KEM_RANDOM_BYTES;
1873
0
    memcpy(key->d, seed, ML_KEM_RANDOM_BYTES);
1874
0
    seed += ML_KEM_RANDOM_BYTES;
1875
0
    memcpy(key->z, seed, ML_KEM_RANDOM_BYTES);
1876
0
    return key;
1877
0
}
1878
1879
/* Parse input as a public key */
1880
int ossl_ml_kem_parse_public_key(const uint8_t *in, size_t len, ML_KEM_KEY *key)
1881
0
{
1882
0
    EVP_MD_CTX *mdctx = NULL;
1883
0
    const ML_KEM_VINFO *vinfo;
1884
0
    int ret = 0;
1885
1886
    /* Keys with key material are immutable */
1887
0
    if (key == NULL
1888
0
        || ossl_ml_kem_have_pubkey(key)
1889
0
        || ossl_ml_kem_have_dkenc(key))
1890
0
        return 0;
1891
0
    vinfo = key->vinfo;
1892
1893
0
    if (len != vinfo->pubkey_bytes
1894
0
        || (mdctx = EVP_MD_CTX_new()) == NULL)
1895
0
        return 0;
1896
1897
0
    if (add_storage(OPENSSL_malloc(vinfo->puballoc), NULL, 0, key))
1898
0
        ret = parse_pubkey(in, mdctx, key);
1899
1900
0
    if (!ret)
1901
0
        ossl_ml_kem_key_reset(key);
1902
0
    EVP_MD_CTX_free(mdctx);
1903
0
    return ret;
1904
0
}
1905
1906
/* Parse input as a new private key */
1907
int ossl_ml_kem_parse_private_key(const uint8_t *in, size_t len,
1908
    ML_KEM_KEY *key)
1909
0
{
1910
0
    EVP_MD_CTX *mdctx = NULL;
1911
0
    const ML_KEM_VINFO *vinfo;
1912
0
    int ret = 0;
1913
1914
    /* Keys with key material are immutable */
1915
0
    if (key == NULL
1916
0
        || ossl_ml_kem_have_pubkey(key)
1917
0
        || ossl_ml_kem_have_dkenc(key))
1918
0
        return 0;
1919
0
    vinfo = key->vinfo;
1920
1921
0
    if (len != vinfo->prvkey_bytes
1922
0
        || (mdctx = EVP_MD_CTX_new()) == NULL)
1923
0
        return 0;
1924
1925
0
    if (add_storage(OPENSSL_malloc(vinfo->puballoc),
1926
0
            OPENSSL_secure_malloc(vinfo->prvalloc), 1, key))
1927
0
        ret = parse_prvkey(in, mdctx, key);
1928
1929
0
    if (!ret)
1930
0
        ossl_ml_kem_key_reset(key);
1931
0
    EVP_MD_CTX_free(mdctx);
1932
0
    return ret;
1933
0
}
1934
1935
/*
1936
 * Generate a new keypair, either from the saved seed (when non-null), or from
1937
 * the RNG.
1938
 */
1939
int ossl_ml_kem_genkey(uint8_t *pubenc, size_t publen, ML_KEM_KEY *key)
1940
0
{
1941
0
    uint8_t seed[ML_KEM_SEED_BYTES];
1942
0
    EVP_MD_CTX *mdctx = NULL;
1943
0
    const ML_KEM_VINFO *vinfo;
1944
0
    int ret = 0;
1945
1946
0
    if (key == NULL
1947
0
        || ossl_ml_kem_have_pubkey(key)
1948
0
        || ossl_ml_kem_have_dkenc(key))
1949
0
        return 0;
1950
0
    vinfo = key->vinfo;
1951
1952
0
    if (pubenc != NULL && publen != vinfo->pubkey_bytes)
1953
0
        return 0;
1954
1955
0
    if (key->seedbuf != NULL) {
1956
0
        if (!ossl_ml_kem_encode_seed(seed, sizeof(seed), key))
1957
0
            return 0;
1958
0
        ossl_ml_kem_key_reset(key);
1959
0
    } else if (RAND_priv_bytes_ex(key->libctx, seed, sizeof(seed),
1960
0
                   key->vinfo->secbits)
1961
0
        <= 0) {
1962
0
        return 0;
1963
0
    }
1964
1965
0
    if ((mdctx = EVP_MD_CTX_new()) == NULL)
1966
0
        return 0;
1967
1968
    /*
1969
     * Data derived from (d, z) defaults secret, and to avoid side-channel
1970
     * leaks should not influence control flow.
1971
     */
1972
0
    CONSTTIME_SECRET(seed, ML_KEM_SEED_BYTES);
1973
1974
0
    if (add_storage(OPENSSL_malloc(vinfo->puballoc),
1975
0
            OPENSSL_secure_malloc(vinfo->prvalloc), 1, key))
1976
0
        ret = genkey(seed, mdctx, pubenc, key);
1977
0
    OPENSSL_cleanse(seed, sizeof(seed));
1978
1979
    /* Declassify secret inputs and derived outputs before returning control */
1980
0
    CONSTTIME_DECLASSIFY(seed, ML_KEM_SEED_BYTES);
1981
1982
0
    EVP_MD_CTX_free(mdctx);
1983
0
    if (!ret) {
1984
0
        ossl_ml_kem_key_reset(key);
1985
0
        return 0;
1986
0
    }
1987
1988
    /* The public components are already declassified */
1989
0
    CONSTTIME_DECLASSIFY(key->s, vinfo->rank * sizeof(scalar));
1990
0
    CONSTTIME_DECLASSIFY(key->z, 2 * ML_KEM_RANDOM_BYTES);
1991
0
    return 1;
1992
0
}
1993
1994
/*
1995
 * FIPS 203, Section 6.2, Algorithm 17: ML-KEM.Encaps_internal
1996
 * This is the deterministic version with randomness supplied externally.
1997
 */
1998
int ossl_ml_kem_encap_seed(uint8_t *ctext, size_t clen,
1999
    uint8_t *shared_secret, size_t slen,
2000
    const uint8_t *entropy, size_t elen,
2001
    const ML_KEM_KEY *key)
2002
0
{
2003
0
    const ML_KEM_VINFO *vinfo;
2004
0
    EVP_MD_CTX *mdctx;
2005
0
    int ret = 0;
2006
2007
0
    if (key == NULL || !ossl_ml_kem_have_pubkey(key))
2008
0
        return 0;
2009
0
    vinfo = key->vinfo;
2010
2011
0
    if (ctext == NULL || clen != vinfo->ctext_bytes
2012
0
        || shared_secret == NULL || slen != ML_KEM_SHARED_SECRET_BYTES
2013
0
        || entropy == NULL || elen != ML_KEM_RANDOM_BYTES
2014
0
        || (mdctx = EVP_MD_CTX_new()) == NULL)
2015
0
        return 0;
2016
    /*
2017
     * Data derived from the encap entropy defaults secret, and to avoid
2018
     * side-channel leaks should not influence control flow.
2019
     */
2020
0
    CONSTTIME_SECRET(entropy, elen);
2021
2022
    /*-
2023
     * This avoids the need to handle allocation failures for two (max 2KB
2024
     * each) vectors, that are never retained on return from this function.
2025
     * We stack-allocate these.
2026
     */
2027
0
#define case_encap_seed(bits)                                        \
2028
0
    {                                                                \
2029
0
        scalar tmp[2 * ML_KEM_##bits##_RANK];                        \
2030
0
                                                                     \
2031
0
        ret = encap(ctext, shared_secret, entropy, tmp, mdctx, key); \
2032
0
        OPENSSL_cleanse((void *)tmp, sizeof(tmp));                   \
2033
0
    }
2034
0
    switch (vinfo->evp_type) {
2035
0
    case EVP_PKEY_ML_KEM_512:
2036
0
        case_encap_seed(512);
2037
0
        break;
2038
0
    case EVP_PKEY_ML_KEM_768:
2039
0
        case_encap_seed(768);
2040
0
        break;
2041
0
    case EVP_PKEY_ML_KEM_1024:
2042
0
        case_encap_seed(1024);
2043
0
        break;
2044
0
    }
2045
0
#undef case_encap_seed
2046
2047
    /* Declassify secret inputs and derived outputs before returning control */
2048
0
    CONSTTIME_DECLASSIFY(entropy, elen);
2049
0
    CONSTTIME_DECLASSIFY(ctext, clen);
2050
0
    CONSTTIME_DECLASSIFY(shared_secret, slen);
2051
2052
0
    EVP_MD_CTX_free(mdctx);
2053
0
    return ret;
2054
0
}
2055
2056
int ossl_ml_kem_encap_rand(uint8_t *ctext, size_t clen,
2057
    uint8_t *shared_secret, size_t slen,
2058
    const ML_KEM_KEY *key)
2059
0
{
2060
0
    uint8_t r[ML_KEM_RANDOM_BYTES];
2061
2062
0
    if (key == NULL)
2063
0
        return 0;
2064
2065
0
    if (RAND_bytes_ex(key->libctx, r, ML_KEM_RANDOM_BYTES,
2066
0
            key->vinfo->secbits)
2067
0
        < 1)
2068
0
        return 0;
2069
2070
0
    return ossl_ml_kem_encap_seed(ctext, clen, shared_secret, slen,
2071
0
        r, sizeof(r), key);
2072
0
}
2073
2074
int ossl_ml_kem_decap(uint8_t *shared_secret, size_t slen,
2075
    const uint8_t *ctext, size_t clen,
2076
    const ML_KEM_KEY *key)
2077
0
{
2078
0
    const ML_KEM_VINFO *vinfo;
2079
0
    EVP_MD_CTX *mdctx;
2080
0
    int ret = 0;
2081
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
2082
    int classify_bytes = 2 * sizeof(scalar) + ML_KEM_RANDOM_BYTES;
2083
#endif
2084
2085
    /* Need a private key here */
2086
0
    if (!ossl_ml_kem_have_prvkey(key))
2087
0
        return 0;
2088
0
    vinfo = key->vinfo;
2089
2090
0
    if (shared_secret == NULL || slen != ML_KEM_SHARED_SECRET_BYTES
2091
0
        || ctext == NULL || clen != vinfo->ctext_bytes
2092
0
        || (mdctx = EVP_MD_CTX_new()) == NULL) {
2093
0
        (void)RAND_bytes_ex(key->libctx, shared_secret,
2094
0
            ML_KEM_SHARED_SECRET_BYTES, vinfo->secbits);
2095
0
        return 0;
2096
0
    }
2097
    /*
2098
     * Data derived from |s| and |z| defaults secret, and to avoid side-channel
2099
     * leaks should not influence control flow.
2100
     */
2101
0
    CONSTTIME_SECRET(key->s, classify_bytes);
2102
2103
    /*-
2104
     * This avoids the need to handle allocation failures for two (max 2KB
2105
     * each) vectors and an encoded ciphertext (max 1568 bytes), that are never
2106
     * retained on return from this function.
2107
     * We stack-allocate these.
2108
     */
2109
0
#define case_decap(bits)                                          \
2110
0
    {                                                             \
2111
0
        uint8_t cbuf[CTEXT_BYTES(bits)];                          \
2112
0
        scalar tmp[2 * ML_KEM_##bits##_RANK];                     \
2113
0
                                                                  \
2114
0
        ret = decap(shared_secret, ctext, cbuf, tmp, mdctx, key); \
2115
0
        OPENSSL_cleanse((void *)tmp, sizeof(tmp));                \
2116
0
    }
2117
0
    switch (vinfo->evp_type) {
2118
0
    case EVP_PKEY_ML_KEM_512:
2119
0
        case_decap(512);
2120
0
        break;
2121
0
    case EVP_PKEY_ML_KEM_768:
2122
0
        case_decap(768);
2123
0
        break;
2124
0
    case EVP_PKEY_ML_KEM_1024:
2125
0
        case_decap(1024);
2126
0
        break;
2127
0
    }
2128
0
#undef case_decap
2129
2130
    /* Declassify secret inputs and derived outputs before returning control */
2131
0
    CONSTTIME_DECLASSIFY(key->s, classify_bytes);
2132
0
    CONSTTIME_DECLASSIFY(shared_secret, slen);
2133
0
    EVP_MD_CTX_free(mdctx);
2134
2135
0
    return ret;
2136
0
}
2137
2138
int ossl_ml_kem_pubkey_cmp(const ML_KEM_KEY *key1, const ML_KEM_KEY *key2)
2139
0
{
2140
    /*
2141
     * This handles any unexpected differences in the ML-KEM variant rank,
2142
     * giving different key component structures, barring SHA3-256 hash
2143
     * collisions, the keys are the same size.
2144
     */
2145
0
    if (ossl_ml_kem_have_pubkey(key1) && ossl_ml_kem_have_pubkey(key2))
2146
0
        return memcmp(key1->pkhash, key2->pkhash, ML_KEM_PKHASH_BYTES) == 0;
2147
2148
    /*
2149
     * No match if just one of the public keys is not available, otherwise both
2150
     * are unavailable, and for now such keys are considered equal.
2151
     */
2152
0
    return (!(ossl_ml_kem_have_pubkey(key1) ^ ossl_ml_kem_have_pubkey(key2)));
2153
0
}