Coverage Report

Created: 2026-02-14 07:20

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/openssl35/crypto/ml_kem/ml_kem.c
Line
Count
Source
1
/*
2
 * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
3
 *
4
 * Licensed under the Apache License 2.0 (the "License").  You may not use
5
 * this file except in compliance with the License.  You can obtain a copy
6
 * in the file LICENSE in the source distribution or at
7
 * https://www.openssl.org/source/license.html
8
 */
9
10
#include <openssl/byteorder.h>
11
#include <openssl/rand.h>
12
#include <openssl/proverr.h>
13
#include "crypto/ml_kem.h"
14
#include "internal/common.h"
15
#include "internal/constant_time.h"
16
#include "internal/sha3.h"
17
18
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
19
#include <valgrind/memcheck.h>
20
#endif
21
22
#if ML_KEM_SEED_BYTES != ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES
23
#error "ML-KEM keygen seed length != shared secret + random bytes length"
24
#endif
25
#if ML_KEM_SHARED_SECRET_BYTES != ML_KEM_RANDOM_BYTES
26
#error "Invalid unequal lengths of ML-KEM shared secret and random inputs"
27
#endif
28
29
#if UINT_MAX < UINT32_MAX
30
#error "Unsupported compiler: sizeof(unsigned int) < sizeof(uint32_t)"
31
#endif
32
33
/*
 * Handy function-like bit-extraction macros.
 *
 * bit0(b) yields the least significant bit of |b|; bitn(n, b) yields bit |n|
 * of |b|.  Both arguments are fully parenthesised so that compound
 * expressions (e.g. |i & 7|) can be passed safely.
 */
#define bit0(b) ((b) & 1)
#define bitn(n, b) (((b) >> (n)) & 1)
36
37
/*
 * Ring parameters.
 *
 * Values in [0, q-1] are represented losslessly in LOG2PRIME (12) bits, and
 * the Barrett shift is twice that width.  INVERSE_DEGREE is (n/2)^-1 mod q,
 * the scaling factor applied at the end of the inverse NTT.
 */
#define LOG2PRIME 12
#define BARRETT_SHIFT (2 * LOG2PRIME)
#define DEGREE ML_KEM_DEGREE
#define INVERSE_DEGREE (ML_KEM_PRIME - 2 * 13)

/* SHAKE128 block size, only when the SHA3 header exposes SHA3_BLOCKSIZE() */
#if defined(SHA3_BLOCKSIZE)
#define SHAKE128_BLOCKSIZE SHA3_BLOCKSIZE(128)
#endif
49
50
/*
 * constish_time_non_zero(b) turns a value that can only be 0 or 1 into a
 * select mask: all ones when |b| is 1, all zeros when it is 0 (unsigned
 * twos-complement arithmetic assumed).  It is constant time in practice.
 *
 * Although the mask feeds constant-time selects, no value barrier is applied.
 * Value barriers impede auto-vectorization (likely because they force the
 * value to transit through a general-purpose register); on AArch64 this is a
 * 2x difference.
 *
 * Value barriers are usually added to selects because Clang turns consecutive
 * selects with the same condition into a branch instead of CMOV/CSEL.  That
 * pattern does not occur in Kyber, so omitting the barrier seems to be safe
 * so far, but see |cbd_2|, |cbd_3|, where reduction needs to be specialised
 * to the sign of the input, rather than adding |q| in advance and using the
 * generic |reduce_once|.  (David Benjamin, Chromium)
 *
 * The disabled alternative below is kept for reference; like the active
 * definition it is an expression and must not end in a semicolon.
 */
#if 0
#define constish_time_non_zero(b) (~constant_time_is_zero(b))
#else
#define constish_time_non_zero(b) (0u - (b))
#endif
72
73
/*
 * Size of the buffer used for scalar rejection sampling.  It must be a
 * multiple of 12 but is otherwise arbitrary.  The preferred size is the
 * internal block size of SHAKE128, which avoids internal buffering and
 * copying inside SHAKE128; that block size, (1600 - 256)/8 = 168 bytes,
 * happens to be divisible by 12.
 *
 * When the block size is unknown, or is not a multiple of 12, fall back to
 * 168.
 */
#if !defined(SHAKE128_BLOCKSIZE) || (SHAKE128_BLOCKSIZE) % 12 != 0
#define SCALAR_SAMPLING_BUFSIZE 168
#else
#define SCALAR_SAMPLING_BUFSIZE (SHAKE128_BLOCKSIZE)
#endif
87
88
/*
89
 * Structure of keys
90
 */
91
typedef struct ossl_ml_kem_scalar_st {
92
    /* On every function entry and exit, 0 <= c[i] < ML_KEM_PRIME. */
93
    uint16_t c[ML_KEM_DEGREE];
94
} scalar;
95
96
/* Key material allocation layout */
97
#define DECLARE_ML_KEM_KEYDATA(name, rank, private_sz)               \
98
    struct name##_alloc {                                            \
99
        /* Public vector |t| */                                      \
100
        scalar tbuf[(rank)];                                         \
101
        /* Pre-computed matrix |m| (FIPS 203 |A| transpose) */       \
102
        scalar mbuf[(rank) * (rank)] /* optional private key data */ \
103
            private_sz                                               \
104
    }
