Coverage Report

Created: 2026-04-28 06:29

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/openssl/crypto/ml_kem/ml_kem.c
Line
Count
Source
1
/*
2
 * Copyright 2024-2026 The OpenSSL Project Authors. All Rights Reserved.
3
 *
4
 * Licensed under the Apache License 2.0 (the "License").  You may not use
5
 * this file except in compliance with the License.  You can obtain a copy
6
 * in the file LICENSE in the source distribution or at
7
 * https://www.openssl.org/source/license.html
8
 */
9
10
#include <openssl/byteorder.h>
11
#include <openssl/rand.h>
12
#include <openssl/proverr.h>
13
#include "crypto/ml_kem.h"
14
#include "internal/common.h"
15
#include "internal/constant_time.h"
16
#include "internal/sha3.h"
17
18
/*
 * Compile-time sanity checks on sizes this file relies on.
 *
 * ML_KEM_SEED_BYTES must equal ML_KEM_SHARED_SECRET_BYTES +
 * ML_KEM_RANDOM_BYTES.
 */
#if ML_KEM_SEED_BYTES != ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES
#error "ML-KEM keygen seed length != shared secret + random bytes length"
#endif
/* The shared secret and the random inputs must have the same length. */
#if ML_KEM_SHARED_SECRET_BYTES != ML_KEM_RANDOM_BYTES
#error "Invalid unequal lengths of ML-KEM shared secret and random inputs"
#endif

/* Require unsigned int to be at least as wide as uint32_t. */
#if UINT_MAX < UINT32_MAX
#error "Unsupported compiler: sizeof(unsigned int) < sizeof(uint32_t)"
#endif
28
29
/* Handy function-like bit-extraction macros */
/* bit0(b): the least significant bit of |b|, as 0 or 1. */
#define bit0(b) ((b) & 1)
/*
 * bitn(n, b): bit |n| of |b|, as 0 or 1.  Both arguments are parenthesised so
 * that expression arguments (e.g. a conditional expression for |n|) bind as
 * intended instead of being split by the surrounding ">>".
 */
#define bitn(n, b) (((b) >> (n)) & 1)

/*
 * 12 bits are sufficient to losslessly represent values in [0, q-1].
 * INVERSE_DEGREE is (n/2)^-1 mod q, used at the end of the inverse NTT.  With
 * n = 256 and q = 3329: 128 * 3303 = 127 * 3329 + 1, and 3303 = q - 2 * 13.
 */
#define DEGREE ML_KEM_DEGREE
#define INVERSE_DEGREE (ML_KEM_PRIME - 2 * 13)
#define LOG2PRIME 12
/* Barrett reduction scales by 2^(2 * LOG2PRIME) = 2^24. */
#define BARRETT_SHIFT (2 * LOG2PRIME)

#ifdef SHA3_BLOCKSIZE
#define SHAKE128_BLOCKSIZE SHA3_BLOCKSIZE(128)
#endif
45
46
/*
 * The scalar rejection-sampling buffer size needs to be a multiple of 12, but
 * is otherwise arbitrary.  The preferred size matches the internal block size
 * of SHAKE128, which avoids internal buffering and copying in SHAKE128.  That
 * block size, (1600 - 256)/8 = 168 bytes, happens to be divisible by 12.
 *
 * sample_scalar() consumes the buffer 3 bytes at a time, and any multiple of
 * 12 is also a multiple of 3.
 *
 * If the block size is unknown, or is not divisible by 12, 168 is used as a
 * fallback.
 */
#if defined(SHAKE128_BLOCKSIZE) && (SHAKE128_BLOCKSIZE) % 12 == 0
#define SCALAR_SAMPLING_BUFSIZE (SHAKE128_BLOCKSIZE)
#else
#define SCALAR_SAMPLING_BUFSIZE 168
#endif
60
61
/*
 * Structure of keys
 *
 * A |scalar| holds the ML_KEM_DEGREE coefficients of one ring element.  The
 * key vectors and the matrix declared below are arrays of |scalar|.
 */
typedef struct ossl_ml_kem_scalar_st {
    /* On every function entry and exit, 0 <= c[i] < ML_KEM_PRIME. */
    uint16_t c[ML_KEM_DEGREE];
} scalar;
68
69
/*
 * Key material allocation layout.
 *
 * Each macro declares a struct named <name>_alloc.  The sizes of these
 * structs are recorded in |vinfo_map| via sizeof().
 */
/* Public data: the vector |t| followed by the rank x rank matrix |m|. */
#define DECLARE_ML_KEM_PUBKEYDATA(name, rank)                  \
    struct name##_alloc {                                      \
        /* Public vector |t| */                                \
        scalar tbuf[(rank)];                                   \
        /* Pre-computed matrix |m| (FIPS 203 |A| transpose) */ \
        scalar mbuf[(rank) * (rank)];                          \
    }

/*
 * Private data: the vector |s| followed by 2 * ML_KEM_RANDOM_BYTES bytes.
 * NOTE(review): |zbuf| presumably holds the implicit-rejection secret |z|
 * plus one more 32-byte value (e.g. the seed |d|); confirm at the users.
 */
#define DECLARE_ML_KEM_PRVKEYDATA(name, rank)  \
    struct name##_alloc {                      \
        scalar sbuf[rank];                     \
        uint8_t zbuf[2 * ML_KEM_RANDOM_BYTES]; \
    }

