Coverage Report

Created: 2026-04-01 06:39

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/openssl35/crypto/ml_kem/ml_kem.c
Line
Count
Source
1
/*
2
 * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
3
 *
4
 * Licensed under the Apache License 2.0 (the "License").  You may not use
5
 * this file except in compliance with the License.  You can obtain a copy
6
 * in the file LICENSE in the source distribution or at
7
 * https://www.openssl.org/source/license.html
8
 */
9
10
#include <openssl/byteorder.h>
11
#include <openssl/rand.h>
12
#include <openssl/proverr.h>
13
#include "crypto/ml_kem.h"
14
#include "internal/common.h"
15
#include "internal/constant_time.h"
16
#include "internal/sha3.h"
17
18
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
19
#include <valgrind/memcheck.h>
20
#endif
21
22
#if ML_KEM_SEED_BYTES != ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES
23
#error "ML-KEM keygen seed length != shared secret + random bytes length"
24
#endif
25
#if ML_KEM_SHARED_SECRET_BYTES != ML_KEM_RANDOM_BYTES
26
#error "Invalid unequal lengths of ML-KEM shared secret and random inputs"
27
#endif
28
29
#if UINT_MAX < UINT32_MAX
30
#error "Unsupported compiler: sizeof(unsigned int) < sizeof(uint32_t)"
31
#endif
32
33
/*
 * Handy function-like bit-extraction macros.
 *
 * bit0(b) yields the least-significant bit of |b|; bitn(n, b) yields bit
 * number |n| of |b|, bit 0 being the least significant.  Every argument is
 * parenthesised in the expansion, so compound expressions (for example
 * |i | 1| or |cond ? 1 : 2|) can be passed for either parameter.
 */
#define bit0(b) ((b) & 1)
#define bitn(n, b) (((b) >> (n)) & 1)
36
37
/*
38
 * 12 bits are sufficient to losslessly represent values in [0, q-1].
39
 * INVERSE_DEGREE is (n/2)^-1 mod q; used in inverse NTT.
40
 */
41
3.20M
#define DEGREE ML_KEM_DEGREE
42
#define INVERSE_DEGREE (ML_KEM_PRIME - 2 * 13)
43
#define LOG2PRIME 12
44
#define BARRETT_SHIFT (2 * LOG2PRIME)
45
46
#ifdef SHA3_BLOCKSIZE
47
#define SHAKE128_BLOCKSIZE SHA3_BLOCKSIZE(128)
48
#endif
49
50
/*
51
 * Return whether a value that can only be 0 or 1 is non-zero, in constant time
52
 * in practice!  The return value is a mask that is all ones if true, and all
53
 * zeros otherwise (twos-complement arithmetic assumed for unsigned values).
54
 *
55
 * Although this is used in constant-time selects, we omit a value barrier
56
 * here.  Value barriers impede auto-vectorization (likely because it forces
57
 * the value to transit through a general-purpose register). On AArch64, this
58
 * is a difference of 2x.
59
 *
60
 * We usually add value barriers to selects because Clang turns consecutive
61
 * selects with the same condition into a branch instead of CMOV/CSEL. This
62
 * condition does not occur in Kyber, so omitting it seems to be safe so far,
63
 * but see |cbd_2|, |cbd_3|, where reduction needs to be specialised to the
64
 * sign of the input, rather than adding |q| in advance, and using the generic
65
 * |reduce_once|.  (David Benjamin, Chromium)
66
 */
67
/*
 * The disabled alternative is built on constant_time_is_zero(); the active
 * one is plain unsigned negation: 0u - 1 is all ones, 0u - 0 is zero.
 * Both expand to a single parenthesised expression with no trailing
 * semicolon, so either can be used inside a larger expression.
 */
#if 0
#define constish_time_non_zero(b) (~constant_time_is_zero(b))
#else
#define constish_time_non_zero(b) (0u - (b))
#endif
72
73
/*
74
 * The scalar rejection-sampling buffer size needs to be a multiple of 12, but
75
 * is otherwise arbitrary, the preferred block size matches the internal buffer
76
 * size of SHAKE128, avoiding internal buffering and copying in SHAKE128. That
77
 * block size of (1600 - 256)/8 bytes, or 168, just happens to divide by 12!
78
 *
79
 * If the blocksize is unknown, or is not divisible by 12, 168 is used as a
80
 * fallback.
81
 */
82
/*
 * Fall back to 168 bytes unless SHAKE128_BLOCKSIZE is both known and a
 * multiple of 12, in which case that block size is used directly.
 */
#if !defined(SHAKE128_BLOCKSIZE) || (SHAKE128_BLOCKSIZE) % 12 != 0
#define SCALAR_SAMPLING_BUFSIZE 168
#else
#define SCALAR_SAMPLING_BUFSIZE (SHAKE128_BLOCKSIZE)
#endif
87
88
/*
89
 * Structure of keys
90
 */
91
typedef struct ossl_ml_kem_scalar_st {
92
    /* On every function entry and exit, 0 <= c[i] < ML_KEM_PRIME. */
93
    uint16_t c[ML_KEM_DEGREE];
94
} scalar;
95
96
/* Key material allocation layout */
97
#define DECLARE_ML_KEM_KEYDATA(name, rank, private_sz)               \
98
    struct name##_alloc {                                            \
99
        /* Public vector |t| */                                      \
100
        scalar tbuf[(rank)];                                         \
101
        /* Pre-computed matrix |m| (FIPS 203 |A| transpose) */       \
102
        scalar mbuf[(rank) * (rank)] /* optional private key data */ \
103
            private_sz                                               \
104
    }