105
106
/* Declare variant-specific public and private storage */
107
#define DECLARE_ML_KEM_VARIANT_KEYDATA(bits)                        \
108
    DECLARE_ML_KEM_KEYDATA(pubkey_##bits, ML_KEM_##bits##_RANK, ;); \
109
    DECLARE_ML_KEM_KEYDATA(prvkey_##bits, ML_KEM_##bits##_RANK, ; scalar sbuf[ML_KEM_##bits##_RANK]; uint8_t zbuf[2 * ML_KEM_RANDOM_BYTES];)
110
DECLARE_ML_KEM_VARIANT_KEYDATA(512);
111
DECLARE_ML_KEM_VARIANT_KEYDATA(768);
112
DECLARE_ML_KEM_VARIANT_KEYDATA(1024);
113
#undef DECLARE_ML_KEM_VARIANT_KEYDATA
114
#undef DECLARE_ML_KEM_KEYDATA
115
116
/*
 * Function-pointer type for a sampler that fills |out| from the
 * ML_KEM_RANDOM_BYTES + 1 byte buffer |in| using |mdctx| and |key|, and
 * returns an int status.
 * NOTE(review): presumably the type of the |cbd_2| / |cbd_3| samplers
 * referenced elsewhere in this file; confirm against their definitions.
 */
typedef __owur int (*CBD_FUNC)(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
117
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key);
118
/* Forward declaration: hash_h_pubkey() calls this before its definition. */
static void scalar_encode(uint8_t *out, const scalar *s, int bits);
119
120
/*
 * Wire-format sizes, parameterised by the variant suffix |b| (512, 768 or
 * 1024).
 *
 * A losslessly encoded vector uses 12 bits per element.  The wire-form
 * public key is the lossless encoding of the public vector |t| followed by
 * the public seed |rho|.  Our serialised private key is the concatenation of
 * the private vector |s|, the public key, the public key hash and the
 * failure secret |z|.
 */
#define VECTOR_BYTES(b) ((3 * DEGREE / 2) * ML_KEM_##b##_RANK)
#define PUBKEY_BYTES(b) (VECTOR_BYTES(b) + ML_KEM_RANDOM_BYTES)
#define PRVKEY_BYTES(b) \
    (VECTOR_BYTES(b) + PUBKEY_BYTES(b) + ML_KEM_PKHASH_BYTES + ML_KEM_RANDOM_BYTES)

/*
 * Encapsulation produces a vector "u" and a scalar "v", whose coordinates
 * (numbers modulo the ML-KEM prime "q") are lossily encoded using "du" and
 * "dv" bits, respectively.  That encoding is the ciphertext consumed by
 * decapsulation.
 */
#define U_VECTOR_BYTES(b) ((DEGREE / 8) * ML_KEM_##b##_DU * ML_KEM_##b##_RANK)
#define V_SCALAR_BYTES(b) ((DEGREE / 8) * ML_KEM_##b##_DV)
#define CTEXT_BYTES(b) (U_VECTOR_BYTES(b) + V_SCALAR_BYTES(b))
142
143
/*
 * Constant-time validation hooks.
 *
 * CONSTTIME_SECRET(ptr, len) marks |len| bytes at |ptr| as secret.  Secret
 * data is tracked as it flows to registers and other parts of memory; using
 * it as a branch condition or as a memory index triggers valgrind warnings.
 *
 * CONSTTIME_DECLASSIFY(ptr, len) marks |len| bytes at |ptr| as public again;
 * public data is not subject to constant-time rules.
 *
 * Both expand to nothing unless OPENSSL_CONSTANT_TIME_VALIDATION is defined.
 */
#ifndef OPENSSL_CONSTANT_TIME_VALIDATION

#define CONSTTIME_SECRET(ptr, len)
#define CONSTTIME_DECLASSIFY(ptr, len)

#else

#define CONSTTIME_SECRET(ptr, len) VALGRIND_MAKE_MEM_UNDEFINED(ptr, len)
#define CONSTTIME_DECLASSIFY(ptr, len) VALGRIND_MAKE_MEM_DEFINED(ptr, len)

#endif
166
167
/*
168
 * Indices of slots in the vinfo tables below
169
 */
170
56.9k
#define ML_KEM_512_VINFO 0
171
158k
#define ML_KEM_768_VINFO 1
172
60.3k
#define ML_KEM_1024_VINFO 2
173
174
/*
175
 * Per-variant fixed parameters
176
 */
177
static const ML_KEM_VINFO vinfo_map[3] = {
178
    { "ML-KEM-512",
179
        PRVKEY_BYTES(512),
180
        sizeof(struct prvkey_512_alloc),
181
        PUBKEY_BYTES(512),
182
        sizeof(struct pubkey_512_alloc),
183
        CTEXT_BYTES(512),
184
        VECTOR_BYTES(512),
185
        U_VECTOR_BYTES(512),
186
        EVP_PKEY_ML_KEM_512,
187
        ML_KEM_512_BITS,
188
        ML_KEM_512_RANK,
189
        ML_KEM_512_DU,
190
        ML_KEM_512_DV,
191
        ML_KEM_512_SECBITS },
192
    { "ML-KEM-768",
193
        PRVKEY_BYTES(768),
194
        sizeof(struct prvkey_768_alloc),
195
        PUBKEY_BYTES(768),
196
        sizeof(struct pubkey_768_alloc),
197
        CTEXT_BYTES(768),
198
        VECTOR_BYTES(768),
199
        U_VECTOR_BYTES(768),
200
        EVP_PKEY_ML_KEM_768,
201
        ML_KEM_768_BITS,
202
        ML_KEM_768_RANK,
203
        ML_KEM_768_DU,
204
        ML_KEM_768_DV,
205
        ML_KEM_768_SECBITS },
206
    { "ML-KEM-1024",
207
        PRVKEY_BYTES(1024),
208
        sizeof(struct prvkey_1024_alloc),
209
        PUBKEY_BYTES(1024),
210
        sizeof(struct pubkey_1024_alloc),
211
        CTEXT_BYTES(1024),
212
        VECTOR_BYTES(1024),
213
        U_VECTOR_BYTES(1024),
214
        EVP_PKEY_ML_KEM_1024,
215
        ML_KEM_1024_BITS,
216
        ML_KEM_1024_RANK,
217
        ML_KEM_1024_DU,
218
        ML_KEM_1024_DV,
219
        ML_KEM_1024_SECBITS }
220
};
221
222
/*
223
 * Remainders modulo `kPrime`, for sufficiently small inputs, are computed in
224
 * constant time via Barrett reduction, and a final call to reduce_once(),
225
 * which reduces inputs that are at most 2*kPrime and is also constant-time.
226
 */
227
static const int kPrime = ML_KEM_PRIME;
228
static const unsigned int kBarrettShift = BARRETT_SHIFT;
229
static const size_t kBarrettMultiplier = (1 << BARRETT_SHIFT) / ML_KEM_PRIME;
230
static const uint16_t kHalfPrime = (ML_KEM_PRIME - 1) / 2;
231
static const uint16_t kInverseDegree = INVERSE_DEGREE;
232
233
/*
234
 * Python helper:
235
 *
236
 * p = 3329
237
 * def bitreverse(i):
238
 *     ret = 0
239
 *     for n in range(7):
240
 *         bit = i & 1
241
 *         ret <<= 1
242
 *         ret |= bit
243
 *         i >>= 1
244
 *     return ret
245
 */
246
247
/*-
 * Forward NTT roots: the first precomputed array from Appendix A of
 * FIPS 203.  Equivalent Python, using the bitreverse() helper above:
 * kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)]
 */
static const uint16_t kNTTRoots[128] = {
    1, 1729, 2580, 3289, 2642, 630, 1897, 848,
    1062, 1919, 193, 797, 2786, 3260, 569, 1746,
    296, 2447, 1339, 1476, 3046, 56, 2240, 1333,
    1426, 2094, 535, 2882, 2393, 2879, 1974, 821,
    289, 331, 3253, 1756, 1197, 2304, 2277, 2055,
    650, 1977, 2513, 632, 2865, 33, 1320, 1915,
    2319, 1435, 807, 452, 1438, 2868, 1534, 2402,
    2647, 2617, 1481, 648, 2474, 3110, 1227, 910,
    17, 2761, 583, 2649, 1637, 723, 2288, 1100,
    1409, 2662, 3281, 233, 756, 2156, 3015, 3050,
    1703, 1651, 2789, 1789, 1847, 952, 1461, 2687,
    939, 2308, 2437, 2388, 733, 2337, 268, 641,
    1584, 2298, 2037, 3220, 375, 2549, 2090, 1645,
    1063, 319, 2773, 757, 2099, 561, 2466, 2594,
    2804, 1092, 403, 1026, 1143, 2150, 2775, 886,
    1722, 1212, 1874, 1029, 2110, 2935, 885, 2154,
};
381
382
/*
 * Inverse NTT roots, from
 * InverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)]
 * stored in the order in which the inverse NTT loop consumes them (entry 0
 * is skipped by that loop), i.e. for source indices:
 *
 *  0, 64, 65, ..., 127, 32, 33, ..., 63, 16, 17, ..., 31, 8, 9, ...
 */
static const uint16_t kInverseNTTRoots[128] = {
    1, 1175, 2444, 394, 1219, 2300, 1455, 2117,
    1607, 2443, 554, 1179, 2186, 2303, 2926, 2237,
    525, 735, 863, 2768, 1230, 2572, 556, 3010,
    2266, 1684, 1239, 780, 2954, 109, 1292, 1031,
    1745, 2688, 3061, 992, 2596, 941, 892, 1021,
    2390, 642, 1868, 2377, 1482, 1540, 540, 1678,
    1626, 279, 314, 1173, 2573, 3096, 48, 667,
    1920, 2229, 1041, 2606, 1692, 680, 2746, 568,
    3312, 2419, 2102, 219, 855, 2681, 1848, 712,
    682, 927, 1795, 461, 1891, 2877, 2522, 1894,
    1010, 1414, 2009, 3296, 464, 2697, 816, 1352,
    2679, 1274, 1052, 1025, 2132, 1573, 76, 2998,
    3040, 2508, 1355, 450, 936, 447, 2794, 1235,
    1903, 1996, 1089, 3273, 283, 1853, 1990, 882,
    3033, 1583, 2760, 69, 543, 2532, 3136, 1410,
    2267, 2481, 1432, 2699, 687, 40, 749, 1600,
};
518
519
/*
 * Roots used for multiplication in the NTT domain: the second precomputed
 * array from Appendix A of FIPS 203, normalised to non-negative values.
 * Equivalent Python:
 * ModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)]
 */
static const uint16_t kModRoots[128] = {
    17, 3312, 2761, 568, 583, 2746, 2649, 680,
    1637, 1692, 723, 2606, 2288, 1041, 1100, 2229,
    1409, 1920, 2662, 667, 3281, 48, 233, 3096,
    756, 2573, 2156, 1173, 3015, 314, 3050, 279,
    1703, 1626, 1651, 1678, 2789, 540, 1789, 1540,
    1847, 1482, 952, 2377, 1461, 1868, 2687, 642,
    939, 2390, 2308, 1021, 2437, 892, 2388, 941,
    733, 2596, 2337, 992, 268, 3061, 641, 2688,
    1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
    375, 2954, 2549, 780, 2090, 1239, 1645, 1684,
    1063, 2266, 319, 3010, 2773, 556, 757, 2572,
    2099, 1230, 561, 2768, 2466, 863, 2594, 735,
    2804, 525, 1092, 2237, 403, 2926, 1026, 2303,
    1143, 2186, 2150, 1179, 2775, 554, 886, 2443,
    1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
    2110, 1219, 2935, 394, 885, 2444, 2154, 1175,
};
654
655
/*
656
 * single_keccak hashes |inlen| bytes from |in| and writes |outlen| bytes of
657
 * output to |out|. If the |md| specifies a fixed-output function, like
658
 * SHA3-256, then |outlen| must be the correct length for that function.
659
 */
660
static __owur int single_keccak(uint8_t *out, size_t outlen, const uint8_t *in, size_t inlen,
661
    EVP_MD_CTX *mdctx)
662
353k
{
663
353k
    unsigned int sz = (unsigned int)outlen;
664
665
353k
    if (!EVP_DigestUpdate(mdctx, in, inlen))
666
0
        return 0;
667
353k
    if (EVP_MD_xof(EVP_MD_CTX_get0_md(mdctx)))
668
302k
        return EVP_DigestFinalXOF(mdctx, out, outlen);
669
50.7k
    return EVP_DigestFinal_ex(mdctx, out, &sz)
670
50.7k
        && ossl_assert((size_t)sz == outlen);
671
353k
}
672
673
/*
674
 * FIPS 203, Section 4.1, equation (4.3): PRF. Takes 32+1 input bytes, and uses
675
 * SHAKE256 to produce the input to SamplePolyCBD_eta: FIPS 203, algorithm 8.
676
 */
677
static __owur int prf(uint8_t *out, size_t len, const uint8_t in[ML_KEM_RANDOM_BYTES + 1],
678
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
679
302k
{
680
302k
    return EVP_DigestInit_ex(mdctx, key->shake256_md, NULL)
681
302k
        && single_keccak(out, len, in, ML_KEM_RANDOM_BYTES + 1, mdctx);
682
302k
}
683
684
/*
685
 * FIPS 203, Section 4.1, equation (4.4): H.  SHA3-256 hash of a variable
686
 * length input, producing 32 bytes of output.
687
 */
688
static __owur int hash_h(uint8_t out[ML_KEM_PKHASH_BYTES], const uint8_t *in, size_t len,
689
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
690
311
{
691
311
    return EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL)
692
311
        && single_keccak(out, ML_KEM_PKHASH_BYTES, in, len, mdctx);
693
311
}
694
695
/* Incremental hash_h of expanded public key */
696
static int
697
hash_h_pubkey(uint8_t pkhash[ML_KEM_PKHASH_BYTES],
698
    EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
699
50.1k
{
700
50.1k
    const ML_KEM_VINFO *vinfo = key->vinfo;
701
50.1k
    const scalar *t = key->t, *end = t + vinfo->rank;
702
50.1k
    unsigned int sz;
703
704
50.1k
    if (!EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL))
705
0
        return 0;
706
707
150k
    do {
708
150k
        uint8_t buf[3 * DEGREE / 2];
709
710
150k
        scalar_encode(buf, t++, 12);
711
150k
        if (!EVP_DigestUpdate(mdctx, buf, sizeof(buf)))
712
0
            return 0;
713
150k
    } while (t < end);
714
715
50.1k
    if (!EVP_DigestUpdate(mdctx, key->rho, ML_KEM_RANDOM_BYTES))
716
0
        return 0;
717
50.1k
    return EVP_DigestFinal_ex(mdctx, pkhash, &sz)
718
50.1k
        && ossl_assert(sz == ML_KEM_PKHASH_BYTES);
719
50.1k
}
720
721
/*
722
 * FIPS 203, Section 4.1, equation (4.5): G.  SHA3-512 hash of a variable
723
 * length input, producing 64 bytes of output, in particular the seeds
724
 * (d,z) for key generation.
725
 */
726
static __owur int hash_g(uint8_t out[ML_KEM_SEED_BYTES], const uint8_t *in, size_t len,
727
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
728
50.4k
{
729
50.4k
    return EVP_DigestInit_ex(mdctx, key->sha3_512_md, NULL)
730
50.4k
        && single_keccak(out, ML_KEM_SEED_BYTES, in, len, mdctx);
731
50.4k
}
732
733
/*
734
 * FIPS 203, Section 4.1, equation (4.4): J. SHAKE256 taking a variable length
735
 * input to compute a 32-byte implicit rejection shared secret, of the same
736
 * length as the expected shared secret.  (Computed even on success to avoid
737
 * side-channel leaks).
738
 */
739
static __owur int kdf(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
740
    const uint8_t z[ML_KEM_RANDOM_BYTES],
741
    const uint8_t *ctext, size_t len,
742
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
743
123
{
744
123
    return EVP_DigestInit_ex(mdctx, key->shake256_md, NULL)
745
123
        && EVP_DigestUpdate(mdctx, z, ML_KEM_RANDOM_BYTES)
746
123
        && EVP_DigestUpdate(mdctx, ctext, len)
747
123
        && EVP_DigestFinalXOF(mdctx, out, ML_KEM_SHARED_SECRET_BYTES);
748
123
}
749
750
/*
751
 * FIPS 203, Section 4.2.2, Algorithm 7: "SampleNTT" (steps 3-17, steps 1, 2
752
 * are performed by the caller). Rejection-samples a Keccak stream to get
753
 * uniformly distributed elements in the range [0,q). This is used for matrix
754
 * expansion and only operates on public inputs.
755
 */
756
static __owur int sample_scalar(scalar *out, EVP_MD_CTX *mdctx)
757
453k
{
758
453k
    uint16_t *curr = out->c, *endout = curr + DEGREE;
759
453k
    uint8_t buf[SCALAR_SAMPLING_BUFSIZE], *in;
760
453k
    uint8_t *endin = buf + sizeof(buf);
761
453k
    uint16_t d;
762
453k
    uint8_t b1, b2, b3;
763
764
1.36M
    do {
765
1.36M
        if (!EVP_DigestSqueeze(mdctx, in = buf, sizeof(buf)))
766
0
            return 0;
767
71.1M
        do {
768
71.1M
            b1 = *in++;
769
71.1M
            b2 = *in++;
770
71.1M
            b3 = *in++;
771
772
71.1M
            if (curr >= endout)
773
152k
                break;
774
71.0M
            if ((d = ((b2 & 0x0f) << 8) + b1) < kPrime)
775
58.5M
                *curr++ = d;
776
71.0M
            if (curr >= endout)
777
301k
                break;
778
70.7M
            if ((d = (b3 << 4) + (b2 >> 4)) < kPrime)
779
57.5M
                *curr++ = d;
780
70.7M
        } while (in < endin);
781
1.36M
    } while (curr < endout);
782
453k
    return 1;
783
453k
}
784
785
/*-
786
 * reduce_once reduces 0 <= x < 2*kPrime, mod kPrime.
787
 *
788
 * Subtract |q| if the input is larger, without exposing a side-channel,
789
 * avoiding the "clangover" attack.  See |constish_time_non_zero| for a
790
 * discussion on why the value barrier is by default omitted.
791
 */
792
static __owur uint16_t reduce_once(uint16_t x)
793
989M
{
794
989M
    const uint16_t subtracted = x - kPrime;
795
989M
    uint16_t mask = constish_time_non_zero(subtracted >> 15);
796
797
989M
    return (mask & x) | (~mask & subtracted);
798
989M
}
799
800
/*
801
 * Constant-time reduce x mod kPrime using Barrett reduction. x must be less
802
 * than kPrime + 2 * kPrime^2.  This is sufficient to reduce a product of
803
 * two already reduced u_int16 values, in fact it is sufficient for each
804
 * to be less than 2^12, because (kPrime * (2 * kPrime + 1)) > 2^24.
805
 */
806
static __owur uint16_t reduce(uint32_t x)
807
446M
{
808
446M
    uint64_t product = (uint64_t)x * kBarrettMultiplier;
809
446M
    uint32_t quotient = (uint32_t)(product >> kBarrettShift);
810
446M
    uint32_t remainder = x - quotient * kPrime;
811
812
446M
    return reduce_once(remainder);
813
446M
}
814
815
/* Multiply a scalar by a constant. */
816
static void scalar_mult_const(scalar *s, uint16_t a)
817
1.15k
{
818
1.15k
    uint16_t *curr = s->c, *end = curr + DEGREE, tmp;
819
820
296k
    do {
821
296k
        tmp = reduce(*curr * a);
822
296k
        *curr++ = tmp;
823
296k
    } while (curr < end);
824
1.15k
}
825
826
/*-
827
 * FIPS 203, Section 4.3, Algorithm 9: "NTT".
828
 * In-place number theoretic transform of a given scalar.  Note that ML-KEM's
829
 * kPrime 3329 does not have a 512th root of unity, so this transform leaves
830
 * off the last iteration of the usual FFT code, with the 128 relevant roots of
831
 * unity being stored in NTTRoots.  This means the output should be seen as 128
832
 * elements in GF(3329^2), with the coefficients of the elements being
833
 * consecutive entries in |s->c|.
834
 */
835
static void scalar_ntt(scalar *s)
836
302k
{
837
302k
    const uint16_t *roots = kNTTRoots;
838
302k
    uint16_t *end = s->c + DEGREE;
839
302k
    int offset = DEGREE / 2;
840
841
2.11M
    do {
842
2.11M
        uint16_t *curr = s->c, *peer;
843
844
38.3M
        do {
845
38.3M
            uint16_t *pause = curr + offset, even, odd;
846
38.3M
            uint32_t zeta = *++roots;
847
848
38.3M
            peer = pause;
849
270M
            do {
850
270M
                even = *curr;
851
270M
                odd = reduce(*peer * zeta);
852
270M
                *peer++ = reduce_once(even - odd + kPrime);
853
270M
                *curr++ = reduce_once(odd + even);
854
270M
            } while (curr < pause);
855
38.3M
        } while ((curr = peer) < end);
856
2.11M
    } while ((offset >>= 1) >= 2);
857
302k
}
858
859
/*-
860
 * FIPS 203, Section 4.3, Algorithm 10: "NTT^(-1)".
861
 * In-place inverse number theoretic transform of a given scalar, with pairs of
862
 * entries of s->v being interpreted as elements of GF(3329^2). Just as with
863
 * the number theoretic transform, this leaves off the first step of the normal
864
 * iFFT to account for the fact that 3329 does not have a 512th root of unity,
865
 * using the precomputed 128 roots of unity stored in InverseNTTRoots.
866
 */
867
static void scalar_inverse_ntt(scalar *s)
868
1.15k
{
869
1.15k
    const uint16_t *roots = kInverseNTTRoots;
870
1.15k
    uint16_t *end = s->c + DEGREE;
871
1.15k
    int offset = 2;
872
873
8.09k
    do {
874
8.09k
        uint16_t *curr = s->c, *peer;
875
876
146k
        do {
877
146k
            uint16_t *pause = curr + offset, even, odd;
878
146k
            uint32_t zeta = *++roots;
879
880
146k
            peer = pause;
881
1.03M
            do {
882
1.03M
                even = *curr;
883
1.03M
                odd = *peer;
884
1.03M
                *peer++ = reduce(zeta * (even - odd + kPrime));
885
1.03M
                *curr++ = reduce_once(odd + even);
886
1.03M
            } while (curr < pause);
887
146k
        } while ((curr = peer) < end);
888
8.09k
    } while ((offset <<= 1) < DEGREE);
889
1.15k
    scalar_mult_const(s, kInverseDegree);
890
1.15k
}
891
892
/* Addition updating the LHS scalar in-place. */
893
static void scalar_add(scalar *lhs, const scalar *rhs)
894
1.03k
{
895
1.03k
    int i;
896
897
265k
    for (i = 0; i < DEGREE; i++)
898
264k
        lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
899
1.03k
}
900
901
/* Subtraction updating the LHS scalar in-place. */
902
static void scalar_sub(scalar *lhs, const scalar *rhs)
903
123
{
904
123
    int i;
905
906
31.6k
    for (i = 0; i < DEGREE; i++)
907
31.4k
        lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime);
908
123
}
909
910
/*
911
 * Multiplying two scalars in the number theoretically transformed state. Since
912
 * 3329 does not have a 512th root of unity, this means we have to interpret
913
 * the 2*ith and (2*i+1)th entries of the scalar as elements of
914
 * GF(3329)[X]/(X^2 - 17^(2*bitreverse(i)+1)).
915
 *
916
 * The value of 17^(2*bitreverse(i)+1) mod 3329 is stored in the precomputed
917
 * ModRoots table. Note that our Barrett transform only allows us to multiply
918
 * two reduced numbers together, so we need some intermediate reduction steps,
919
 * even if an uint64_t could hold 3 multiplied numbers.
920
 */
921
static void scalar_mult(scalar *out, const scalar *lhs,
922
    const scalar *rhs)
923
1.15k
{
924
1.15k
    uint16_t *curr = out->c, *end = curr + DEGREE;
925
1.15k
    const uint16_t *lc = lhs->c, *rc = rhs->c;
926
1.15k
    const uint16_t *roots = kModRoots;
927
928
148k
    do {
929
148k
        uint32_t l0 = *lc++, r0 = *rc++;
930
148k
        uint32_t l1 = *lc++, r1 = *rc++;
931
148k
        uint32_t zetapow = *roots++;
932
933
148k
        *curr++ = reduce(l0 * r0 + reduce(l1 * r1) * zetapow);
934
148k
        *curr++ = reduce(l0 * r1 + l1 * r0);
935
148k
    } while (curr < end);
936
1.15k
}
937
938
/* Above, but add the result to an existing scalar */
939
static ossl_inline void scalar_mult_add(scalar *out, const scalar *lhs,
940
    const scalar *rhs)
941
453k
{
942
453k
    uint16_t *curr = out->c, *end = curr + DEGREE;
943
453k
    const uint16_t *lc = lhs->c, *rc = rhs->c;
944
453k
    const uint16_t *roots = kModRoots;
945
946
58.0M
    do {
947
58.0M
        uint32_t l0 = *lc++, r0 = *rc++;
948
58.0M
        uint32_t l1 = *lc++, r1 = *rc++;
949
58.0M
        uint16_t *c0 = curr++;
950
58.0M
        uint16_t *c1 = curr++;
951
58.0M
        uint32_t zetapow = *roots++;
952
953
58.0M
        *c0 = reduce(*c0 + l0 * r0 + reduce(l1 * r1) * zetapow);
954
58.0M
        *c1 = reduce(*c1 + l0 * r1 + l1 * r0);
955
58.0M
    } while (curr < end);
956
453k
}
957
958
/*-
959
 * FIPS 203, Section 4.2.1, Algorithm 5: "ByteEncode_d", for 2<=d<=12.
960
 * Here |bits| is |d|.  For efficiency, we handle the d=1 case separately.
961
 */
962
static void scalar_encode(uint8_t *out, const scalar *s, int bits)
963
302k
{
964
302k
    const uint16_t *curr = s->c, *end = curr + DEGREE;
965
302k
    uint64_t accum = 0, element;
966
302k
    int used = 0;
967
968
77.3M
    do {
969
77.3M
        element = *curr++;
970
77.3M
        if (used + bits < 64) {
971
62.8M
            accum |= element << used;
972
62.8M
            used += bits;
973
62.8M
        } else if (used + bits > 64) {
974
9.66M
            out = OPENSSL_store_u64_le(out, accum | (element << used));
975
9.66M
            accum = element >> (64 - used);
976
9.66M
            used = (used + bits) - 64;
977
9.66M
        } else {
978
4.82M
            out = OPENSSL_store_u64_le(out, accum | (element << used));
979
4.82M
            accum = 0;
980
4.82M
            used = 0;
981
4.82M
        }
982
77.3M
    } while (curr < end);
983
302k
}
984
985
/*
986
 * scalar_encode_1 is |scalar_encode| specialised for |bits| == 1.
987
 */
988
static void scalar_encode_1(uint8_t out[DEGREE / 8], const scalar *s)
989
123
{
990
123
    int i, j;
991
123
    uint8_t out_byte;
992
993
4.05k
    for (i = 0; i < DEGREE; i += 8) {
994
3.93k
        out_byte = 0;
995
35.4k
        for (j = 0; j < 8; j++)
996
31.4k
            out_byte |= bit0(s->c[i + j]) << j;
997
3.93k
        *out = out_byte;
998
3.93k
        out++;
999
3.93k
    }
1000
123
}
1001
1002
/*-
1003
 * FIPS 203, Section 4.2.1, Algorithm 6: "ByteDecode_d", for 2<=d<12.
1004
 * Here |bits| is |d|.  For efficiency, we handle the d=1 and d=12 cases
1005
 * separately.
1006
 *
1007
 * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
1008
 * |out|.
1009
 */
1010
static void scalar_decode(scalar *out, const uint8_t *in, int bits)
1011
479
{
1012
479
    uint16_t *curr = out->c, *end = curr + DEGREE;
1013
479
    uint64_t accum = 0;
1014
479
    int accum_bits = 0, todo = bits;
1015
479
    uint16_t bitmask = (((uint16_t)1) << bits) - 1, mask = bitmask;
1016
479
    uint16_t element = 0;
1017
1018
135k
    do {
1019
135k
        if (accum_bits == 0) {
1020
16.9k
            in = OPENSSL_load_u64_le(&accum, in);
1021
16.9k
            accum_bits = 64;
1022
16.9k
        }
1023
135k
        if (todo == bits && accum_bits >= bits) {
1024
            /* No partial "element", and all the required bits available */
1025
109k
            *curr++ = ((uint16_t)accum) & mask;
1026
109k
            accum >>= bits;
1027
109k
            accum_bits -= bits;
1028
109k
        } else if (accum_bits >= todo) {
1029
            /* A partial "element", and all the required bits available */
1030
13.2k
            *curr++ = element | ((((uint16_t)accum) & mask) << (bits - todo));
1031
13.2k
            accum >>= todo;
1032
13.2k
            accum_bits -= todo;
1033
13.2k
            element = 0;
1034
13.2k
            todo = bits;
1035
13.2k
            mask = bitmask;
1036
13.2k
        } else {
1037
            /*
1038
             * Only some of the requisite bits accumulated, store |accum_bits|
1039
             * of these in |element|.  The accumulated bitcount becomes 0, but
1040
             * as soon as we have more bits we'll want to merge accum_bits
1041
             * fewer of them into the final |element|.
1042
             *
1043
             * Note that with a 64-bit accumulator and |bits| always 12 or
1044
             * less, if we're here, the previous iteration had all the
1045
             * requisite bits, and so there are no kept bits in |element|.
1046
             */
1047
13.2k
            element = ((uint16_t)accum) & mask;
1048
13.2k
            todo -= accum_bits;
1049
13.2k
            mask = bitmask >> accum_bits;
1050
13.2k
            accum_bits = 0;
1051
13.2k
        }
1052
135k
    } while (curr < end);
1053
479
}
1054
1055
/*
 * Unpacks DEGREE 12-bit coefficients from |in| into |out|, consuming three
 * input bytes per pair of coefficients.  Returns 1 on success, or 0 as soon
 * as a decoded pair contains a value that is not below |kPrime|.
 */
static __owur int scalar_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2])
{
    uint16_t *coeff = out->c;
    const uint16_t *const stop = coeff + DEGREE;

    while (coeff < stop) {
        const uint8_t lo = in[0], mid = in[1], hi = in[2];

        /* First byte plus the low nibble of the shared middle byte */
        coeff[0] = lo | ((mid & 0x0f) << 8);
        /* High nibble of the middle byte plus the whole third byte */
        coeff[1] = (mid >> 4) | (hi << 4);

        /* Both values are stored before the combined range check */
        if ((coeff[0] >= kPrime) | (coeff[1] >= kPrime))
            return 0;
        in += 3;
        coeff += 2;
    }
    return 1;
}
1072
1073
/*-
1074
 * scalar_decode_decompress_add is a combination of decoding and decompression
1075
 * both specialised for |bits| == 1, with the result added (and sum reduced) to
1076
 * the output scalar.
1077
 *
1078
 * NOTE: this function MUST not leak an input-data-dependent timing signal.
1079
 * A timing leak in a related function in the reference Kyber implementation
1080
 * made the "clangover" attack (CVE-2024-37880) possible, giving key recovery
1081
 * for ML-KEM-512 in minutes, provided the attacker has access to precise
1082
 * timing of a CPU performing chosen-ciphertext decap.  Admittedly this is only
1083
 * a risk when private keys are reused (perhaps KEMTLS servers).
1084
 */
1085
static void
1086
scalar_decode_decompress_add(scalar *out, const uint8_t in[DEGREE / 8])
1087
265
{
1088
265
    static const uint16_t half_q_plus_1 = (ML_KEM_PRIME >> 1) + 1;
1089
265
    uint16_t *curr = out->c, *end = curr + DEGREE;
1090
265
    uint16_t mask;
1091
265
    uint8_t b;
1092
1093
    /*
1094
     * Add |half_q_plus_1| if the bit is set, without exposing a side-channel,
1095
     * avoiding the "clangover" attack.  See |constish_time_non_zero| for a
1096
     * discussion on why the value barrier is by default omitted.
1097
     */
1098
265
#define decode_decompress_add_bit                        \
1099
67.8k
    mask = constish_time_non_zero(bit0(b));              \
1100
67.8k
    *curr = reduce_once(*curr + (mask & half_q_plus_1)); \
1101
67.8k
    curr++;                                              \
1102
67.8k
    b >>= 1
1103
1104
    /* Unrolled to process each byte in one iteration */
1105
8.48k
    do {
1106
8.48k
        b = *in++;
1107
8.48k
        decode_decompress_add_bit;
1108
8.48k
        decode_decompress_add_bit;
1109
8.48k
        decode_decompress_add_bit;
1110
8.48k
        decode_decompress_add_bit;
1111
1112
8.48k
        decode_decompress_add_bit;
1113
8.48k
        decode_decompress_add_bit;
1114
8.48k
        decode_decompress_add_bit;
1115
8.48k
        decode_decompress_add_bit;
1116
8.48k
    } while (curr < end);
1117
265
#undef decode_decompress_add_bit
1118
265
}
1119
1120
/*
1121
 * FIPS 203, Section 4.2.1, Equation (4.7): Compress_d.
1122
 *
1123
 * Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
1124
 * numbers close to each other together. The formula used is
1125
 * round(2^|bits|/kPrime*x) mod 2^|bits|.
1126
 * Uses Barrett reduction to achieve constant time. Since we need both the
1127
 * remainder (for rounding) and the quotient (as the result), we cannot use
1128
 * |reduce| here, but need to do the Barrett reduction directly.
1129
 */
1130
static __owur uint16_t compress(uint16_t x, int bits)
{
    /* Numerator of the rounding division: x * 2^bits */
    const uint32_t numer = (uint32_t)x << bits;
    /* Barrett estimate of numer / kPrime, and what it leaves over */
    uint32_t q = (uint32_t)(((uint64_t)numer * kBarrettMultiplier) >> kBarrettShift);
    const uint32_t rem = numer - q * kPrime;

    /*
     * Round to nearest without branching.  The estimate may leave a
     * remainder anywhere in [0, 2 * kPrime), so up to two corrections apply:
     *   rem <= kHalfPrime                     -> +0
     *   kHalfPrime < rem <= kPrime+kHalfPrime -> +1
     *   kPrime + kHalfPrime < rem             -> +2
     */
    q += 1 & constant_time_lt_32(kHalfPrime, rem);
    q += 1 & constant_time_lt_32(kPrime + kHalfPrime, rem);

    /* Result is taken modulo 2^bits */
    return q & ((1 << bits) - 1);
}
1147
1148
/*
1149
 * FIPS 203, Section 4.2.1, Equation (4.8): Decompress_d.
1150
1151
 * Decompresses |x| by using a close equi-distant representative. The formula
1152
 * is round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us
1153
 * to implement this logic using only bit operations.
1154
 */
1155
static __owur uint16_t decompress(uint16_t x, int bits)
{
    const uint32_t scaled = (uint32_t)x * kPrime;
    const uint32_t modulus = 1 << bits;
    /* |scaled| / 2^bits and |scaled| % 2^bits, both by bit operations */
    const uint32_t whole = scaled >> bits;
    const uint32_t frac = scaled & (modulus - 1);

    /*
     * Round to nearest: the top bit of the |bits|-wide fractional part is
     * set exactly when the fraction is at least one half, so adding that
     * single bit rounds the quotient.
     */
    return whole + (frac >> (bits - 1));
}
1172
1173
/*-
1174
 * FIPS 203, Section 4.2.1, Equation (4.7): "Compress_d".
1175
 * In-place lossy rounding of scalars to 2^d bits.
1176
 */
1177
static void scalar_compress(scalar *s, int bits)
{
    uint16_t *coeff = s->c;
    uint16_t *const stop = coeff + DEGREE;

    /* Replace every coefficient by its |bits|-bit compressed form */
    for (; coeff < stop; ++coeff)
        *coeff = compress(*coeff, bits);
}
1184
1185
/*
1186
 * FIPS 203, Section 4.2.1, Equation (4.8): "Decompress_d".
1187
 * In-place approximate recovery of scalars from 2^d bit compression.
1188
 */
1189
static void scalar_decompress(scalar *s, int bits)
{
    uint16_t *coeff = s->c;
    uint16_t *const stop = coeff + DEGREE;

    /* Replace every |bits|-bit value by its decompressed representative */
    for (; coeff < stop; ++coeff)
        *coeff = decompress(*coeff, bits);
}
1196
1197
/* Addition updating the LHS vector in-place. */
1198
static void vector_add(scalar *lhs, const scalar *rhs, int rank)
{
    int i = 0;

    /* Element-wise lhs[i] += rhs[i]; the first element is always processed */
    do {
        scalar_add(&lhs[i], &rhs[i]);
    } while (++i < rank);
}
1204
1205
/*
1206
 * Encodes an entire vector into 32*|rank|*|bits| bytes. Note that since 256
1207
 * (DEGREE) is divisible by 8, the individual vector entries will always fill a
1208
 * whole number of bytes, so we do not need to worry about bit packing here.
1209
 */
1210
static void vector_encode(uint8_t *out, const scalar *a, int bits, int rank)
1211
50.5k
{
1212
50.5k
    int stride = bits * DEGREE / 8;
1213
1214
202k
    for (; rank-- > 0; out += stride)
1215
151k
        scalar_encode(out, a++, bits);
1216
50.5k
}
1217
1218
/*
1219
 * Decodes 32*|rank|*|bits| bytes from |in| into |out|. It returns early
1220
 * if any parsed value is >= |ML_KEM_PRIME|.  The resulting scalars are
1221
 * then decompressed and transformed via the NTT.
1222
 *
1223
 * Note: Used only in decrypt_cpa(), which returns void and so does not check
1224
 * the return value of this function.  Side-channels are fine when the input
1225
 * ciphertext to decap() is simply syntactically invalid.
1226
 */
1227
static void
1228
vector_decode_decompress_ntt(scalar *out, const uint8_t *in, int bits, int rank)
1229
123
{
1230
123
    int stride = bits * DEGREE / 8;
1231
1232
479
    for (; rank-- > 0; in += stride, ++out) {
1233
356
        scalar_decode(out, in, bits);
1234
356
        scalar_decompress(out, bits);
1235
356
        scalar_ntt(out);
1236
356
    }
1237
123
}
1238
1239
/* vector_decode(), specialised to bits == 12. */
1240
static __owur int vector_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2], int rank)
{
    /* Bytes occupied by one 12-bit encoded scalar */
    const int stride = 3 * DEGREE / 2;
    int i;

    for (i = 0; i < rank; i++) {
        /* Stop at the first scalar that fails its range check */
        if (!scalar_decode_12(&out[i], in + i * stride))
            return 0;
    }
    return 1;
}
1249
1250
/* In-place compression of each scalar component */
1251
static void vector_compress(scalar *a, int bits, int rank)
{
    int i = 0;

    /* Compress each scalar in place; the first one is always processed */
    do {
        scalar_compress(&a[i], bits);
    } while (++i < rank);
}
1257
1258
/* The output scalar must not overlap with the inputs */
1259
static void inner_product(scalar *out, const scalar *lhs, const scalar *rhs,
    int rank)
{
    int i;

    /* The first term initialises |out|, the rest are accumulated into it */
    scalar_mult(out, &lhs[0], &rhs[0]);
    for (i = 1; i < rank; i++)
        scalar_mult_add(out, &lhs[i], &rhs[i]);
}
1266
1267
/*
1268
 * Here, the output vector must not overlap with the inputs, the result is
1269
 * directly subjected to inverse NTT.
1270
 */