/* Declare variant-specific public and private storage */
#define DECLARE_ML_KEM_VARIANT_KEYDATA(bits)                        \
    DECLARE_ML_KEM_PUBKEYDATA(pubkey_##bits, ML_KEM_##bits##_RANK); \
    DECLARE_ML_KEM_PRVKEYDATA(prvkey_##bits, ML_KEM_##bits##_RANK)

/* Defines struct pubkey_<bits>_alloc and struct prvkey_<bits>_alloc. */
DECLARE_ML_KEM_VARIANT_KEYDATA(512);
DECLARE_ML_KEM_VARIANT_KEYDATA(768);
DECLARE_ML_KEM_VARIANT_KEYDATA(1024);
#undef DECLARE_ML_KEM_VARIANT_KEYDATA
#undef DECLARE_ML_KEM_PUBKEYDATA
#undef DECLARE_ML_KEM_PRVKEYDATA
95
96
/*
 * Function-pointer type for a sampler that fills |out| from the
 * ML_KEM_RANDOM_BYTES + 1 input bytes |in|, using |mdctx| and |key|.
 * NOTE(review): presumably bound to the SamplePolyCBD_eta implementations
 * (FIPS 203, Algorithm 8) fed by prf(); confirm at the definitions.
 */
typedef __owur int (*CBD_FUNC)(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key);
/* Forward declaration: hash_h_pubkey() uses scalar_encode() before its definition. */
static void scalar_encode(uint8_t *out, const scalar *s, int bits);
99
100
/*
 * The wire-form of a losslessly encoded vector uses 12 bits per element.
 *
 * The wire-form public key consists of the lossless encoding of the public
 * vector |t|, followed by the public seed |rho|.
 *
 * Our serialised private key concatenates serialisations of the private vector
 * |s|, the public key, the public key hash, and the failure secret |z|.
 */
/* 12 bits per coefficient: 3 * DEGREE / 2 bytes per scalar, times the rank. */
#define VECTOR_BYTES(b) ((3 * DEGREE / 2) * ML_KEM_##b##_RANK)
#define PUBKEY_BYTES(b) (VECTOR_BYTES(b) + ML_KEM_RANDOM_BYTES)
/*
 * |s| (VECTOR_BYTES) + public key (PUBKEY_BYTES) + hash (ML_KEM_PKHASH_BYTES)
 * + |z| (ML_KEM_RANDOM_BYTES) = 2 * PUBKEY_BYTES + ML_KEM_PKHASH_BYTES.
 */
#define PRVKEY_BYTES(b) (2 * PUBKEY_BYTES(b) + ML_KEM_PKHASH_BYTES)

/*
 * Encapsulation produces a vector "u" and a scalar "v", whose coordinates
 * (numbers modulo the ML-KEM prime "q") are lossily encoded using "du" and
 * "dv" bits, respectively.  This encoding is the ciphertext input for
 * decapsulation.
 */
/* Each bit of precision costs DEGREE / 8 bytes per scalar. */
#define U_VECTOR_BYTES(b) ((DEGREE / 8) * ML_KEM_##b##_DU * ML_KEM_##b##_RANK)
#define V_SCALAR_BYTES(b) ((DEGREE / 8) * ML_KEM_##b##_DV)
#define CTEXT_BYTES(b) (U_VECTOR_BYTES(b) + V_SCALAR_BYTES(b))
122
123
/*
 * Indices of slots in the vinfo tables below.  These must match the order of
 * the entries in |vinfo_map| (512, 768, 1024).
 */
#define ML_KEM_512_VINFO 0
#define ML_KEM_768_VINFO 1
#define ML_KEM_1024_VINFO 2
129
130
/*
 * Per-variant fixed parameters, indexed by ML_KEM_<bits>_VINFO.
 *
 * Each entry lists, in order: the algorithm name, the serialised private key
 * size, the private key allocation size, the public key size, the public key
 * allocation size, the ciphertext size, the encoded vector size, the encoded
 * "u" vector size, the EVP_PKEY type, the bit strength, the rank, du, dv, the
 * security bits, and the security category.
 */
static const ML_KEM_VINFO vinfo_map[3] = {
    { "ML-KEM-512",
        PRVKEY_BYTES(512),
        sizeof(struct prvkey_512_alloc),
        PUBKEY_BYTES(512),
        sizeof(struct pubkey_512_alloc),
        CTEXT_BYTES(512),
        VECTOR_BYTES(512),
        U_VECTOR_BYTES(512),
        EVP_PKEY_ML_KEM_512,
        ML_KEM_512_BITS,
        ML_KEM_512_RANK,
        ML_KEM_512_DU,
        ML_KEM_512_DV,
        ML_KEM_512_SECBITS,
        ML_KEM_512_SECURITY_CATEGORY },
    { "ML-KEM-768",
        PRVKEY_BYTES(768),
        sizeof(struct prvkey_768_alloc),
        PUBKEY_BYTES(768),
        sizeof(struct pubkey_768_alloc),
        CTEXT_BYTES(768),
        VECTOR_BYTES(768),
        U_VECTOR_BYTES(768),
        EVP_PKEY_ML_KEM_768,
        ML_KEM_768_BITS,
        ML_KEM_768_RANK,
        ML_KEM_768_DU,
        ML_KEM_768_DV,
        ML_KEM_768_SECBITS,
        ML_KEM_768_SECURITY_CATEGORY },
    { "ML-KEM-1024",
        PRVKEY_BYTES(1024),
        sizeof(struct prvkey_1024_alloc),
        PUBKEY_BYTES(1024),
        sizeof(struct pubkey_1024_alloc),
        CTEXT_BYTES(1024),
        VECTOR_BYTES(1024),
        U_VECTOR_BYTES(1024),
        EVP_PKEY_ML_KEM_1024,
        ML_KEM_1024_BITS,
        ML_KEM_1024_RANK,
        ML_KEM_1024_DU,
        ML_KEM_1024_DV,
        ML_KEM_1024_SECBITS,
        ML_KEM_1024_SECURITY_CATEGORY }
};
180
181
/*
 * Remainders modulo `kPrime`, for sufficiently small inputs, are computed in
 * constant time via Barrett reduction and a final call to reduce_once().
 * reduce_once() handles inputs less than 2*kPrime and is also constant-time.
 */
static const int kPrime = ML_KEM_PRIME;
/* Barrett reduction multiplies by floor(2^24 / kPrime) and shifts right by 24. */
static const unsigned int kBarrettShift = BARRETT_SHIFT;
static const size_t kBarrettMultiplier = (1 << BARRETT_SHIFT) / ML_KEM_PRIME;
/* (kPrime - 1) / 2: the rounding threshold used by compress(). */
static const uint16_t kHalfPrime = (ML_KEM_PRIME - 1) / 2;
/* (DEGREE / 2)^-1 mod kPrime, applied at the end of the inverse NTT. */
static const uint16_t kInverseDegree = INVERSE_DEGREE;
191
192
/*
 * Python helper (7-bit bit reversal) used to derive the tables below:
 *
 * p = 3329
 * def bitreverse(i):
 *     ret = 0
 *     for n in range(7):
 *         bit = i & 1
 *         ret <<= 1
 *         ret |= bit
 *         i >>= 1
 *     return ret
 */

/*-
 * First precomputed array from Appendix A of FIPS 203, or else Python:
 * kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)]
 *
 * scalar_ntt_generic() reads entries 1..127 in order; entry 0 is not used.
 */
static const uint16_t kNTTRoots[128] = {
    0x001, 0x6c1, 0xa14, 0xcd9, 0xa52, 0x276, 0x769, 0x350,
    0x426, 0x77f, 0x0c1, 0x31d, 0xae2, 0xcbc, 0x239, 0x6d2,
    0x128, 0x98f, 0x53b, 0x5c4, 0xbe6, 0x038, 0x8c0, 0x535,
    0x592, 0x82e, 0x217, 0xb42, 0x959, 0xb3f, 0x7b6, 0x335,
    0x121, 0x14b, 0xcb5, 0x6dc, 0x4ad, 0x900, 0x8e5, 0x807,
    0x28a, 0x7b9, 0x9d1, 0x278, 0xb31, 0x021, 0x528, 0x77b,
    0x90f, 0x59b, 0x327, 0x1c4, 0x59e, 0xb34, 0x5fe, 0x962,
    0xa57, 0xa39, 0x5c9, 0x288, 0x9aa, 0xc26, 0x4cb, 0x38e,
    0x011, 0xac9, 0x247, 0xa59, 0x665, 0x2d3, 0x8f0, 0x44c,
    0x581, 0xa66, 0xcd1, 0x0e9, 0x2f4, 0x86c, 0xbc7, 0xbea,
    0x6a7, 0x673, 0xae5, 0x6fd, 0x737, 0x3b8, 0x5b5, 0xa7f,
    0x3ab, 0x904, 0x985, 0x954, 0x2dd, 0x921, 0x10c, 0x281,
    0x630, 0x8fa, 0x7f5, 0xc94, 0x177, 0x9f5, 0x82a, 0x66d,
    0x427, 0x13f, 0xad5, 0x2f5, 0x833, 0x231, 0x9a2, 0xa22,
    0xaf4, 0x444, 0x193, 0x402, 0x477, 0x866, 0xad7, 0x376,
    0x6ba, 0x4bc, 0x752, 0x405, 0x83e, 0xb77, 0x375, 0x86a
};
228
229
/*
 * InverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)]
 * Listed in order of use in the inverse NTT loop (index 0 is skipped):
 *
 *  0, 64, 65, ..., 127, 32, 33, ..., 63, 16, 17, ..., 31, 8, 9, ...
 *
 * scalar_inverse_ntt_generic() therefore reads entries 1..127 sequentially.
 */
static const uint16_t kInverseNTTRoots[128] = {
    0x001, 0x497, 0x98c, 0x18a, 0x4c3, 0x8fc, 0x5af, 0x845,
    0x647, 0x98b, 0x22a, 0x49b, 0x88a, 0x8ff, 0xb6e, 0x8bd,
    0x20d, 0x2df, 0x35f, 0xad0, 0x4ce, 0xa0c, 0x22c, 0xbc2,
    0x8da, 0x694, 0x4d7, 0x30c, 0xb8a, 0x06d, 0x50c, 0x407,
    0x6d1, 0xa80, 0xbf5, 0x3e0, 0xa24, 0x3ad, 0x37c, 0x3fd,
    0x956, 0x282, 0x74c, 0x949, 0x5ca, 0x604, 0x21c, 0x68e,
    0x65a, 0x117, 0x13a, 0x495, 0xa0d, 0xc18, 0x030, 0x29b,
    0x780, 0x8b5, 0x411, 0xa2e, 0x69c, 0x2a8, 0xaba, 0x238,
    0xcf0, 0x973, 0x836, 0x0db, 0x357, 0xa79, 0x738, 0x2c8,
    0x2aa, 0x39f, 0x703, 0x1cd, 0x763, 0xb3d, 0x9da, 0x766,
    0x3f2, 0x586, 0x7d9, 0xce0, 0x1d0, 0xa89, 0x330, 0x548,
    0xa77, 0x4fa, 0x41c, 0x401, 0x854, 0x625, 0x04c, 0xbb6,
    0xbe0, 0x9cc, 0x54b, 0x1c2, 0x3a8, 0x1bf, 0xaea, 0x4d3,
    0x76f, 0x7cc, 0x441, 0xcc9, 0x11b, 0x73d, 0x7c6, 0x372,
    0xbd9, 0x62f, 0xac8, 0x045, 0x21f, 0x9e4, 0xc40, 0x582,
    0x8db, 0x9b1, 0x598, 0xa8b, 0x2af, 0x028, 0x2ed, 0x640
};
253
254
/*
 * Second precomputed array from Appendix A of FIPS 203 (normalised positive),
 * or else Python:
 * ModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)]
 *
 * scalar_mult() and scalar_mult_add() consume one entry per pair of
 * coefficients, in order.
 */
static const uint16_t kModRoots[128] = {
    0x011, 0xcf0, 0xac9, 0x238, 0x247, 0xaba, 0xa59, 0x2a8,
    0x665, 0x69c, 0x2d3, 0xa2e, 0x8f0, 0x411, 0x44c, 0x8b5,
    0x581, 0x780, 0xa66, 0x29b, 0xcd1, 0x030, 0x0e9, 0xc18,
    0x2f4, 0xa0d, 0x86c, 0x495, 0xbc7, 0x13a, 0xbea, 0x117,
    0x6a7, 0x65a, 0x673, 0x68e, 0xae5, 0x21c, 0x6fd, 0x604,
    0x737, 0x5ca, 0x3b8, 0x949, 0x5b5, 0x74c, 0xa7f, 0x282,
    0x3ab, 0x956, 0x904, 0x3fd, 0x985, 0x37c, 0x954, 0x3ad,
    0x2dd, 0xa24, 0x921, 0x3e0, 0x10c, 0xbf5, 0x281, 0xa80,
    0x630, 0x6d1, 0x8fa, 0x407, 0x7f5, 0x50c, 0xc94, 0x06d,
    0x177, 0xb8a, 0x9f5, 0x30c, 0x82a, 0x4d7, 0x66d, 0x694,
    0x427, 0x8da, 0x13f, 0xbc2, 0xad5, 0x22c, 0x2f5, 0xa0c,
    0x833, 0x4ce, 0x231, 0xad0, 0x9a2, 0x35f, 0xa22, 0x2df,
    0xaf4, 0x20d, 0x444, 0x8bd, 0x193, 0xb6e, 0x402, 0x8ff,
    0x477, 0x88a, 0x866, 0x49b, 0xad7, 0x22a, 0x376, 0x98b,
    0x6ba, 0x647, 0x4bc, 0x845, 0x752, 0x5af, 0x405, 0x8fc,
    0x83e, 0x4c3, 0xb77, 0x18a, 0x375, 0x98c, 0x86a, 0x497
};
277
278
/*
279
 * single_keccak hashes |inlen| bytes from |in| and writes |outlen| bytes of
280
 * output to |out|. If the |md| specifies a fixed-output function, like
281
 * SHA3-256, then |outlen| must be the correct length for that function.
282
 */
283
static __owur int single_keccak(uint8_t *out, size_t outlen, const uint8_t *in, size_t inlen,
284
    EVP_MD_CTX *mdctx)
285
0
{
286
0
    unsigned int sz = (unsigned int)outlen;
287
288
0
    if (!EVP_DigestUpdate(mdctx, in, inlen))
289
0
        return 0;
290
0
    if (EVP_MD_xof(EVP_MD_CTX_get0_md(mdctx)))
291
0
        return EVP_DigestFinalXOF(mdctx, out, outlen);
292
0
    return EVP_DigestFinal_ex(mdctx, out, &sz)
293
0
        && ossl_assert((size_t)sz == outlen);
294
0
}
295
296
/*
297
 * FIPS 203, Section 4.1, equation (4.3): PRF. Takes 32+1 input bytes, and uses
298
 * SHAKE256 to produce the input to SamplePolyCBD_eta: FIPS 203, algorithm 8.
299
 */
300
static __owur int prf(uint8_t *out, size_t len, const uint8_t in[ML_KEM_RANDOM_BYTES + 1],
301
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
302
0
{
303
0
    return EVP_DigestInit_ex(mdctx, key->shake256_md, NULL)
304
0
        && single_keccak(out, len, in, ML_KEM_RANDOM_BYTES + 1, mdctx);
305
0
}
306
307
/*
308
 * FIPS 203, Section 4.1, equation (4.4): H.  SHA3-256 hash of a variable
309
 * length input, producing 32 bytes of output.
310
 */
311
static __owur int hash_h(uint8_t out[ML_KEM_PKHASH_BYTES], const uint8_t *in, size_t len,
312
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
313
0
{
314
0
    return EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL)
315
0
        && single_keccak(out, ML_KEM_PKHASH_BYTES, in, len, mdctx);
316
0
}
317
318
/* Incremental hash_h of expanded public key */
319
static int
320
hash_h_pubkey(uint8_t pkhash[ML_KEM_PKHASH_BYTES],
321
    EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
322
0
{
323
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
324
0
    const scalar *t = key->t, *end = t + vinfo->rank;
325
0
    unsigned int sz;
326
327
0
    if (!EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL))
328
0
        return 0;
329
330
0
    do {
331
0
        uint8_t buf[3 * DEGREE / 2];
332
333
0
        scalar_encode(buf, t++, 12);
334
0
        if (!EVP_DigestUpdate(mdctx, buf, sizeof(buf)))
335
0
            return 0;
336
0
    } while (t < end);
337
338
0
    if (!EVP_DigestUpdate(mdctx, key->rho, ML_KEM_RANDOM_BYTES))
339
0
        return 0;
340
0
    return EVP_DigestFinal_ex(mdctx, pkhash, &sz)
341
0
        && ossl_assert(sz == ML_KEM_PKHASH_BYTES);
342
0
}
343
344
/*
345
 * FIPS 203, Section 4.1, equation (4.5): G.  SHA3-512 hash of a variable
346
 * length input, producing 64 bytes of output, in particular the seeds
347
 * (d,z) for key generation.
348
 */
349
static __owur int hash_g(uint8_t out[ML_KEM_SEED_BYTES], const uint8_t *in, size_t len,
350
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
351
0
{
352
0
    return EVP_DigestInit_ex(mdctx, key->sha3_512_md, NULL)
353
0
        && single_keccak(out, ML_KEM_SEED_BYTES, in, len, mdctx);
354
0
}
355
356
/*
357
 * FIPS 203, Section 4.1, equation (4.4): J. SHAKE256 taking a variable length
358
 * input to compute a 32-byte implicit rejection shared secret, of the same
359
 * length as the expected shared secret.  (Computed even on success to avoid
360
 * side-channel leaks).
361
 */
362
static __owur int kdf(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
363
    const uint8_t z[ML_KEM_RANDOM_BYTES],
364
    const uint8_t *ctext, size_t len,
365
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
366
0
{
367
0
    return EVP_DigestInit_ex(mdctx, key->shake256_md, NULL)
368
0
        && EVP_DigestUpdate(mdctx, z, ML_KEM_RANDOM_BYTES)
369
0
        && EVP_DigestUpdate(mdctx, ctext, len)
370
0
        && EVP_DigestFinalXOF(mdctx, out, ML_KEM_SHARED_SECRET_BYTES);
371
0
}
372
373
/*
374
 * FIPS 203, Section 4.2.2, Algorithm 7: "SampleNTT" (steps 3-17, steps 1, 2
375
 * are performed by the caller). Rejection-samples a Keccak stream to get
376
 * uniformly distributed elements in the range [0,q). This is used for matrix
377
 * expansion and only operates on public inputs.
378
 */
379
static __owur int sample_scalar(scalar *out, EVP_MD_CTX *mdctx)
380
0
{
381
0
    uint16_t *curr = out->c, *endout = curr + DEGREE;
382
0
    uint8_t buf[SCALAR_SAMPLING_BUFSIZE], *in;
383
0
    uint8_t *endin = buf + sizeof(buf);
384
0
    uint16_t d;
385
0
    uint8_t b1, b2, b3;
386
387
0
    do {
388
0
        if (!EVP_DigestSqueeze(mdctx, in = buf, sizeof(buf)))
389
0
            return 0;
390
0
        do {
391
0
            b1 = *in++;
392
0
            b2 = *in++;
393
0
            b3 = *in++;
394
395
0
            if (curr >= endout)
396
0
                break;
397
0
            if ((d = ((b2 & 0x0f) << 8) + b1) < kPrime)
398
0
                *curr++ = d;
399
0
            if (curr >= endout)
400
0
                break;
401
0
            if ((d = (b3 << 4) + (b2 >> 4)) < kPrime)
402
0
                *curr++ = d;
403
0
        } while (in < endin);
404
0
    } while (curr < endout);
405
0
    return 1;
406
0
}
407
408
/*
 * NOTE(review): presumably passed to CRYPTO_THREAD_run_once() so that
 * ml_kem_ntt_init() runs exactly once; the call site is not in this chunk.
 */
static CRYPTO_ONCE ml_kem_ntt_once = CRYPTO_ONCE_STATIC_INIT;

#if defined(_ARCH_PPC64)
#include "crypto/ppc_arch.h"
#endif

#if defined(MLKEM_NTT_PPC_ASM) && defined(_ARCH_PPC64)
/*
 * PPC64LE platform support.
 *
 * scalar_ntt / scalar_inverse_ntt are function pointers.  They start out
 * pointing at the generic C implementations, and ml_kem_ntt_init() may
 * switch them to the assembly wrappers below.
 */
typedef void (*ml_kem_scalar_ntt_fn)(scalar *p);
typedef void (*ml_kem_scalar_inverse_ntt_fn)(scalar *p);

/* Forward declarations so the pointers can be initialised here. */
static void scalar_ntt_generic(scalar *p);
static void scalar_inverse_ntt_generic(scalar *p);

static ml_kem_scalar_ntt_fn scalar_ntt = scalar_ntt_generic;
static ml_kem_scalar_inverse_ntt_fn scalar_inverse_ntt = scalar_inverse_ntt_generic;

/* Assembly entry points operating on the coefficient array in place. */
void mlkem_ntt_ppc(uint16_t *c);
void mlkem_inverse_ntt_ppc(uint16_t *c);

/* Adapters from the |scalar| interface to the assembly routines. */
static void scalar_ntt_ppc(scalar *s)
{
    mlkem_ntt_ppc(s->c);
}

static void scalar_inverse_ntt_ppc(scalar *s)
{
    mlkem_inverse_ntt_ppc(s->c);
}
#else
/*
 * Without the assembly, the generic implementations are defined directly
 * under the names scalar_ntt / scalar_inverse_ntt.
 */
#define scalar_ntt_generic scalar_ntt
#define scalar_inverse_ntt_generic scalar_inverse_ntt
#endif
443
444
/*
 * Initialise the NTT function pointers to the PPC64le implementations when
 * available.
 *
 * On builds with MLKEM_NTT_PPC_ASM for _ARCH_PPC64 that target little-endian
 * and whose OPENSSL_ppccap_P reports PPC_CRYPTO207, scalar_ntt and
 * scalar_inverse_ntt are pointed at the assembly wrappers.  Otherwise the
 * generic C implementations stay in use, and on all other builds this
 * function does nothing.
 */
static void ml_kem_ntt_init(void)
{
#if defined(MLKEM_NTT_PPC_ASM) && defined(_ARCH_PPC64)
/*
 * Check for the byte-order macros explicitly.  In an #if expression an
 * undefined identifier evaluates to 0, so a bare
 * "__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__" would be true on a compiler
 * that defines neither macro, and big-endian would be mistaken for
 * little-endian.
 */
#if defined(__LITTLE_ENDIAN__)                                   \
    || (defined(__BYTE_ORDER__) && defined(__ORDER_LITTLE_ENDIAN__) \
        && __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)
    if (OPENSSL_ppccap_P & PPC_CRYPTO207) {
        scalar_ntt = scalar_ntt_ppc;
        scalar_inverse_ntt = scalar_inverse_ntt_ppc;
    }
#endif
#endif
}
459
460
/*-
 * reduce_once reduces 0 <= x < 2*kPrime, mod kPrime.
 *
 * Subtract |q| if the input is at least |q|, without exposing a side-channel,
 * avoiding the "clangover" attack.  See |constish_time_true| for a
 * discussion on why the value barrier is by default omitted.
 */
static __owur uint16_t reduce_once(uint16_t x)
{
    /*
     * If x < kPrime, the 16-bit difference wraps to a value >= 2^15, so bit 15
     * is set.  If kPrime <= x < 2*kPrime, the difference is < kPrime < 2^15,
     * so bit 15 is clear.
     */
    const uint16_t subtracted = x - kPrime;
    /* Assumed all-ones when bit 15 was set, zero otherwise (see constant_time.h). */
    uint16_t mask = constish_time_true(subtracted >> 15);

    /* Branch-free select: x when x < kPrime, else x - kPrime. */
    return (mask & x) | (~mask & subtracted);
}
474
475
/*
 * Constant-time reduce x mod kPrime using Barrett reduction.  x must be less
 * than kPrime + 2 * kPrime^2.  This is sufficient to reduce a product of
 * two already reduced uint16_t values.  In fact it is sufficient for each
 * to be less than 2^12, because (kPrime * (2 * kPrime + 1)) > 2^24.
 */
static __owur uint16_t reduce(uint32_t x)
{
    /* Quotient estimate floor(x * floor(2^24 / kPrime) / 2^24). */
    uint64_t product = (uint64_t)x * kBarrettMultiplier;
    uint32_t quotient = (uint32_t)(product >> kBarrettShift);
    /*
     * The estimate may be one too small, so the remainder lies in
     * [0, 2*kPrime), which is the input range reduce_once() accepts.
     */
    uint32_t remainder = x - quotient * kPrime;

    return reduce_once(remainder);
}
489
490
/* Multiply a scalar by a constant. */
491
static void scalar_mult_const(scalar *s, uint16_t a)
492
0
{
493
0
    uint16_t *curr = s->c, *end = curr + DEGREE, tmp;
494
495
0
    do {
496
0
        tmp = reduce(*curr * a);
497
0
        *curr++ = tmp;
498
0
    } while (curr < end);
499
0
}
500
501
/*-
502
 * FIPS 203, Section 4.3, Algorithm 9: "NTT".
503
 * In-place number theoretic transform of a given scalar.  Note that ML-KEM's
504
 * kPrime 3329 does not have a 512th root of unity, so this transform leaves
505
 * off the last iteration of the usual FFT code, with the 128 relevant roots of
506
 * unity being stored in NTTRoots.  This means the output should be seen as 128
507
 * elements in GF(3329^2), with the coefficients of the elements being
508
 * consecutive entries in |s->c|.
509
 */
510
static void scalar_ntt_generic(scalar *s)
511
0
{
512
0
    const uint16_t *roots = kNTTRoots;
513
0
    uint16_t *end = s->c + DEGREE;
514
0
    int offset = DEGREE / 2;
515
516
0
    do {
517
0
        uint16_t *curr = s->c, *peer;
518
519
0
        do {
520
0
            uint16_t *pause = curr + offset, even, odd;
521
0
            uint32_t zeta = *++roots;
522
523
0
            peer = pause;
524
0
            do {
525
0
                even = *curr;
526
0
                odd = reduce(*peer * zeta);
527
0
                *peer++ = reduce_once(even - odd + kPrime);
528
0
                *curr++ = reduce_once(odd + even);
529
0
            } while (curr < pause);
530
0
        } while ((curr = peer) < end);
531
0
    } while ((offset >>= 1) >= 2);
532
0
}
533
534
/*-
535
 * FIPS 203, Section 4.3, Algorithm 10: "NTT^(-1)".
536
 * In-place inverse number theoretic transform of a given scalar, with pairs of
537
 * entries of s->v being interpreted as elements of GF(3329^2). Just as with
538
 * the number theoretic transform, this leaves off the first step of the normal
539
 * iFFT to account for the fact that 3329 does not have a 512th root of unity,
540
 * using the precomputed 128 roots of unity stored in InverseNTTRoots.
541
 */
542
static void scalar_inverse_ntt_generic(scalar *s)
543
0
{
544
0
    const uint16_t *roots = kInverseNTTRoots;
545
0
    uint16_t *end = s->c + DEGREE;
546
0
    int offset = 2;
547
548
0
    do {
549
0
        uint16_t *curr = s->c, *peer;
550
551
0
        do {
552
0
            uint16_t *pause = curr + offset, even, odd;
553
0
            uint32_t zeta = *++roots;
554
555
0
            peer = pause;
556
0
            do {
557
0
                even = *curr;
558
0
                odd = *peer;
559
0
                *peer++ = reduce(zeta * (even - odd + kPrime));
560
0
                *curr++ = reduce_once(odd + even);
561
0
            } while (curr < pause);
562
0
        } while ((curr = peer) < end);
563
0
    } while ((offset <<= 1) < DEGREE);
564
0
    scalar_mult_const(s, kInverseDegree);
565
0
}
566
567
/* Addition updating the LHS scalar in-place. */
568
static void scalar_add(scalar *lhs, const scalar *rhs)
569
0
{
570
0
    int i;
571
572
0
    for (i = 0; i < DEGREE; i++)
573
0
        lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
574
0
}
575
576
/* Subtraction updating the LHS scalar in-place. */
577
static void scalar_sub(scalar *lhs, const scalar *rhs)
578
0
{
579
0
    int i;
580
581
0
    for (i = 0; i < DEGREE; i++)
582
0
        lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime);
583
0
}
584
585
/*
586
 * Multiplying two scalars in the number theoretically transformed state. Since
587
 * 3329 does not have a 512th root of unity, this means we have to interpret
588
 * the 2*ith and (2*i+1)th entries of the scalar as elements of
589
 * GF(3329)[X]/(X^2 - 17^(2*bitreverse(i)+1)).
590
 *
591
 * The value of 17^(2*bitreverse(i)+1) mod 3329 is stored in the precomputed
592
 * ModRoots table. Note that our Barrett transform only allows us to multiply
593
 * two reduced numbers together, so we need some intermediate reduction steps,
594
 * even if an uint64_t could hold 3 multiplied numbers.
595
 */
596
static void scalar_mult(scalar *out, const scalar *lhs,
597
    const scalar *rhs)
598
0
{
599
0
    uint16_t *curr = out->c, *end = curr + DEGREE;
600
0
    const uint16_t *lc = lhs->c, *rc = rhs->c;
601
0
    const uint16_t *roots = kModRoots;
602
603
0
    do {
604
0
        uint32_t l0 = *lc++, r0 = *rc++;
605
0
        uint32_t l1 = *lc++, r1 = *rc++;
606
0
        uint32_t zetapow = *roots++;
607
608
0
        *curr++ = reduce(l0 * r0 + reduce(l1 * r1) * zetapow);
609
0
        *curr++ = reduce(l0 * r1 + l1 * r0);
610
0
    } while (curr < end);
611
0
}
612
613
/* Above, but add the result to an existing scalar */
614
static ossl_inline void scalar_mult_add(scalar *out, const scalar *lhs,
615
    const scalar *rhs)
616
0
{
617
0
    uint16_t *curr = out->c, *end = curr + DEGREE;
618
0
    const uint16_t *lc = lhs->c, *rc = rhs->c;
619
0
    const uint16_t *roots = kModRoots;
620
621
0
    do {
622
0
        uint32_t l0 = *lc++, r0 = *rc++;
623
0
        uint32_t l1 = *lc++, r1 = *rc++;
624
0
        uint16_t *c0 = curr++;
625
0
        uint16_t *c1 = curr++;
626
0
        uint32_t zetapow = *roots++;
627
628
0
        *c0 = reduce(*c0 + l0 * r0 + reduce(l1 * r1) * zetapow);
629
0
        *c1 = reduce(*c1 + l0 * r1 + l1 * r0);
630
0
    } while (curr < end);
631
0
}
632
633
/*-
634
 * FIPS 203, Section 4.2.1, Algorithm 5: "ByteEncode_d", for 2<=d<=12.
635
 * Here |bits| is |d|.  For efficiency, we handle the d=1 case separately.
636
 */
637
static void scalar_encode(uint8_t *out, const scalar *s, int bits)
638
0
{
639
0
    const uint16_t *curr = s->c, *end = curr + DEGREE;
640
0
    uint64_t accum = 0, element;
641
0
    int used = 0;
642
643
0
    do {
644
0
        element = *curr++;
645
0
        if (used + bits < 64) {
646
0
            accum |= element << used;
647
0
            used += bits;
648
0
        } else if (used + bits > 64) {
649
0
            out = OPENSSL_store_u64_le(out, accum | (element << used));
650
0
            accum = element >> (64 - used);
651
0
            used = (used + bits) - 64;
652
0
        } else {
653
0
            out = OPENSSL_store_u64_le(out, accum | (element << used));
654
0
            accum = 0;
655
0
            used = 0;
656
0
        }
657
0
    } while (curr < end);
658
0
}
659
660
/*
661
 * scalar_encode_1 is |scalar_encode| specialised for |bits| == 1.
662
 */
663
static void scalar_encode_1(uint8_t out[DEGREE / 8], const scalar *s)
664
0
{
665
0
    int i, j;
666
0
    uint8_t out_byte;
667
668
0
    for (i = 0; i < DEGREE; i += 8) {
669
0
        out_byte = 0;
670
0
        for (j = 0; j < 8; j++)
671
0
            out_byte |= bit0(s->c[i + j]) << j;
672
0
        *out = out_byte;
673
0
        out++;
674
0
    }
675
0
}
676
677
/*-
 * FIPS 203, Section 4.2.1, Algorithm 6: "ByteDecode_d", for 2<=d<12.
 * Here |bits| is |d|.  For efficiency, we handle the d=1 and d=12 cases
 * separately.
 *
 * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
 * |out|.  Input is read in 64-bit little-endian words, so DEGREE * bits must
 * be a multiple of 64 and that many bytes must be readable from |in|.  Each
 * value is masked to |bits| bits.  No range check against kPrime is done; for
 * bits < 12 the values are < 2^11 < kPrime anyway.
 */
static void scalar_decode(scalar *out, const uint8_t *in, int bits)
{
    uint16_t *curr = out->c, *end = curr + DEGREE;
    /* |accum| holds |accum_bits| unread input bits, least significant first. */
    uint64_t accum = 0;
    /* |todo|: bits still needed to complete the current output value. */
    int accum_bits = 0, todo = bits;
    uint16_t bitmask = (((uint16_t)1) << bits) - 1, mask = bitmask;
    /* Low-order bits of a value that straddles two 64-bit input words. */
    uint16_t element = 0;

    do {
        if (accum_bits == 0) {
            in = OPENSSL_load_u64_le(&accum, in);
            accum_bits = 64;
        }
        if (todo == bits && accum_bits >= bits) {
            /* No partial "element", and all the required bits available */
            *curr++ = ((uint16_t)accum) & mask;
            accum >>= bits;
            accum_bits -= bits;
        } else if (accum_bits >= todo) {
            /* A partial "element", and all the required bits available */
            *curr++ = element | ((((uint16_t)accum) & mask) << (bits - todo));
            accum >>= todo;
            accum_bits -= todo;
            element = 0;
            todo = bits;
            mask = bitmask;
        } else {
            /*
             * Only some of the requisite bits accumulated, store |accum_bits|
             * of these in |element|.  The accumulated bitcount becomes 0, but
             * as soon as we have more bits we'll want to merge accum_bits
             * fewer of them into the final |element|.
             *
             * Note that with a 64-bit accumulator and |bits| always 12 or
             * less, if we're here, the previous iteration had all the
             * requisite bits, and so there are no kept bits in |element|.
             */
            element = ((uint16_t)accum) & mask;
            todo -= accum_bits;
            mask = bitmask >> accum_bits;
            accum_bits = 0;
        }
    } while (curr < end);
}
729
730
/*
 * ByteDecode_12 with a range check: unpacks 3 * DEGREE / 2 bytes from |in|
 * into the DEGREE 12-bit coefficients of |out|, two coefficients per three
 * input bytes (little-endian bit order).
 *
 * Returns 1 if every coefficient is < kPrime.  Returns 0 at the first pair
 * that contains an out-of-range coefficient; |out| is then filled only up to
 * and including that pair.
 */
static __owur int scalar_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2])
{
    uint16_t *dst = out->c;
    const uint16_t *stop = out->c + DEGREE;

    for (; dst < stop; dst += 2, in += 3) {
        uint16_t lo = (uint16_t)(in[0] | ((in[1] & 0x0f) << 8));
        uint16_t hi = (uint16_t)((in[1] >> 4) | (in[2] << 4));

        dst[0] = lo;
        dst[1] = hi;
        /* Bitwise "|" evaluates both checks before the single branch. */
        if ((lo >= kPrime) | (hi >= kPrime))
            return 0;
    }
    return 1;
}
747
748
/*-
 * scalar_decode_decompress_add is a combination of decoding and decompression
 * both specialised for |bits| == 1, with the result added (and sum reduced) to
 * the output scalar.
 *
 * NOTE: this function MUST NOT leak an input-data-dependent timing signal.
 * A timing leak in a related function in the reference Kyber implementation
 * made the "clangover" attack (CVE-2024-37880) possible, giving key recovery
 * for ML-KEM-512 in minutes, provided the attacker has access to precise
 * timing of a CPU performing chosen-ciphertext decap.  Admittedly this is only
 * a risk when private keys are reused (perhaps KEMTLS servers).
 */
static void
scalar_decode_decompress_add(scalar *out, const uint8_t in[DEGREE / 8])
{
    /* 1665 = (3329 >> 1) + 1, the value a set bit decompresses to. */
    static const uint16_t half_q_plus_1 = (ML_KEM_PRIME >> 1) + 1;
    uint16_t *curr = out->c, *end = curr + DEGREE;
    uint16_t mask;
    uint8_t b;

    /*
     * Add |half_q_plus_1| if the bit is set, without exposing a side-channel,
     * avoiding the "clangover" attack.  See |constish_time_true| for a
     * discussion on why the value barrier is by default omitted.
     *
     * Consumes bit 0 of |b| for the current coefficient, then shifts |b| so
     * the next bit is ready.
     */
#define decode_decompress_add_bit                        \
    mask = constish_time_true(bit0(b));                  \
    *curr = reduce_once(*curr + (mask & half_q_plus_1)); \
    curr++;                                              \
    b >>= 1

    /* Unrolled to process each byte in one iteration */
    do {
        b = *in++;
        decode_decompress_add_bit;
        decode_decompress_add_bit;
        decode_decompress_add_bit;
        decode_decompress_add_bit;

        decode_decompress_add_bit;
        decode_decompress_add_bit;
        decode_decompress_add_bit;
        decode_decompress_add_bit;
    } while (curr < end);
#undef decode_decompress_add_bit
}
794
795
/*
 * FIPS 203, Section 4.2.1, Equation (4.7): Compress_d.
 *
 * Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
 * numbers close to each other together.  The formula used is
 * round(2^|bits|/kPrime*x) mod 2^|bits|.
 * Uses Barrett reduction to achieve constant time.  Since we need both the
 * remainder (for rounding) and the quotient (as the result), we cannot use
 * |reduce| here, but need to do the Barrett reduction directly.
 */
static __owur uint16_t compress(uint16_t x, int bits)
{
    uint32_t shifted = (uint32_t)x << bits;
    uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
    /* Barrett estimate of shifted / kPrime; may be one too small. */
    uint32_t quotient = (uint32_t)(product >> kBarrettShift);
    /* Hence 0 <= remainder < 2 * kPrime. */
    uint32_t remainder = shifted - quotient * kPrime;

    /*
     * Adjust the quotient to round correctly:
     *   0 <= remainder <= kHalfPrime round to 0
     *   kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
     *   kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
     * Both comparisons are constant-time, and each contributes 0 or 1.
     */
    quotient += 1 & constant_time_lt_32(kHalfPrime, remainder);
    quotient += 1 & constant_time_lt_32(kPrime + kHalfPrime, remainder);
    /* "mod 2^bits" from the formula above. */
    return quotient & ((1 << bits) - 1);
}
822
823
/*
824
 * FIPS 203, Section 4.2.1, Equation (4.8): Decompress_d.
825
826
 * Decompresses |x| by using a close equi-distant representative. The formula
827
 * is round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us
828
 * to implement this logic using only bit operations.
829
 */
830
static __owur uint16_t decompress(uint16_t x, int bits)
831
0
{
832
0
    uint32_t product = (uint32_t)x * kPrime;
833
0
    uint32_t power = 1 << bits;
834
    /* This is |product| % power, since |power| is a power of 2. */
835
0
    uint32_t remainder = product & (power - 1);
836
    /* This is |product| / power, since |power| is a power of 2. */
837
0
    uint32_t lower = product >> bits;
838
839
    /*
840
     * The rounding logic works since the first half of numbers mod |power|
841
     * have a 0 as first bit, and the second half has a 1 as first bit, since
842
     * |power| is a power of 2. As a 12 bit number, |remainder| is always
843
     * positive, so we will shift in 0s for a right shift.
844
     */
845
0
    return lower + (remainder >> (bits - 1));
846
0
}
847
848
/*-
849
 * FIPS 203, Section 4.2.1, Equation (4.7): "Compress_d".
850
 * In-place lossy rounding of scalars to 2^d bits.
851
 */
852
static void scalar_compress(scalar *s, int bits)
853
0
{
854
0
    int i;
855
856
0
    for (i = 0; i < DEGREE; i++)
857
0
        s->c[i] = compress(s->c[i], bits);
858
0
}
859
860
/*
861
 * FIPS 203, Section 4.2.1, Equation (4.8): "Decompress_d".
862
 * In-place approximate recovery of scalars from 2^d bit compression.
863
 */
864
static void scalar_decompress(scalar *s, int bits)
865
0
{
866
0
    int i;
867
868
0
    for (i = 0; i < DEGREE; i++)
869
0
        s->c[i] = decompress(s->c[i], bits);
870
0
}
871
872
/* Addition updating the LHS vector in-place. */
873
static void vector_add(scalar *lhs, const scalar *rhs, int rank)
874
0
{
875
0
    do {
876
0
        scalar_add(lhs++, rhs++);
877
0
    } while (--rank > 0);
878
0
}
879
880
/*
881
 * Encodes an entire vector into 32*|rank|*|bits| bytes. Note that since 256
882
 * (DEGREE) is divisible by 8, the individual vector entries will always fill a
883
 * whole number of bytes, so we do not need to worry about bit packing here.
884
 */
885
static void vector_encode(uint8_t *out, const scalar *a, int bits, int rank)
886
0
{
887
0
    int stride = bits * DEGREE / 8;
888
889
0
    for (; rank-- > 0; out += stride)
890
0
        scalar_encode(out, a++, bits);
891
0
}
892
893
/*
894
 * Decodes 32*|rank|*|bits| bytes from |in| into |out|. It returns early
895
 * if any parsed value is >= |ML_KEM_PRIME|.  The resulting scalars are
896
 * then decompressed and transformed via the NTT.
897
 *
898
 * Note: Used only in decrypt_cpa(), which returns void and so does not check
899
 * the return value of this function.  Side-channels are fine when the input
900
 * ciphertext to decap() is simply syntactically invalid.
901
 */
902
static void
903
vector_decode_decompress_ntt(scalar *out, const uint8_t *in, int bits, int rank)
904
0
{
905
0
    int stride = bits * DEGREE / 8;
906
907
0
    for (; rank-- > 0; in += stride, ++out) {
908
0
        scalar_decode(out, in, bits);
909
0
        scalar_decompress(out, bits);
910
0
        scalar_ntt(out);
911
0
    }
912
0
}
913
914
/* vector_decode(), specialised to bits == 12. */
915
static __owur int vector_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2], int rank)
916
0
{
917
0
    int stride = 3 * DEGREE / 2;
918
919
0
    for (; rank-- > 0; in += stride)
920
0
        if (!scalar_decode_12(out++, in))
921
0
            return 0;
922
0
    return 1;
923
0
}
924
925
/* In-place compression of each scalar component */
926
static void vector_compress(scalar *a, int bits, int rank)
927
0
{
928
0
    do {
929
0
        scalar_compress(a++, bits);
930
0
    } while (--rank > 0);
931
0
}
932
933
/* The output scalar must not overlap with the inputs */
934
static void inner_product(scalar *out, const scalar *lhs, const scalar *rhs,
935
    int rank)