105
106
/* Declare variant-specific public and private storage */
107
#define DECLARE_ML_KEM_VARIANT_KEYDATA(bits)                        \
108
    DECLARE_ML_KEM_KEYDATA(pubkey_##bits, ML_KEM_##bits##_RANK, ;); \
109
    DECLARE_ML_KEM_KEYDATA(prvkey_##bits, ML_KEM_##bits##_RANK, ; scalar sbuf[ML_KEM_##bits##_RANK]; uint8_t zbuf[2 * ML_KEM_RANDOM_BYTES];)
110
DECLARE_ML_KEM_VARIANT_KEYDATA(512);
111
DECLARE_ML_KEM_VARIANT_KEYDATA(768);
112
DECLARE_ML_KEM_VARIANT_KEYDATA(1024);
113
#undef DECLARE_ML_KEM_VARIANT_KEYDATA
114
#undef DECLARE_ML_KEM_KEYDATA
115
116
typedef __owur int (*CBD_FUNC)(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
117
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key);
118
static void scalar_encode(uint8_t *out, const scalar *s, int bits);
119
120
/*
121
 * The wire-form of a losslessly encoded vector uses 12-bits per element.
122
 *
123
 * The wire-form public key consists of the lossless encoding of the public
124
 * vector |t|, followed by the public seed |rho|.
125
 *
126
 * Our serialised private key concatenates serialisations of the private vector
127
 * |s|, the public key, the public key hash, and the failure secret |z|.
128
 */
129
#define VECTOR_BYTES(b) ((3 * DEGREE / 2) * ML_KEM_##b##_RANK)
130
#define PUBKEY_BYTES(b) (VECTOR_BYTES(b) + ML_KEM_RANDOM_BYTES)
131
#define PRVKEY_BYTES(b) (2 * PUBKEY_BYTES(b) + ML_KEM_PKHASH_BYTES)
132
133
/*
134
 * Encapsulation produces a vector "u" and a scalar "v", whose coordinates
135
 * (numbers modulo the ML-KEM prime "q") are lossily encoded using "du" and
136
 * "dv" bits, respectively.  This encoding is the ciphertext input for
137
 * decapsulation.
138
 */
139
#define U_VECTOR_BYTES(b) ((DEGREE / 8) * ML_KEM_##b##_DU * ML_KEM_##b##_RANK)
140
#define V_SCALAR_BYTES(b) ((DEGREE / 8) * ML_KEM_##b##_DV)
141
#define CTEXT_BYTES(b) (U_VECTOR_BYTES(b) + V_SCALAR_BYTES(b))
142
143
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
144
145
/*
146
 * CONSTTIME_SECRET takes a pointer and a number of bytes and marks that region
147
 * of memory as secret. Secret data is tracked as it flows to registers and
148
 * other parts of a memory. If secret data is used as a condition for a branch,
149
 * or as a memory index, it will trigger warnings in valgrind.
150
 */
151
#define CONSTTIME_SECRET(ptr, len) VALGRIND_MAKE_MEM_UNDEFINED(ptr, len)
152
153
/*
154
 * CONSTTIME_DECLASSIFY takes a pointer and a number of bytes and marks that
155
 * region of memory as public. Public data is not subject to constant-time
156
 * rules.
157
 */
158
#define CONSTTIME_DECLASSIFY(ptr, len) VALGRIND_MAKE_MEM_DEFINED(ptr, len)
159
160
#else
161
162
#define CONSTTIME_SECRET(ptr, len)
163
#define CONSTTIME_DECLASSIFY(ptr, len)
164
165
#endif
166
167
/*
168
 * Indices of slots in the vinfo tables below
169
 */
170
60.5k
#define ML_KEM_512_VINFO 0
171
162k
#define ML_KEM_768_VINFO 1
172
57.0k
#define ML_KEM_1024_VINFO 2
173
174
/*
175
 * Per-variant fixed parameters
176
 */
177
static const ML_KEM_VINFO vinfo_map[3] = {
178
    { "ML-KEM-512",
179
        PRVKEY_BYTES(512),
180
        sizeof(struct prvkey_512_alloc),
181
        PUBKEY_BYTES(512),
182
        sizeof(struct pubkey_512_alloc),
183
        CTEXT_BYTES(512),
184
        VECTOR_BYTES(512),
185
        U_VECTOR_BYTES(512),
186
        EVP_PKEY_ML_KEM_512,
187
        ML_KEM_512_BITS,
188
        ML_KEM_512_RANK,
189
        ML_KEM_512_DU,
190
        ML_KEM_512_DV,
191
        ML_KEM_512_SECBITS },
192
    { "ML-KEM-768",
193
        PRVKEY_BYTES(768),
194
        sizeof(struct prvkey_768_alloc),
195
        PUBKEY_BYTES(768),
196
        sizeof(struct pubkey_768_alloc),
197
        CTEXT_BYTES(768),
198
        VECTOR_BYTES(768),
199
        U_VECTOR_BYTES(768),
200
        EVP_PKEY_ML_KEM_768,
201
        ML_KEM_768_BITS,
202
        ML_KEM_768_RANK,
203
        ML_KEM_768_DU,
204
        ML_KEM_768_DV,
205
        ML_KEM_768_SECBITS },
206
    { "ML-KEM-1024",
207
        PRVKEY_BYTES(1024),
208
        sizeof(struct prvkey_1024_alloc),
209
        PUBKEY_BYTES(1024),
210
        sizeof(struct pubkey_1024_alloc),
211
        CTEXT_BYTES(1024),
212
        VECTOR_BYTES(1024),
213
        U_VECTOR_BYTES(1024),
214
        EVP_PKEY_ML_KEM_1024,
215
        ML_KEM_1024_BITS,
216
        ML_KEM_1024_RANK,
217
        ML_KEM_1024_DU,
218
        ML_KEM_1024_DV,
219
        ML_KEM_1024_SECBITS }
220
};
221
222
/*
223
 * Remainders modulo `kPrime`, for sufficiently small inputs, are computed in
224
 * constant time via Barrett reduction, and a final call to reduce_once(),
225
 * which reduces inputs that are at most 2*kPrime and is also constant-time.
226
 */
227
static const int kPrime = ML_KEM_PRIME;
228
static const unsigned int kBarrettShift = BARRETT_SHIFT;
229
static const size_t kBarrettMultiplier = (1 << BARRETT_SHIFT) / ML_KEM_PRIME;
230
static const uint16_t kHalfPrime = (ML_KEM_PRIME - 1) / 2;
231
static const uint16_t kInverseDegree = INVERSE_DEGREE;
232
233
/*
234
 * Python helper:
235
 *
236
 * p = 3329
237
 * def bitreverse(i):
238
 *     ret = 0
239
 *     for n in range(7):
240
 *         bit = i & 1
241
 *         ret <<= 1
242
 *         ret |= bit
243
 *         i >>= 1
244
 *     return ret
245
 */
246
247
/*-
 * First precomputed array from Appendix A of FIPS 203: the powers of 17
 * modulo 3329 used by the forward NTT, in bit-reversed exponent order.
 * With bitreverse(i) reversing the low 7 bits of i, the Python equivalent is:
 * kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)]
 */
static const uint16_t kNTTRoots[128] = {
    1,    1729, 2580, 3289, 2642, 630,  1897, 848,
    1062, 1919, 193,  797,  2786, 3260, 569,  1746,
    296,  2447, 1339, 1476, 3046, 56,   2240, 1333,
    1426, 2094, 535,  2882, 2393, 2879, 1974, 821,
    289,  331,  3253, 1756, 1197, 2304, 2277, 2055,
    650,  1977, 2513, 632,  2865, 33,   1320, 1915,
    2319, 1435, 807,  452,  1438, 2868, 1534, 2402,
    2647, 2617, 1481, 648,  2474, 3110, 1227, 910,
    17,   2761, 583,  2649, 1637, 723,  2288, 1100,
    1409, 2662, 3281, 233,  756,  2156, 3015, 3050,
    1703, 1651, 2789, 1789, 1847, 952,  1461, 2687,
    939,  2308, 2437, 2388, 733,  2337, 268,  641,
    1584, 2298, 2037, 3220, 375,  2549, 2090, 1645,
    1063, 319,  2773, 757,  2099, 561,  2466, 2594,
    2804, 1092, 403,  1026, 1143, 2150, 2775, 886,
    1722, 1212, 1874, 1029, 2110, 2935, 885,  2154,
};
381
382
/*
 * Inverses of the forward NTT roots, that is
 * InverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)]
 * with bitreverse(i) reversing the low 7 bits of i and p = 3329.
 *
 * The entries are stored in the order in which the inverse NTT loop consumes
 * them (the entry at position 0 is skipped by that loop), i.e. for i equal to
 *
 *  0, 64, 65, ..., 127, 32, 33, ..., 63, 16, 17, ..., 31, 8, 9, ...
 */
static const uint16_t kInverseNTTRoots[128] = {
    1,    1175, 2444, 394,  1219, 2300, 1455, 2117,
    1607, 2443, 554,  1179, 2186, 2303, 2926, 2237,
    525,  735,  863,  2768, 1230, 2572, 556,  3010,
    2266, 1684, 1239, 780,  2954, 109,  1292, 1031,
    1745, 2688, 3061, 992,  2596, 941,  892,  1021,
    2390, 642,  1868, 2377, 1482, 1540, 540,  1678,
    1626, 279,  314,  1173, 2573, 3096, 48,   667,
    1920, 2229, 1041, 2606, 1692, 680,  2746, 568,
    3312, 2419, 2102, 219,  855,  2681, 1848, 712,
    682,  927,  1795, 461,  1891, 2877, 2522, 1894,
    1010, 1414, 2009, 3296, 464,  2697, 816,  1352,
    2679, 1274, 1052, 1025, 2132, 1573, 76,   2998,
    3040, 2508, 1355, 450,  936,  447,  2794, 1235,
    1903, 1996, 1089, 3273, 283,  1853, 1990, 882,
    3033, 1583, 2760, 69,   543,  2532, 3136, 1410,
    2267, 2481, 1432, 2699, 687,  40,   749,  1600,
};
518
519
/*
 * Second precomputed array from Appendix A of FIPS 203 (normalised to
 * non-negative residues).  With bitreverse(i) reversing the low 7 bits of i
 * and p = 3329, the Python equivalent is:
 * ModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)]
 */
static const uint16_t kModRoots[128] = {
    17,   3312, 2761, 568,  583,  2746, 2649, 680,
    1637, 1692, 723,  2606, 2288, 1041, 1100, 2229,
    1409, 1920, 2662, 667,  3281, 48,   233,  3096,
    756,  2573, 2156, 1173, 3015, 314,  3050, 279,
    1703, 1626, 1651, 1678, 2789, 540,  1789, 1540,
    1847, 1482, 952,  2377, 1461, 1868, 2687, 642,
    939,  2390, 2308, 1021, 2437, 892,  2388, 941,
    733,  2596, 2337, 992,  268,  3061, 641,  2688,
    1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
    375,  2954, 2549, 780,  2090, 1239, 1645, 1684,
    1063, 2266, 319,  3010, 2773, 556,  757,  2572,
    2099, 1230, 561,  2768, 2466, 863,  2594, 735,
    2804, 525,  1092, 2237, 403,  2926, 1026, 2303,
    1143, 2186, 2150, 1179, 2775, 554,  886,  2443,
    1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
    2110, 1219, 2935, 394,  885,  2444, 2154, 1175,
};
654
655
/*
656
 * single_keccak hashes |inlen| bytes from |in| and writes |outlen| bytes of
657
 * output to |out|. If the |md| specifies a fixed-output function, like
658
 * SHA3-256, then |outlen| must be the correct length for that function.
659
 */
660
static __owur int single_keccak(uint8_t *out, size_t outlen, const uint8_t *in, size_t inlen,
661
    EVP_MD_CTX *mdctx)
662
361k
{
663
361k
    unsigned int sz = (unsigned int)outlen;
664
665
361k
    if (!EVP_DigestUpdate(mdctx, in, inlen))
666
0
        return 0;
667
361k
    if (EVP_MD_xof(EVP_MD_CTX_get0_md(mdctx)))
668
310k
        return EVP_DigestFinalXOF(mdctx, out, outlen);
669
51.9k
    return EVP_DigestFinal_ex(mdctx, out, &sz)
670
51.9k
        && ossl_assert((size_t)sz == outlen);
671
361k
}
672
673
/*
674
 * FIPS 203, Section 4.1, equation (4.3): PRF. Takes 32+1 input bytes, and uses
675
 * SHAKE256 to produce the input to SamplePolyCBD_eta: FIPS 203, algorithm 8.
676
 */
677
static __owur int prf(uint8_t *out, size_t len, const uint8_t in[ML_KEM_RANDOM_BYTES + 1],
678
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
679
310k
{
680
310k
    return EVP_DigestInit_ex(mdctx, key->shake256_md, NULL)
681
310k
        && single_keccak(out, len, in, ML_KEM_RANDOM_BYTES + 1, mdctx);
682
310k
}
683
684
/*
685
 * FIPS 203, Section 4.1, equation (4.4): H.  SHA3-256 hash of a variable
686
 * length input, producing 32 bytes of output.
687
 */
688
static __owur int hash_h(uint8_t out[ML_KEM_PKHASH_BYTES], const uint8_t *in, size_t len,
689
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
690
325
{
691
325
    return EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL)
692
325
        && single_keccak(out, ML_KEM_PKHASH_BYTES, in, len, mdctx);
693
325
}
694
695
/* Incremental hash_h of expanded public key */
696
static int
697
hash_h_pubkey(uint8_t pkhash[ML_KEM_PKHASH_BYTES],
698
    EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
699
51.3k
{
700
51.3k
    const ML_KEM_VINFO *vinfo = key->vinfo;
701
51.3k
    const scalar *t = key->t, *end = t + vinfo->rank;
702
51.3k
    unsigned int sz;
703
704
51.3k
    if (!EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL))
705
0
        return 0;
706
707
153k
    do {
708
153k
        uint8_t buf[3 * DEGREE / 2];
709
710
153k
        scalar_encode(buf, t++, 12);
711
153k
        if (!EVP_DigestUpdate(mdctx, buf, sizeof(buf)))
712
0
            return 0;
713
153k
    } while (t < end);
714
715
51.3k
    if (!EVP_DigestUpdate(mdctx, key->rho, ML_KEM_RANDOM_BYTES))
716
0
        return 0;
717
51.3k
    return EVP_DigestFinal_ex(mdctx, pkhash, &sz)
718
51.3k
        && ossl_assert(sz == ML_KEM_PKHASH_BYTES);
719
51.3k
}
720
721
/*
722
 * FIPS 203, Section 4.1, equation (4.5): G.  SHA3-512 hash of a variable
723
 * length input, producing 64 bytes of output, in particular the seeds
724
 * (d,z) for key generation.
725
 */
726
static __owur int hash_g(uint8_t out[ML_KEM_SEED_BYTES], const uint8_t *in, size_t len,
727
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
728
51.6k
{
729
51.6k
    return EVP_DigestInit_ex(mdctx, key->sha3_512_md, NULL)
730
51.6k
        && single_keccak(out, ML_KEM_SEED_BYTES, in, len, mdctx);
731
51.6k
}
732
733
/*
734
 * FIPS 203, Section 4.1, equation (4.4): J. SHAKE256 taking a variable length
735
 * input to compute a 32-byte implicit rejection shared secret, of the same
736
 * length as the expected shared secret.  (Computed even on success to avoid
737
 * side-channel leaks).
738
 */
739
static __owur int kdf(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
740
    const uint8_t z[ML_KEM_RANDOM_BYTES],
741
    const uint8_t *ctext, size_t len,
742
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
743
141
{
744
141
    return EVP_DigestInit_ex(mdctx, key->shake256_md, NULL)
745
141
        && EVP_DigestUpdate(mdctx, z, ML_KEM_RANDOM_BYTES)
746
141
        && EVP_DigestUpdate(mdctx, ctext, len)
747
141
        && EVP_DigestFinalXOF(mdctx, out, ML_KEM_SHARED_SECRET_BYTES);
748
141
}
749
750
/*
751
 * FIPS 203, Section 4.2.2, Algorithm 7: "SampleNTT" (steps 3-17, steps 1, 2
752
 * are performed by the caller). Rejection-samples a Keccak stream to get
753
 * uniformly distributed elements in the range [0,q). This is used for matrix
754
 * expansion and only operates on public inputs.
755
 */
756
static __owur int sample_scalar(scalar *out, EVP_MD_CTX *mdctx)
757
464k
{
758
464k
    uint16_t *curr = out->c, *endout = curr + DEGREE;
759
464k
    uint8_t buf[SCALAR_SAMPLING_BUFSIZE], *in;
760
464k
    uint8_t *endin = buf + sizeof(buf);
761
464k
    uint16_t d;
762
464k
    uint8_t b1, b2, b3;
763
764
1.39M
    do {
765
1.39M
        if (!EVP_DigestSqueeze(mdctx, in = buf, sizeof(buf)))
766
0
            return 0;
767
72.7M
        do {
768
72.7M
            b1 = *in++;
769
72.7M
            b2 = *in++;
770
72.7M
            b3 = *in++;
771
772
72.7M
            if (curr >= endout)
773
155k
                break;
774
72.6M
            if ((d = ((b2 & 0x0f) << 8) + b1) < kPrime)
775
59.9M
                *curr++ = d;
776
72.6M
            if (curr >= endout)
777
308k
                break;
778
72.3M
            if ((d = (b3 << 4) + (b2 >> 4)) < kPrime)
779
58.9M
                *curr++ = d;
780
72.3M
        } while (in < endin);
781
1.39M
    } while (curr < endout);
782
464k
    return 1;
783
464k
}
784
785
/*-
786
 * reduce_once reduces 0 <= x < 2*kPrime, mod kPrime.
787
 *
788
 * Subtract |q| if the input is larger, without exposing a side-channel,
789
 * avoiding the "clangover" attack.  See |constish_time_non_zero| for a
790
 * discussion on why the value barrier is by default omitted.
791
 */
792
static __owur uint16_t reduce_once(uint16_t x)
793
1.01G
{
794
1.01G
    const uint16_t subtracted = x - kPrime;
795
1.01G
    uint16_t mask = constish_time_non_zero(subtracted >> 15);
796
797
1.01G
    return (mask & x) | (~mask & subtracted);
798
1.01G
}
799
800
/*
801
 * Constant-time reduce x mod kPrime using Barrett reduction. x must be less
802
 * than kPrime + 2 * kPrime^2.  This is sufficient to reduce a product of
803
 * two already reduced u_int16 values, in fact it is sufficient for each
804
 * to be less than 2^12, because (kPrime * (2 * kPrime + 1)) > 2^24.
805
 */
806
static __owur uint16_t reduce(uint32_t x)
807
457M
{
808
457M
    uint64_t product = (uint64_t)x * kBarrettMultiplier;
809
457M
    uint32_t quotient = (uint32_t)(product >> kBarrettShift);
810
457M
    uint32_t remainder = x - quotient * kPrime;
811
812
457M
    return reduce_once(remainder);
813
457M
}
814
815
/* Multiply a scalar by a constant. */
816
static void scalar_mult_const(scalar *s, uint16_t a)
817
1.35k
{
818
1.35k
    uint16_t *curr = s->c, *end = curr + DEGREE, tmp;
819
820
346k
    do {
821
346k
        tmp = reduce(*curr * a);
822
346k
        *curr++ = tmp;
823
346k
    } while (curr < end);
824
1.35k
}
825
826
/*-
827
 * FIPS 203, Section 4.3, Algorithm 9: "NTT".
828
 * In-place number theoretic transform of a given scalar.  Note that ML-KEM's
829
 * kPrime 3329 does not have a 512th root of unity, so this transform leaves
830
 * off the last iteration of the usual FFT code, with the 128 relevant roots of
831
 * unity being stored in NTTRoots.  This means the output should be seen as 128
832
 * elements in GF(3329^2), with the coefficients of the elements being
833
 * consecutive entries in |s->c|.
834
 */
835
static void scalar_ntt(scalar *s)
836
309k
{
837
309k
    const uint16_t *roots = kNTTRoots;
838
309k
    uint16_t *end = s->c + DEGREE;
839
309k
    int offset = DEGREE / 2;
840
841
2.16M
    do {
842
2.16M
        uint16_t *curr = s->c, *peer;
843
844
39.2M
        do {
845
39.2M
            uint16_t *pause = curr + offset, even, odd;
846
39.2M
            uint32_t zeta = *++roots;
847
848
39.2M
            peer = pause;
849
277M
            do {
850
277M
                even = *curr;
851
277M
                odd = reduce(*peer * zeta);
852
277M
                *peer++ = reduce_once(even - odd + kPrime);
853
277M
                *curr++ = reduce_once(odd + even);
854
277M
            } while (curr < pause);
855
39.2M
        } while ((curr = peer) < end);
856
2.16M
    } while ((offset >>= 1) >= 2);
857
309k
}
858
859
/*-
860
 * FIPS 203, Section 4.3, Algorithm 10: "NTT^(-1)".
861
 * In-place inverse number theoretic transform of a given scalar, with pairs of
862
 * entries of s->v being interpreted as elements of GF(3329^2). Just as with
863
 * the number theoretic transform, this leaves off the first step of the normal
864
 * iFFT to account for the fact that 3329 does not have a 512th root of unity,
865
 * using the precomputed 128 roots of unity stored in InverseNTTRoots.
866
 */
867
static void scalar_inverse_ntt(scalar *s)
868
1.35k
{
869
1.35k
    const uint16_t *roots = kInverseNTTRoots;
870
1.35k
    uint16_t *end = s->c + DEGREE;
871
1.35k
    int offset = 2;
872
873
9.47k
    do {
874
9.47k
        uint16_t *curr = s->c, *peer;
875
876
171k
        do {
877
171k
            uint16_t *pause = curr + offset, even, odd;
878
171k
            uint32_t zeta = *++roots;
879
880
171k
            peer = pause;
881
1.21M
            do {
882
1.21M
                even = *curr;
883
1.21M
                odd = *peer;
884
1.21M
                *peer++ = reduce(zeta * (even - odd + kPrime));
885
1.21M
                *curr++ = reduce_once(odd + even);
886
1.21M
            } while (curr < pause);
887
171k
        } while ((curr = peer) < end);
888
9.47k
    } while ((offset <<= 1) < DEGREE);
889
1.35k
    scalar_mult_const(s, kInverseDegree);
890
1.35k
}
891
892
/* Addition updating the LHS scalar in-place. */
893
static void scalar_add(scalar *lhs, const scalar *rhs)
894
1.21k
{
895
1.21k
    int i;
896
897
311k
    for (i = 0; i < DEGREE; i++)
898
310k
        lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
899
1.21k
}
900
901
/* Subtraction updating the LHS scalar in-place. */
902
static void scalar_sub(scalar *lhs, const scalar *rhs)
903
141
{
904
141
    int i;
905
906
36.2k
    for (i = 0; i < DEGREE; i++)
907
36.0k
        lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime);
908
141
}
909
910
/*
911
 * Multiplying two scalars in the number theoretically transformed state. Since
912
 * 3329 does not have a 512th root of unity, this means we have to interpret
913
 * the 2*ith and (2*i+1)th entries of the scalar as elements of
914
 * GF(3329)[X]/(X^2 - 17^(2*bitreverse(i)+1)).
915
 *
916
 * The value of 17^(2*bitreverse(i)+1) mod 3329 is stored in the precomputed
917
 * ModRoots table. Note that our Barrett transform only allows us to multiply
918
 * two reduced numbers together, so we need some intermediate reduction steps,
919
 * even if an uint64_t could hold 3 multiplied numbers.
920
 */
921
static void scalar_mult(scalar *out, const scalar *lhs,
922
    const scalar *rhs)
923
1.35k
{
924
1.35k
    uint16_t *curr = out->c, *end = curr + DEGREE;
925
1.35k
    const uint16_t *lc = lhs->c, *rc = rhs->c;
926
1.35k
    const uint16_t *roots = kModRoots;
927
928
173k
    do {
929
173k
        uint32_t l0 = *lc++, r0 = *rc++;
930
173k
        uint32_t l1 = *lc++, r1 = *rc++;
931
173k
        uint32_t zetapow = *roots++;
932
933
173k
        *curr++ = reduce(l0 * r0 + reduce(l1 * r1) * zetapow);
934
173k
        *curr++ = reduce(l0 * r1 + l1 * r0);
935
173k
    } while (curr < end);
936
1.35k
}
937
938
/* Above, but add the result to an existing scalar */
939
static ossl_inline void scalar_mult_add(scalar *out, const scalar *lhs,
940
    const scalar *rhs)
941
464k
{
942
464k
    uint16_t *curr = out->c, *end = curr + DEGREE;
943
464k
    const uint16_t *lc = lhs->c, *rc = rhs->c;
944
464k
    const uint16_t *roots = kModRoots;
945
946
59.4M
    do {
947
59.4M
        uint32_t l0 = *lc++, r0 = *rc++;
948
59.4M
        uint32_t l1 = *lc++, r1 = *rc++;
949
59.4M
        uint16_t *c0 = curr++;
950
59.4M
        uint16_t *c1 = curr++;
951
59.4M
        uint32_t zetapow = *roots++;
952
953
59.4M
        *c0 = reduce(*c0 + l0 * r0 + reduce(l1 * r1) * zetapow);
954
59.4M
        *c1 = reduce(*c1 + l0 * r1 + l1 * r0);
955
59.4M
    } while (curr < end);
956
464k
}
957
958
/*-
959
 * FIPS 203, Section 4.2.1, Algorithm 5: "ByteEncode_d", for 2<=d<=12.
960
 * Here |bits| is |d|.  For efficiency, we handle the d=1 case separately.
961
 */
962
static void scalar_encode(uint8_t *out, const scalar *s, int bits)
963
309k
{
964
309k
    const uint16_t *curr = s->c, *end = curr + DEGREE;
965
309k
    uint64_t accum = 0, element;
966
309k
    int used = 0;
967
968
79.1M
    do {
969
79.1M
        element = *curr++;
970
79.1M
        if (used + bits < 64) {
971
64.3M
            accum |= element << used;
972
64.3M
            used += bits;
973
64.3M
        } else if (used + bits > 64) {
974
9.89M
            out = OPENSSL_store_u64_le(out, accum | (element << used));
975
9.89M
            accum = element >> (64 - used);
976
9.89M
            used = (used + bits) - 64;
977
9.89M
        } else {
978
4.94M
            out = OPENSSL_store_u64_le(out, accum | (element << used));
979
4.94M
            accum = 0;
980
4.94M
            used = 0;
981
4.94M
        }
982
79.1M
    } while (curr < end);
983
309k
}
984
985
/*
986
 * scalar_encode_1 is |scalar_encode| specialised for |bits| == 1.
987
 */
988
static void scalar_encode_1(uint8_t out[DEGREE / 8], const scalar *s)
989
141
{
990
141
    int i, j;
991
141
    uint8_t out_byte;
992
993
4.65k
    for (i = 0; i < DEGREE; i += 8) {
994
4.51k
        out_byte = 0;
995
40.6k
        for (j = 0; j < 8; j++)
996
36.0k
            out_byte |= bit0(s->c[i + j]) << j;
997
4.51k
        *out = out_byte;
998
4.51k
        out++;
999
4.51k
    }
1000
141
}
1001
1002
/*-
1003
 * FIPS 203, Section 4.2.1, Algorithm 6: "ByteDecode_d", for 2<=d<12.
1004
 * Here |bits| is |d|.  For efficiency, we handle the d=1 and d=12 cases
1005
 * separately.
1006
 *
1007
 * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
1008
 * |out|.
1009
 */
1010
static void scalar_decode(scalar *out, const uint8_t *in, int bits)
1011
550
{
1012
550
    uint16_t *curr = out->c, *end = curr + DEGREE;
1013
550
    uint64_t accum = 0;
1014
550
    int accum_bits = 0, todo = bits;
1015
550
    uint16_t bitmask = (((uint16_t)1) << bits) - 1, mask = bitmask;
1016
550
    uint16_t element = 0;
1017
1018
155k
    do {
1019
155k
        if (accum_bits == 0) {
1020
19.4k
            in = OPENSSL_load_u64_le(&accum, in);
1021
19.4k
            accum_bits = 64;
1022
19.4k
        }
1023
155k
        if (todo == bits && accum_bits >= bits) {
1024
            /* No partial "element", and all the required bits available */
1025
125k
            *curr++ = ((uint16_t)accum) & mask;
1026
125k
            accum >>= bits;
1027
125k
            accum_bits -= bits;
1028
125k
        } else if (accum_bits >= todo) {
1029
            /* A partial "element", and all the required bits available */
1030
15.0k
            *curr++ = element | ((((uint16_t)accum) & mask) << (bits - todo));
1031
15.0k
            accum >>= todo;
1032
15.0k
            accum_bits -= todo;
1033
15.0k
            element = 0;
1034
15.0k
            todo = bits;
1035
15.0k
            mask = bitmask;
1036
15.0k
        } else {
1037
            /*
1038
             * Only some of the requisite bits accumulated, store |accum_bits|
1039
             * of these in |element|.  The accumulated bitcount becomes 0, but
1040
             * as soon as we have more bits we'll want to merge accum_bits
1041
             * fewer of them into the final |element|.
1042
             *
1043
             * Note that with a 64-bit accumulator and |bits| always 12 or
1044
             * less, if we're here, the previous iteration had all the
1045
             * requisite bits, and so there are no kept bits in |element|.
1046
             */
1047
15.0k
            element = ((uint16_t)accum) & mask;
1048
15.0k
            todo -= accum_bits;
1049
15.0k
            mask = bitmask >> accum_bits;
1050
15.0k
            accum_bits = 0;
1051
15.0k
        }
1052
155k
    } while (curr < end);
1053
550
}
1054
1055
/*
 * Unpacks |DEGREE| 12-bit values from |in| into |out|, three input bytes per
 * pair of coefficients (the first byte supplies the low 8 bits of the first
 * coefficient).  Returns 1 when every decoded value is below |kPrime|, and 0
 * as soon as a pair contains a value that is not.  Both coefficients of the
 * offending pair are stored before the early return.
 */
static __owur int scalar_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2])
{
    uint16_t *coef = out->c;
    int pair;

    for (pair = 0; pair < DEGREE / 2; pair++, in += 3, coef += 2) {
        const uint16_t first = in[0] | ((in[1] & 0x0f) << 8);
        const uint16_t second = (in[1] >> 4) | (in[2] << 4);

        coef[0] = first;
        coef[1] = second;
        /* Bitwise OR: both range checks are always evaluated */
        if ((first >= kPrime) | (second >= kPrime))
            return 0;
    }
    return 1;
}
1072
1073
/*-
1074
 * scalar_decode_decompress_add is a combination of decoding and decompression
1075
 * both specialised for |bits| == 1, with the result added (and sum reduced) to
1076
 * the output scalar.
1077
 *
1078
 * NOTE: this function MUST not leak an input-data-depedennt timing signal.
1079
 * A timing leak in a related function in the reference Kyber implementation
1080
 * made the "clangover" attack (CVE-2024-37880) possible, giving key recovery
1081
 * for ML-KEM-512 in minutes, provided the attacker has access to precise
1082
 * timing of a CPU performing chosen-ciphertext decap.  Admittedly this is only
1083
 * a risk when private keys are reused (perhaps KEMTLS servers).
1084
 */
1085
static void
1086
scalar_decode_decompress_add(scalar *out, const uint8_t in[DEGREE / 8])
1087
310
{
1088
310
    static const uint16_t half_q_plus_1 = (ML_KEM_PRIME >> 1) + 1;
1089
310
    uint16_t *curr = out->c, *end = curr + DEGREE;
1090
310
    uint16_t mask;
1091
310
    uint8_t b;
1092
1093
    /*
1094
     * Add |half_q_plus_1| if the bit is set, without exposing a side-channel,
1095
     * avoiding the "clangover" attack.  See |constish_time_non_zero| for a
1096
     * discussion on why the value barrier is by default omitted.
1097
     */
1098
310
#define decode_decompress_add_bit                        \
1099
79.3k
    mask = constish_time_non_zero(bit0(b));              \
1100
79.3k
    *curr = reduce_once(*curr + (mask & half_q_plus_1)); \
1101
79.3k
    curr++;                                              \
1102
79.3k
    b >>= 1
1103
1104
    /* Unrolled to process each byte in one iteration */
1105
9.92k
    do {
1106
9.92k
        b = *in++;
1107
9.92k
        decode_decompress_add_bit;
1108
9.92k
        decode_decompress_add_bit;
1109
9.92k
        decode_decompress_add_bit;
1110
9.92k
        decode_decompress_add_bit;
1111
1112
9.92k
        decode_decompress_add_bit;
1113
9.92k
        decode_decompress_add_bit;
1114
9.92k
        decode_decompress_add_bit;
1115
9.92k
        decode_decompress_add_bit;
1116
9.92k
    } while (curr < end);
1117
310
#undef decode_decompress_add_bit
1118
310
}
1119
1120
/*
1121
 * FIPS 203, Section 4.2.1, Equation (4.7): Compress_d.
1122
 *
1123
 * Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
1124
 * numbers close to each other together. The formula used is
1125
 * round(2^|bits|/kPrime*x) mod 2^|bits|.
1126
 * Uses Barrett reduction to achieve constant time. Since we need both the
1127
 * remainder (for rounding) and the quotient (as the result), we cannot use
1128
 * |reduce| here, but need to do the Barrett reduction directly.
1129
 */
1130
static __owur uint16_t compress(uint16_t x, int bits)
1131
346k
{
1132
346k
    uint32_t shifted = (uint32_t)x << bits;
1133
346k
    uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
1134
346k
    uint32_t quotient = (uint32_t)(product >> kBarrettShift);
1135
346k
    uint32_t remainder = shifted - quotient * kPrime;
1136
1137
    /*
1138
     * Adjust the quotient to round correctly:
1139
     *   0 <= remainder <= kHalfPrime round to 0
1140
     *   kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
1141
     *   kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
1142
     */
1143
346k
    quotient += 1 & constant_time_lt_32(kHalfPrime, remainder);
1144
346k
    quotient += 1 & constant_time_lt_32(kPrime + kHalfPrime, remainder);
1145
346k
    return quotient & ((1 << bits) - 1);
1146
346k
}
1147
1148
/*
1149
 * FIPS 203, Section 4.2.1, Equation (4.8): Decompress_d.
1150
1151
 * Decompresses |x| by using a close equi-distant representative. The formula
1152
 * is round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us
1153
 * to implement this logic using only bit operations.
1154
 */
1155
static __owur uint16_t decompress(uint16_t x, int bits)
1156
140k
{
1157
140k
    uint32_t product = (uint32_t)x * kPrime;
1158
140k
    uint32_t power = 1 << bits;
1159
    /* This is |product| % power, since |power| is a power of 2. */
1160
140k
    uint32_t remainder = product & (power - 1);
1161
    /* This is |product| / power, since |power| is a power of 2. */
1162
140k
    uint32_t lower = product >> bits;
1163
1164
    /*
1165
     * The rounding logic works since the first half of numbers mod |power|
1166
     * have a 0 as first bit, and the second half has a 1 as first bit, since
1167
     * |power| is a power of 2. As a 12 bit number, |remainder| is always
1168
     * positive, so we will shift in 0s for a right shift.
1169
     */
1170
140k
    return lower + (remainder >> (bits - 1));
1171
140k
}
1172
1173
/*-
1174
 * FIPS 203, Section 4.2.1, Equation (4.7): "Compress_d".
1175
 * In-place lossy rounding of scalars to 2^d bits.
1176
 */
1177
static void scalar_compress(scalar *s, int bits)
1178
1.35k
{
1179
1.35k
    int i;
1180
1181
347k
    for (i = 0; i < DEGREE; i++)
1182
346k
        s->c[i] = compress(s->c[i], bits);
1183
1.35k
}
1184
1185
/*
1186
 * FIPS 203, Section 4.2.1, Equation (4.8): "Decompress_d".
1187
 * In-place approximate recovery of scalars from 2^d bit compression.
1188
 */
1189
static void scalar_decompress(scalar *s, int bits)
1190
550
{
1191
550
    int i;
1192
1193
141k
    for (i = 0; i < DEGREE; i++)
1194
140k
        s->c[i] = decompress(s->c[i], bits);
1195
550
}
1196
1197
/* Addition updating the LHS vector in-place. */
1198
static void vector_add(scalar *lhs, const scalar *rhs, int rank)
1199
310
{
1200
902
    do {
1201
902
        scalar_add(lhs++, rhs++);
1202
902
    } while (--rank > 0);
1203
310
}
1204
1205
/*
1206
 * Encodes an entire vector into 32*|rank|*|bits| bytes. Note that since 256
1207
 * (DEGREE) is divisible by 8, the individual vector entries will always fill a
1208
 * whole number of bytes, so we do not need to worry about bit packing here.
1209
 */
1210
static void vector_encode(uint8_t *out, const scalar *a, int bits, int rank)
1211
51.7k
{
1212
51.7k
    int stride = bits * DEGREE / 8;
1213
1214
206k
    for (; rank-- > 0; out += stride)
1215
155k
        scalar_encode(out, a++, bits);
1216
51.7k
}
1217
1218
/*
1219
 * Decodes 32*|rank|*|bits| bytes from |in| into |out|. It returns early
1220
 * if any parsed value is >= |ML_KEM_PRIME|.  The resulting scalars are
1221
 * then decompressed and transformed via the NTT.
1222
 *
1223
 * Note: Used only in decrypt_cpa(), which returns void and so does not check
1224
 * the return value of this function.  Side-channels are fine when the input
1225
 * ciphertext to decap() is simply syntactically invalid.
1226
 */
1227
static void
1228
vector_decode_decompress_ntt(scalar *out, const uint8_t *in, int bits, int rank)
1229
141
{
1230
141
    int stride = bits * DEGREE / 8;
1231
1232
550
    for (; rank-- > 0; in += stride, ++out) {
1233
409
        scalar_decode(out, in, bits);
1234
409
        scalar_decompress(out, bits);
1235
409
        scalar_ntt(out);
1236
409
    }
1237
141
}
1238
1239
/* vector_decode(), specialised to bits == 12. */
1240
static __owur int vector_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2], int rank)
1241
625
{
1242
625
    int stride = 3 * DEGREE / 2;
1243
1244
1.56k
    for (; rank-- > 0; in += stride)
1245
1.22k
        if (!scalar_decode_12(out++, in))
1246
286
            return 0;
1247
339
    return 1;
1248
625
}
1249
1250
/* In-place compression of each scalar component */
1251
static void vector_compress(scalar *a, int bits, int rank)
1252
310
{
1253
902
    do {
1254
902
        scalar_compress(a++, bits);
1255
902
    } while (--rank > 0);
1256
310
}
1257
1258
/* The output scalar must not overlap with the inputs */
1259
static void inner_product(scalar *out, const scalar *lhs, const scalar *rhs,
1260
    int rank)