1271
static void
1272
matrix_mult_intt(scalar *out, const scalar *m, const scalar *a, int rank)
1273
265
{
1274
265
    const scalar *ar;
1275
265
    int i, j;
1276
1277
1.03k
    for (i = rank; i-- > 0; ++out) {
1278
769
        scalar_mult(out, m++, ar = a);
1279
2.40k
        for (j = rank - 1; j > 0; --j)
1280
1.63k
            scalar_mult_add(out, m++, ++ar);
1281
769
        scalar_inverse_ntt(out);
1282
769
    }
1283
265
}
1284
1285
/* Here, the output vector must not overlap with the inputs */
1286
static void
1287
matrix_mult_transpose_add(scalar *out, const scalar *m, const scalar *a, int rank)
1288
50.1k
{
1289
50.1k
    const scalar *mc = m, *mr, *ar;
1290
50.1k
    int i, j;
1291
1292
200k
    for (i = rank; i-- > 0; ++out) {
1293
150k
        scalar_mult_add(out, mr = mc++, ar = a);
1294
451k
        for (j = rank; --j > 0;)
1295
301k
            scalar_mult_add(out, (mr += rank), ++ar);
1296
150k
    }
1297
50.1k
}
1298
1299
/*-
1300
 * Expands the matrix from a seed for key generation and for encaps-CPA.
1301
 * NOTE: FIPS 203 matrix "A" is the transpose of this matrix, computed
1302
 * by appending the (i,j) indices to the seed in the opposite order!
1303
 *
1304
 * Where FIPS 203 computes t = A * s + e, we use the transpose of "m".
1305
 */