936
0
{
937
0
    scalar_mult(out, lhs, rhs);
938
0
    while (--rank > 0)
939
0
        scalar_mult_add(out, ++lhs, ++rhs);
940
0
}
941
942
/*
943
 * Here, the output vector must not overlap with the inputs, the result is
944
 * directly subjected to inverse NTT.
945
 */
946
static void
947
matrix_mult_intt(scalar *out, const scalar *m, const scalar *a, int rank)
948
0
{
949
0
    const scalar *ar;
950
0
    int i, j;
951
952
0
    for (i = rank; i-- > 0; ++out) {
953
0
        scalar_mult(out, m++, ar = a);
954
0
        for (j = rank - 1; j > 0; --j)
955
0
            scalar_mult_add(out, m++, ++ar);
956
0
        scalar_inverse_ntt(out);
957
0
    }
958
0
}
959
960
/* Here, the output vector must not overlap with the inputs */
961
static void
962
matrix_mult_transpose_add(scalar *out, const scalar *m, const scalar *a, int rank)
963
0
{
964
0
    const scalar *mc = m, *mr, *ar;
965
0
    int i, j;
966
967
0
    for (i = rank; i-- > 0; ++out) {
968
0
        scalar_mult_add(out, mr = mc++, ar = a);
969
0
        for (j = rank; --j > 0;)
970
0
            scalar_mult_add(out, (mr += rank), ++ar);
971
0
    }
972
0
}
973
974
/*-
975
 * Expands the matrix from a seed for key generation and for encaps-CPA.
976
 * NOTE: FIPS 203 matrix "A" is the transpose of this matrix, computed
977
 * by appending the (i,j) indices to the seed in the opposite order!
978
 *
979
 * Where FIPS 203 computes t = A * s + e, we use the transpose of "m".
980
 */