1261
451
{
1262
451
    scalar_mult(out, lhs, rhs);
1263
1.31k
    while (--rank > 0)
1264
860
        scalar_mult_add(out, ++lhs, ++rhs);
1265
451
}
1266
1267
/*
1268
 * Here, the output vector must not overlap with the inputs, the result is
1269
 * directly subjected to inverse NTT.
1270
 */
1271
static void
1272
matrix_mult_intt(scalar *out, const scalar *m, const scalar *a, int rank)
1273
310
{
1274
310
    const scalar *ar;
1275
310
    int i, j;
1276
1277
1.21k
    for (i = rank; i-- > 0; ++out) {
1278
902
        scalar_mult(out, m++, ar = a);
1279
2.81k
        for (j = rank - 1; j > 0; --j)
1280
1.90k
            scalar_mult_add(out, m++, ++ar);
1281
902
        scalar_inverse_ntt(out);
1282
902
    }
1283
310
}
1284
1285
/* Here, the output vector must not overlap with the inputs */
1286
static void
1287
matrix_mult_transpose_add(scalar *out, const scalar *m, const scalar *a, int rank)
1288
51.3k
{
1289
51.3k
    const scalar *mc = m, *mr, *ar;
1290
51.3k
    int i, j;
1291
1292
205k
    for (i = rank; i-- > 0; ++out) {
1293
153k
        scalar_mult_add(out, mr = mc++, ar = a);
1294
461k
        for (j = rank; --j > 0;)
1295
307k
            scalar_mult_add(out, (mr += rank), ++ar);
1296
153k
    }
1297
51.3k
}
1298
1299
/*-
1300
 * Expands the matrix from a seed for key generation and for encaps-CPA.
1301
 * NOTE: FIPS 203 matrix "A" is the transpose of this matrix, computed
1302
 * by appending the (i,j) indices to the seed in the opposite order!
1303
 *
1304
 * Where FIPS 203 computes t = A * s + e, we use the transpose of "m".
1305
 */