1306
static __owur int matrix_expand(EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
{
    const int rank = key->vinfo->rank;
    uint8_t xof_in[ML_KEM_RANDOM_BYTES + 2];
    /* The two trailing bytes carry the (row, column) indices */
    uint8_t *const idx = xof_in + ML_KEM_RANDOM_BYTES;
    scalar *entry = key->m;
    int row, col;

    memcpy(xof_in, key->rho, ML_KEM_RANDOM_BYTES);
    for (row = 0; row < rank; row++) {
        idx[0] = row;
        for (col = 0; col < rank; col++, entry++) {
            idx[1] = col;
            /* Fresh SHAKE128 instance per entry, absorbing rho || row || col */
            if (!EVP_DigestInit_ex(mdctx, key->shake128_md, NULL))
                return 0;
            if (!EVP_DigestUpdate(mdctx, xof_in, sizeof(xof_in)))
                return 0;
            if (!sample_scalar(entry, mdctx))
                return 0;
        }
    }
    return 1;
}
1326
1327
/*
1328
 * FIPS 203, Algorithm 8, with eta fixed to two and the PRF call
1329
 * included. Creates binomially distributed elements by sampling 2*|eta| bits,
1330
 * and setting the coefficient to the count of the first bits minus the count of
1331
 * the second bits, resulting in a centered binomial distribution. Since eta is
1332
 * two this gives -2/2 with a probability of 1/16, -1/1 with probability 1/4,
1333
 * and 0 with probability 3/8.
1334
 */
1335
static __owur int cbd_2(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
1336
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1337
301k
{
1338
301k
    uint16_t *curr = out->c, *end = curr + DEGREE;
1339
301k
    uint8_t randbuf[4 * DEGREE / 8], *r = randbuf; /* 64 * eta slots */
1340
301k
    uint16_t value, mask;
1341
301k
    uint8_t b;
1342
1343
301k
    if (!prf(randbuf, sizeof(randbuf), in, mdctx, key))
1344
0
        return 0;
1345
1346
38.5M
    do {
1347
38.5M
        b = *r++;
1348
1349
        /*
1350
         * Add |kPrime| if |value| underflowed.  See |constish_time_non_zero|
1351
         * for a discussion on why the value barrier is by default omitted.
1352
         * While this could have been written reduce_once(value + kPrime), this
1353
         * is one extra addition and small range of |value| tempts some
1354
         * versions of Clang to emit a branch.
1355
         */
1356
38.5M
        value = bit0(b) + bitn(1, b);
1357
38.5M
        value -= bitn(2, b) + bitn(3, b);
1358
38.5M
        mask = constish_time_non_zero(value >> 15);
1359
38.5M
        *curr++ = value + (kPrime & mask);
1360
1361
38.5M
        value = bitn(4, b) + bitn(5, b);
1362
38.5M
        value -= bitn(6, b) + bitn(7, b);
1363
38.5M
        mask = constish_time_non_zero(value >> 15);
1364
38.5M
        *curr++ = value + (kPrime & mask);
1365
38.5M
    } while (curr < end);
1366
301k
    return 1;
1367
301k
}
1368
1369
/*
1370
 * FIPS 203, Algorithm 8, with eta fixed to three and the PRF call
1371
 * included. Creates binomially distributed elements by sampling 2*|eta| bits,
1372
 * and setting the coefficient to the count of the first bits minus the count of
1373
 * the second bits, resulting in a centered binomial distribution.
1374
 */
1375
static __owur int cbd_3(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
1376
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1377
1.21k
{
1378
1.21k
    uint16_t *curr = out->c, *end = curr + DEGREE;
1379
1.21k
    uint8_t randbuf[6 * DEGREE / 8], *r = randbuf; /* 64 * eta slots */
1380
1.21k
    uint8_t b1, b2, b3;
1381
1.21k
    uint16_t value, mask;
1382
1383
1.21k
    if (!prf(randbuf, sizeof(randbuf), in, mdctx, key))
1384
0
        return 0;
1385
1386
77.5k
    do {
1387
77.5k
        b1 = *r++;
1388
77.5k
        b2 = *r++;
1389
77.5k
        b3 = *r++;
1390
1391
        /*
1392
         * Add |kPrime| if |value| underflowed.  See |constish_time_non_zero|
1393
         * for a discussion on why the value barrier is by default omitted.
1394
         * While this could have been written reduce_once(value + kPrime), this
1395
         * is one extra addition and small range of |value| tempts some
1396
         * versions of Clang to emit a branch.
1397
         */
1398
77.5k
        value = bit0(b1) + bitn(1, b1) + bitn(2, b1);
1399
77.5k
        value -= bitn(3, b1) + bitn(4, b1) + bitn(5, b1);
1400
77.5k
        mask = constish_time_non_zero(value >> 15);
1401
77.5k
        *curr++ = value + (kPrime & mask);
1402
1403
77.5k
        value = bitn(6, b1) + bitn(7, b1) + bit0(b2);
1404
77.5k
        value -= bitn(1, b2) + bitn(2, b2) + bitn(3, b2);
1405
77.5k
        mask = constish_time_non_zero(value >> 15);
1406
77.5k
        *curr++ = value + (kPrime & mask);
1407
1408
77.5k
        value = bitn(4, b2) + bitn(5, b2) + bitn(6, b2);
1409
77.5k
        value -= bitn(7, b2) + bit0(b3) + bitn(1, b3);
1410
77.5k
        mask = constish_time_non_zero(value >> 15);
1411
77.5k
        *curr++ = value + (kPrime & mask);
1412
1413
77.5k
        value = bitn(2, b3) + bitn(3, b3) + bitn(4, b3);
1414
77.5k
        value -= bitn(5, b3) + bitn(6, b3) + bitn(7, b3);
1415
77.5k
        mask = constish_time_non_zero(value >> 15);
1416
77.5k
        *curr++ = value + (kPrime & mask);
1417
77.5k
    } while (curr < end);
1418
1.21k
    return 1;
1419
1.21k
}
1420
1421
/*
1422
 * Generates a secret vector by using |cbd| with the given seed to generate
1423
 * scalar elements and incrementing |counter| for each slot of the vector.
1424
 */
1425
static __owur int gencbd_vector(scalar *out, CBD_FUNC cbd, uint8_t *counter,
1426
    const uint8_t seed[ML_KEM_RANDOM_BYTES], int rank,
1427
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1428
265
{
1429
265
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1430
1431
265
    memcpy(input, seed, ML_KEM_RANDOM_BYTES);
1432
769
    do {
1433
769
        input[ML_KEM_RANDOM_BYTES] = (*counter)++;
1434
769
        if (!cbd(out++, input, mdctx, key))
1435
0
            return 0;
1436
769
    } while (--rank > 0);
1437
265
    return 1;
1438
265
}
1439
1440
/*
1441
 * As above plus NTT transform.
1442
 */
1443
static __owur int gencbd_vector_ntt(scalar *out, CBD_FUNC cbd, uint8_t *counter,
1444
    const uint8_t seed[ML_KEM_RANDOM_BYTES], int rank,
1445
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1446
100k
{
1447
100k
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1448
1449
100k
    memcpy(input, seed, ML_KEM_RANDOM_BYTES);
1450
301k
    do {
1451
301k
        input[ML_KEM_RANDOM_BYTES] = (*counter)++;
1452
301k
        if (!cbd(out, input, mdctx, key))
1453
0
            return 0;
1454
301k
        scalar_ntt(out++);
1455
301k
    } while (--rank > 0);
1456
100k
    return 1;
1457
100k
}
1458
1459
/* The |ETA1| value for ML-KEM-512 is 3, the rest and all ETA2 values are 2. */
1460
34.6k
/* Evaluates to cbd_3 when |evp_type| is EVP_PKEY_ML_KEM_512, else to cbd_2. */
#define CBD1(evp_type) ((evp_type) == EVP_PKEY_ML_KEM_512 ? cbd_3 : cbd_2)
1461
1462
/*
1463
 * FIPS 203, Section 5.2, Algorithm 14: K-PKE.Encrypt.
1464
 *
1465
 * Encrypts a message with given randomness to the ciphertext in |out|. Without
1466
 * applying the Fujisaki-Okamoto transform this would not result in a CCA
1467
 * secure scheme, since lattice schemes are vulnerable to decryption failure
1468
 * oracles.
1469
 *
1470
 * The steps are re-ordered to make more efficient/localised use of storage.
1471
 *
1472
 * Note also that the input public key is assumed to hold a precomputed matrix
1473
 * |A| (our key->m, with the public key holding an expanded (16-bit per scalar
1474
 * coefficient) key->t vector).
1475
 *
1476
 * Caller passes storage in |tmp| for two temporary vectors.
1477
 */
1478
static __owur int encrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
1479
    const uint8_t message[DEGREE / 8],
1480
    const uint8_t r[ML_KEM_RANDOM_BYTES], scalar *tmp,
1481
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1482
265
{
1483
265
    const ML_KEM_VINFO *vinfo = key->vinfo;
1484
265
    CBD_FUNC cbd_1 = CBD1(vinfo->evp_type);
1485
265
    int rank = vinfo->rank;
1486
    /* We can use tmp[0..rank-1] as storage for |y|, then |e1|, ... */
1487
265
    scalar *y = &tmp[0], *e1 = y, *e2 = y;
1488
    /* We can use tmp[rank]..tmp[2*rank - 1] for |u| */
1489
265
    scalar *u = &tmp[rank];
1490
265
    scalar v;
1491
265
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1492
265
    uint8_t counter = 0;
1493
265
    int du = vinfo->du;
1494
265
    int dv = vinfo->dv;
1495
1496
    /* FIPS 203 "y" vector */
1497
265
    if (!gencbd_vector_ntt(y, cbd_1, &counter, r, rank, mdctx, key))
1498
0
        return 0;
1499
    /* FIPS 203 "v" scalar */
1500
265
    inner_product(&v, key->t, y, rank);
1501
265
    scalar_inverse_ntt(&v);
1502
    /* FIPS 203 "u" vector */
1503
265
    matrix_mult_intt(u, key->m, y, rank);
1504
1505
    /* All done with |y|, now free to reuse tmp[0] for FIPS 203 |e1| */
1506
265
    if (!gencbd_vector(e1, cbd_2, &counter, r, rank, mdctx, key))
1507
0
        return 0;
1508
265
    vector_add(u, e1, rank);
1509
265
    vector_compress(u, du, rank);
1510
265
    vector_encode(out, u, du, rank);
1511
1512
    /* All done with |e1|, now free to reuse tmp[0] for FIPS 203 |e2| */
1513
265
    memcpy(input, r, ML_KEM_RANDOM_BYTES);
1514
265
    input[ML_KEM_RANDOM_BYTES] = counter;
1515
265
    if (!cbd_2(e2, input, mdctx, key))
1516
0
        return 0;
1517
265
    scalar_add(&v, e2);
1518
1519
    /* Combine message with |v| */
1520
265
    scalar_decode_decompress_add(&v, message);
1521
265
    scalar_compress(&v, dv);
1522
265
    scalar_encode(out + vinfo->u_vector_bytes, &v, dv);
1523
265
    return 1;
1524
265
}
1525
1526
/*
1527
 * FIPS 203, Section 5.3, Algorithm 15: K-PKE.Decrypt.
1528
 */
1529
static void
1530
decrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
1531
    const uint8_t *ctext, scalar *u, const ML_KEM_KEY *key)
1532
123
{
1533
123
    const ML_KEM_VINFO *vinfo = key->vinfo;
1534
123
    scalar v, mask;
1535
123
    int rank = vinfo->rank;
1536
123
    int du = vinfo->du;
1537
123
    int dv = vinfo->dv;
1538
1539
123
    vector_decode_decompress_ntt(u, ctext, du, rank);
1540
123
    scalar_decode(&v, ctext + vinfo->u_vector_bytes, dv);
1541
123
    scalar_decompress(&v, dv);
1542
123
    inner_product(&mask, key->s, u, rank);
1543
123
    scalar_inverse_ntt(&mask);
1544
123
    scalar_sub(&v, &mask);
1545
123
    scalar_compress(&v, 1);
1546
123
    scalar_encode_1(out, &v);
1547
123
}
1548
1549
/*-
1550
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1551
 * FIPS 203, Section 7.2, Algorithm 20: "ML-KEM.Encaps".
1552
 *
1553
 * Fills the |out| buffer with the |ek| output of "ML-KEM.KeyGen", or,
1554
 * equivalently, the |ek| input of "ML-KEM.Encaps", i.e. returns the
1555
 * wire-format of an ML-KEM public key.
1556
 */
1557
static void encode_pubkey(uint8_t *out, const ML_KEM_KEY *key)
1558
50.1k
{
1559
50.1k
    const uint8_t *rho = key->rho;
1560
50.1k
    const ML_KEM_VINFO *vinfo = key->vinfo;
1561
1562
50.1k
    vector_encode(out, key->t, 12, vinfo->rank);
1563
50.1k
    memcpy(out + vinfo->vector_bytes, rho, ML_KEM_RANDOM_BYTES);
1564
50.1k
}
1565
1566
/*-
1567
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1568
 *
1569
 * Fills the |out| buffer with the |dk| output of "ML-KEM.KeyGen".
1570
 * This matches the input format of parse_prvkey() below.
1571
 */
1572
static void encode_prvkey(uint8_t *out, const ML_KEM_KEY *key)
1573
119
{
1574
119
    const ML_KEM_VINFO *vinfo = key->vinfo;
1575
1576
119
    vector_encode(out, key->s, 12, vinfo->rank);
1577
119
    out += vinfo->vector_bytes;
1578
119
    encode_pubkey(out, key);
1579
119
    out += vinfo->pubkey_bytes;
1580
119
    memcpy(out, key->pkhash, ML_KEM_PKHASH_BYTES);
1581
119
    out += ML_KEM_PKHASH_BYTES;
1582
119
    memcpy(out, key->z, ML_KEM_RANDOM_BYTES);
1583
119
}
1584
1585
/*-
1586
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1587
 * FIPS 203, Section 7.2, Algorithm 20: "ML-KEM.Encaps".
1588
 *
1589
 * This function parses the |in| buffer as the |ek| output of "ML-KEM.KeyGen",
1590
 * or, equivalently, the |ek| input of "ML-KEM.Encaps", i.e. decodes the
1591
 * wire-format of the ML-KEM public key.
1592
 */
1593
static int parse_pubkey(const uint8_t *in, EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
{
    const ML_KEM_VINFO *vinfo = key->vinfo;

    /* The |t| vector comes first, and must pass the range check */
    if (!vector_decode_12(key->t, in, vinfo->rank)) {
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
            "%s invalid public 't' vector",
            vinfo->algorithm_name);
        return 0;
    }

    /* The matrix seed |rho| follows the encoded vector */
    memcpy(key->rho, in + vinfo->vector_bytes, ML_KEM_RANDOM_BYTES);

    /*
     * Derived data kept with the key: the hash of the whole encoded public
     * key (needed for both encap and decap), and the expanded matrix.
     */
    if (hash_h(key->pkhash, in, vinfo->pubkey_bytes, mdctx, key)
        && matrix_expand(mdctx, key))
        return 1;

    ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
        "internal error while parsing %s public key",
        vinfo->algorithm_name);
    return 0;
}
1619
1620
/*
1621
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1622
 *
1623
 * Parses the |in| buffer as a |dk| output of "ML-KEM.KeyGen".
1624
 * This matches the output format of encode_prvkey() above.
1625
 */
1626
static int parse_prvkey(const uint8_t *in, EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
{
    const ML_KEM_VINFO *vinfo = key->vinfo;
    const uint8_t *pub, *hash;

    /* Input layout: s || encoded public key || pkhash || z */
    if (!vector_decode_12(key->s, in, vinfo->rank)) {
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
            "%s invalid private 's' vector",
            vinfo->algorithm_name);
        return 0;
    }

    /* parse_pubkey() also recomputes key->pkhash from the embedded key */
    pub = in + vinfo->vector_bytes;
    if (!parse_pubkey(pub, mdctx, key))
        return 0;

    /* The stored hash must agree with the one just recomputed */
    hash = pub + vinfo->pubkey_bytes;
    if (memcmp(key->pkhash, hash, ML_KEM_PKHASH_BYTES) != 0) {
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
            "%s public key hash mismatch",
            vinfo->algorithm_name);
        return 0;
    }

    /* What remains is the implicit-rejection secret |z| */
    memcpy(key->z, hash + ML_KEM_PKHASH_BYTES, ML_KEM_RANDOM_BYTES);
    return 1;
}
1655
1656
/*
1657
 * FIPS 203, Section 6.1, Algorithm 16: "ML-KEM.KeyGen_internal".
1658
 *
1659
 * The implementation of Section 5.1, Algorithm 13, "K-PKE.KeyGen(d)" is
1660
 * inlined.
1661
 *
1662
 * The caller MUST pass a pre-allocated digest context that is not shared with
1663
 * any concurrent computation.
1664
 *
1665
 * This function optionally outputs the serialised wire-form |ek| public key
1666
 * into the provided |pubenc| buffer, and generates the content of the |rho|,
1667
 * |pkhash|, |t|, |m|, |s| and |z| components of the private |key| (which must
1668
 * have preallocated space for these).
1669
 *
1670
 * Keys are computed from a 32-byte random |d| plus the 1 byte rank for
1671
 * domain separation.  These are concatenated and hashed to produce a pair of
1672
 * 32-byte seeds public "rho", used to generate the matrix, and private "sigma",
1673
 * used to generate the secret vector |s|.
1674
 *
1675
 * The second random input |z| is copied verbatim into the Fujisaki-Okamoto
1676
 * (FO) transform "implicit-rejection" secret (the |z| component of the private
1677
 * key), which thwarts chosen-ciphertext attacks, provided decap() runs in
1678
 * constant time, with no side channel leaks, on all well-formed (valid length,
1679
 * and correctly encoded) ciphertext inputs.
1680
 */
1681
static __owur int genkey(const uint8_t seed[ML_KEM_SEED_BYTES],
1682
    EVP_MD_CTX *mdctx, uint8_t *pubenc, ML_KEM_KEY *key)
1683
34.3k
{
1684
34.3k
    uint8_t hashed[2 * ML_KEM_RANDOM_BYTES];
1685
34.3k
    const uint8_t *const sigma = hashed + ML_KEM_RANDOM_BYTES;
1686
34.3k
    uint8_t augmented_seed[ML_KEM_RANDOM_BYTES + 1];
1687
34.3k
    const ML_KEM_VINFO *vinfo = key->vinfo;
1688
34.3k
    CBD_FUNC cbd_1 = CBD1(vinfo->evp_type);
1689
34.3k
    int rank = vinfo->rank;
1690
34.3k
    uint8_t counter = 0;
1691
34.3k
    int ret = 0;
1692
1693
    /*
1694
     * Use the "d" seed salted with the rank to derive the public and private
1695
     * seeds rho and sigma.
1696
     */
1697
34.3k
    memcpy(augmented_seed, seed, ML_KEM_RANDOM_BYTES);
1698
34.3k
    augmented_seed[ML_KEM_RANDOM_BYTES] = (uint8_t)rank;
1699
34.3k
    if (!hash_g(hashed, augmented_seed, sizeof(augmented_seed), mdctx, key))
1700
0
        goto end;
1701
34.3k
    memcpy(key->rho, hashed, ML_KEM_RANDOM_BYTES);
1702
    /* The |rho| matrix seed is public */
1703
34.3k
    CONSTTIME_DECLASSIFY(key->rho, ML_KEM_RANDOM_BYTES);
1704
1705
    /* FIPS 203 |e| vector is initial value of key->t */
1706
34.3k
    if (!matrix_expand(mdctx, key)
1707
34.3k
        || !gencbd_vector_ntt(key->s, cbd_1, &counter, sigma, rank, mdctx, key)
1708
34.3k
        || !gencbd_vector_ntt(key->t, cbd_1, &counter, sigma, rank, mdctx, key))
1709
0
        goto end;
1710
1711
    /* To |e| we now add the product of transpose |m| and |s|, giving |t|. */
1712
34.3k
    matrix_mult_transpose_add(key->t, key->m, key->s, rank);
1713
    /* The |t| vector is public */
1714
34.3k
    CONSTTIME_DECLASSIFY(key->t, vinfo->rank * sizeof(scalar));
1715
1716
34.3k
    if (pubenc == NULL) {
1717
        /* Incremental digest of public key without in-full serialisation. */
1718
34.3k
        if (!hash_h_pubkey(key->pkhash, mdctx, key))
1719
0
            goto end;
1720
34.3k
    } else {
1721
0
        encode_pubkey(pubenc, key);
1722
0
        if (!hash_h(key->pkhash, pubenc, vinfo->pubkey_bytes, mdctx, key))
1723
0
            goto end;
1724
0
    }
1725
1726
    /* Save |z| portion of seed for "implicit rejection" on failure. */
1727
34.3k
    memcpy(key->z, seed + ML_KEM_RANDOM_BYTES, ML_KEM_RANDOM_BYTES);
1728
1729
    /* Optionally save the |d| portion of the seed */
1730
34.3k
    key->d = key->z + ML_KEM_RANDOM_BYTES;
1731
34.3k
    if (key->prov_flags & ML_KEM_KEY_RETAIN_SEED) {
1732
34.3k
        memcpy(key->d, seed, ML_KEM_RANDOM_BYTES);
1733
34.3k
    } else {
1734
0
        OPENSSL_cleanse(key->d, ML_KEM_RANDOM_BYTES);
1735
0
        key->d = NULL;
1736
0
    }
1737
1738
34.3k
    ret = 1;
1739
34.3k
end:
1740
34.3k
    OPENSSL_cleanse((void *)augmented_seed, ML_KEM_RANDOM_BYTES);
1741
34.3k
    OPENSSL_cleanse((void *)sigma, ML_KEM_RANDOM_BYTES);
1742
34.3k
    if (ret == 0) {
1743
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1744
0
            "internal error while generating %s private key",
1745
0
            vinfo->algorithm_name);
1746
0
    }
1747
34.3k
    return ret;
1748
34.3k
}
1749
1750
/*-
1751
 * FIPS 203, Section 6.2, Algorithm 17: "ML-KEM.Encaps_internal".
1752
 * This is the deterministic version with randomness supplied externally.
1753
 *
1754
 * The caller must pass space for two vectors in |tmp|.
1755
 * The |ctext| buffer must have space for the ciphertext of the ML-KEM variant
1756
 * of the provided key.
1757
 */
1758
static int encap(uint8_t *ctext, uint8_t secret[ML_KEM_SHARED_SECRET_BYTES],
1759
    const uint8_t entropy[ML_KEM_RANDOM_BYTES],
1760
    scalar *tmp, EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1761
142
{
1762
142
    uint8_t input[ML_KEM_RANDOM_BYTES + ML_KEM_PKHASH_BYTES];
1763
142
    uint8_t Kr[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES];
1764
142
    uint8_t *r = Kr + ML_KEM_SHARED_SECRET_BYTES;
1765
142
    int ret;
1766
1767
142
    memcpy(input, entropy, ML_KEM_RANDOM_BYTES);
1768
142
    memcpy(input + ML_KEM_RANDOM_BYTES, key->pkhash, ML_KEM_PKHASH_BYTES);
1769
142
    ret = hash_g(Kr, input, sizeof(input), mdctx, key)
1770
142
        && encrypt_cpa(ctext, entropy, r, tmp, mdctx, key);
1771
142
    OPENSSL_cleanse((void *)input, sizeof(input));
1772
1773
142
    if (ret)
1774
142
        memcpy(secret, Kr, ML_KEM_SHARED_SECRET_BYTES);
1775
0
    else
1776
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1777
0
            "internal error while performing %s encapsulation",
1778
0
            key->vinfo->algorithm_name);
1779
142
    return ret;
1780
142
}
1781
1782
/*
1783
 * FIPS 203, Section 6.3, Algorithm 18: ML-KEM.Decaps_internal
1784
 *
1785
 * Barring failure of the supporting SHA3/SHAKE primitives, this is fully
1786
 * deterministic, the randomness for the FO transform is extracted during
1787
 * private key generation.
1788
 *
1789
 * The caller must pass space for two vectors in |tmp|.
1790
 * The |ctext| and |tmp_ctext| buffers must each have space for the ciphertext
1791
 * of the key's ML-KEM variant.
1792
 */
1793
static int decap(uint8_t secret[ML_KEM_SHARED_SECRET_BYTES],
1794
    const uint8_t *ctext, uint8_t *tmp_ctext, scalar *tmp,
1795
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1796
123
{
1797
123
    uint8_t decrypted[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_PKHASH_BYTES];
1798
123
    uint8_t failure_key[ML_KEM_RANDOM_BYTES];
1799
123
    uint8_t Kr[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES];
1800
123
    uint8_t *r = Kr + ML_KEM_SHARED_SECRET_BYTES;
1801
123
    const uint8_t *pkhash = key->pkhash;
1802
123
    const ML_KEM_VINFO *vinfo = key->vinfo;
1803
123
    int i;
1804
123
    uint8_t mask;
1805
1806
    /*
1807
     * If our KDF is unavailable, fail early! Otherwise, keep going ignoring
1808
     * any further errors, returning success, and whatever we got for a shared
1809
     * secret.  The decrypt_cpa() function is just arithmetic on secret data,
1810
     * so should not be subject to failure that makes its output predictable.
1811
     *
1812
     * We guard against "should never happen" catastrophic failure of the
1813
     * "pure" function |hash_g| by overwriting the shared secret with the
1814
     * content of the failure key and returning early, if nevertheless hash_g
1815
     * fails.  This is not constant-time, but a failure of |hash_g| already
1816
     * implies loss of side-channel resistance.
1817
     *
1818
     * The same action is taken, if also |encrypt_cpa| should catastrophically
1819
     * fail, due to failure of the |PRF| underlying the CBD functions.
1820
     */
1821
123
    if (!kdf(failure_key, key->z, ctext, vinfo->ctext_bytes, mdctx, key)) {
1822
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1823
0
            "internal error while performing %s decapsulation",
1824
0
            vinfo->algorithm_name);
1825
0
        return 0;
1826
0
    }
1827
123
    decrypt_cpa(decrypted, ctext, tmp, key);
1828
123
    memcpy(decrypted + ML_KEM_SHARED_SECRET_BYTES, pkhash, ML_KEM_PKHASH_BYTES);
1829
123
    if (!hash_g(Kr, decrypted, sizeof(decrypted), mdctx, key)
1830
123
        || !encrypt_cpa(tmp_ctext, decrypted, r, tmp, mdctx, key)) {
1831
0
        memcpy(secret, failure_key, ML_KEM_SHARED_SECRET_BYTES);
1832
0
        OPENSSL_cleanse(decrypted, ML_KEM_SHARED_SECRET_BYTES);
1833
0
        return 1;
1834
0
    }
1835
123
    mask = constant_time_eq_int_8(0,
1836
123
        CRYPTO_memcmp(ctext, tmp_ctext, vinfo->ctext_bytes));
1837
4.05k
    for (i = 0; i < ML_KEM_SHARED_SECRET_BYTES; i++)
1838
3.93k
        secret[i] = constant_time_select_8(mask, Kr[i], failure_key[i]);
1839
123
    OPENSSL_cleanse(decrypted, ML_KEM_SHARED_SECRET_BYTES);
1840
123
    OPENSSL_cleanse(Kr, sizeof(Kr));
1841
123
    return 1;
1842
123
}
1843
1844
/*
1845
 * After allocating storage for public or private key data, update the key
1846
 * component pointers to reference that storage.
1847
 */
1848
static __owur int add_storage(scalar *p, int private, ML_KEM_KEY *key)
{
    int k = key->vinfo->rank;

    /* Nothing to attach: the caller's allocation came back empty */
    if (p == NULL)
        return 0;

    /*
     * Key material is being attached, so from here on the seed buffer is
     * the home of |rho| followed by |pkhash|; any stashed seed is dropped.
     */
    memset(key->seedbuf, 0, sizeof(key->seedbuf));
    key->rho = key->seedbuf;
    key->pkhash = key->seedbuf + ML_KEM_RANDOM_BYTES;
    key->z = NULL;
    key->d = NULL;

    /* Public part: |t| (k scalars) immediately followed by |m| */
    key->t = p;
    key->m = key->t + k;

    /*
     * Private part: |s| starts k * k scalars after |m|, and |z| starts k
     * scalars after |s|.  The |d| pointer stays NULL here; callers that
     * know the seed set it themselves.
     */
    if (private) {
        key->s = key->m + k * k;
        key->z = (uint8_t *)(key->s + k);
    }
    return 1;
}
1878
1879
/*
1880
 * After freeing the storage associated with a key that failed to be
1881
 * constructed, reset the internal pointers back to NULL.
1882
 */
1883
void ossl_ml_kem_key_reset(ML_KEM_KEY *key)
{
    /* Only keys with attached storage have anything to release */
    if (key->t != NULL) {
        /*
         * With a private key present, wipe |s| plus the two
         * ML_KEM_RANDOM_BYTES-sized values that follow it (|z| and |d|)
         * in a single call before the storage is released.
         */
        if (ossl_ml_kem_have_prvkey(key)) {
            size_t wipe_len = key->vinfo->rank * sizeof(scalar)
                + 2 * ML_KEM_RANDOM_BYTES;

            OPENSSL_cleanse(key->s, wipe_len);
        }
        OPENSSL_free(key->t);

        /* All component pointers referenced the storage just freed */
        key->t = NULL;
        key->m = NULL;
        key->s = NULL;
        key->z = NULL;
        key->d = NULL;
    }
}
1900
1901
/*
1902
 * ----- API exported to the provider
1903
 *
1904
 * Parameters with an implicit fixed length in the internal static API of each
1905
 * variant have an explicit checked length argument at this layer.
1906
 */
1907
1908
/* Retrieve the parameters of one of the ML-KEM variants */
1909
const ML_KEM_VINFO *ossl_ml_kem_get_vinfo(int evp_type)
1910
276k
{
1911
276k
    switch (evp_type) {
1912
56.9k
    case EVP_PKEY_ML_KEM_512:
1913
56.9k
        return &vinfo_map[ML_KEM_512_VINFO];
1914
158k
    case EVP_PKEY_ML_KEM_768:
1915
158k
        return &vinfo_map[ML_KEM_768_VINFO];
1916
60.3k
    case EVP_PKEY_ML_KEM_1024:
1917
60.3k
        return &vinfo_map[ML_KEM_1024_VINFO];
1918
276k
    }
1919
0
    return NULL;
1920
276k
}
1921
1922
/*
 * Allocate an empty key object for the ML-KEM variant named by |evp_type|,
 * fetching the four SHA3/SHAKE digests it needs from |libctx| using
 * |properties|.  Returns NULL on an unknown variant, allocation failure, or
 * when any digest cannot be fetched.
 */
ML_KEM_KEY *ossl_ml_kem_key_new(OSSL_LIB_CTX *libctx, const char *properties,
    int evp_type)
{
    const ML_KEM_VINFO *vinfo = ossl_ml_kem_get_vinfo(evp_type);
    ML_KEM_KEY *key;

    if (vinfo == NULL) {
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_PASSED_INVALID_ARGUMENT,
            "unsupported ML-KEM key type: %d", evp_type);
        return NULL;
    }

    key = OPENSSL_malloc(sizeof(*key));
    if (key == NULL)
        return NULL;

    key->vinfo = vinfo;
    key->libctx = libctx;
    key->prov_flags = ML_KEM_KEY_PROV_FLAGS_DEFAULT;

    /* No key material yet: every component pointer starts out NULL */
    key->t = NULL;
    key->m = NULL;
    key->s = NULL;
    key->encoded_dk = NULL;
    key->pkhash = NULL;
    key->rho = NULL;
    key->z = NULL;
    key->d = NULL;

    key->shake128_md = EVP_MD_fetch(libctx, "SHAKE128", properties);
    key->shake256_md = EVP_MD_fetch(libctx, "SHAKE256", properties);
    key->sha3_256_md = EVP_MD_fetch(libctx, "SHA3-256", properties);
    key->sha3_512_md = EVP_MD_fetch(libctx, "SHA3-512", properties);

    /* All four digests are mandatory; release everything if one is absent */
    if (key->shake128_md == NULL
        || key->shake256_md == NULL
        || key->sha3_256_md == NULL
        || key->sha3_512_md == NULL) {
        ossl_ml_kem_key_free(key);
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
            "missing SHA3 digest algorithms while creating %s key",
            vinfo->algorithm_name);
        return NULL;
    }
    return key;
}
1959
1960
ML_KEM_KEY *ossl_ml_kem_key_dup(const ML_KEM_KEY *key, int selection)
1961
27
{
1962
27
    int ok = 0;
1963
27
    ML_KEM_KEY *ret;
1964
1965
    /*
1966
     * Partially decoded keys, not yet imported or loaded, should never be
1967
     * duplicated.
1968
     */
1969
27
    if (ossl_ml_kem_decoded_key(key))
1970
0
        return NULL;
1971
1972
27
    if (key == NULL
1973
27
        || (ret = OPENSSL_memdup(key, sizeof(*key))) == NULL)
1974
0
        return NULL;
1975
27
    ret->d = ret->z = ret->rho = ret->pkhash = NULL;
1976
27
    ret->s = ret->m = ret->t = NULL;
1977
1978
    /* Clear selection bits we can't fulfill */
1979
27
    if (!ossl_ml_kem_have_pubkey(key))
1980
0
        selection = 0;
1981
27
    else if (!ossl_ml_kem_have_prvkey(key))
1982
27
        selection &= ~OSSL_KEYMGMT_SELECT_PRIVATE_KEY;
1983
1984
27
    switch (selection & OSSL_KEYMGMT_SELECT_KEYPAIR) {
1985
0
    case 0:
1986
0
        ok = 1;
1987
0
        break;
1988
27
    case OSSL_KEYMGMT_SELECT_PUBLIC_KEY:
1989
27
        ok = add_storage(OPENSSL_memdup(key->t, key->vinfo->puballoc), 0, ret);
1990
27
        ret->rho = ret->seedbuf;
1991
27
        ret->pkhash = ret->rho + ML_KEM_RANDOM_BYTES;
1992
27
        break;
1993
0
    case OSSL_KEYMGMT_SELECT_PRIVATE_KEY:
1994
0
        ok = add_storage(OPENSSL_memdup(key->t, key->vinfo->prvalloc), 1, ret);
1995
        /* Duplicated keys retain |d|, if available */
1996
0
        if (key->d != NULL)
1997
0
            ret->d = ret->z + ML_KEM_RANDOM_BYTES;
1998
0
        break;
1999
27
    }
2000
2001
27
    if (!ok) {
2002
0
        OPENSSL_free(ret);
2003
0
        return NULL;
2004
0
    }
2005
2006
27
    EVP_MD_up_ref(ret->shake128_md);
2007
27
    EVP_MD_up_ref(ret->shake256_md);
2008
27
    EVP_MD_up_ref(ret->sha3_256_md);
2009
27
    EVP_MD_up_ref(ret->sha3_512_md);
2010
2011
27
    return ret;
2012
27
}
2013
2014
/*
 * Release a key object: drop the digest references, wipe and free any key
 * material, then free the object itself.  A NULL |key| is a no-op.
 */