981
static __owur int matrix_expand(EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
982
0
{
983
0
    scalar *out = key->m;
984
0
    uint8_t input[ML_KEM_RANDOM_BYTES + 2];
985
0
    int rank = key->vinfo->rank;
986
0
    int i, j;
987
988
0
    memcpy(input, key->rho, ML_KEM_RANDOM_BYTES);
989
0
    for (i = 0; i < rank; i++) {
990
0
        for (j = 0; j < rank; j++) {
991
0
            input[ML_KEM_RANDOM_BYTES] = i;
992
0
            input[ML_KEM_RANDOM_BYTES + 1] = j;
993
0
            if (!EVP_DigestInit_ex(mdctx, key->shake128_md, NULL)
994
0
                || !EVP_DigestUpdate(mdctx, input, sizeof(input))
995
0
                || !sample_scalar(out++, mdctx))
996
0
                return 0;
997
0
        }
998
0
    }
999
0
    return 1;
1000
0
}
1001
1002
/*
1003
 * Algorithm 7 from the spec, with eta fixed to two and the PRF call
1004
 * included. Creates binominally distributed elements by sampling 2*|eta| bits,
1005
 * and setting the coefficient to the count of the first bits minus the count of
1006
 * the second bits, resulting in a centered binomial distribution. Since eta is
1007
 * two this gives -2/2 with a probability of 1/16, -1/1 with probability 1/4,
1008
 * and 0 with probability 3/8.
1009
 */
1010
static __owur int cbd_2(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
1011
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1012
0
{
1013
0
    uint16_t *curr = out->c, *end = curr + DEGREE;
1014
0
    uint8_t randbuf[4 * DEGREE / 8], *r = randbuf; /* 64 * eta slots */
1015
0
    uint16_t value, mask;
1016
0
    uint8_t b;
1017
1018
0
    if (!prf(randbuf, sizeof(randbuf), in, mdctx, key))
1019
0
        return 0;
1020
1021
0
    do {
1022
0
        b = *r++;
1023
1024
        /*
1025
         * Add |kPrime| if |value| underflowed.  See |constish_time_true| for
1026
         * a discussion on why the value barrier is by default omitted.  While
1027
         * this could have been written reduce_once(value + kPrime), this is
1028
         * one extra addition and small range of |value| tempts some versions
1029
         * of Clang to emit a branch.
1030
         */
1031
0
        value = bit0(b) + bitn(1, b);
1032
0
        value -= bitn(2, b) + bitn(3, b);
1033
0
        mask = constish_time_true(value >> 15);
1034
0
        *curr++ = value + (kPrime & mask);
1035
1036
0
        value = bitn(4, b) + bitn(5, b);
1037
0
        value -= bitn(6, b) + bitn(7, b);
1038
0
        mask = constish_time_true(value >> 15);
1039
0
        *curr++ = value + (kPrime & mask);
1040
0
    } while (curr < end);
1041
0
    return 1;
1042
0
}
1043
1044
/*
1045
 * Algorithm 7 from the spec, with eta fixed to three and the PRF call
1046
 * included. Creates binominally distributed elements by sampling 3*|eta| bits,
1047
 * and setting the coefficient to the count of the first bits minus the count of
1048
 * the second bits, resulting in a centered binomial distribution.
1049
 */
1050
static __owur int cbd_3(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
1051
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1052
0
{
1053
0
    uint16_t *curr = out->c, *end = curr + DEGREE;
1054
0
    uint8_t randbuf[6 * DEGREE / 8], *r = randbuf; /* 64 * eta slots */
1055
0
    uint8_t b1, b2, b3;
1056
0
    uint16_t value, mask;
1057
1058
0
    if (!prf(randbuf, sizeof(randbuf), in, mdctx, key))
1059
0
        return 0;
1060
1061
0
    do {
1062
0
        b1 = *r++;
1063
0
        b2 = *r++;
1064
0
        b3 = *r++;
1065
1066
        /*
1067
         * Add |kPrime| if |value| underflowed.  See |constish_time_true|
1068
         * for a discussion on why the value barrier is by default omitted.
1069
         * While this could have been written reduce_once(value + kPrime), this
1070
         * is one extra addition and small range of |value| tempts some
1071
         * versions of Clang to emit a branch.
1072
         */
1073
0
        value = bit0(b1) + bitn(1, b1) + bitn(2, b1);
1074
0
        value -= bitn(3, b1) + bitn(4, b1) + bitn(5, b1);
1075
0
        mask = constish_time_true(value >> 15);
1076
0
        *curr++ = value + (kPrime & mask);
1077
1078
0
        value = bitn(6, b1) + bitn(7, b1) + bit0(b2);
1079
0
        value -= bitn(1, b2) + bitn(2, b2) + bitn(3, b2);
1080
0
        mask = constish_time_true(value >> 15);
1081
0
        *curr++ = value + (kPrime & mask);
1082
1083
0
        value = bitn(4, b2) + bitn(5, b2) + bitn(6, b2);
1084
0
        value -= bitn(7, b2) + bit0(b3) + bitn(1, b3);
1085
0
        mask = constish_time_true(value >> 15);
1086
0
        *curr++ = value + (kPrime & mask);
1087
1088
0
        value = bitn(2, b3) + bitn(3, b3) + bitn(4, b3);
1089
0
        value -= bitn(5, b3) + bitn(6, b3) + bitn(7, b3);
1090
0
        mask = constish_time_true(value >> 15);
1091
0
        *curr++ = value + (kPrime & mask);
1092
0
    } while (curr < end);
1093
0
    return 1;
1094
0
}
1095
1096
/*
1097
 * Generates a secret vector by using |cbd| with the given seed to generate
1098
 * scalar elements and incrementing |counter| for each slot of the vector.
1099
 */
1100
static __owur int gencbd_vector(scalar *out, CBD_FUNC cbd, uint8_t *counter,
1101
    const uint8_t seed[ML_KEM_RANDOM_BYTES], int rank,
1102
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1103
0
{
1104
0
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1105
1106
0
    memcpy(input, seed, ML_KEM_RANDOM_BYTES);
1107
0
    do {
1108
0
        input[ML_KEM_RANDOM_BYTES] = (*counter)++;
1109
0
        if (!cbd(out++, input, mdctx, key))
1110
0
            return 0;
1111
0
    } while (--rank > 0);
1112
0
    return 1;
1113
0
}
1114
1115
/*
1116
 * As above plus NTT transform.
1117
 */
1118
static __owur int gencbd_vector_ntt(scalar *out, CBD_FUNC cbd, uint8_t *counter,
1119
    const uint8_t seed[ML_KEM_RANDOM_BYTES], int rank,
1120
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1121
0
{
1122
0
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1123
1124
0
    memcpy(input, seed, ML_KEM_RANDOM_BYTES);
1125
0
    do {
1126
0
        input[ML_KEM_RANDOM_BYTES] = (*counter)++;
1127
0
        if (!cbd(out, input, mdctx, key))
1128
0
            return 0;
1129
0
        scalar_ntt(out++);
1130
0
    } while (--rank > 0);
1131
0
    return 1;
1132
0
}
1133
1134
/*
 * Picks the sampler for the |ETA1| noise.  ETA1 is 2 for every variant
 * except ML-KEM-512, where it is 3.  All ETA2 values are 2.
 */
#define CBD1(evp_type) ((evp_type) != EVP_PKEY_ML_KEM_512 ? cbd_2 : cbd_3)
1136
1137
/*
1138
 * FIPS 203, Section 5.2, Algorithm 14: K-PKE.Encrypt.
1139
 *
1140
 * Encrypts a message with given randomness to the ciphertext in |out|. Without
1141
 * applying the Fujisaki-Okamoto transform this would not result in a CCA
1142
 * secure scheme, since lattice schemes are vulnerable to decryption failure
1143
 * oracles.
1144
 *
1145
 * The steps are re-ordered to make more efficient/localised use of storage.
1146
 *
1147
 * Note also that the input public key is assumed to hold a precomputed matrix
1148
 * |A| (our key->m, with the public key holding an expanded (16-bit per scalar
1149
 * coefficient) key->t vector).
1150
 *
1151
 * Caller passes storage in |tmp| for for two temporary vectors.
1152
 */
1153
static __owur int encrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
1154
    const uint8_t message[DEGREE / 8],
1155
    const uint8_t r[ML_KEM_RANDOM_BYTES], scalar *tmp,
1156
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1157
0
{
1158
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1159
0
    CBD_FUNC cbd_1 = CBD1(vinfo->evp_type);
1160
0
    int rank = vinfo->rank;
1161
    /* We can use tmp[0..rank-1] as storage for |y|, then |e1|, ... */
1162
0
    scalar *y = &tmp[0], *e1 = y, *e2 = y;
1163
    /* We can use tmp[rank]..tmp[2*rank - 1] for |u| */
1164
0
    scalar *u = &tmp[rank];
1165
0
    scalar v;
1166
0
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1167
0
    uint8_t counter = 0;
1168
0
    int du = vinfo->du;
1169
0
    int dv = vinfo->dv;
1170
1171
    /* FIPS 203 "y" vector */
1172
0
    if (!gencbd_vector_ntt(y, cbd_1, &counter, r, rank, mdctx, key))
1173
0
        return 0;
1174
    /* FIPS 203 "v" scalar */
1175
0
    inner_product(&v, key->t, y, rank);
1176
0
    scalar_inverse_ntt(&v);
1177
    /* FIPS 203 "u" vector */
1178
0
    matrix_mult_intt(u, key->m, y, rank);
1179
1180
    /* All done with |y|, now free to reuse tmp[0] for FIPS 203 |e1| */
1181
0
    if (!gencbd_vector(e1, cbd_2, &counter, r, rank, mdctx, key))
1182
0
        return 0;
1183
0
    vector_add(u, e1, rank);
1184
0
    vector_compress(u, du, rank);
1185
0
    vector_encode(out, u, du, rank);
1186
1187
    /* All done with |e1|, now free to reuse tmp[0] for FIPS 203 |e2| */
1188
0
    memcpy(input, r, ML_KEM_RANDOM_BYTES);
1189
0
    input[ML_KEM_RANDOM_BYTES] = counter;
1190
0
    if (!cbd_2(e2, input, mdctx, key))
1191
0
        return 0;
1192
0
    scalar_add(&v, e2);
1193
1194
    /* Combine message with |v| */
1195
0
    scalar_decode_decompress_add(&v, message);
1196
0
    scalar_compress(&v, dv);
1197
0
    scalar_encode(out + vinfo->u_vector_bytes, &v, dv);
1198
0
    return 1;
1199
0
}
1200
1201
/*
1202
 * FIPS 203, Section 5.3, Algorithm 15: K-PKE.Decrypt.
1203
 */
1204
static void
1205
decrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
1206
    const uint8_t *ctext, scalar *u, const ML_KEM_KEY *key)
1207
0
{
1208
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1209
0
    scalar v, mask;
1210
0
    int rank = vinfo->rank;
1211
0
    int du = vinfo->du;
1212
0
    int dv = vinfo->dv;
1213
1214
0
    vector_decode_decompress_ntt(u, ctext, du, rank);
1215
0
    scalar_decode(&v, ctext + vinfo->u_vector_bytes, dv);
1216
0
    scalar_decompress(&v, dv);
1217
0
    inner_product(&mask, key->s, u, rank);
1218
0
    scalar_inverse_ntt(&mask);
1219
0
    scalar_sub(&v, &mask);
1220
0
    scalar_compress(&v, 1);
1221
0
    scalar_encode_1(out, &v);
1222
0
}
1223
1224
/*-
1225
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1226
 * FIPS 203, Section 7.2, Algorithm 20: "ML-KEM.Encaps".
1227
 *
1228
 * Fills the |out| buffer with the |ek| output of "ML-KEM.KeyGen", or,
1229
 * equivalently, the |ek| input of "ML-KEM.Encaps", i.e. returns the
1230
 * wire-format of an ML-KEM public key.
1231
 */
1232
static void encode_pubkey(uint8_t *out, const ML_KEM_KEY *key)
1233
0
{
1234
0
    const uint8_t *rho = key->rho;
1235
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1236
1237
0
    vector_encode(out, key->t, 12, vinfo->rank);
1238
0
    memcpy(out + vinfo->vector_bytes, rho, ML_KEM_RANDOM_BYTES);
1239
0
}
1240
1241
/*-
1242
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1243
 *
1244
 * Fills the |out| buffer with the |dk| output of "ML-KEM.KeyGen".
1245
 * This matches the input format of parse_prvkey() below.
1246
 */
1247
static void encode_prvkey(uint8_t *out, const ML_KEM_KEY *key)
1248
0
{
1249
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1250
1251
0
    vector_encode(out, key->s, 12, vinfo->rank);
1252
0
    out += vinfo->vector_bytes;
1253
0
    encode_pubkey(out, key);
1254
0
    out += vinfo->pubkey_bytes;
1255
0
    memcpy(out, key->pkhash, ML_KEM_PKHASH_BYTES);
1256
0
    out += ML_KEM_PKHASH_BYTES;
1257
0
    memcpy(out, key->z, ML_KEM_RANDOM_BYTES);
1258
0
}
1259
1260
/*-
1261
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1262
 * FIPS 203, Section 7.2, Algorithm 20: "ML-KEM.Encaps".
1263
 *
1264
 * This function parses the |in| buffer as the |ek| output of "ML-KEM.KeyGen",
1265
 * or, equivalently, the |ek| input of "ML-KEM.Encaps", i.e. decodes the
1266
 * wire-format of the ML-KEM public key.
1267
 */
1268
static int parse_pubkey(const uint8_t *in, EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
1269
0
{
1270
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1271
1272
    /* Decode and check |t| */
1273
0
    if (!vector_decode_12(key->t, in, vinfo->rank)) {
1274
0
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1275
0
            "%s invalid public 't' vector",
1276
0
            vinfo->algorithm_name);
1277
0
        return 0;
1278
0
    }
1279
    /* Save the matrix |m| recovery seed |rho| */
1280
0
    memcpy(key->rho, in + vinfo->vector_bytes, ML_KEM_RANDOM_BYTES);
1281
    /*
1282
     * Pre-compute the public key hash, needed for both encap and decap.
1283
     * Also pre-compute the matrix expansion, stored with the public key.
1284
     */
1285
0
    if (!hash_h(key->pkhash, in, vinfo->pubkey_bytes, mdctx, key)
1286
0
        || !matrix_expand(mdctx, key)) {
1287
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1288
0
            "internal error while parsing %s public key",
1289
0
            vinfo->algorithm_name);
1290
0
        return 0;
1291
0
    }
1292
0
    return 1;
1293
0
}
1294
1295
/*
1296
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1297
 *
1298
 * Parses the |in| buffer as a |dk| output of "ML-KEM.KeyGen".
1299
 * This matches the output format of encode_prvkey() above.
1300
 */
1301
static int parse_prvkey(const uint8_t *in, EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
1302
0
{
1303
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1304
1305
    /* Decode and check |s|. */
1306
0
    if (!vector_decode_12(key->s, in, vinfo->rank)) {
1307
0
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1308
0
            "%s invalid private 's' vector",
1309
0
            vinfo->algorithm_name);
1310
0
        return 0;
1311
0
    }
1312
0
    in += vinfo->vector_bytes;
1313
1314
0
    if (!parse_pubkey(in, mdctx, key))
1315
0
        return 0;
1316
0
    in += vinfo->pubkey_bytes;
1317
1318
    /* Check public key hash. */
1319
0
    if (memcmp(key->pkhash, in, ML_KEM_PKHASH_BYTES) != 0) {
1320
0
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1321
0
            "%s public key hash mismatch",
1322
0
            vinfo->algorithm_name);
1323
0
        return 0;
1324
0
    }
1325
0
    in += ML_KEM_PKHASH_BYTES;
1326
1327
0
    memcpy(key->z, in, ML_KEM_RANDOM_BYTES);
1328
0
    return 1;
1329
0
}
1330
1331
/*
1332
 * FIPS 203, Section 6.1, Algorithm 16: "ML-KEM.KeyGen_internal".
1333
 *
1334
 * The implementation of Section 5.1, Algorithm 13, "K-PKE.KeyGen(d)" is
1335
 * inlined.
1336
 *
1337
 * The caller MUST pass a pre-allocated digest context that is not shared with
1338
 * any concurrent computation.
1339
 *
1340
 * This function optionally outputs the serialised wire-form |ek| public key
1341
 * into the provided |pubenc| buffer, and generates the content of the |rho|,
1342
 * |pkhash|, |t|, |m|, |s| and |z| components of the private |key| (which must
1343
 * have preallocated space for these).
1344
 *
1345
 * Keys are computed from a 32-byte random |d| plus the 1 byte rank for
1346
 * domain separation.  These are concatenated and hashed to produce a pair of
1347
 * 32-byte seeds public "rho", used to generate the matrix, and private "sigma",
1348
 * used to generate the secret vector |s|.
1349
 *
1350
 * The second random input |z| is copied verbatim into the Fujisaki-Okamoto
1351
 * (FO) transform "implicit-rejection" secret (the |z| component of the private
1352
 * key), which thwarts chosen-ciphertext attacks, provided decap() runs in
1353
 * constant time, with no side channel leaks, on all well-formed (valid length,
1354
 * and correctly encoded) ciphertext inputs.
1355
 */
1356
static __owur int genkey(const uint8_t seed[ML_KEM_SEED_BYTES],
1357
    EVP_MD_CTX *mdctx, uint8_t *pubenc, ML_KEM_KEY *key)
1358
0
{
1359
0
    uint8_t hashed[2 * ML_KEM_RANDOM_BYTES];
1360
0
    const uint8_t *const sigma = hashed + ML_KEM_RANDOM_BYTES;
1361
0
    uint8_t augmented_seed[ML_KEM_RANDOM_BYTES + 1];
1362
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1363
0
    CBD_FUNC cbd_1 = CBD1(vinfo->evp_type);
1364
0
    int rank = vinfo->rank;
1365
0
    uint8_t counter = 0;
1366
0
    int ret = 0;
1367
1368
    /*
1369
     * Use the "d" seed salted with the rank to derive the public and private
1370
     * seeds rho and sigma.
1371
     */
1372
0
    memcpy(augmented_seed, seed, ML_KEM_RANDOM_BYTES);
1373
0
    augmented_seed[ML_KEM_RANDOM_BYTES] = (uint8_t)rank;
1374
0
    if (!hash_g(hashed, augmented_seed, sizeof(augmented_seed), mdctx, key))
1375
0
        goto end;
1376
0
    memcpy(key->rho, hashed, ML_KEM_RANDOM_BYTES);
1377
    /* The |rho| matrix seed is public */
1378
0
    CONSTTIME_DECLASSIFY(key->rho, ML_KEM_RANDOM_BYTES);
1379
1380
    /* FIPS 203 |e| vector is initial value of key->t */
1381
0
    if (!matrix_expand(mdctx, key)
1382
0
        || !gencbd_vector_ntt(key->s, cbd_1, &counter, sigma, rank, mdctx, key)
1383
0
        || !gencbd_vector_ntt(key->t, cbd_1, &counter, sigma, rank, mdctx, key))
1384
0
        goto end;
1385
1386
    /* To |e| we now add the product of transpose |m| and |s|, giving |t|. */
1387
0
    matrix_mult_transpose_add(key->t, key->m, key->s, rank);
1388
    /* The |t| vector is public */
1389
0
    CONSTTIME_DECLASSIFY(key->t, vinfo->rank * sizeof(scalar));
1390
1391
0
    if (pubenc == NULL) {
1392
        /* Incremental digest of public key without in-full serialisation. */
1393
0
        if (!hash_h_pubkey(key->pkhash, mdctx, key))
1394
0
            goto end;
1395
0
    } else {
1396
0
        encode_pubkey(pubenc, key);
1397
0
        if (!hash_h(key->pkhash, pubenc, vinfo->pubkey_bytes, mdctx, key))
1398
0
            goto end;
1399
0
    }
1400
1401
    /* Save |z| portion of seed for "implicit rejection" on failure. */
1402
0
    memcpy(key->z, seed + ML_KEM_RANDOM_BYTES, ML_KEM_RANDOM_BYTES);
1403
1404
    /* Save the |d| portion of the seed */
1405
0
    key->d = key->z + ML_KEM_RANDOM_BYTES;
1406
0
    memcpy(key->d, seed, ML_KEM_RANDOM_BYTES);
1407
1408
0
    ret = 1;
1409
0
end:
1410
0
    OPENSSL_cleanse((void *)augmented_seed, ML_KEM_RANDOM_BYTES);
1411
0
    OPENSSL_cleanse((void *)sigma, ML_KEM_RANDOM_BYTES);
1412
0
    if (ret == 0) {
1413
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1414
0
            "internal error while generating %s private key",
1415
0
            vinfo->algorithm_name);
1416
0
    }
1417
0
    return ret;
1418
0
}
1419
1420
/*-
1421
 * FIPS 203, Section 6.2, Algorithm 17: "ML-KEM.Encaps_internal".
1422
 * This is the deterministic version with randomness supplied externally.
1423
 *
1424
 * The caller must pass space for two vectors in |tmp|.
1425
 * The |ctext| buffer have space for the ciphertext of the ML-KEM variant
1426
 * of the provided key.
1427
 */
1428
static int encap(uint8_t *ctext, uint8_t secret[ML_KEM_SHARED_SECRET_BYTES],
1429
    const uint8_t entropy[ML_KEM_RANDOM_BYTES],
1430
    scalar *tmp, EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1431
0
{
1432
0
    uint8_t input[ML_KEM_RANDOM_BYTES + ML_KEM_PKHASH_BYTES];
1433
0
    uint8_t Kr[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES];
1434
0
    uint8_t *r = Kr + ML_KEM_SHARED_SECRET_BYTES;
1435
0
    int ret;
1436
1437
0
    memcpy(input, entropy, ML_KEM_RANDOM_BYTES);
1438
0
    memcpy(input + ML_KEM_RANDOM_BYTES, key->pkhash, ML_KEM_PKHASH_BYTES);
1439
0
    ret = hash_g(Kr, input, sizeof(input), mdctx, key)
1440
0
        && encrypt_cpa(ctext, entropy, r, tmp, mdctx, key);
1441
0
    OPENSSL_cleanse((void *)input, sizeof(input));
1442
1443
0
    if (ret)
1444
0
        memcpy(secret, Kr, ML_KEM_SHARED_SECRET_BYTES);
1445
0
    else
1446
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1447
0
            "internal error while performing %s encapsulation",
1448
0
            key->vinfo->algorithm_name);
1449
0
    return ret;
1450
0
}
1451
1452
/*
1453
 * Hash the input message |m'| and public key digest |h|
1454
 * to obtain |K| and |r|.
1455
 */
1456
static int hash_kr(uint8_t *out, uint8_t *in,
1457
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1458
0
{
1459
0
    unsigned int sz, wanted;
1460
1461
0
    wanted = ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES;
1462
0
    return (EVP_DigestInit_ex(mdctx, key->sha3_512_md, NULL)
1463
0
        && EVP_DigestUpdate(mdctx, in, ML_KEM_RANDOM_BYTES)
1464
0
        && EVP_DigestUpdate(mdctx, key->pkhash, ML_KEM_PKHASH_BYTES)
1465
0
        && EVP_DigestFinal_ex(mdctx, out, &sz)
1466
0
        && ossl_assert(sz == wanted));
1467
0
}
1468
1469
/*-
 * Scratch space used by decap(), carved out of one buffer so that a single
 * cleanse wipes it all:
 *
 *   Kbar (failure key) | K | r | m'
 *
 * The public key digest is read directly from the key rather than copied in
 * here, so it needs no cleansing.
 */
#define DECAP_BUFFER_SZ (ML_KEM_SHARED_SECRET_BYTES /* Kbar */ \
    + ML_KEM_SHARED_SECRET_BYTES /* K */                       \
    + ML_KEM_RANDOM_BYTES /* r */                              \
    + ML_KEM_RANDOM_BYTES /* m' */)
1475
1476
/*
1477
 * FIPS 203, Section 6.3, Algorithm 18: ML-KEM.Decaps_internal
1478
 *
1479
 * Barring failure of the supporting SHA3/SHAKE primitives, this is fully
1480
 * deterministic, the randomness for the FO transform is extracted during
1481
 * private key generation.
1482
 *
1483
 * The caller must pass space for two vectors in |tmp|.
1484
 * The |ctext| and |tmp_ctext| buffers must each have space for the ciphertext
1485
 * of the key's ML-KEM variant.
1486
 */
1487
static int decap(uint8_t secret[ML_KEM_SHARED_SECRET_BYTES],
1488
    const uint8_t *ctext, uint8_t *tmp_ctext, scalar *tmp,
1489
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1490
0
{
1491
0
    uint8_t buf[DECAP_BUFFER_SZ];
1492
0
    uint8_t *failure_key = buf; /* Kbar */
1493
0
    uint8_t *Kr = failure_key + ML_KEM_SHARED_SECRET_BYTES;
1494
0
    uint8_t *r = Kr + ML_KEM_SHARED_SECRET_BYTES;
1495
0
    uint8_t *m = r + ML_KEM_RANDOM_BYTES; /* m' */
1496
0
    const ML_KEM_VINFO *vinfo = key->vinfo;
1497
0
    int i;
1498
0
    uint8_t mask;
1499
1500
    /*
1501
     * If our KDF is unavailable, fail early! Otherwise, keep going ignoring
1502
     * any further errors, returning success, and whatever we got for a shared
1503
     * secret.  The decrypt_cpa() function is just arithmetic on secret data,
1504
     * so should not be subject to failure that makes its output predictable.
1505
     *
1506
     * We guard against "should never happen" catastrophic failure of the
1507
     * "pure" function |hash_g| by overwriting the shared secret with the
1508
     * content of the failure key and returning early, if nevertheless hash_g
1509
     * fails.  This is not constant-time, but a failure of |hash_g| already
1510
     * implies loss of side-channel resistance.
1511
     *
1512
     * The same action is taken, if also |encrypt_cpa| should catastrophically
1513
     * fail, due to failure of the |PRF| underlying the CBD functions.
1514
     */
1515
0
    if (!kdf(failure_key, key->z, ctext, vinfo->ctext_bytes, mdctx, key)) {
1516
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1517
0
            "internal error while performing %s decapsulation",
1518
0
            vinfo->algorithm_name);
1519
0
        return 0;
1520
0
    }
1521
0
    decrypt_cpa(m, ctext, tmp, key);
1522
0
    if (!hash_kr(Kr, m, mdctx, key)
1523
0
        || !encrypt_cpa(tmp_ctext, m, r, tmp, mdctx, key)) {
1524
0
        memcpy(secret, failure_key, ML_KEM_SHARED_SECRET_BYTES);
1525
0
        goto end;
1526
0
    }
1527
0
    mask = constant_time_eq_int_8(0,
1528
0
        CRYPTO_memcmp(ctext, tmp_ctext, vinfo->ctext_bytes));
1529
0
    for (i = 0; i < ML_KEM_SHARED_SECRET_BYTES; i++)
1530
0
        secret[i] = constant_time_select_8(mask, Kr[i], failure_key[i]);
1531
0
end:
1532
0
    OPENSSL_cleanse(buf, DECAP_BUFFER_SZ);
1533
0
    return 1;
1534
0
}
1535
1536
/*
1537
 * After allocating storage for public or private key data, update the key
1538
 * component pointers to reference that storage.
1539
 *
1540
 * The caller should only store private data in `priv` *after* a successful
1541
 * (non-zero) return from this function.
1542
 */
1543
static __owur int add_storage(scalar *pub, scalar *priv, int private, ML_KEM_KEY *key)
1544
0
{
1545
0
    int rank = key->vinfo->rank;
1546
1547
0
    if (pub == NULL || (private && priv == NULL)) {
1548
        /*
1549
         * One of these could be allocated correctly. It is legal to call free with a NULL
1550
         * pointer, so always attempt to free both allocations here
1551
         */
1552
0
        OPENSSL_free(pub);
1553
0
        OPENSSL_secure_free(priv);
1554
0
        return 0;
1555
0
    }
1556
1557
    /*
1558
     * We're adding key material, set up rho and pkhash to point to the rho_pkhash buffer
1559
     */
1560
0
    memset(key->rho_pkhash, 0, sizeof(key->rho_pkhash));
1561
0
    key->rho = key->rho_pkhash;
1562
0
    key->pkhash = key->rho_pkhash + ML_KEM_RANDOM_BYTES;
1563
0
    key->d = key->z = NULL;
1564
1565
    /* A public key needs space for |t| and |m| */
1566
0
    key->m = (key->t = pub) + rank;
1567
1568
    /*
1569
     * A private key also needs space for |s| and |z|.
1570
     * The |z| buffer always includes additional space for |d|, but a key's |d|
1571
     * pointer is left NULL when parsed from the NIST format, which omits that
1572
     * information.  Only keys generated from a (d, z) seed pair will have a
1573
     * non-NULL |d| pointer.
1574
     */
1575
0
    if (private)
1576
0
        key->z = (uint8_t *)(rank + (key->s = priv));
1577
0
    return 1;
1578
0
}
1579
1580
/*
1581
 * After freeing the storage associated with a key that failed to be
1582
 * constructed, reset the internal pointers back to NULL.
1583
 */
1584
void ossl_ml_kem_key_reset(ML_KEM_KEY *key)
1585
0
{
1586
    /*
1587
     * seedbuf can be allocated and contain |z| and |d| if the key is
1588
     * being created from a private key encoding.  Similarly a pending
1589
     * serialised (encoded) private key may be queued up to load.
1590
     * Clear and free that data now.
1591
     */
1592
0
    if (key->seedbuf != NULL)
1593
0
        OPENSSL_secure_clear_free(key->seedbuf, ML_KEM_SEED_BYTES);
1594
0
    if (ossl_ml_kem_have_dkenc(key))
1595
0
        OPENSSL_secure_clear_free(key->encoded_dk, key->vinfo->prvkey_bytes);
1596
1597
    /*-
1598
     * Cleanse any sensitive data:
1599
     * - The private vector |s| is immediately followed by the FO failure
1600
     *   secret |z|, and seed |d|, we can cleanse all three in one call.
1601
     */
1602
0
    if (key->t != NULL) {
1603
0
        if (ossl_ml_kem_have_prvkey(key))
1604
0
            OPENSSL_secure_clear_free(key->s, key->vinfo->prvalloc);
1605
0
        OPENSSL_free(key->t);
1606
0
    }
1607
0
    key->d = key->z = key->seedbuf = key->encoded_dk = (uint8_t *)(key->s = key->m = key->t = NULL);
1608
0
}
1609
1610
/*
1611
 * ----- API exported to the provider
1612
 *
1613
 * Parameters with an implicit fixed length in the internal static API of each
1614
 * variant have an explicit checked length argument at this layer.
1615
 */
1616
1617
/* Retrieve the parameters of one of the ML-KEM variants */
1618
const ML_KEM_VINFO *ossl_ml_kem_get_vinfo(int evp_type)
1619
0
{
1620
0
    (void)CRYPTO_THREAD_run_once(&ml_kem_ntt_once, ml_kem_ntt_init);
1621
1622
0
    switch (evp_type) {
1623
0
    case EVP_PKEY_ML_KEM_512:
1624
0
        return &vinfo_map[ML_KEM_512_VINFO];
1625
0
    case EVP_PKEY_ML_KEM_768:
1626
0
        return &vinfo_map[ML_KEM_768_VINFO];
1627
0
    case EVP_PKEY_ML_KEM_1024:
1628
0
        return &vinfo_map[ML_KEM_1024_VINFO];
1629
0
    }
1630
0
    return NULL;
1631
0
}
1632
1633
/*
1634
 * @brief Fetch digest algorithms based on a propq.
1635
 * For the import case ossl_ml_kem_key_new() gets passed a NULL propq,
1636
 * so the propq is optionally deferred to the import using OSSL_PARAM.
1637
 */
1638
int ossl_ml_kem_key_fetch_digest(ML_KEM_KEY *key, const char *propq)
1639
0
{
1640
0
    if (key->shake128_md != NULL) {
1641
0
        EVP_MD_free(key->shake128_md);
1642
0
        EVP_MD_free(key->shake256_md);
1643
0
        EVP_MD_free(key->sha3_256_md);
1644
0
        EVP_MD_free(key->sha3_512_md);
1645
0
    }
1646
0
    key->shake128_md = EVP_MD_fetch(key->libctx, "SHAKE128", propq);
1647
0
    key->shake256_md = EVP_MD_fetch(key->libctx, "SHAKE256", propq);
1648
0
    key->sha3_256_md = EVP_MD_fetch(key->libctx, "SHA3-256", propq);
1649
0
    key->sha3_512_md = EVP_MD_fetch(key->libctx, "SHA3-512", propq);
1650
0
    return (key->shake128_md != NULL && key->shake256_md != NULL
1651
0
        && key->sha3_256_md != NULL && key->sha3_512_md != NULL);
1652
0
}
1653
1654
/*
 * Allocate an empty ML-KEM key of the variant named by |evp_type|, with no
 * key material, and fetch its SHA3/SHAKE digests using |properties|.
 * Returns NULL (with an error raised) on an unknown type or a missing digest.
 */
ML_KEM_KEY *ossl_ml_kem_key_new(OSSL_LIB_CTX *libctx, const char *properties,
    int evp_type)
{
    const ML_KEM_VINFO *info = ossl_ml_kem_get_vinfo(evp_type);
    ML_KEM_KEY *nk;

    if (info == NULL) {
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_PASSED_INVALID_ARGUMENT,
            "unsupported ML-KEM key type: %d", evp_type);
        return NULL;
    }

    nk = OPENSSL_malloc(sizeof(*nk));
    if (nk == NULL)
        return NULL;

    nk->vinfo = info;
    nk->libctx = libctx;
    nk->prov_flags = ML_KEM_KEY_PROV_FLAGS_DEFAULT;

    /* No key material, stashed seed or pending encoded key yet */
    nk->seedbuf = NULL;
    nk->encoded_dk = NULL;
    nk->pkhash = NULL;
    nk->rho = NULL;
    nk->z = NULL;
    nk->d = NULL;
    nk->t = NULL;
    nk->m = NULL;
    nk->s = NULL;

    /* The digest fetch below inspects these, so they must start out NULL */
    nk->shake128_md = NULL;
    nk->shake256_md = NULL;
    nk->sha3_256_md = NULL;
    nk->sha3_512_md = NULL;

    if (!ossl_ml_kem_key_fetch_digest(nk, properties)) {
        ossl_ml_kem_key_free(nk);
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
            "missing SHA3 digest algorithms while creating %s key",
            info->algorithm_name);
        return NULL;
    }
    return nk;
}
1684
1685
ML_KEM_KEY *ossl_ml_kem_key_dup(const ML_KEM_KEY *key, int selection)
1686
0
{
1687
0
    int ok = 0;
1688
0
    ML_KEM_KEY *ret;
1689
0
    void *tmp_pub;
1690
0
    void *tmp_priv;
1691
1692
0
    if (key == NULL)
1693
0
        return NULL;
1694
    /*
1695
     * Partially decoded keys, not yet imported or loaded, should never be
1696
     * duplicated.
1697
     */
1698
0
    if (ossl_ml_kem_decoded_key(key))
1699
0
        return NULL;
1700
1701
0
    else if ((ret = OPENSSL_memdup(key, sizeof(*key))) == NULL)
1702
0
        return NULL;
1703
1704
0
    ret->d = ret->z = ret->rho = ret->pkhash = NULL;
1705
0
    ret->s = ret->m = ret->t = NULL;
1706
1707
    /* Clear selection bits we can't fulfill */
1708
0
    if (!ossl_ml_kem_have_pubkey(key))
1709
0
        selection = 0;
1710
0
    else if (!ossl_ml_kem_have_prvkey(key))
1711
0
        selection &= ~OSSL_KEYMGMT_SELECT_PRIVATE_KEY;
1712
1713
0
    switch (selection & OSSL_KEYMGMT_SELECT_KEYPAIR) {
1714
0
    case 0:
1715
0
        ok = 1;
1716
0
        break;
1717
0
    case OSSL_KEYMGMT_SELECT_PUBLIC_KEY:
1718
0
        ok = add_storage(OPENSSL_memdup(key->t, key->vinfo->puballoc), NULL, 0, ret);
1719
0
        break;
1720
0
    case OSSL_KEYMGMT_SELECT_PRIVATE_KEY:
1721
0
        tmp_pub = OPENSSL_memdup(key->t, key->vinfo->puballoc);
1722
0
        if (tmp_pub == NULL)
1723
0
            break;
1724
0
        tmp_priv = OPENSSL_secure_malloc(key->vinfo->prvalloc);
1725
0
        if (tmp_priv == NULL) {
1726
0
            OPENSSL_free(tmp_pub);
1727
0
            break;
1728
0
        }
1729
0
        if ((ok = add_storage(tmp_pub, tmp_priv, 1, ret)) != 0)
1730
0
            memcpy(tmp_priv, key->s, key->vinfo->prvalloc);
1731
        /* Duplicated keys retain |d|, if available */
1732
0
        if (key->d != NULL)
1733
0
            ret->d = ret->z + ML_KEM_RANDOM_BYTES;
1734
0
        break;
1735
0
    }
1736
1737
0
    if (!ok) {
1738
0
        OPENSSL_free(ret);
1739
0
        return NULL;
1740
0
    }
1741
1742
0
    EVP_MD_up_ref(ret->shake128_md);
1743
0
    EVP_MD_up_ref(ret->shake256_md);
1744
0
    EVP_MD_up_ref(ret->sha3_256_md);
1745
0
    EVP_MD_up_ref(ret->sha3_512_md);
1746
1747
0
    return ret;
1748
0
}
1749
1750
/*
 * Release an ML-KEM key: drop its digest references, cleanse and free all
 * key storage via ossl_ml_kem_key_reset(), then free the key itself.
 * A NULL |key| is ignored.
 */