1306
static __owur int matrix_expand(EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
1307
51.6k
{
1308
51.6k
    scalar *out = key->m;
1309
51.6k
    uint8_t input[ML_KEM_RANDOM_BYTES + 2];
1310
51.6k
    int rank = key->vinfo->rank;
1311
51.6k
    int i, j;
1312
1313
51.6k
    memcpy(input, key->rho, ML_KEM_RANDOM_BYTES);
1314
206k
    for (i = 0; i < rank; i++) {
1315
618k
        for (j = 0; j < rank; j++) {
1316
464k
            input[ML_KEM_RANDOM_BYTES] = i;
1317
464k
            input[ML_KEM_RANDOM_BYTES + 1] = j;
1318
464k
            if (!EVP_DigestInit_ex(mdctx, key->shake128_md, NULL)
1319
464k
                || !EVP_DigestUpdate(mdctx, input, sizeof(input))
1320
464k
                || !sample_scalar(out++, mdctx))
1321
0
                return 0;
1322
464k
        }
1323
154k
    }
1324
51.6k
    return 1;
1325
51.6k
}
1326
1327
/*
1328
 * Algorithm 7 from the spec, with eta fixed to two and the PRF call
1329
 * included. Creates binominally distributed elements by sampling 2*|eta| bits,
1330
 * and setting the coefficient to the count of the first bits minus the count of
1331
 * the second bits, resulting in a centered binomial distribution. Since eta is
1332
 * two this gives -2/2 with a probability of 1/16, -1/1 with probability 1/4,
1333
 * and 0 with probability 3/8.
1334
 */
1335
static __owur int cbd_2(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
1336
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1337
308k
{
1338
308k
    uint16_t *curr = out->c, *end = curr + DEGREE;
1339
308k
    uint8_t randbuf[4 * DEGREE / 8], *r = randbuf; /* 64 * eta slots */
1340
308k
    uint16_t value, mask;
1341
308k
    uint8_t b;
1342
1343
308k
    if (!prf(randbuf, sizeof(randbuf), in, mdctx, key))
1344
0
        return 0;
1345
1346
39.5M
    do {
1347
39.5M
        b = *r++;
1348
1349
        /*
1350
         * Add |kPrime| if |value| underflowed.  See |constish_time_non_zero|
1351
         * for a discussion on why the value barrier is by default omitted.
1352
         * While this could have been written reduce_once(value + kPrime), this
1353
         * is one extra addition and small range of |value| tempts some
1354
         * versions of Clang to emit a branch.
1355
         */
1356
39.5M
        value = bit0(b) + bitn(1, b);
1357
39.5M
        value -= bitn(2, b) + bitn(3, b);
1358
39.5M
        mask = constish_time_non_zero(value >> 15);
1359
39.5M
        *curr++ = value + (kPrime & mask);
1360
1361
39.5M
        value = bitn(4, b) + bitn(5, b);
1362
39.5M
        value -= bitn(6, b) + bitn(7, b);
1363
39.5M
        mask = constish_time_non_zero(value >> 15);
1364
39.5M
        *curr++ = value + (kPrime & mask);
1365
39.5M
    } while (curr < end);
1366
308k
    return 1;
1367
308k
}
1368
1369
/*
1370
 * Algorithm 7 from the spec, with eta fixed to three and the PRF call
1371
 * included. Creates binominally distributed elements by sampling 3*|eta| bits,
1372
 * and setting the coefficient to the count of the first bits minus the count of
1373
 * the second bits, resulting in a centered binomial distribution.
1374
 */
1375
static __owur int cbd_3(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
1376
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1377
1.31k
{
1378
1.31k
    uint16_t *curr = out->c, *end = curr + DEGREE;
1379
1.31k
    uint8_t randbuf[6 * DEGREE / 8], *r = randbuf; /* 64 * eta slots */
1380
1.31k
    uint8_t b1, b2, b3;
1381
1.31k
    uint16_t value, mask;
1382
1383
1.31k
    if (!prf(randbuf, sizeof(randbuf), in, mdctx, key))
1384
0
        return 0;
1385
1386
83.9k
    do {
1387
83.9k
        b1 = *r++;
1388
83.9k
        b2 = *r++;
1389
83.9k
        b3 = *r++;
1390
1391
        /*
1392
         * Add |kPrime| if |value| underflowed.  See |constish_time_non_zero|
1393
         * for a discussion on why the value barrier is by default omitted.
1394
         * While this could have been written reduce_once(value + kPrime), this
1395
         * is one extra addition and small range of |value| tempts some
1396
         * versions of Clang to emit a branch.
1397
         */
1398
83.9k
        value = bit0(b1) + bitn(1, b1) + bitn(2, b1);
1399
83.9k
        value -= bitn(3, b1) + bitn(4, b1) + bitn(5, b1);
1400
83.9k
        mask = constish_time_non_zero(value >> 15);
1401
83.9k
        *curr++ = value + (kPrime & mask);
1402
1403
83.9k
        value = bitn(6, b1) + bitn(7, b1) + bit0(b2);
1404
83.9k
        value -= bitn(1, b2) + bitn(2, b2) + bitn(3, b2);
1405
83.9k
        mask = constish_time_non_zero(value >> 15);
1406
83.9k
        *curr++ = value + (kPrime & mask);
1407
1408
83.9k
        value = bitn(4, b2) + bitn(5, b2) + bitn(6, b2);
1409
83.9k
        value -= bitn(7, b2) + bit0(b3) + bitn(1, b3);
1410
83.9k
        mask = constish_time_non_zero(value >> 15);
1411
83.9k
        *curr++ = value + (kPrime & mask);
1412
1413
83.9k
        value = bitn(2, b3) + bitn(3, b3) + bitn(4, b3);
1414
83.9k
        value -= bitn(5, b3) + bitn(6, b3) + bitn(7, b3);
1415
83.9k
        mask = constish_time_non_zero(value >> 15);
1416
83.9k
        *curr++ = value + (kPrime & mask);
1417
83.9k
    } while (curr < end);
1418
1.31k
    return 1;
1419
1.31k
}
1420
1421
/*
1422
 * Generates a secret vector by using |cbd| with the given seed to generate
1423
 * scalar elements and incrementing |counter| for each slot of the vector.
1424
 */
1425
static __owur int gencbd_vector(scalar *out, CBD_FUNC cbd, uint8_t *counter,
1426
    const uint8_t seed[ML_KEM_RANDOM_BYTES], int rank,
1427
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1428
310
{
1429
310
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1430
1431
310
    memcpy(input, seed, ML_KEM_RANDOM_BYTES);
1432
902
    do {
1433
902
        input[ML_KEM_RANDOM_BYTES] = (*counter)++;
1434
902
        if (!cbd(out++, input, mdctx, key))
1435
0
            return 0;
1436
902
    } while (--rank > 0);
1437
310
    return 1;
1438
310
}
1439
1440
/*
1441
 * As above plus NTT transform.
1442
 */
1443
static __owur int gencbd_vector_ntt(scalar *out, CBD_FUNC cbd, uint8_t *counter,
1444
    const uint8_t seed[ML_KEM_RANDOM_BYTES], int rank,
1445
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1446
103k
{
1447
103k
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1448
1449
103k
    memcpy(input, seed, ML_KEM_RANDOM_BYTES);
1450
308k
    do {
1451
308k
        input[ML_KEM_RANDOM_BYTES] = (*counter)++;
1452
308k
        if (!cbd(out, input, mdctx, key))
1453
0
            return 0;
1454
308k
        scalar_ntt(out++);
1455
308k
    } while (--rank > 0);
1456
103k
    return 1;
1457
103k
}
1458
1459
/*
 * Selects the CBD sampler for the |ETA1| noise of a variant.  |ETA1| is 3 for
 * ML-KEM-512 (hence |cbd_3|) and 2 for the other variants; all ETA2 values
 * are 2 as well, so those cases use |cbd_2|.
 */
#define CBD1(evp_type) \
    (((evp_type) == EVP_PKEY_ML_KEM_512) ? cbd_3 : cbd_2)
1461
1462
/*
1463
 * FIPS 203, Section 5.2, Algorithm 14: K-PKE.Encrypt.
1464
 *
1465
 * Encrypts a message with given randomness to the ciphertext in |out|. Without
1466
 * applying the Fujisaki-Okamoto transform this would not result in a CCA
1467
 * secure scheme, since lattice schemes are vulnerable to decryption failure
1468
 * oracles.
1469
 *
1470
 * The steps are re-ordered to make more efficient/localised use of storage.
1471
 *
1472
 * Note also that the input public key is assumed to hold a precomputed matrix
1473
 * |A| (our key->m, with the public key holding an expanded (16-bit per scalar
1474
 * coefficient) key->t vector).
1475
 *
1476
 * Caller passes storage in |tmp| for for two temporary vectors.
1477
 */
1478
static __owur int encrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
1479
    const uint8_t message[DEGREE / 8],
1480
    const uint8_t r[ML_KEM_RANDOM_BYTES], scalar *tmp,
1481
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1482
310
{
1483
310
    const ML_KEM_VINFO *vinfo = key->vinfo;
1484
310
    CBD_FUNC cbd_1 = CBD1(vinfo->evp_type);
1485
310
    int rank = vinfo->rank;
1486
    /* We can use tmp[0..rank-1] as storage for |y|, then |e1|, ... */
1487
310
    scalar *y = &tmp[0], *e1 = y, *e2 = y;
1488
    /* We can use tmp[rank]..tmp[2*rank - 1] for |u| */
1489
310
    scalar *u = &tmp[rank];
1490
310
    scalar v;
1491
310
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1492
310
    uint8_t counter = 0;
1493
310
    int du = vinfo->du;
1494
310
    int dv = vinfo->dv;
1495
1496
    /* FIPS 203 "y" vector */
1497
310
    if (!gencbd_vector_ntt(y, cbd_1, &counter, r, rank, mdctx, key))
1498
0
        return 0;
1499
    /* FIPS 203 "v" scalar */
1500
310
    inner_product(&v, key->t, y, rank);
1501
310
    scalar_inverse_ntt(&v);
1502
    /* FIPS 203 "u" vector */
1503
310
    matrix_mult_intt(u, key->m, y, rank);
1504
1505
    /* All done with |y|, now free to reuse tmp[0] for FIPS 203 |e1| */
1506
310
    if (!gencbd_vector(e1, cbd_2, &counter, r, rank, mdctx, key))
1507
0
        return 0;
1508
310
    vector_add(u, e1, rank);
1509
310
    vector_compress(u, du, rank);
1510
310
    vector_encode(out, u, du, rank);
1511
1512
    /* All done with |e1|, now free to reuse tmp[0] for FIPS 203 |e2| */
1513
310
    memcpy(input, r, ML_KEM_RANDOM_BYTES);
1514
310
    input[ML_KEM_RANDOM_BYTES] = counter;
1515
310
    if (!cbd_2(e2, input, mdctx, key))
1516
0
        return 0;
1517
310
    scalar_add(&v, e2);
1518
1519
    /* Combine message with |v| */
1520
310
    scalar_decode_decompress_add(&v, message);
1521
310
    scalar_compress(&v, dv);
1522
310
    scalar_encode(out + vinfo->u_vector_bytes, &v, dv);
1523
310
    return 1;
1524
310
}
1525
1526
/*
1527
 * FIPS 203, Section 5.3, Algorithm 15: K-PKE.Decrypt.
1528
 */
1529
static void
1530
decrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
1531
    const uint8_t *ctext, scalar *u, const ML_KEM_KEY *key)
1532
141
{
1533
141
    const ML_KEM_VINFO *vinfo = key->vinfo;
1534
141
    scalar v, mask;
1535
141
    int rank = vinfo->rank;
1536
141
    int du = vinfo->du;
1537
141
    int dv = vinfo->dv;
1538
1539
141
    vector_decode_decompress_ntt(u, ctext, du, rank);
1540
141
    scalar_decode(&v, ctext + vinfo->u_vector_bytes, dv);
1541
141
    scalar_decompress(&v, dv);
1542
141
    inner_product(&mask, key->s, u, rank);
1543
141
    scalar_inverse_ntt(&mask);
1544
141
    scalar_sub(&v, &mask);
1545
141
    scalar_compress(&v, 1);
1546
141
    scalar_encode_1(out, &v);
1547
141
}
1548
1549
/*-
1550
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1551
 * FIPS 203, Section 7.2, Algorithm 20: "ML-KEM.Encaps".
1552
 *
1553
 * Fills the |out| buffer with the |ek| output of "ML-KEM.KeyGen", or,
1554
 * equivalently, the |ek| input of "ML-KEM.Encaps", i.e. returns the
1555
 * wire-format of an ML-KEM public key.
1556
 */
1557
static void encode_pubkey(uint8_t *out, const ML_KEM_KEY *key)
1558
51.2k
{
1559
51.2k
    const uint8_t *rho = key->rho;
1560
51.2k
    const ML_KEM_VINFO *vinfo = key->vinfo;
1561
1562
51.2k
    vector_encode(out, key->t, 12, vinfo->rank);
1563
51.2k
    memcpy(out + vinfo->vector_bytes, rho, ML_KEM_RANDOM_BYTES);
1564
51.2k
}
1565
1566
/*-
1567
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1568
 *
1569
 * Fills the |out| buffer with the |dk| output of "ML-KEM.KeyGen".
1570
 * This matches the input format of parse_prvkey() below.
1571
 */
1572
static void encode_prvkey(uint8_t *out, const ML_KEM_KEY *key)
1573
132
{
1574
132
    const ML_KEM_VINFO *vinfo = key->vinfo;
1575
1576
132
    vector_encode(out, key->s, 12, vinfo->rank);
1577
132
    out += vinfo->vector_bytes;
1578
132
    encode_pubkey(out, key);
1579
132
    out += vinfo->pubkey_bytes;
1580
132
    memcpy(out, key->pkhash, ML_KEM_PKHASH_BYTES);
1581
132
    out += ML_KEM_PKHASH_BYTES;
1582
132
    memcpy(out, key->z, ML_KEM_RANDOM_BYTES);
1583
132
}
1584
1585
/*-
1586
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1587
 * FIPS 203, Section 7.2, Algorithm 20: "ML-KEM.Encaps".
1588
 *
1589
 * This function parses the |in| buffer as the |ek| output of "ML-KEM.KeyGen",
1590
 * or, equivalently, the |ek| input of "ML-KEM.Encaps", i.e. decodes the
1591
 * wire-format of the ML-KEM public key.
1592
 */
1593
static int parse_pubkey(const uint8_t *in, EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
1594
545
{
1595
545
    const ML_KEM_VINFO *vinfo = key->vinfo;
1596
1597
    /* Decode and check |t| */
1598
545
    if (!vector_decode_12(key->t, in, vinfo->rank)) {
1599
220
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1600
220
            "%s invalid public 't' vector",
1601
220
            vinfo->algorithm_name);
1602
220
        return 0;
1603
220
    }
1604
    /* Save the matrix |m| recovery seed |rho| */
1605
325
    memcpy(key->rho, in + vinfo->vector_bytes, ML_KEM_RANDOM_BYTES);
1606
    /*
1607
     * Pre-compute the public key hash, needed for both encap and decap.
1608
     * Also pre-compute the matrix expansion, stored with the public key.
1609
     */
1610
325
    if (!hash_h(key->pkhash, in, vinfo->pubkey_bytes, mdctx, key)
1611
325
        || !matrix_expand(mdctx, key)) {
1612
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1613
0
            "internal error while parsing %s public key",
1614
0
            vinfo->algorithm_name);
1615
0
        return 0;
1616
0
    }
1617
325
    return 1;
1618
325
}
1619
1620
/*
1621
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1622
 *
1623
 * Parses the |in| buffer as a |dk| output of "ML-KEM.KeyGen".
1624
 * This matches the output format of encode_prvkey() above.
1625
 */
1626
static int parse_prvkey(const uint8_t *in, EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
1627
80
{
1628
80
    const ML_KEM_VINFO *vinfo = key->vinfo;
1629
1630
    /* Decode and check |s|. */
1631
80
    if (!vector_decode_12(key->s, in, vinfo->rank)) {
1632
66
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1633
66
            "%s invalid private 's' vector",
1634
66
            vinfo->algorithm_name);
1635
66
        return 0;
1636
66
    }
1637
14
    in += vinfo->vector_bytes;
1638
1639
14
    if (!parse_pubkey(in, mdctx, key))
1640
11
        return 0;
1641
3
    in += vinfo->pubkey_bytes;
1642
1643
    /* Check public key hash. */
1644
3
    if (memcmp(key->pkhash, in, ML_KEM_PKHASH_BYTES) != 0) {
1645
3
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1646
3
            "%s public key hash mismatch",
1647
3
            vinfo->algorithm_name);
1648
3
        return 0;
1649
3
    }
1650
0
    in += ML_KEM_PKHASH_BYTES;
1651
1652
0
    memcpy(key->z, in, ML_KEM_RANDOM_BYTES);
1653
0
    return 1;
1654
3
}
1655
1656
/*
1657
 * FIPS 203, Section 6.1, Algorithm 16: "ML-KEM.KeyGen_internal".
1658
 *
1659
 * The implementation of Section 5.1, Algorithm 13, "K-PKE.KeyGen(d)" is
1660
 * inlined.
1661
 *
1662
 * The caller MUST pass a pre-allocated digest context that is not shared with
1663
 * any concurrent computation.
1664
 *
1665
 * This function optionally outputs the serialised wire-form |ek| public key
1666
 * into the provided |pubenc| buffer, and generates the content of the |rho|,
1667
 * |pkhash|, |t|, |m|, |s| and |z| components of the private |key| (which must
1668
 * have preallocated space for these).
1669
 *
1670
 * Keys are computed from a 32-byte random |d| plus the 1 byte rank for
1671
 * domain separation.  These are concatenated and hashed to produce a pair of
1672
 * 32-byte seeds public "rho", used to generate the matrix, and private "sigma",
1673
 * used to generate the secret vector |s|.
1674
 *
1675
 * The second random input |z| is copied verbatim into the Fujisaki-Okamoto
1676
 * (FO) transform "implicit-rejection" secret (the |z| component of the private
1677
 * key), which thwarts chosen-ciphertext attacks, provided decap() runs in
1678
 * constant time, with no side channel leaks, on all well-formed (valid length,
1679
 * and correctly encoded) ciphertext inputs.
1680
 */
1681
static __owur int genkey(const uint8_t seed[ML_KEM_SEED_BYTES],
1682
    EVP_MD_CTX *mdctx, uint8_t *pubenc, ML_KEM_KEY *key)
1683
35.0k
{
1684
35.0k
    uint8_t hashed[2 * ML_KEM_RANDOM_BYTES];
1685
35.0k
    const uint8_t *const sigma = hashed + ML_KEM_RANDOM_BYTES;
1686
35.0k
    uint8_t augmented_seed[ML_KEM_RANDOM_BYTES + 1];
1687
35.0k
    const ML_KEM_VINFO *vinfo = key->vinfo;
1688
35.0k
    CBD_FUNC cbd_1 = CBD1(vinfo->evp_type);
1689
35.0k
    int rank = vinfo->rank;
1690
35.0k
    uint8_t counter = 0;
1691
35.0k
    int ret = 0;
1692
1693
    /*
1694
     * Use the "d" seed salted with the rank to derive the public and private
1695
     * seeds rho and sigma.
1696
     */
1697
35.0k
    memcpy(augmented_seed, seed, ML_KEM_RANDOM_BYTES);
1698
35.0k
    augmented_seed[ML_KEM_RANDOM_BYTES] = (uint8_t)rank;
1699
35.0k
    if (!hash_g(hashed, augmented_seed, sizeof(augmented_seed), mdctx, key))
1700
0
        goto end;
1701
35.0k
    memcpy(key->rho, hashed, ML_KEM_RANDOM_BYTES);
1702
    /* The |rho| matrix seed is public */
1703
35.0k
    CONSTTIME_DECLASSIFY(key->rho, ML_KEM_RANDOM_BYTES);
1704
1705
    /* FIPS 203 |e| vector is initial value of key->t */
1706
35.0k
    if (!matrix_expand(mdctx, key)
1707
35.0k
        || !gencbd_vector_ntt(key->s, cbd_1, &counter, sigma, rank, mdctx, key)
1708
35.0k
        || !gencbd_vector_ntt(key->t, cbd_1, &counter, sigma, rank, mdctx, key))
1709
0
        goto end;
1710
1711
    /* To |e| we now add the product of transpose |m| and |s|, giving |t|. */
1712
35.0k
    matrix_mult_transpose_add(key->t, key->m, key->s, rank);
1713
    /* The |t| vector is public */
1714
35.0k
    CONSTTIME_DECLASSIFY(key->t, vinfo->rank * sizeof(scalar));
1715
1716
35.0k
    if (pubenc == NULL) {
1717
        /* Incremental digest of public key without in-full serialisation. */
1718
35.0k
        if (!hash_h_pubkey(key->pkhash, mdctx, key))
1719
0
            goto end;
1720
35.0k
    } else {
1721
0
        encode_pubkey(pubenc, key);
1722
0
        if (!hash_h(key->pkhash, pubenc, vinfo->pubkey_bytes, mdctx, key))
1723
0
            goto end;
1724
0
    }
1725
1726
    /* Save |z| portion of seed for "implicit rejection" on failure. */
1727
35.0k
    memcpy(key->z, seed + ML_KEM_RANDOM_BYTES, ML_KEM_RANDOM_BYTES);
1728
1729
    /* Optionally save the |d| portion of the seed */
1730
35.0k
    key->d = key->z + ML_KEM_RANDOM_BYTES;
1731
35.0k
    if (key->prov_flags & ML_KEM_KEY_RETAIN_SEED) {
1732
35.0k
        memcpy(key->d, seed, ML_KEM_RANDOM_BYTES);
1733
35.0k
    } else {
1734
0
        OPENSSL_cleanse(key->d, ML_KEM_RANDOM_BYTES);
1735
0
        key->d = NULL;
1736
0
    }
1737
1738
35.0k
    ret = 1;
1739
35.0k
end:
1740
35.0k
    OPENSSL_cleanse((void *)augmented_seed, ML_KEM_RANDOM_BYTES);
1741
35.0k
    OPENSSL_cleanse((void *)sigma, ML_KEM_RANDOM_BYTES);
1742
35.0k
    if (ret == 0) {
1743
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1744
0
            "internal error while generating %s private key",
1745
0
            vinfo->algorithm_name);
1746
0
    }
1747
35.0k
    return ret;
1748
35.0k
}
1749
1750
/*-
1751
 * FIPS 203, Section 6.2, Algorithm 17: "ML-KEM.Encaps_internal".
1752
 * This is the deterministic version with randomness supplied externally.
1753
 *
1754
 * The caller must pass space for two vectors in |tmp|.
1755
 * The |ctext| buffer have space for the ciphertext of the ML-KEM variant
1756
 * of the provided key.
1757
 */
1758
static int encap(uint8_t *ctext, uint8_t secret[ML_KEM_SHARED_SECRET_BYTES],
1759
    const uint8_t entropy[ML_KEM_RANDOM_BYTES],
1760
    scalar *tmp, EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1761
169
{
1762
169
    uint8_t input[ML_KEM_RANDOM_BYTES + ML_KEM_PKHASH_BYTES];
1763
169
    uint8_t Kr[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES];
1764
169
    uint8_t *r = Kr + ML_KEM_SHARED_SECRET_BYTES;
1765
169
    int ret;
1766
1767
169
    memcpy(input, entropy, ML_KEM_RANDOM_BYTES);
1768
169
    memcpy(input + ML_KEM_RANDOM_BYTES, key->pkhash, ML_KEM_PKHASH_BYTES);
1769
169
    ret = hash_g(Kr, input, sizeof(input), mdctx, key)
1770
169
        && encrypt_cpa(ctext, entropy, r, tmp, mdctx, key);
1771
169
    OPENSSL_cleanse((void *)input, sizeof(input));
1772
1773
169
    if (ret)
1774
169
        memcpy(secret, Kr, ML_KEM_SHARED_SECRET_BYTES);
1775
0
    else
1776
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1777
0
            "internal error while performing %s encapsulation",
1778
0
            key->vinfo->algorithm_name);
1779
169
    return ret;
1780
169
}
1781
1782
/*
1783
 * FIPS 203, Section 6.3, Algorithm 18: ML-KEM.Decaps_internal
1784
 *
1785
 * Barring failure of the supporting SHA3/SHAKE primitives, this is fully
1786
 * deterministic, the randomness for the FO transform is extracted during
1787
 * private key generation.
1788
 *
1789
 * The caller must pass space for two vectors in |tmp|.
1790
 * The |ctext| and |tmp_ctext| buffers must each have space for the ciphertext
1791
 * of the key's ML-KEM variant.
1792
 */
1793
static int decap(uint8_t secret[ML_KEM_SHARED_SECRET_BYTES],
1794
    const uint8_t *ctext, uint8_t *tmp_ctext, scalar *tmp,
1795
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1796
85
{
1797
85
    uint8_t decrypted[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_PKHASH_BYTES];
1798
85
    uint8_t failure_key[ML_KEM_RANDOM_BYTES];
1799
85
    uint8_t Kr[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES];
1800
85
    uint8_t *r = Kr + ML_KEM_SHARED_SECRET_BYTES;
1801
85
    const uint8_t *pkhash = key->pkhash;
1802
85
    const ML_KEM_VINFO *vinfo = key->vinfo;
1803
85
    int i;
1804
85
    uint8_t mask;
1805
1806
    /*
1807
     * If our KDF is unavailable, fail early! Otherwise, keep going ignoring
1808
     * any further errors, returning success, and whatever we got for a shared
1809
     * secret.  The decrypt_cpa() function is just arithmetic on secret data,
1810
     * so should not be subject to failure that makes its output predictable.
1811
     *
1812
     * We guard against "should never happen" catastrophic failure of the
1813
     * "pure" function |hash_g| by overwriting the shared secret with the
1814
     * content of the failure key and returning early, if nevertheless hash_g
1815
     * fails.  This is not constant-time, but a failure of |hash_g| already
1816
     * implies loss of side-channel resistance.
1817
     *
1818
     * The same action is taken, if also |encrypt_cpa| should catastrophically
1819
     * fail, due to failure of the |PRF| underlying the CBD functions.
1820
     */
1821
85
    if (!kdf(failure_key, key->z, ctext, vinfo->ctext_bytes, mdctx, key)) {
1822
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1823
0
            "internal error while performing %s decapsulation",
1824
0
            vinfo->algorithm_name);
1825
0
        return 0;
1826
0
    }
1827
85
    decrypt_cpa(decrypted, ctext, tmp, key);
1828
85
    memcpy(decrypted + ML_KEM_SHARED_SECRET_BYTES, pkhash, ML_KEM_PKHASH_BYTES);
1829
85
    if (!hash_g(Kr, decrypted, sizeof(decrypted), mdctx, key)
1830
85
        || !encrypt_cpa(tmp_ctext, decrypted, r, tmp, mdctx, key)) {
1831
0
        memcpy(secret, failure_key, ML_KEM_SHARED_SECRET_BYTES);
1832
0
        OPENSSL_cleanse(decrypted, ML_KEM_SHARED_SECRET_BYTES);
1833
0
        return 1;
1834
0
    }
1835
85
    mask = constant_time_eq_int_8(0,
1836
85
        CRYPTO_memcmp(ctext, tmp_ctext, vinfo->ctext_bytes));
1837
2.80k
    for (i = 0; i < ML_KEM_SHARED_SECRET_BYTES; i++)
1838
2.72k
        secret[i] = constant_time_select_8(mask, Kr[i], failure_key[i]);
1839
85
    OPENSSL_cleanse(decrypted, ML_KEM_SHARED_SECRET_BYTES);
1840
85
    OPENSSL_cleanse(Kr, sizeof(Kr));
1841
85
    return 1;
1842
85
}
1843
1844
/*
1845
 * After allocating storage for public or private key data, update the key
1846
 * component pointers to reference that storage.
1847
 */
1848
static __owur int add_storage(scalar *p, int private, ML_KEM_KEY *key)
1849
17.4k
{
1850
17.4k
    int rank = key->vinfo->rank;
1851
1852
17.4k
    if (p == NULL)
1853
0
        return 0;
1854
1855
    /*
1856
     * We're adding key material, the seed buffer will now hold |rho| and
1857
     * |pkhash|.
1858
     */
1859
17.4k
    memset(key->seedbuf, 0, sizeof(key->seedbuf));
1860
17.4k
    key->rho = key->seedbuf;
1861
17.4k
    key->pkhash = key->seedbuf + ML_KEM_RANDOM_BYTES;
1862
17.4k
    key->d = key->z = NULL;
1863
1864
    /* A public key needs space for |t| and |m| */
1865
17.4k
    key->m = (key->t = p) + rank;
1866
1867
    /*
1868
     * A private key also needs space for |s| and |z|.
1869
     * The |z| buffer always includes additional space for |d|, but a key's |d|
1870
     * pointer is left NULL when parsed from the NIST format, which omits that
1871
     * information.  Only keys generated from a (d, z) seed pair will have a
1872
     * non-NULL |d| pointer.
1873
     */
1874
17.4k
    if (private)
1875
17.2k
        key->z = (uint8_t *)(rank + (key->s = key->m + rank * rank));
1876
17.4k
    return 1;
1877
17.4k
}
1878
1879
/*
1880
 * After freeing the storage associated with a key that failed to be
1881
 * constructed, reset the internal pointers back to NULL.
1882
 */
1883
void ossl_ml_kem_key_reset(ML_KEM_KEY *key)
1884
17.5k
{
1885
17.5k
    if (key->t == NULL)
1886
111
        return;
1887
    /*-
1888
     * Cleanse any sensitive data:
1889
     * - The private vector |s| is immediately followed by the FO failure
1890
     *   secret |z|, and seed |d|, we can cleanse all three in one call.
1891
     *
1892
     * - Otherwise, when key->d is set, cleanse the stashed seed.
1893
     */
1894
17.4k
    if (ossl_ml_kem_have_prvkey(key))
1895
17.2k
        OPENSSL_cleanse(key->s,
1896
17.2k
            key->vinfo->rank * sizeof(scalar) + 2 * ML_KEM_RANDOM_BYTES);
1897
17.4k
    OPENSSL_free(key->t);
1898
17.4k
    key->d = key->z = (uint8_t *)(key->s = key->m = key->t = NULL);
1899
17.4k
}
1900
1901
/*
1902
 * ----- API exported to the provider
1903
 *
1904
 * Parameters with an implicit fixed length in the internal static API of each
1905
 * variant have an explicit checked length argument at this layer.
1906
 */
1907
1908
/* Retrieve the parameters of one of the ML-KEM variants */
1909
const ML_KEM_VINFO *ossl_ml_kem_get_vinfo(int evp_type)
1910
279k
{
1911
279k
    switch (evp_type) {
1912
60.5k
    case EVP_PKEY_ML_KEM_512:
1913
60.5k
        return &vinfo_map[ML_KEM_512_VINFO];
1914
162k
    case EVP_PKEY_ML_KEM_768:
1915
162k
        return &vinfo_map[ML_KEM_768_VINFO];
1916
57.0k
    case EVP_PKEY_ML_KEM_1024:
1917
57.0k
        return &vinfo_map[ML_KEM_1024_VINFO];
1918
279k
    }
1919
0
    return NULL;
1920
279k
}
1921
1922
/*
 * Allocate an empty key (no key material) for the ML-KEM variant |evp_type|,
 * fetching the SHA3/SHAKE digests it needs from |libctx| with |properties|.
 *
 * Returns NULL if the variant is unsupported, allocation fails, or any of the
 * four digests cannot be fetched.
 */
ML_KEM_KEY *ossl_ml_kem_key_new(OSSL_LIB_CTX *libctx, const char *properties,
    int evp_type)
{
    const ML_KEM_VINFO *vinfo = ossl_ml_kem_get_vinfo(evp_type);
    ML_KEM_KEY *key;

    if (vinfo == NULL) {
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_PASSED_INVALID_ARGUMENT,
            "unsupported ML-KEM key type: %d", evp_type);
        return NULL;
    }

    key = OPENSSL_malloc(sizeof(*key));
    if (key == NULL)
        return NULL;

    key->vinfo = vinfo;
    key->libctx = libctx;
    key->prov_flags = ML_KEM_KEY_PROV_FLAGS_DEFAULT;

    /* No key material yet */
    key->encoded_dk = NULL;
    key->pkhash = NULL;
    key->rho = NULL;
    key->z = NULL;
    key->d = NULL;
    key->t = NULL;
    key->m = NULL;
    key->s = NULL;

    key->shake128_md = EVP_MD_fetch(libctx, "SHAKE128", properties);
    key->shake256_md = EVP_MD_fetch(libctx, "SHAKE256", properties);
    key->sha3_256_md = EVP_MD_fetch(libctx, "SHA3-256", properties);
    key->sha3_512_md = EVP_MD_fetch(libctx, "SHA3-512", properties);

    if (key->shake128_md == NULL
        || key->shake256_md == NULL
        || key->sha3_256_md == NULL
        || key->sha3_512_md == NULL) {
        /* Releases whichever digests were fetched, and the key itself */
        ossl_ml_kem_key_free(key);
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
            "missing SHA3 digest algorithms while creating %s key",
            vinfo->algorithm_name);
        return NULL;
    }

    return key;
}
1959
1960
ML_KEM_KEY *ossl_ml_kem_key_dup(const ML_KEM_KEY *key, int selection)
1961
27
{
1962
27
    int ok = 0;
1963
27
    ML_KEM_KEY *ret;
1964
1965
    /*
1966
     * Partially decoded keys, not yet imported or loaded, should never be
1967
     * duplicated.
1968
     */
1969
27
    if (ossl_ml_kem_decoded_key(key))
1970
0
        return NULL;
1971
1972
27
    if (key == NULL
1973
27
        || (ret = OPENSSL_memdup(key, sizeof(*key))) == NULL)
1974
0
        return NULL;
1975
27
    ret->d = ret->z = ret->rho = ret->pkhash = NULL;
1976
27
    ret->s = ret->m = ret->t = NULL;
1977
1978
    /* Clear selection bits we can't fulfill */
1979
27
    if (!ossl_ml_kem_have_pubkey(key))
1980
0
        selection = 0;
1981
27
    else if (!ossl_ml_kem_have_prvkey(key))
1982
27
        selection &= ~OSSL_KEYMGMT_SELECT_PRIVATE_KEY;
1983
1984
27
    switch (selection & OSSL_KEYMGMT_SELECT_KEYPAIR) {
1985
0
    case 0:
1986
0
        ok = 1;
1987
0
        break;
1988
27
    case OSSL_KEYMGMT_SELECT_PUBLIC_KEY:
1989
27
        ok = add_storage(OPENSSL_memdup(key->t, key->vinfo->puballoc), 0, ret);
1990
27
        ret->rho = ret->seedbuf;
1991
27
        ret->pkhash = ret->rho + ML_KEM_RANDOM_BYTES;
1992
27
        break;
1993
0
    case OSSL_KEYMGMT_SELECT_PRIVATE_KEY:
1994
0
        ok = add_storage(OPENSSL_memdup(key->t, key->vinfo->prvalloc), 1, ret);
1995
        /* Duplicated keys retain |d|, if available */
1996
0
        if (key->d != NULL)
1997
0
            ret->d = ret->z + ML_KEM_RANDOM_BYTES;
1998
0
        break;
1999
27
    }
2000
2001
27
    if (!ok) {
2002
0
        OPENSSL_free(ret);
2003
0
        return NULL;
2004
0
    }
2005
2006
27
    EVP_MD_up_ref(ret->shake128_md);
2007
27
    EVP_MD_up_ref(ret->shake256_md);
2008
27
    EVP_MD_up_ref(ret->sha3_256_md);
2009
27
    EVP_MD_up_ref(ret->sha3_512_md);
2010
2011
27
    return ret;
2012
27
}
2013
2014
/*
 * Free |key| together with its digest references and any key material,
 * cleansing secret data first.  A NULL |key| is ignored.
 */