void ossl_ml_kem_key_free(ML_KEM_KEY *key)
{
    if (key != NULL) {
        EVP_MD_free(key->shake128_md);
        EVP_MD_free(key->shake256_md);
        EVP_MD_free(key->sha3_256_md);
        EVP_MD_free(key->sha3_512_md);

        /*
         * A key that was decoded but never loaded keeps its secrets in the
         * seed buffer and, possibly, in a separately allocated encoded
         * private key; wipe both.
         */
        if (ossl_ml_kem_decoded_key(key)) {
            OPENSSL_cleanse(key->seedbuf, sizeof(key->seedbuf));
            if (ossl_ml_kem_have_dkenc(key)) {
                OPENSSL_cleanse(key->encoded_dk, key->vinfo->prvkey_bytes);
                OPENSSL_free(key->encoded_dk);
            }
        }

        ossl_ml_kem_key_reset(key);
        OPENSSL_free(key);
    }
}
2034
2035
/* Serialise the public component of an ML-KEM key */
2036
int ossl_ml_kem_encode_public_key(uint8_t *out, size_t len,
2037
    const ML_KEM_KEY *key)
2038
50.0k
{
2039
50.0k
    if (!ossl_ml_kem_have_pubkey(key)
2040
50.0k
        || len != key->vinfo->pubkey_bytes)
2041
0
        return 0;
2042
50.0k
    encode_pubkey(out, key);
2043
50.0k
    return 1;
2044
50.0k
}
2045
2046
/* Serialise an ML-KEM private key */
2047
int ossl_ml_kem_encode_private_key(uint8_t *out, size_t len,
2048
    const ML_KEM_KEY *key)