void ossl_ml_kem_key_free(ML_KEM_KEY *key)
{
    if (key != NULL) {
        EVP_MD_free(key->shake128_md);
        EVP_MD_free(key->shake256_md);
        EVP_MD_free(key->sha3_256_md);
        EVP_MD_free(key->sha3_512_md);

        ossl_ml_kem_key_reset(key);
        OPENSSL_free(key);
    }
}
1763
1764
/* Serialise the public component of an ML-KEM key */
1765
int ossl_ml_kem_encode_public_key(uint8_t *out, size_t len,
1766
    const ML_KEM_KEY *key)
1767
0
{
1768
0
    if (!ossl_ml_kem_have_pubkey(key)
1769
0
        || len != key->vinfo->pubkey_bytes)
1770
0
        return 0;
1771
0
    encode_pubkey(out, key);
1772
0
    return 1;
1773
0
}
1774
1775
/* Serialise an ML-KEM private key */
1776
int ossl_ml_kem_encode_private_key(uint8_t *out, size_t len,
1777
    const ML_KEM_KEY *key)
1778
0
{
1779
0
    if (!ossl_ml_kem_have_prvkey(key)
1780
0
        || len != key->vinfo->prvkey_bytes)
1781
0
        return 0;
1782
0
    encode_prvkey(out, key);
1783
0
    return 1;
1784
0
}
1785
1786
int ossl_ml_kem_encode_seed(uint8_t *out, size_t len,
1787
    const ML_KEM_KEY *key)