void ossl_ml_kem_key_free(ML_KEM_KEY *key)
{
    if (key != NULL) {
        EVP_MD_free(key->shake128_md);
        EVP_MD_free(key->shake256_md);
        EVP_MD_free(key->sha3_256_md);
        EVP_MD_free(key->sha3_512_md);

        /*
         * A decoded, not yet loaded, key may hold a seed in the seed buffer
         * and/or a separately allocated encoded private key.
         */
        if (ossl_ml_kem_decoded_key(key)) {
            OPENSSL_cleanse(key->seedbuf, sizeof(key->seedbuf));
            if (ossl_ml_kem_have_dkenc(key)) {
                OPENSSL_cleanse(key->encoded_dk, key->vinfo->prvkey_bytes);
                OPENSSL_free(key->encoded_dk);
            }
        }

        /* Cleanse and release the key component storage, if allocated */
        ossl_ml_kem_key_reset(key);
        OPENSSL_free(key);
    }
}
2034
2035
/* Serialise the public component of an ML-KEM key */
2036
int ossl_ml_kem_encode_public_key(uint8_t *out, size_t len,
2037
    const ML_KEM_KEY *key)
2038
51.1k
{
2039
51.1k
    if (!ossl_ml_kem_have_pubkey(key)
2040
51.1k
        || len != key->vinfo->pubkey_bytes)
2041
0
        return 0;
2042
51.1k
    encode_pubkey(out, key);
2043
51.1k
    return 1;
2044
51.1k
}
2045
2046
/* Serialise an ML-KEM private key */
2047
int ossl_ml_kem_encode_private_key(uint8_t *out, size_t len,
2048
    const ML_KEM_KEY *key)