2049
119
{
2050
119
    if (!ossl_ml_kem_have_prvkey(key)
2051
119
        || len != key->vinfo->prvkey_bytes)
2052
0
        return 0;
2053
119
    encode_prvkey(out, key);
2054
119
    return 1;
2055
119
}
2056
2057
int ossl_ml_kem_encode_seed(uint8_t *out, size_t len,
2058
    const ML_KEM_KEY *key)
2059
201
{
2060
201
    if (key == NULL || key->d == NULL || len != ML_KEM_SEED_BYTES)
2061
29
        return 0;
2062
    /*
2063
     * Both in the seed buffer, and in the allocated storage, the |d| component
2064
     * of the seed is stored last, so we must copy each separately.
2065
     */
2066
172
    memcpy(out, key->d, ML_KEM_RANDOM_BYTES);
2067
172
    out += ML_KEM_RANDOM_BYTES;
2068
172
    memcpy(out, key->z, ML_KEM_RANDOM_BYTES);
2069
172
    return 1;
2070
201
}
2071
2072
/*
2073
 * Stash the seed without (yet) performing a keygen, used during decoding, to
2074
 * avoid an extra keygen if we're only going to export the key again to load
2075
 * into another provider.
2076
 */
2077
ML_KEM_KEY *ossl_ml_kem_set_seed(const uint8_t *seed, size_t seedlen, ML_KEM_KEY *key)
{
    const uint8_t *d_in = seed;
    const uint8_t *z_in = seed + ML_KEM_RANDOM_BYTES;

    /* Only an empty key may take a seed, and only one of the exact size */
    if (key == NULL
        || ossl_ml_kem_have_pubkey(key)
        || ossl_ml_kem_have_seed(key)
        || seedlen != ML_KEM_SEED_BYTES)
        return NULL;

    /*
     * The key holds no public or private material, so its seed buffer is
     * free to carry |z| in the first half and |d| in the second.  The input
     * is laid out the other way around: |d| first, then |z|.
     */
    key->z = key->seedbuf;
    key->d = key->seedbuf + ML_KEM_RANDOM_BYTES;
    memcpy(key->d, d_in, ML_KEM_RANDOM_BYTES);
    memcpy(key->z, z_in, ML_KEM_RANDOM_BYTES);
    return key;
}
2095
2096
/* Parse input as a public key */
2097
int ossl_ml_kem_parse_public_key(const uint8_t *in, size_t len, ML_KEM_KEY *key)
{
    EVP_MD_CTX *mdctx;
    const ML_KEM_VINFO *vinfo;
    int ret;

    /* Keys with key material are immutable */
    if (key == NULL)
        return 0;
    if (ossl_ml_kem_have_pubkey(key) || ossl_ml_kem_have_dkenc(key))
        return 0;

    vinfo = key->vinfo;
    if (len != vinfo->pubkey_bytes)
        return 0;

    mdctx = EVP_MD_CTX_new();
    if (mdctx == NULL)
        return 0;

    /* Attach freshly allocated public storage, then decode into it */
    ret = add_storage(OPENSSL_malloc(vinfo->puballoc), 0, key)
        ? parse_pubkey(in, mdctx, key)
        : 0;

    /* On any failure, give the storage back and leave the key empty */
    if (!ret)
        ossl_ml_kem_key_reset(key);
    EVP_MD_CTX_free(mdctx);
    return ret;
}
2122
2123
/* Parse input as a new private key */
2124
int ossl_ml_kem_parse_private_key(const uint8_t *in, size_t len,
2125
    ML_KEM_KEY *key)