1788
0
{
1789
0
    if (key == NULL || key->d == NULL || len != ML_KEM_SEED_BYTES)
1790
0
        return 0;
1791
    /*
1792
     * Both in the seed buffer, and in the allocated storage, the |d| component
1793
     * of the seed is stored last, so we must copy each separately.
1794
     */
1795
0
    memcpy(out, key->d, ML_KEM_RANDOM_BYTES);
1796
0
    out += ML_KEM_RANDOM_BYTES;
1797
0
    memcpy(out, key->z, ML_KEM_RANDOM_BYTES);
1798
0
    return 1;
1799
0
}
1800
1801
/*
1802
 * Stash the seed without (yet) performing a keygen, used during decoding, to
1803
 * avoid an extra keygen if we're only going to export the key again to load
1804
 * into another provider.
1805
 */
1806
ML_KEM_KEY *ossl_ml_kem_set_seed(const uint8_t *seed, size_t seedlen, ML_KEM_KEY *key)
1807
0
{
1808
0
    if (key == NULL
1809
0
        || ossl_ml_kem_have_pubkey(key)
1810
0
        || ossl_ml_kem_have_seed(key)
1811
0
        || seedlen != ML_KEM_SEED_BYTES)
1812
0
        return NULL;
1813
1814
0
    if (key->seedbuf == NULL) {
1815
0
        key->seedbuf = OPENSSL_secure_malloc(seedlen);
1816
0
        if (key->seedbuf == NULL)
1817
0
            return NULL;
1818
0
    }
1819
1820
0
    key->z = key->seedbuf;
1821
0
    key->d = key->z + ML_KEM_RANDOM_BYTES;
1822
0
    memcpy(key->d, seed, ML_KEM_RANDOM_BYTES);
1823
0
    seed += ML_KEM_RANDOM_BYTES;
1824
0
    memcpy(key->z, seed, ML_KEM_RANDOM_BYTES);
1825
0
    return key;
1826
0
}
1827
1828
/* Parse input as a public key */
1829
int ossl_ml_kem_parse_public_key(const uint8_t *in, size_t len, ML_KEM_KEY *key)
1830
0
{
1831
0
    EVP_MD_CTX *mdctx = NULL;
1832
0
    const ML_KEM_VINFO *vinfo;
1833
0
    int ret = 0;
1834
1835
    /* Keys with key material are immutable */
1836
0
    if (key == NULL
1837
0
        || ossl_ml_kem_have_pubkey(key)
1838
0
        || ossl_ml_kem_have_dkenc(key))
1839
0
        return 0;
1840
0
    vinfo = key->vinfo;
1841
1842
0
    if (len != vinfo->pubkey_bytes
1843
0
        || (mdctx = EVP_MD_CTX_new()) == NULL)
1844
0
        return 0;
1845
1846
0
    if (add_storage(OPENSSL_malloc(vinfo->puballoc), NULL, 0, key))
1847
0
        ret = parse_pubkey(in, mdctx, key);
1848
1849
0
    if (!ret)
1850
0
        ossl_ml_kem_key_reset(key);
1851
0
    EVP_MD_CTX_free(mdctx);
1852
0
    return ret;
1853
0
}
1854
1855
/* Parse input as a new private key */
1856
int ossl_ml_kem_parse_private_key(const uint8_t *in, size_t len,
1857
    ML_KEM_KEY *key)