2049
132
{
2050
132
    if (!ossl_ml_kem_have_prvkey(key)
2051
132
        || len != key->vinfo->prvkey_bytes)
2052
0
        return 0;
2053
132
    encode_prvkey(out, key);
2054
132
    return 1;
2055
132
}
2056
2057
int ossl_ml_kem_encode_seed(uint8_t *out, size_t len,
2058
    const ML_KEM_KEY *key)
2059
220
{
2060
220
    if (key == NULL || key->d == NULL || len != ML_KEM_SEED_BYTES)
2061
30
        return 0;
2062
    /*
2063
     * Both in the seed buffer, and in the allocated storage, the |d| component
2064
     * of the seed is stored last, so we must copy each separately.
2065
     */
2066
190
    memcpy(out, key->d, ML_KEM_RANDOM_BYTES);
2067
190
    out += ML_KEM_RANDOM_BYTES;
2068
190
    memcpy(out, key->z, ML_KEM_RANDOM_BYTES);
2069
190
    return 1;
2070
220
}
2071
2072
/*
2073
 * Stash the seed without (yet) performing a keygen, used during decoding, to
2074
 * avoid an extra keygen if we're only going to export the key again to load
2075
 * into another provider.
2076
 */
2077
ML_KEM_KEY *ossl_ml_kem_set_seed(const uint8_t *seed, size_t seedlen, ML_KEM_KEY *key)
2078
30
{
2079
30
    if (key == NULL
2080
30
        || ossl_ml_kem_have_pubkey(key)
2081
30
        || ossl_ml_kem_have_seed(key)
2082
30
        || seedlen != ML_KEM_SEED_BYTES)
2083
0
        return NULL;
2084
    /*
2085
     * With no public or private key material on hand, we can use the seed
2086
     * buffer for |z| and |d|, in that order.
2087
     */
2088
30
    key->z = key->seedbuf;
2089
30
    key->d = key->z + ML_KEM_RANDOM_BYTES;
2090
30
    memcpy(key->d, seed, ML_KEM_RANDOM_BYTES);
2091
30
    seed += ML_KEM_RANDOM_BYTES;
2092
30
    memcpy(key->z, seed, ML_KEM_RANDOM_BYTES);
2093
30
    return key;
2094
30
}
2095
2096
/* Parse input as a public key */
2097
int ossl_ml_kem_parse_public_key(const uint8_t *in, size_t len, ML_KEM_KEY *key)
2098
531
{
2099
531
    EVP_MD_CTX *mdctx = NULL;
2100
531
    const ML_KEM_VINFO *vinfo;
2101
531
    int ret = 0;
2102
2103
    /* Keys with key material are immutable */
2104
531
    if (key == NULL
2105
531
        || ossl_ml_kem_have_pubkey(key)
2106
531
        || ossl_ml_kem_have_dkenc(key))
2107
0
        return 0;
2108
531
    vinfo = key->vinfo;
2109
2110
531
    if (len != vinfo->pubkey_bytes
2111
531
        || (mdctx = EVP_MD_CTX_new()) == NULL)
2112
0
        return 0;
2113
2114
531
    if (add_storage(OPENSSL_malloc(vinfo->puballoc), 0, key))
2115
531
        ret = parse_pubkey(in, mdctx, key);
2116
2117
531
    if (!ret)
2118
209
        ossl_ml_kem_key_reset(key);
2119
531
    EVP_MD_CTX_free(mdctx);
2120
531
    return ret;
2121
531
}
2122
2123
/* Parse input as a new private key */
2124
int ossl_ml_kem_parse_private_key(const uint8_t *in, size_t len,
2125
    ML_KEM_KEY *key)