2126
93
{
2127
93
    EVP_MD_CTX *mdctx = NULL;
2128
93
    const ML_KEM_VINFO *vinfo;
2129
93
    int ret = 0;
2130
2131
    /* Keys with key material are immutable */
2132
93
    if (key == NULL
2133
93
        || ossl_ml_kem_have_pubkey(key)
2134
93
        || ossl_ml_kem_have_dkenc(key))
2135
0
        return 0;
2136
93
    vinfo = key->vinfo;
2137
2138
93
    if (len != vinfo->prvkey_bytes
2139
93
        || (mdctx = EVP_MD_CTX_new()) == NULL)
2140
0
        return 0;
2141
2142
93
    if (add_storage(OPENSSL_malloc(vinfo->prvalloc), 1, key))
2143
93
        ret = parse_prvkey(in, mdctx, key);
2144
2145
93
    if (!ret)
2146
93
        ossl_ml_kem_key_reset(key);
2147
93
    EVP_MD_CTX_free(mdctx);
2148
93
    return ret;
2149
93
}
2150
2151
/*
2152
 * Generate a new keypair, either from the saved seed (when non-null), or from
2153
 * the RNG.
2154
 */
2155
int ossl_ml_kem_genkey(uint8_t *pubenc, size_t publen, ML_KEM_KEY *key)
{
    uint8_t seed[ML_KEM_SEED_BYTES];
    EVP_MD_CTX *mdctx = NULL;
    const ML_KEM_VINFO *vinfo;
    int ret = 0;

    /* Only an empty key (no public key, no encoded private key) qualifies */
    if (key == NULL
        || ossl_ml_kem_have_pubkey(key)
        || ossl_ml_kem_have_dkenc(key))
        return 0;
    vinfo = key->vinfo;

    /* The encoded public key is optional, but must be the right size */
    if (pubenc != NULL && publen != vinfo->pubkey_bytes)
        return 0;

    if (ossl_ml_kem_have_seed(key)) {
        /* Move the stashed seed to the stack and detach it from the key */
        if (!ossl_ml_kem_encode_seed(seed, sizeof(seed), key))
            return 0;
        key->d = key->z = NULL;
    } else if (RAND_priv_bytes_ex(key->libctx, seed, sizeof(seed),
                   key->vinfo->secbits)
        <= 0) {
        /* The RNG may have written part of a seed before failing */
        OPENSSL_cleanse(seed, sizeof(seed));
        return 0;
    }

    if ((mdctx = EVP_MD_CTX_new()) == NULL) {
        /* Don't leave the seed behind on the stack */
        OPENSSL_cleanse(seed, sizeof(seed));
        return 0;
    }

    /*
     * Data derived from (d, z) is secret by default, and to avoid
     * side-channel leaks should not influence control flow.
     */
    CONSTTIME_SECRET(seed, ML_KEM_SEED_BYTES);

    if (add_storage(OPENSSL_malloc(vinfo->prvalloc), 1, key))
        ret = genkey(seed, mdctx, pubenc, key);
    OPENSSL_cleanse(seed, sizeof(seed));

    /* Declassify secret inputs and derived outputs before returning control */
    CONSTTIME_DECLASSIFY(seed, ML_KEM_SEED_BYTES);

    EVP_MD_CTX_free(mdctx);
    if (!ret) {
        ossl_ml_kem_key_reset(key);
        return 0;
    }

    /* The public components are already declassified */
    CONSTTIME_DECLASSIFY(key->s, vinfo->rank * sizeof(scalar));
    CONSTTIME_DECLASSIFY(key->z, 2 * ML_KEM_RANDOM_BYTES);
    return 1;
}
2208
2209
/*
2210
 * FIPS 203, Section 6.2, Algorithm 17: ML-KEM.Encaps_internal
2211
 * This is the deterministic version with randomness supplied externally.
2212
 */