1858
0
{
1859
0
    EVP_MD_CTX *mdctx = NULL;
1860
0
    const ML_KEM_VINFO *vinfo;
1861
0
    int ret = 0;
1862
1863
    /* Keys with key material are immutable */
1864
0
    if (key == NULL
1865
0
        || ossl_ml_kem_have_pubkey(key)
1866
0
        || ossl_ml_kem_have_dkenc(key))
1867
0
        return 0;
1868
0
    vinfo = key->vinfo;
1869
1870
0
    if (len != vinfo->prvkey_bytes
1871
0
        || (mdctx = EVP_MD_CTX_new()) == NULL)
1872
0
        return 0;
1873
1874
0
    if (add_storage(OPENSSL_malloc(vinfo->puballoc),
1875
0
            OPENSSL_secure_malloc(vinfo->prvalloc), 1, key))
1876
0
        ret = parse_prvkey(in, mdctx, key);
1877
1878
0
    if (!ret)
1879
0
        ossl_ml_kem_key_reset(key);
1880
0
    EVP_MD_CTX_free(mdctx);
1881
0
    return ret;
1882
0
}
1883
1884
/*
1885
 * Generate a new keypair, either from the saved seed (when non-null), or from
1886
 * the RNG.
1887
 */
1888
int ossl_ml_kem_genkey(uint8_t *pubenc, size_t publen, ML_KEM_KEY *key)
1889
0
{
1890
0
    uint8_t seed[ML_KEM_SEED_BYTES];
1891
0
    EVP_MD_CTX *mdctx = NULL;
1892
0
    const ML_KEM_VINFO *vinfo;
1893
0
    int ret = 0;
1894
1895
0
    if (key == NULL
1896
0
        || ossl_ml_kem_have_pubkey(key)
1897
0
        || ossl_ml_kem_have_dkenc(key))
1898
0
        return 0;
1899
0
    vinfo = key->vinfo;
1900
1901
0
    if (pubenc != NULL && publen != vinfo->pubkey_bytes)
1902
0
        return 0;
1903
1904
0
    if (key->seedbuf != NULL) {
1905
0
        if (!ossl_ml_kem_encode_seed(seed, sizeof(seed), key))
1906
0
            return 0;
1907
0
        ossl_ml_kem_key_reset(key);
1908
0
    } else if (RAND_priv_bytes_ex(key->libctx, seed, sizeof(seed),
1909
0
                   key->vinfo->secbits)
1910
0
        <= 0) {
1911
0
        return 0;
1912
0
    }
1913
1914
0
    if ((mdctx = EVP_MD_CTX_new()) == NULL)
1915
0
        return 0;
1916
1917
    /*
1918
     * Data derived from (d, z) defaults secret, and to avoid side-channel
1919
     * leaks should not influence control flow.
1920
     */
1921
0
    CONSTTIME_SECRET(seed, ML_KEM_SEED_BYTES);
1922
1923
0
    if (add_storage(OPENSSL_malloc(vinfo->puballoc),
1924
0
            OPENSSL_secure_malloc(vinfo->prvalloc), 1, key))
1925
0
        ret = genkey(seed, mdctx, pubenc, key);
1926
0
    OPENSSL_cleanse(seed, sizeof(seed));
1927
1928
    /* Declassify secret inputs and derived outputs before returning control */
1929
0
    CONSTTIME_DECLASSIFY(seed, ML_KEM_SEED_BYTES);
1930
1931
0
    EVP_MD_CTX_free(mdctx);
1932
0
    if (!ret) {
1933
0
        ossl_ml_kem_key_reset(key);
1934
0
        return 0;
1935
0
    }
1936
1937
    /* The public components are already declassified */
1938
0
    CONSTTIME_DECLASSIFY(key->s, vinfo->rank * sizeof(scalar));
1939
0
    CONSTTIME_DECLASSIFY(key->z, 2 * ML_KEM_RANDOM_BYTES);
1940
0
    return 1;
1941
0
}
1942
1943
/*
1944
 * FIPS 203, Section 6.2, Algorithm 17: ML-KEM.Encaps_internal
1945
 * This is the deterministic version with randomness supplied externally.
1946
 */