2126
80
{
2127
80
    EVP_MD_CTX *mdctx = NULL;
2128
80
    const ML_KEM_VINFO *vinfo;
2129
80
    int ret = 0;
2130
2131
    /* Keys with key material are immutable */
2132
80
    if (key == NULL
2133
80
        || ossl_ml_kem_have_pubkey(key)
2134
80
        || ossl_ml_kem_have_dkenc(key))
2135
0
        return 0;
2136
80
    vinfo = key->vinfo;
2137
2138
80
    if (len != vinfo->prvkey_bytes
2139
80
        || (mdctx = EVP_MD_CTX_new()) == NULL)
2140
0
        return 0;
2141
2142
80
    if (add_storage(OPENSSL_malloc(vinfo->prvalloc), 1, key))
2143
80
        ret = parse_prvkey(in, mdctx, key);
2144
2145
80
    if (!ret)
2146
80
        ossl_ml_kem_key_reset(key);
2147
80
    EVP_MD_CTX_free(mdctx);
2148
80
    return ret;
2149
80
}
2150
2151
/*
2152
 * Generate a new keypair, either from the saved seed (when non-null), or from
2153
 * the RNG.
2154
 */
2155
int ossl_ml_kem_genkey(uint8_t *pubenc, size_t publen, ML_KEM_KEY *key)
2156
51.3k
{
2157
51.3k
    uint8_t seed[ML_KEM_SEED_BYTES];
2158
51.3k
    EVP_MD_CTX *mdctx = NULL;
2159
51.3k
    const ML_KEM_VINFO *vinfo;
2160
51.3k
    int ret = 0;
2161
2162
51.3k
    if (key == NULL
2163
51.3k
        || ossl_ml_kem_have_pubkey(key)
2164
51.3k
        || ossl_ml_kem_have_dkenc(key))
2165
0
        return 0;
2166
51.3k
    vinfo = key->vinfo;
2167
2168
51.3k
    if (pubenc != NULL && publen != vinfo->pubkey_bytes)
2169
0
        return 0;
2170
2171
51.3k
    if (ossl_ml_kem_have_seed(key)) {
2172
58
        if (!ossl_ml_kem_encode_seed(seed, sizeof(seed), key))
2173
0
            return 0;
2174
58
        key->d = key->z = NULL;
2175
51.3k
    } else if (RAND_priv_bytes_ex(key->libctx, seed, sizeof(seed),
2176
51.3k
                   key->vinfo->secbits)
2177
51.3k
        <= 0) {
2178
0
        return 0;
2179
0
    }
2180
2181
51.3k
    if ((mdctx = EVP_MD_CTX_new()) == NULL)
2182
0
        return 0;
2183
2184
    /*
2185
     * Data derived from (d, z) defaults secret, and to avoid side-channel
2186
     * leaks should not influence control flow.
2187
     */
2188
51.3k
    CONSTTIME_SECRET(seed, ML_KEM_SEED_BYTES);
2189
2190
51.3k
    if (add_storage(OPENSSL_malloc(vinfo->prvalloc), 1, key))
2191
51.3k
        ret = genkey(seed, mdctx, pubenc, key);
2192
51.3k
    OPENSSL_cleanse(seed, sizeof(seed));
2193
2194
    /* Declassify secret inputs and derived outputs before returning control */
2195
51.3k
    CONSTTIME_DECLASSIFY(seed, ML_KEM_SEED_BYTES);
2196
2197
51.3k
    EVP_MD_CTX_free(mdctx);
2198
51.3k
    if (!ret) {
2199
0
        ossl_ml_kem_key_reset(key);
2200
0
        return 0;
2201
0
    }
2202
2203
    /* The public components are already declassified */
2204
51.3k
    CONSTTIME_DECLASSIFY(key->s, vinfo->rank * sizeof(scalar));
2205
51.3k
    CONSTTIME_DECLASSIFY(key->z, 2 * ML_KEM_RANDOM_BYTES);
2206
51.3k
    return 1;
2207
51.3k
}
2208
2209
/*
2210
 * FIPS 203, Section 6.2, Algorithm 17: ML-KEM.Encaps_internal
2211
 * This is the deterministic version with randomness supplied externally.
2212
 */