2213
int ossl_ml_kem_encap_seed(uint8_t *ctext, size_t clen,
2214
    uint8_t *shared_secret, size_t slen,
2215
    const uint8_t *entropy, size_t elen,
2216
    const ML_KEM_KEY *key)
2217
142
{
2218
142
    const ML_KEM_VINFO *vinfo;
2219
142
    EVP_MD_CTX *mdctx;
2220
142
    int ret = 0;
2221
2222
142
    if (key == NULL || !ossl_ml_kem_have_pubkey(key))
2223
0
        return 0;
2224
142
    vinfo = key->vinfo;
2225
2226
142
    if (ctext == NULL || clen != vinfo->ctext_bytes
2227
142
        || shared_secret == NULL || slen != ML_KEM_SHARED_SECRET_BYTES
2228
142
        || entropy == NULL || elen != ML_KEM_RANDOM_BYTES
2229
142
        || (mdctx = EVP_MD_CTX_new()) == NULL)
2230
0
        return 0;
2231
    /*
2232
     * Data derived from the encap entropy defaults secret, and to avoid
2233
     * side-channel leaks should not influence control flow.
2234
     */
2235
142
    CONSTTIME_SECRET(entropy, elen);
2236
2237
    /*-
2238
     * This avoids the need to handle allocation failures for two (max 2KB
2239
     * each) vectors, that are never retained on return from this function.
2240
     * We stack-allocate these.
2241
     */
2242
142
#define case_encap_seed(bits)                                        \
2243
142
    case EVP_PKEY_ML_KEM_##bits: {                                   \
2244
142
        scalar tmp[2 * ML_KEM_##bits##_RANK];                        \
2245
142
                                                                     \
2246
142
        ret = encap(ctext, shared_secret, entropy, tmp, mdctx, key); \
2247
142
        OPENSSL_cleanse((void *)tmp, sizeof(tmp));                   \
2248
142
        break;                                                       \
2249
142
    }
2250
142
    switch (vinfo->evp_type) {
2251
51
        case_encap_seed(512);
2252
53
        case_encap_seed(768);
2253
38
        case_encap_seed(1024);
2254
142
    }
2255
142
#undef case_encap_seed
2256
2257
    /* Declassify secret inputs and derived outputs before returning control */
2258
142
    CONSTTIME_DECLASSIFY(entropy, elen);
2259
142
    CONSTTIME_DECLASSIFY(ctext, clen);
2260
142
    CONSTTIME_DECLASSIFY(shared_secret, slen);
2261
2262
142
    EVP_MD_CTX_free(mdctx);
2263
142
    return ret;
2264
142
}
2265
2266
int ossl_ml_kem_encap_rand(uint8_t *ctext, size_t clen,
2267
    uint8_t *shared_secret, size_t slen,
2268
    const ML_KEM_KEY *key)
2269
142
{
2270
142
    uint8_t r[ML_KEM_RANDOM_BYTES];
2271
2272
142
    if (key == NULL)
2273
0
        return 0;
2274
2275
142
    if (RAND_bytes_ex(key->libctx, r, ML_KEM_RANDOM_BYTES,
2276
142
            key->vinfo->secbits)
2277
142
        < 1)
2278
0
        return 0;
2279
2280
142
    return ossl_ml_kem_encap_seed(ctext, clen, shared_secret, slen,
2281
142
        r, sizeof(r), key);
2282
142
}
2283
2284
int ossl_ml_kem_decap(uint8_t *shared_secret, size_t slen,
2285
    const uint8_t *ctext, size_t clen,
2286
    const ML_KEM_KEY *key)
2287
123
{
2288
123
    const ML_KEM_VINFO *vinfo;
2289
123
    EVP_MD_CTX *mdctx;
2290
123
    int ret = 0;
2291
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
2292
    int classify_bytes;
2293
#endif
2294
2295
    /* Need a private key here */
2296
123
    if (!ossl_ml_kem_have_prvkey(key))
2297
0
        return 0;
2298
123
    vinfo = key->vinfo;
2299
2300
123
    if (shared_secret == NULL || slen != ML_KEM_SHARED_SECRET_BYTES
2301
123
        || ctext == NULL || clen != vinfo->ctext_bytes
2302
123
        || (mdctx = EVP_MD_CTX_new()) == NULL) {
2303
0
        (void)RAND_bytes_ex(key->libctx, shared_secret,
2304
0
            ML_KEM_SHARED_SECRET_BYTES, vinfo->secbits);
2305
0
        return 0;
2306
0
    }
2307
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
2308
    /*
2309
     * Data derived from |s| and |z| defaults secret, and to avoid side-channel
2310
     * leaks should not influence control flow.
2311
     */
2312
    classify_bytes = 2 * sizeof(scalar) + ML_KEM_RANDOM_BYTES;
2313
    CONSTTIME_SECRET(key->s, classify_bytes);
2314
#endif
2315
2316
    /*-
2317
     * This avoids the need to handle allocation failures for two (max 2KB
2318
     * each) vectors and an encoded ciphertext (max 1568 bytes), that are never
2319
     * retained on return from this function.
2320
     * We stack-allocate these.
2321
     */
2322
123
#define case_decap(bits)                                          \
2323
123
    case EVP_PKEY_ML_KEM_##bits: {                                \
2324
123
        uint8_t cbuf[CTEXT_BYTES(bits)];                          \
2325
123
        scalar tmp[2 * ML_KEM_##bits##_RANK];                     \
2326
123
                                                                  \
2327
123
        ret = decap(shared_secret, ctext, cbuf, tmp, mdctx, key); \
2328
123
        OPENSSL_cleanse((void *)tmp, sizeof(tmp));                \
2329
123
        break;                                                    \
2330
123
    }
2331
123
    switch (vinfo->evp_type) {
2332
51
        case_decap(512);
2333
34
        case_decap(768);
2334
38
        case_decap(1024);
2335
123
    }
2336
2337
    /* Declassify secret inputs and derived outputs before returning control */
2338
123
    CONSTTIME_DECLASSIFY(key->s, classify_bytes);
2339
123
    CONSTTIME_DECLASSIFY(shared_secret, slen);
2340
123
    EVP_MD_CTX_free(mdctx);
2341
2342
123
    return ret;
2343
123
#undef case_decap
2344
123
}
2345
2346
int ossl_ml_kem_pubkey_cmp(const ML_KEM_KEY *key1, const ML_KEM_KEY *key2)
2347
145
{
2348
    /*
2349
     * This handles any unexpected differences in the ML-KEM variant rank,
2350
     * giving different key component structures, barring SHA3-256 hash
2351
     * collisions, the keys are the same size.
2352
     */
2353
145
    if (ossl_ml_kem_have_pubkey(key1) && ossl_ml_kem_have_pubkey(key2))
2354
145
        return memcmp(key1->pkhash, key2->pkhash, ML_KEM_PKHASH_BYTES) == 0;
2355
2356
    /*
2357
     * No match if just one of the public keys is not available, otherwise both
2358
     * are unavailable, and for now such keys are considered equal.
2359
     */
2360
0
    return (!(ossl_ml_kem_have_pubkey(key1) ^ ossl_ml_kem_have_pubkey(key2)));
2361
145
}