1947
int ossl_ml_kem_encap_seed(uint8_t *ctext, size_t clen,
1948
    uint8_t *shared_secret, size_t slen,
1949
    const uint8_t *entropy, size_t elen,
1950
    const ML_KEM_KEY *key)
1951
0
{
1952
0
    const ML_KEM_VINFO *vinfo;
1953
0
    EVP_MD_CTX *mdctx;
1954
0
    int ret = 0;
1955
1956
0
    if (key == NULL || !ossl_ml_kem_have_pubkey(key))
1957
0
        return 0;
1958
0
    vinfo = key->vinfo;
1959
1960
0
    if (ctext == NULL || clen != vinfo->ctext_bytes
1961
0
        || shared_secret == NULL || slen != ML_KEM_SHARED_SECRET_BYTES
1962
0
        || entropy == NULL || elen != ML_KEM_RANDOM_BYTES
1963
0
        || (mdctx = EVP_MD_CTX_new()) == NULL)
1964
0
        return 0;
1965
    /*
1966
     * Data derived from the encap entropy defaults secret, and to avoid
1967
     * side-channel leaks should not influence control flow.
1968
     */
1969
0
    CONSTTIME_SECRET(entropy, elen);
1970
1971
    /*-
1972
     * This avoids the need to handle allocation failures for two (max 2KB
1973
     * each) vectors, that are never retained on return from this function.
1974
     * We stack-allocate these.
1975
     */
1976
0
#define case_encap_seed(bits)                                        \
1977
0
    {                                                                \
1978
0
        scalar tmp[2 * ML_KEM_##bits##_RANK];                        \
1979
0
                                                                     \
1980
0
        ret = encap(ctext, shared_secret, entropy, tmp, mdctx, key); \
1981
0
        OPENSSL_cleanse((void *)tmp, sizeof(tmp));                   \
1982
0
    }
1983
0
    switch (vinfo->evp_type) {
1984
0
    case EVP_PKEY_ML_KEM_512:
1985
0
        case_encap_seed(512);
1986
0
        break;
1987
0
    case EVP_PKEY_ML_KEM_768:
1988
0
        case_encap_seed(768);
1989
0
        break;
1990
0
    case EVP_PKEY_ML_KEM_1024:
1991
0
        case_encap_seed(1024);
1992
0
        break;
1993
0
    }
1994
0
#undef case_encap_seed
1995
1996
    /* Declassify secret inputs and derived outputs before returning control */
1997
0
    CONSTTIME_DECLASSIFY(entropy, elen);
1998
0
    CONSTTIME_DECLASSIFY(ctext, clen);
1999
0
    CONSTTIME_DECLASSIFY(shared_secret, slen);
2000
2001
0
    EVP_MD_CTX_free(mdctx);
2002
0
    return ret;
2003
0
}
2004
2005
int ossl_ml_kem_encap_rand(uint8_t *ctext, size_t clen,
2006
    uint8_t *shared_secret, size_t slen,
2007
    const ML_KEM_KEY *key)
2008
0
{
2009
0
    uint8_t r[ML_KEM_RANDOM_BYTES];
2010
2011
0
    if (key == NULL)
2012
0
        return 0;
2013
2014
0
    if (RAND_bytes_ex(key->libctx, r, ML_KEM_RANDOM_BYTES,
2015
0
            key->vinfo->secbits)
2016
0
        < 1)
2017
0
        return 0;
2018
2019
0
    return ossl_ml_kem_encap_seed(ctext, clen, shared_secret, slen,
2020
0
        r, sizeof(r), key);
2021
0
}
2022
2023
int ossl_ml_kem_decap(uint8_t *shared_secret, size_t slen,
2024
    const uint8_t *ctext, size_t clen,
2025
    const ML_KEM_KEY *key)
2026
0
{
2027
0
    const ML_KEM_VINFO *vinfo;
2028
0
    EVP_MD_CTX *mdctx;
2029
0
    int ret = 0;
2030
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
2031
    int classify_bytes;
2032
#endif
2033
2034
    /* Need a private key here */
2035
0
    if (!ossl_ml_kem_have_prvkey(key))
2036
0
        return 0;
2037
0
    vinfo = key->vinfo;
2038
2039
0
    if (shared_secret == NULL || slen != ML_KEM_SHARED_SECRET_BYTES
2040
0
        || ctext == NULL || clen != vinfo->ctext_bytes
2041
0
        || (mdctx = EVP_MD_CTX_new()) == NULL) {
2042
0
        (void)RAND_bytes_ex(key->libctx, shared_secret,
2043
0
            ML_KEM_SHARED_SECRET_BYTES, vinfo->secbits);
2044
0
        return 0;
2045
0
    }
2046
    /*
2047
     * Data derived from |s| and |z| defaults secret, and to avoid side-channel
2048
     * leaks should not influence control flow.
2049
     */
2050
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
2051
    classify_bytes = vinfo->rank * sizeof(scalar) + ML_KEM_RANDOM_BYTES;
2052
#endif
2053
0
    CONSTTIME_SECRET(key->s, classify_bytes);
2054
2055
    /*-
2056
     * This avoids the need to handle allocation failures for two (max 2KB
2057
     * each) vectors and an encoded ciphertext (max 1568 bytes), that are never
2058
     * retained on return from this function.
2059
     * We stack-allocate these.
2060
     */
2061
0
#define case_decap(bits)                                          \
2062
0
    {                                                             \
2063
0
        uint8_t cbuf[CTEXT_BYTES(bits)];                          \
2064
0
        scalar tmp[2 * ML_KEM_##bits##_RANK];                     \
2065
0
                                                                  \
2066
0
        ret = decap(shared_secret, ctext, cbuf, tmp, mdctx, key); \
2067
0
        OPENSSL_cleanse((void *)tmp, sizeof(tmp));                \
2068
0
    }
2069
0
    switch (vinfo->evp_type) {
2070
0
    case EVP_PKEY_ML_KEM_512:
2071
0
        case_decap(512);
2072
0
        break;
2073
0
    case EVP_PKEY_ML_KEM_768:
2074
0
        case_decap(768);
2075
0
        break;
2076
0
    case EVP_PKEY_ML_KEM_1024:
2077
0
        case_decap(1024);
2078
0
        break;
2079
0
    }
2080
0
#undef case_decap
2081
2082
    /* Declassify secret inputs and derived outputs before returning control */
2083
0
    CONSTTIME_DECLASSIFY(key->s, classify_bytes);
2084
0
    CONSTTIME_DECLASSIFY(shared_secret, slen);
2085
0
    EVP_MD_CTX_free(mdctx);
2086
2087
0
    return ret;
2088
0
}
2089
2090
int ossl_ml_kem_pubkey_cmp(const ML_KEM_KEY *key1, const ML_KEM_KEY *key2)
2091
0
{
2092
    /*
2093
     * This handles any unexpected differences in the ML-KEM variant rank,
2094
     * giving different key component structures, barring SHA3-256 hash
2095
     * collisions, the keys are the same size.
2096
     */
2097
0
    if (ossl_ml_kem_have_pubkey(key1) && ossl_ml_kem_have_pubkey(key2))
2098
0
        return memcmp(key1->pkhash, key2->pkhash, ML_KEM_PKHASH_BYTES) == 0;
2099
2100
    /*
2101
     * No match if just one of the public keys is not available, otherwise both
2102
     * are unavailable, and for now such keys are considered equal.
2103
     */
2104
0
    return (!(ossl_ml_kem_have_pubkey(key1) ^ ossl_ml_kem_have_pubkey(key2)));
2105
0
}