2213
int ossl_ml_kem_encap_seed(uint8_t *ctext, size_t clen,
2214
    uint8_t *shared_secret, size_t slen,
2215
    const uint8_t *entropy, size_t elen,
2216
    const ML_KEM_KEY *key)
2217
169
{
2218
169
    const ML_KEM_VINFO *vinfo;
2219
169
    EVP_MD_CTX *mdctx;
2220
169
    int ret = 0;
2221
2222
169
    if (key == NULL || !ossl_ml_kem_have_pubkey(key))
2223
0
        return 0;
2224
169
    vinfo = key->vinfo;
2225
2226
169
    if (ctext == NULL || clen != vinfo->ctext_bytes
2227
169
        || shared_secret == NULL || slen != ML_KEM_SHARED_SECRET_BYTES
2228
169
        || entropy == NULL || elen != ML_KEM_RANDOM_BYTES
2229
169
        || (mdctx = EVP_MD_CTX_new()) == NULL)
2230
0
        return 0;
2231
    /*
2232
     * Data derived from the encap entropy defaults secret, and to avoid
2233
     * side-channel leaks should not influence control flow.
2234
     */
2235
169
    CONSTTIME_SECRET(entropy, elen);
2236
2237
    /*-
2238
     * This avoids the need to handle allocation failures for two (max 2KB
2239
     * each) vectors, that are never retained on return from this function.
2240
     * We stack-allocate these.
2241
     */
2242
169
#define case_encap_seed(bits)                                        \
2243
169
    case EVP_PKEY_ML_KEM_##bits: {                                   \
2244
169
        scalar tmp[2 * ML_KEM_##bits##_RANK];                        \
2245
169
                                                                     \
2246
169
        ret = encap(ctext, shared_secret, entropy, tmp, mdctx, key); \
2247
169
        OPENSSL_cleanse((void *)tmp, sizeof(tmp));                   \
2248
169
        break;                                                       \
2249
169
    }
2250
169
    switch (vinfo->evp_type) {
2251
54
        case_encap_seed(512);
2252
75
        case_encap_seed(768);
2253
40
        case_encap_seed(1024);
2254
169
    }
2255
169
#undef case_encap_seed
2256
2257
    /* Declassify secret inputs and derived outputs before returning control */
2258
169
    CONSTTIME_DECLASSIFY(entropy, elen);
2259
169
    CONSTTIME_DECLASSIFY(ctext, clen);
2260
169
    CONSTTIME_DECLASSIFY(shared_secret, slen);
2261
2262
169
    EVP_MD_CTX_free(mdctx);
2263
169
    return ret;
2264
169
}
2265
2266
int ossl_ml_kem_encap_rand(uint8_t *ctext, size_t clen,
2267
    uint8_t *shared_secret, size_t slen,
2268
    const ML_KEM_KEY *key)
2269
169
{
2270
169
    uint8_t r[ML_KEM_RANDOM_BYTES];
2271
2272
169
    if (key == NULL)
2273
0
        return 0;
2274
2275
169
    if (RAND_bytes_ex(key->libctx, r, ML_KEM_RANDOM_BYTES,
2276
169
            key->vinfo->secbits)
2277
169
        < 1)
2278
0
        return 0;
2279
2280
169
    return ossl_ml_kem_encap_seed(ctext, clen, shared_secret, slen,
2281
169
        r, sizeof(r), key);
2282
169
}
2283
2284
int ossl_ml_kem_decap(uint8_t *shared_secret, size_t slen,
2285
    const uint8_t *ctext, size_t clen,
2286
    const ML_KEM_KEY *key)
2287
141
{
2288
141
    const ML_KEM_VINFO *vinfo;
2289
141
    EVP_MD_CTX *mdctx;
2290
141
    int ret = 0;
2291
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
2292
    int classify_bytes;
2293
#endif
2294
2295
    /* Need a private key here */
2296
141
    if (!ossl_ml_kem_have_prvkey(key))
2297
0
        return 0;
2298
141
    vinfo = key->vinfo;
2299
2300
141
    if (shared_secret == NULL || slen != ML_KEM_SHARED_SECRET_BYTES
2301
141
        || ctext == NULL || clen != vinfo->ctext_bytes
2302
141
        || (mdctx = EVP_MD_CTX_new()) == NULL) {
2303
0
        (void)RAND_bytes_ex(key->libctx, shared_secret,
2304
0
            ML_KEM_SHARED_SECRET_BYTES, vinfo->secbits);
2305
0
        return 0;
2306
0
    }
2307
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
2308
    /*
2309
     * Data derived from |s| and |z| defaults secret, and to avoid side-channel
2310
     * leaks should not influence control flow.
2311
     */
2312
    classify_bytes = 2 * sizeof(scalar) + ML_KEM_RANDOM_BYTES;
2313
    CONSTTIME_SECRET(key->s, classify_bytes);
2314
#endif
2315
2316
    /*-
2317
     * This avoids the need to handle allocation failures for two (max 2KB
2318
     * each) vectors and an encoded ciphertext (max 1568 bytes), that are never
2319
     * retained on return from this function.
2320
     * We stack-allocate these.
2321
     */
2322
141
#define case_decap(bits)                                          \
2323
141
    case EVP_PKEY_ML_KEM_##bits: {                                \
2324
141
        uint8_t cbuf[CTEXT_BYTES(bits)];                          \
2325
141
        scalar tmp[2 * ML_KEM_##bits##_RANK];                     \
2326
141
                                                                  \
2327
141
        ret = decap(shared_secret, ctext, cbuf, tmp, mdctx, key); \
2328
141
        OPENSSL_cleanse((void *)tmp, sizeof(tmp));                \
2329
141
        break;                                                    \
2330
141
    }
2331
141
    switch (vinfo->evp_type) {
2332
54
        case_decap(512);
2333
47
        case_decap(768);
2334
40
        case_decap(1024);
2335
141
    }
2336
2337
    /* Declassify secret inputs and derived outputs before returning control */
2338
141
    CONSTTIME_DECLASSIFY(key->s, classify_bytes);
2339
141
    CONSTTIME_DECLASSIFY(shared_secret, slen);
2340
141
    EVP_MD_CTX_free(mdctx);
2341
2342
141
    return ret;
2343
141
#undef case_decap
2344
141
}
2345
2346
int ossl_ml_kem_pubkey_cmp(const ML_KEM_KEY *key1, const ML_KEM_KEY *key2)
2347
140
{
2348
    /*
2349
     * This handles any unexpected differences in the ML-KEM variant rank,
2350
     * giving different key component structures, barring SHA3-256 hash
2351
     * collisions, the keys are the same size.
2352
     */
2353
140
    if (ossl_ml_kem_have_pubkey(key1) && ossl_ml_kem_have_pubkey(key2))
2354
140
        return memcmp(key1->pkhash, key2->pkhash, ML_KEM_PKHASH_BYTES) == 0;
2355
2356
    /*
2357
     * No match if just one of the public keys is not available, otherwise both
2358
     * are unavailable, and for now such keys are considered equal.
2359
     */
2360
0
    return (!(ossl_ml_kem_have_pubkey(key1) ^ ossl_ml_kem_have_pubkey(key2)));
2361
140
}