Coverage Report

Created: 2026-05-24 07:14

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/openssl35/crypto/ml_kem/ml_kem.c
Line
Count
Source
1
/*
2
 * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
3
 *
4
 * Licensed under the Apache License 2.0 (the "License").  You may not use
5
 * this file except in compliance with the License.  You can obtain a copy
6
 * in the file LICENSE in the source distribution or at
7
 * https://www.openssl.org/source/license.html
8
 */
9
10
#include <openssl/byteorder.h>
11
#include <openssl/rand.h>
12
#include <openssl/proverr.h>
13
#include "crypto/ml_kem.h"
14
#include "internal/common.h"
15
#include "internal/constant_time.h"
16
#include "internal/sha3.h"
17
18
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
19
#include <valgrind/memcheck.h>
20
#endif
21
22
/*
 * Compile-time consistency checks: the key-generation seed is the
 * concatenation of a shared-secret-sized and a random-sized part, those two
 * parts have equal length, and unsigned int is at least 32 bits wide.
 */
#if ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES != ML_KEM_SEED_BYTES
#error "ML-KEM keygen seed length != shared secret + random bytes length"
#endif
#if ML_KEM_RANDOM_BYTES != ML_KEM_SHARED_SECRET_BYTES
#error "Invalid unequal lengths of ML-KEM shared secret and random inputs"
#endif

#if !(UINT_MAX >= UINT32_MAX)
#error "Unsupported compiler: sizeof(unsigned int) < sizeof(uint32_t)"
#endif
32
33
/*
 * Handy function-like bit-extraction macros: bit0(b) is the least
 * significant bit of |b|, and bitn(n, b) is bit |n| of |b|.  Every argument
 * is parenthesised, so arbitrary expressions (e.g. conditionals) may be
 * passed for either |n| or |b|.
 */
#define bit0(b) ((b) & 1)
#define bitn(n, b) (((b) >> (n)) & 1)

/*
 * 12 bits are sufficient to losslessly represent values in [0, q-1].
 * INVERSE_DEGREE is (n/2)^-1 mod q; used in inverse NTT.  With n = 256 and
 * q = 3329 (FIPS 203): 128 * (q - 26) = 127 * q + 1, hence q - 2 * 13.
 * BARRETT_SHIFT is the power of two used by the Barrett multiplier
 * floor(2^BARRETT_SHIFT / q).
 */
#define DEGREE ML_KEM_DEGREE
#define INVERSE_DEGREE (ML_KEM_PRIME - 2 * 13)
#define LOG2PRIME 12
#define BARRETT_SHIFT (2 * LOG2PRIME)

#ifdef SHA3_BLOCKSIZE
#define SHAKE128_BLOCKSIZE SHA3_BLOCKSIZE(128)
#endif
49
50
/*
 * Return whether a value that can only be 0 or 1 is non-zero, in constant time
 * in practice!  The return value is a mask that is all ones if true, and all
 * zeros otherwise.  Unsigned negation is well-defined modulo 2^N, so 0u - 1 is
 * all ones.
 *
 * Although this is used in constant-time selects, we omit a value barrier
 * here.  Value barriers impede auto-vectorization (likely because it forces
 * the value to transit through a general-purpose register). On AArch64, this
 * is a difference of 2x.
 *
 * We usually add value barriers to selects because Clang turns consecutive
 * selects with the same condition into a branch instead of CMOV/CSEL. This
 * condition does not occur in Kyber, so omitting it seems to be safe so far,
 * but see |cbd_2|, |cbd_3|, where reduction needs to be specialised to the
 * sign of the input, rather than adding |q| in advance, and using the generic
 * |reduce_once|.  (David Benjamin, Chromium)
 *
 * The disabled alternative below uses the generic constant_time_is_zero()
 * helper.  Neither definition may carry a trailing semicolon, because the
 * macro is expanded inside initialisers and expressions.
 */
#if 0
#define constish_time_non_zero(b) (~constant_time_is_zero(b))
#else
#define constish_time_non_zero(b) (0u - (b))
#endif
72
73
/*
 * Size of the buffer that |sample_scalar| squeezes from SHAKE128 in each
 * round.  It must be a multiple of 12 (and, in particular, of 3, because
 * |sample_scalar| consumes its input three bytes at a time); beyond that the
 * value is arbitrary.
 *
 * The preferred size is the SHAKE128 block size, (1600 - 256) / 8 = 168
 * bytes, which avoids internal buffering and copying inside SHAKE128 and
 * happens to be divisible by 12.  When the block size is unknown, or not a
 * multiple of 12, fall back to a fixed 168.
 */
#if !defined(SHAKE128_BLOCKSIZE) || (SHAKE128_BLOCKSIZE) % 12 != 0
#define SCALAR_SAMPLING_BUFSIZE 168
#else
#define SCALAR_SAMPLING_BUFSIZE (SHAKE128_BLOCKSIZE)
#endif
87
88
/*
89
 * Structure of keys
90
 */
91
typedef struct ossl_ml_kem_scalar_st {
92
    /* On every function entry and exit, 0 <= c[i] < ML_KEM_PRIME. */
93
    uint16_t c[ML_KEM_DEGREE];
94
} scalar;
95
96
/*
 * Key material allocation layout.
 *
 * DECLARE_ML_KEM_KEYDATA(name, rank, private_sz) declares |struct name_alloc|,
 * which holds the public vector |t| (|rank| scalars) and the pre-computed
 * matrix |m| (|rank| * |rank| scalars).
 *
 * |private_sz| is pasted verbatim immediately after the |mbuf| declaration.
 * It therefore has to start with the ';' that terminates |mbuf|, optionally
 * followed by further member declarations.  NOTE(review): passing ';' rather
 * than an empty argument presumably keeps this usable with pre-C99
 * preprocessors, where empty macro arguments are not portable.
 */
#define DECLARE_ML_KEM_KEYDATA(name, rank, private_sz)               \
    struct name##_alloc {                                            \
        /* Public vector |t| */                                      \
        scalar tbuf[(rank)];                                         \
        /* Pre-computed matrix |m| (FIPS 203 |A| transpose) */       \
        scalar mbuf[(rank) * (rank)] /* optional private key data */ \
            private_sz                                               \
    }

/*
 * Declare variant-specific public and private storage for the variant with
 * bit-size label |bits|:
 *   struct pubkey_<bits>_alloc: tbuf, mbuf
 *   struct prvkey_<bits>_alloc: tbuf, mbuf, plus the private vector |sbuf|
 *       (rank scalars) and |zbuf| (2 * ML_KEM_RANDOM_BYTES bytes)
 * The sizes of these structs are recorded in |vinfo_map| below.
 */
#define DECLARE_ML_KEM_VARIANT_KEYDATA(bits)                        \
    DECLARE_ML_KEM_KEYDATA(pubkey_##bits, ML_KEM_##bits##_RANK, ;); \
    DECLARE_ML_KEM_KEYDATA(prvkey_##bits, ML_KEM_##bits##_RANK, ; scalar sbuf[ML_KEM_##bits##_RANK]; uint8_t zbuf[2 * ML_KEM_RANDOM_BYTES];)
DECLARE_ML_KEM_VARIANT_KEYDATA(512);
DECLARE_ML_KEM_VARIANT_KEYDATA(768);
DECLARE_ML_KEM_VARIANT_KEYDATA(1024);
/* The helper macros are local to this section. */
#undef DECLARE_ML_KEM_VARIANT_KEYDATA
#undef DECLARE_ML_KEM_KEYDATA
115
116
typedef __owur int (*CBD_FUNC)(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
117
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key);
118
static void scalar_encode(uint8_t *out, const scalar *s, int bits);
119
120
/*
 * Wire-form sizes, in bytes, for the variant with bit-size label |b|:
 *
 * VECTOR_BYTES: a rank-length vector of scalars, losslessly encoded with
 *     12 bits (1.5 bytes) per coefficient.
 * PUBKEY_BYTES: the encoded public vector |t| followed by the public seed
 *     |rho|.
 * PRVKEY_BYTES: our serialised private key, i.e. the concatenation of the
 *     encoded private vector |s|, the public key, the public key hash, and
 *     the failure secret |z|.
 */
#define VECTOR_BYTES(b) (ML_KEM_##b##_RANK * (3 * DEGREE / 2))
#define PUBKEY_BYTES(b) (ML_KEM_RANDOM_BYTES + VECTOR_BYTES(b))
#define PRVKEY_BYTES(b)                                     \
    (VECTOR_BYTES(b) + PUBKEY_BYTES(b) + ML_KEM_PKHASH_BYTES \
        + ML_KEM_RANDOM_BYTES)
132
133
/*
 * Ciphertext sizes, in bytes, for the variant with bit-size label |b|.
 *
 * Encapsulation produces a vector "u" and a scalar "v", whose coordinates
 * (numbers modulo the ML-KEM prime "q") are lossily compressed to "du" and
 * "dv" bits respectively.  Each scalar has DEGREE coefficients, so it occupies
 * DEGREE / 8 bytes per bit of precision.  The concatenated encodings form the
 * ciphertext that is input to decapsulation.
 */
#define U_VECTOR_BYTES(b) (ML_KEM_##b##_RANK * ML_KEM_##b##_DU * (DEGREE / 8))
#define V_SCALAR_BYTES(b) (ML_KEM_##b##_DV * (DEGREE / 8))
#define CTEXT_BYTES(b) (V_SCALAR_BYTES(b) + U_VECTOR_BYTES(b))
142
143
/*
 * Constant-time validation hooks.
 *
 * CONSTTIME_SECRET(ptr, len) marks |len| bytes at |ptr| as secret.  Secret
 * data is tracked as it flows to registers and other parts of memory.  If it
 * is used as a branch condition or as a memory index, valgrind reports a
 * warning.
 *
 * CONSTTIME_DECLASSIFY(ptr, len) marks |len| bytes at |ptr| as public again.
 * Public data is not subject to constant-time rules.
 *
 * Both hooks expand to nothing unless OPENSSL_CONSTANT_TIME_VALIDATION is
 * defined.  In that case they map to the valgrind memcheck client requests
 * VALGRIND_MAKE_MEM_UNDEFINED and VALGRIND_MAKE_MEM_DEFINED.
 */
#ifndef OPENSSL_CONSTANT_TIME_VALIDATION
#define CONSTTIME_SECRET(ptr, len)
#define CONSTTIME_DECLASSIFY(ptr, len)
#else
#define CONSTTIME_SECRET(ptr, len) VALGRIND_MAKE_MEM_UNDEFINED(ptr, len)
#define CONSTTIME_DECLASSIFY(ptr, len) VALGRIND_MAKE_MEM_DEFINED(ptr, len)
#endif
166
167
/*
 * Slot of each ML-KEM variant in the |vinfo_map| table that follows.
 */
#define ML_KEM_512_VINFO 0  /* "ML-KEM-512" */
#define ML_KEM_768_VINFO 1  /* "ML-KEM-768" */
#define ML_KEM_1024_VINFO 2 /* "ML-KEM-1024" */
173
174
/*
175
 * Per-variant fixed parameters
176
 */
177
static const ML_KEM_VINFO vinfo_map[3] = {
178
    { "ML-KEM-512",
179
        PRVKEY_BYTES(512),
180
        sizeof(struct prvkey_512_alloc),
181
        PUBKEY_BYTES(512),
182
        sizeof(struct pubkey_512_alloc),
183
        CTEXT_BYTES(512),
184
        VECTOR_BYTES(512),
185
        U_VECTOR_BYTES(512),
186
        EVP_PKEY_ML_KEM_512,
187
        ML_KEM_512_BITS,
188
        ML_KEM_512_RANK,
189
        ML_KEM_512_DU,
190
        ML_KEM_512_DV,
191
        ML_KEM_512_SECBITS },
192
    { "ML-KEM-768",
193
        PRVKEY_BYTES(768),
194
        sizeof(struct prvkey_768_alloc),
195
        PUBKEY_BYTES(768),
196
        sizeof(struct pubkey_768_alloc),
197
        CTEXT_BYTES(768),
198
        VECTOR_BYTES(768),
199
        U_VECTOR_BYTES(768),
200
        EVP_PKEY_ML_KEM_768,
201
        ML_KEM_768_BITS,
202
        ML_KEM_768_RANK,
203
        ML_KEM_768_DU,
204
        ML_KEM_768_DV,
205
        ML_KEM_768_SECBITS },
206
    { "ML-KEM-1024",
207
        PRVKEY_BYTES(1024),
208
        sizeof(struct prvkey_1024_alloc),
209
        PUBKEY_BYTES(1024),
210
        sizeof(struct pubkey_1024_alloc),
211
        CTEXT_BYTES(1024),
212
        VECTOR_BYTES(1024),
213
        U_VECTOR_BYTES(1024),
214
        EVP_PKEY_ML_KEM_1024,
215
        ML_KEM_1024_BITS,
216
        ML_KEM_1024_RANK,
217
        ML_KEM_1024_DU,
218
        ML_KEM_1024_DV,
219
        ML_KEM_1024_SECBITS }
220
};
221
222
/*
223
 * Remainders modulo `kPrime`, for sufficiently small inputs, are computed in
224
 * constant time via Barrett reduction, and a final call to reduce_once(),
225
 * which reduces inputs that are at most 2*kPrime and is also constant-time.
226
 */
227
static const int kPrime = ML_KEM_PRIME;
228
static const unsigned int kBarrettShift = BARRETT_SHIFT;
229
static const size_t kBarrettMultiplier = (1 << BARRETT_SHIFT) / ML_KEM_PRIME;
230
static const uint16_t kHalfPrime = (ML_KEM_PRIME - 1) / 2;
231
static const uint16_t kInverseDegree = INVERSE_DEGREE;
232
233
/*
 * Python helper used to derive the tables below:
 *
 * p = 3329
 * def bitreverse(i):
 *     ret = 0
 *     for n in range(7):
 *         bit = i & 1
 *         ret <<= 1
 *         ret |= bit
 *         i >>= 1
 *     return ret
 */

/*-
 * Powers of 17, a primitive 256th root of unity mod 3329, in 7-bit
 * bit-reversed exponent order.  |scalar_ntt| consumes entries 1..127 in
 * order; entry 0 is never used.  This is the first precomputed array from
 * Appendix A of FIPS 203, or else Python:
 *
 * kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)]
 */
static const uint16_t kNTTRoots[128] = {
    1, 1729, 2580, 3289, 2642, 630, 1897, 848,
    1062, 1919, 193, 797, 2786, 3260, 569, 1746,
    296, 2447, 1339, 1476, 3046, 56, 2240, 1333,
    1426, 2094, 535, 2882, 2393, 2879, 1974, 821,
    289, 331, 3253, 1756, 1197, 2304, 2277, 2055,
    650, 1977, 2513, 632, 2865, 33, 1320, 1915,
    2319, 1435, 807, 452, 1438, 2868, 1534, 2402,
    2647, 2617, 1481, 648, 2474, 3110, 1227, 910,
    17, 2761, 583, 2649, 1637, 723, 2288, 1100,
    1409, 2662, 3281, 233, 756, 2156, 3015, 3050,
    1703, 1651, 2789, 1789, 1847, 952, 1461, 2687,
    939, 2308, 2437, 2388, 733, 2337, 268, 641,
    1584, 2298, 2037, 3220, 375, 2549, 2090, 1645,
    1063, 319, 2773, 757, 2099, 561, 2466, 2594,
    2804, 1092, 403, 1026, 1143, 2150, 2775, 886,
    1722, 1212, 1874, 1029, 2110, 2935, 885, 2154,
};
381
382
/*
 * Twiddle factors for |scalar_inverse_ntt|, stored in the order that function
 * consumes them.  Entry 0 is never used.
 *
 * For k >= 1, entry k equals 3329 - kNTTRoots[128 - k], the negated forward
 * root.  It is negated because the inverse butterfly here computes
 * zeta * (even - odd).  Equivalently, in Python:
 *
 * InverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)]
 *
 * taken in the order 0, 64, 65, ..., 127, 32, 33, ..., 63, 16, 17, ..., 31,
 * 8, 9, ...
 */
static const uint16_t kInverseNTTRoots[128] = {
    1, 1175, 2444, 394, 1219, 2300, 1455, 2117,
    1607, 2443, 554, 1179, 2186, 2303, 2926, 2237,
    525, 735, 863, 2768, 1230, 2572, 556, 3010,
    2266, 1684, 1239, 780, 2954, 109, 1292, 1031,
    1745, 2688, 3061, 992, 2596, 941, 892, 1021,
    2390, 642, 1868, 2377, 1482, 1540, 540, 1678,
    1626, 279, 314, 1173, 2573, 3096, 48, 667,
    1920, 2229, 1041, 2606, 1692, 680, 2746, 568,
    3312, 2419, 2102, 219, 855, 2681, 1848, 712,
    682, 927, 1795, 461, 1891, 2877, 2522, 1894,
    1010, 1414, 2009, 3296, 464, 2697, 816, 1352,
    2679, 1274, 1052, 1025, 2132, 1573, 76, 2998,
    3040, 2508, 1355, 450, 936, 447, 2794, 1235,
    1903, 1996, 1089, 3273, 283, 1853, 1990, 882,
    3033, 1583, 2760, 69, 543, 2532, 3136, 1410,
    2267, 2481, 1432, 2699, 687, 40, 749, 1600,
};
518
519
/*
 * The "gamma" values for multiplication in the NTT domain.  Entry i is
 * 17^(2*bitreverse(i) + 1) mod 3329, the constant of the quadratic modulus
 * X^2 - gamma used by |scalar_mult| and |scalar_mult_add| for coefficient
 * pair i.  Entries come in (gamma, 3329 - gamma) pairs.
 *
 * This is the second precomputed array from Appendix A of FIPS 203
 * (normalised positive), or else Python:
 *
 * ModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)]
 */
static const uint16_t kModRoots[128] = {
    17, 3312, 2761, 568, 583, 2746, 2649, 680,
    1637, 1692, 723, 2606, 2288, 1041, 1100, 2229,
    1409, 1920, 2662, 667, 3281, 48, 233, 3096,
    756, 2573, 2156, 1173, 3015, 314, 3050, 279,
    1703, 1626, 1651, 1678, 2789, 540, 1789, 1540,
    1847, 1482, 952, 2377, 1461, 1868, 2687, 642,
    939, 2390, 2308, 1021, 2437, 892, 2388, 941,
    733, 2596, 2337, 992, 268, 3061, 641, 2688,
    1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
    375, 2954, 2549, 780, 2090, 1239, 1645, 1684,
    1063, 2266, 319, 3010, 2773, 556, 757, 2572,
    2099, 1230, 561, 2768, 2466, 863, 2594, 735,
    2804, 525, 1092, 2237, 403, 2926, 1026, 2303,
    1143, 2186, 2150, 1179, 2775, 554, 886, 2443,
    1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
    2110, 1219, 2935, 394, 885, 2444, 2154, 1175,
};
654
655
/*
656
 * single_keccak hashes |inlen| bytes from |in| and writes |outlen| bytes of
657
 * output to |out|. If the |md| specifies a fixed-output function, like
658
 * SHA3-256, then |outlen| must be the correct length for that function.
659
 */
660
static __owur int single_keccak(uint8_t *out, size_t outlen, const uint8_t *in, size_t inlen,
661
    EVP_MD_CTX *mdctx)
662
360k
{
663
360k
    unsigned int sz = (unsigned int)outlen;
664
665
360k
    if (!EVP_DigestUpdate(mdctx, in, inlen))
666
0
        return 0;
667
360k
    if (EVP_MD_xof(EVP_MD_CTX_get0_md(mdctx)))
668
308k
        return EVP_DigestFinalXOF(mdctx, out, outlen);
669
51.7k
    return EVP_DigestFinal_ex(mdctx, out, &sz)
670
51.7k
        && ossl_assert((size_t)sz == outlen);
671
360k
}
672
673
/*
674
 * FIPS 203, Section 4.1, equation (4.3): PRF. Takes 32+1 input bytes, and uses
675
 * SHAKE256 to produce the input to SamplePolyCBD_eta: FIPS 203, algorithm 8.
676
 */
677
static __owur int prf(uint8_t *out, size_t len, const uint8_t in[ML_KEM_RANDOM_BYTES + 1],
678
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
679
308k
{
680
308k
    return EVP_DigestInit_ex(mdctx, key->shake256_md, NULL)
681
308k
        && single_keccak(out, len, in, ML_KEM_RANDOM_BYTES + 1, mdctx);
682
308k
}
683
684
/*
685
 * FIPS 203, Section 4.1, equation (4.4): H.  SHA3-256 hash of a variable
686
 * length input, producing 32 bytes of output.
687
 */
688
static __owur int hash_h(uint8_t out[ML_KEM_PKHASH_BYTES], const uint8_t *in, size_t len,
689
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
690
335
{
691
335
    return EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL)
692
335
        && single_keccak(out, ML_KEM_PKHASH_BYTES, in, len, mdctx);
693
335
}
694
695
/* Incremental hash_h of expanded public key */
696
static int
697
hash_h_pubkey(uint8_t pkhash[ML_KEM_PKHASH_BYTES],
698
    EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
699
51.1k
{
700
51.1k
    const ML_KEM_VINFO *vinfo = key->vinfo;
701
51.1k
    const scalar *t = key->t, *end = t + vinfo->rank;
702
51.1k
    unsigned int sz;
703
704
51.1k
    if (!EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL))
705
0
        return 0;
706
707
153k
    do {
708
153k
        uint8_t buf[3 * DEGREE / 2];
709
710
153k
        scalar_encode(buf, t++, 12);
711
153k
        if (!EVP_DigestUpdate(mdctx, buf, sizeof(buf)))
712
0
            return 0;
713
153k
    } while (t < end);
714
715
51.1k
    if (!EVP_DigestUpdate(mdctx, key->rho, ML_KEM_RANDOM_BYTES))
716
0
        return 0;
717
51.1k
    return EVP_DigestFinal_ex(mdctx, pkhash, &sz)
718
51.1k
        && ossl_assert(sz == ML_KEM_PKHASH_BYTES);
719
51.1k
}
720
721
/*
722
 * FIPS 203, Section 4.1, equation (4.5): G.  SHA3-512 hash of a variable
723
 * length input, producing 64 bytes of output, in particular the seeds
724
 * (d,z) for key generation.
725
 */
726
static __owur int hash_g(uint8_t out[ML_KEM_SEED_BYTES], const uint8_t *in, size_t len,
727
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
728
51.4k
{
729
51.4k
    return EVP_DigestInit_ex(mdctx, key->sha3_512_md, NULL)
730
51.4k
        && single_keccak(out, ML_KEM_SEED_BYTES, in, len, mdctx);
731
51.4k
}
732
733
/*
734
 * FIPS 203, Section 4.1, equation (4.4): J. SHAKE256 taking a variable length
735
 * input to compute a 32-byte implicit rejection shared secret, of the same
736
 * length as the expected shared secret.  (Computed even on success to avoid
737
 * side-channel leaks).
738
 */
739
static __owur int kdf(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
740
    const uint8_t z[ML_KEM_RANDOM_BYTES],
741
    const uint8_t *ctext, size_t len,
742
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
743
132
{
744
132
    return EVP_DigestInit_ex(mdctx, key->shake256_md, NULL)
745
132
        && EVP_DigestUpdate(mdctx, z, ML_KEM_RANDOM_BYTES)
746
132
        && EVP_DigestUpdate(mdctx, ctext, len)
747
132
        && EVP_DigestFinalXOF(mdctx, out, ML_KEM_SHARED_SECRET_BYTES);
748
132
}
749
750
/*
 * FIPS 203, Section 4.2.2, Algorithm 7: "SampleNTT" (steps 3-17, steps 1, 2
 * are performed by the caller). Rejection-samples a Keccak stream to get
 * uniformly distributed elements in the range [0,q). This is used for matrix
 * expansion and only operates on public inputs.
 *
 * The caller has already absorbed the input into |mdctx|, so here it is only
 * squeezed, SCALAR_SAMPLING_BUFSIZE bytes at a time.  Every 3 bytes give two
 * 12-bit candidates, and each is kept only if it is below kPrime.  Sampling
 * stops once all DEGREE coefficients of |out| are filled.  The inputs are
 * public, so the data-dependent branches are not a side-channel concern.
 *
 * Returns 1 on success, 0 if squeezing fails.
 */
static __owur int sample_scalar(scalar *out, EVP_MD_CTX *mdctx)
{
    uint16_t *curr = out->c, *endout = curr + DEGREE;
    uint8_t buf[SCALAR_SAMPLING_BUFSIZE], *in;
    uint8_t *endin = buf + sizeof(buf);
    uint16_t d;
    uint8_t b1, b2, b3;

    do {
        if (!EVP_DigestSqueeze(mdctx, in = buf, sizeof(buf)))
            return 0;
        do {
            /* In bounds: sizeof(buf) is a multiple of 12, hence of 3. */
            b1 = *in++;
            b2 = *in++;
            b3 = *in++;

            /* Stop as soon as |out| is full; leftover bytes are discarded. */
            if (curr >= endout)
                break;
            /* d1 = b1 + 256 * (b2 mod 16) */
            if ((d = ((b2 & 0x0f) << 8) + b1) < kPrime)
                *curr++ = d;
            if (curr >= endout)
                break;
            /* d2 = floor(b2 / 16) + 16 * b3 */
            if ((d = (b3 << 4) + (b2 >> 4)) < kPrime)
                *curr++ = d;
        } while (in < endin);
    } while (curr < endout);
    return 1;
}
784
785
/*-
 * reduce_once reduces 0 <= x < 2*kPrime, mod kPrime.
 *
 * Subtract |q| if the input is larger, without exposing a side-channel,
 * avoiding the "clangover" attack.  See |constish_time_non_zero| for a
 * discussion on why the value barrier is by default omitted.
 *
 * |subtracted| is x - q truncated to 16 bits.  When x < q the difference is
 * negative and wraps to a value with bit 15 set.  When q <= x < 2q it lies in
 * [0, q), below 2^15, so bit 15 is clear.  Bit 15 therefore yields an
 * all-ones mask (return |x|) or an all-zeros mask (return |subtracted|),
 * without any data-dependent branch.
 *
 * NOTE(review): keep this branch-free expression shape; rewriting it as a
 * conditional risks the compiler emitting a secret-dependent branch.
 */
static __owur uint16_t reduce_once(uint16_t x)
{
    const uint16_t subtracted = x - kPrime;
    uint16_t mask = constish_time_non_zero(subtracted >> 15);

    return (mask & x) | (~mask & subtracted);
}
799
800
/*
 * Constant-time reduce x mod kPrime using Barrett reduction. x must be less
 * than kPrime + 2 * kPrime^2.  This is sufficient to reduce a product of
 * two already reduced uint16_t values, in fact it is sufficient for each
 * to be less than 2^12, because (kPrime * (2 * kPrime + 1)) > 2^24.
 *
 * kBarrettMultiplier = floor(2^kBarrettShift / kPrime), so |quotient|
 * approximates floor(x / kPrime) from below and |remainder| is non-negative.
 * Under the input bound above it stays below 2 * kPrime, which is what the
 * final conditional subtraction in |reduce_once| requires.
 */
static __owur uint16_t reduce(uint32_t x)
{
    uint64_t product = (uint64_t)x * kBarrettMultiplier;
    uint32_t quotient = (uint32_t)(product >> kBarrettShift);
    uint32_t remainder = x - quotient * kPrime;

    return reduce_once(remainder);
}
814
815
/* Multiply a scalar by a constant. */
816
static void scalar_mult_const(scalar *s, uint16_t a)
817
1.25k
{
818
1.25k
    uint16_t *curr = s->c, *end = curr + DEGREE, tmp;
819
820
320k
    do {
821
320k
        tmp = reduce(*curr * a);
822
320k
        *curr++ = tmp;
823
320k
    } while (curr < end);
824
1.25k
}
825
826
/*-
827
 * FIPS 203, Section 4.3, Algorithm 9: "NTT".
828
 * In-place number theoretic transform of a given scalar.  Note that ML-KEM's
829
 * kPrime 3329 does not have a 512th root of unity, so this transform leaves
830
 * off the last iteration of the usual FFT code, with the 128 relevant roots of
831
 * unity being stored in NTTRoots.  This means the output should be seen as 128
832
 * elements in GF(3329^2), with the coefficients of the elements being
833
 * consecutive entries in |s->c|.
834
 */
835
static void scalar_ntt(scalar *s)
836
308k
{
837
308k
    const uint16_t *roots = kNTTRoots;
838
308k
    uint16_t *end = s->c + DEGREE;
839
308k
    int offset = DEGREE / 2;
840
841
2.15M
    do {
842
2.15M
        uint16_t *curr = s->c, *peer;
843
844
39.1M
        do {
845
39.1M
            uint16_t *pause = curr + offset, even, odd;
846
39.1M
            uint32_t zeta = *++roots;
847
848
39.1M
            peer = pause;
849
276M
            do {
850
276M
                even = *curr;
851
276M
                odd = reduce(*peer * zeta);
852
276M
                *peer++ = reduce_once(even - odd + kPrime);
853
276M
                *curr++ = reduce_once(odd + even);
854
276M
            } while (curr < pause);
855
39.1M
        } while ((curr = peer) < end);
856
2.15M
    } while ((offset >>= 1) >= 2);
857
308k
}
858
859
/*-
860
 * FIPS 203, Section 4.3, Algorithm 10: "NTT^(-1)".
861
 * In-place inverse number theoretic transform of a given scalar, with pairs of
862
 * entries of s->v being interpreted as elements of GF(3329^2). Just as with
863
 * the number theoretic transform, this leaves off the first step of the normal
864
 * iFFT to account for the fact that 3329 does not have a 512th root of unity,
865
 * using the precomputed 128 roots of unity stored in InverseNTTRoots.
866
 */
867
static void scalar_inverse_ntt(scalar *s)
868
1.25k
{
869
1.25k
    const uint16_t *roots = kInverseNTTRoots;
870
1.25k
    uint16_t *end = s->c + DEGREE;
871
1.25k
    int offset = 2;
872
873
8.75k
    do {
874
8.75k
        uint16_t *curr = s->c, *peer;
875
876
158k
        do {
877
158k
            uint16_t *pause = curr + offset, even, odd;
878
158k
            uint32_t zeta = *++roots;
879
880
158k
            peer = pause;
881
1.12M
            do {
882
1.12M
                even = *curr;
883
1.12M
                odd = *peer;
884
1.12M
                *peer++ = reduce(zeta * (even - odd + kPrime));
885
1.12M
                *curr++ = reduce_once(odd + even);
886
1.12M
            } while (curr < pause);
887
158k
        } while ((curr = peer) < end);
888
8.75k
    } while ((offset <<= 1) < DEGREE);
889
1.25k
    scalar_mult_const(s, kInverseDegree);
890
1.25k
}
891
892
/* Addition updating the LHS scalar in-place. */
893
static void scalar_add(scalar *lhs, const scalar *rhs)
894
1.11k
{
895
1.11k
    int i;
896
897
287k
    for (i = 0; i < DEGREE; i++)
898
286k
        lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
899
1.11k
}
900
901
/* Subtraction updating the LHS scalar in-place. */
902
static void scalar_sub(scalar *lhs, const scalar *rhs)
903
132
{
904
132
    int i;
905
906
33.9k
    for (i = 0; i < DEGREE; i++)
907
33.7k
        lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime);
908
132
}
909
910
/*
911
 * Multiplying two scalars in the number theoretically transformed state. Since
912
 * 3329 does not have a 512th root of unity, this means we have to interpret
913
 * the 2*ith and (2*i+1)th entries of the scalar as elements of
914
 * GF(3329)[X]/(X^2 - 17^(2*bitreverse(i)+1)).
915
 *
916
 * The value of 17^(2*bitreverse(i)+1) mod 3329 is stored in the precomputed
917
 * ModRoots table. Note that our Barrett transform only allows us to multiply
918
 * two reduced numbers together, so we need some intermediate reduction steps,
919
 * even if an uint64_t could hold 3 multiplied numbers.
920
 */
921
static void scalar_mult(scalar *out, const scalar *lhs,
922
    const scalar *rhs)
923
1.25k
{
924
1.25k
    uint16_t *curr = out->c, *end = curr + DEGREE;
925
1.25k
    const uint16_t *lc = lhs->c, *rc = rhs->c;
926
1.25k
    const uint16_t *roots = kModRoots;
927
928
160k
    do {
929
160k
        uint32_t l0 = *lc++, r0 = *rc++;
930
160k
        uint32_t l1 = *lc++, r1 = *rc++;
931
160k
        uint32_t zetapow = *roots++;
932
933
160k
        *curr++ = reduce(l0 * r0 + reduce(l1 * r1) * zetapow);
934
160k
        *curr++ = reduce(l0 * r1 + l1 * r0);
935
160k
    } while (curr < end);
936
1.25k
}
937
938
/* Above, but add the result to an existing scalar */
939
static ossl_inline void scalar_mult_add(scalar *out, const scalar *lhs,
940
    const scalar *rhs)
941
462k
{
942
462k
    uint16_t *curr = out->c, *end = curr + DEGREE;
943
462k
    const uint16_t *lc = lhs->c, *rc = rhs->c;
944
462k
    const uint16_t *roots = kModRoots;
945
946
59.2M
    do {
947
59.2M
        uint32_t l0 = *lc++, r0 = *rc++;
948
59.2M
        uint32_t l1 = *lc++, r1 = *rc++;
949
59.2M
        uint16_t *c0 = curr++;
950
59.2M
        uint16_t *c1 = curr++;
951
59.2M
        uint32_t zetapow = *roots++;
952
953
59.2M
        *c0 = reduce(*c0 + l0 * r0 + reduce(l1 * r1) * zetapow);
954
59.2M
        *c1 = reduce(*c1 + l0 * r1 + l1 * r0);
955
59.2M
    } while (curr < end);
956
462k
}
957
958
/*-
959
 * FIPS 203, Section 4.2.1, Algorithm 5: "ByteEncode_d", for 2<=d<=12.
960
 * Here |bits| is |d|.  For efficiency, we handle the d=1 case separately.
961
 */
962
static void scalar_encode(uint8_t *out, const scalar *s, int bits)
963
308k
{
964
308k
    const uint16_t *curr = s->c, *end = curr + DEGREE;
965
308k
    uint64_t accum = 0, element;
966
308k
    int used = 0;
967
968
78.9M
    do {
969
78.9M
        element = *curr++;
970
78.9M
        if (used + bits < 64) {
971
64.1M
            accum |= element << used;
972
64.1M
            used += bits;
973
64.1M
        } else if (used + bits > 64) {
974
9.85M
            out = OPENSSL_store_u64_le(out, accum | (element << used));
975
9.85M
            accum = element >> (64 - used);
976
9.85M
            used = (used + bits) - 64;
977
9.85M
        } else {
978
4.92M
            out = OPENSSL_store_u64_le(out, accum | (element << used));
979
4.92M
            accum = 0;
980
4.92M
            used = 0;
981
4.92M
        }
982
78.9M
    } while (curr < end);
983
308k
}
984
985
/*
986
 * scalar_encode_1 is |scalar_encode| specialised for |bits| == 1.
987
 */
988
static void scalar_encode_1(uint8_t out[DEGREE / 8], const scalar *s)
989
132
{
990
132
    int i, j;
991
132
    uint8_t out_byte;
992
993
4.35k
    for (i = 0; i < DEGREE; i += 8) {
994
4.22k
        out_byte = 0;
995
38.0k
        for (j = 0; j < 8; j++)
996
33.7k
            out_byte |= bit0(s->c[i + j]) << j;
997
4.22k
        *out = out_byte;
998
4.22k
        out++;
999
4.22k
    }
1000
132
}
1001
1002
/*-
 * FIPS 203, Section 4.2.1, Algorithm 6: "ByteDecode_d", for 2<=d<12.
 * Here |bits| is |d|.  For efficiency, we handle the d=1 and d=12 cases
 * separately.
 *
 * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
 * |out|.  The input is read as little-endian 64-bit words, only when the
 * accumulator is empty, so exactly DEGREE * bits / 8 bytes of |in| are
 * consumed.  Each decoded value is below 2^bits; no reduction modulo kPrime
 * is done here.
 */
static void scalar_decode(scalar *out, const uint8_t *in, int bits)
{
    uint16_t *curr = out->c, *end = curr + DEGREE;
    /* Unconsumed input bits, LSB first; |accum_bits| of them are valid. */
    uint64_t accum = 0;
    /* |todo|: bits still needed to complete the current element. */
    int accum_bits = 0, todo = bits;
    /* |mask| selects the |todo| low bits still needed for the element. */
    uint16_t bitmask = (((uint16_t)1) << bits) - 1, mask = bitmask;
    /* Low bits already collected for a partially decoded element. */
    uint16_t element = 0;

    do {
        if (accum_bits == 0) {
            in = OPENSSL_load_u64_le(&accum, in);
            accum_bits = 64;
        }
        if (todo == bits && accum_bits >= bits) {
            /* No partial "element", and all the required bits available */
            *curr++ = ((uint16_t)accum) & mask;
            accum >>= bits;
            accum_bits -= bits;
        } else if (accum_bits >= todo) {
            /* A partial "element", and all the required bits available */
            *curr++ = element | ((((uint16_t)accum) & mask) << (bits - todo));
            accum >>= todo;
            accum_bits -= todo;
            element = 0;
            todo = bits;
            mask = bitmask;
        } else {
            /*
             * Only some of the requisite bits accumulated, store |accum_bits|
             * of these in |element|.  The accumulated bitcount becomes 0, but
             * as soon as we have more bits we'll want to merge accum_bits
             * fewer of them into the final |element|.
             *
             * Note that with a 64-bit accumulator and |bits| always 12 or
             * less, if we're here, the previous iteration had all the
             * requisite bits, and so there are no kept bits in |element|.
             */
            element = ((uint16_t)accum) & mask;
            todo -= accum_bits;
            mask = bitmask >> accum_bits;
            accum_bits = 0;
        }
    } while (curr < end);
}
1054
1055
/*
 * Decodes 3 * DEGREE / 2 bytes of packed 12-bit values from |in| into the
 * coefficients of |out|.  Each group of three input bytes carries two
 * coefficients: the first takes byte 0 plus the low nibble of byte 1, the
 * second takes the high nibble of byte 1 plus byte 2.
 *
 * Returns 1 if every coefficient is below |kPrime|.  Returns 0 as soon as a
 * pair contains an out-of-range value; that pair has already been stored in
 * |out| by then.
 */
static __owur int scalar_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2])
{
    uint16_t *dst = out->c;
    const uint8_t *src = in;
    int pair;

    for (pair = 0; pair < DEGREE / 2; ++pair, src += 3) {
        uint16_t even = (uint16_t)(src[0] | ((src[1] & 0x0f) << 8));
        uint16_t odd = (uint16_t)((src[1] >> 4) | (src[2] << 4));

        dst[0] = even;
        dst[1] = odd;
        dst += 2;
        /* Non-short-circuit "or": both comparisons are always evaluated. */
        if ((even >= kPrime) | (odd >= kPrime))
            return 0;
    }
    return 1;
}
1072
1073
/*-
 * scalar_decode_decompress_add is a combination of decoding and decompression
 * both specialised for |bits| == 1, with the result added (and sum reduced) to
 * the output scalar.
 *
 * Each of the DEGREE / 8 input bytes supplies eight coefficients, consumed
 * least-significant bit first.  A set bit decompresses to
 * |half_q_plus_1| = (ML_KEM_PRIME >> 1) + 1; a clear bit decompresses to 0.
 * That value is added to the corresponding coefficient of |out|, and the sum
 * is reduced with reduce_once().
 *
 * NOTE: this function MUST NOT leak an input-data-dependent timing signal.
 * A timing leak in a related function in the reference Kyber implementation
 * made the "clangover" attack (CVE-2024-37880) possible, giving key recovery
 * for ML-KEM-512 in minutes, provided the attacker has access to precise
 * timing of a CPU performing chosen-ciphertext decap.  Admittedly this is only
 * a risk when private keys are reused (perhaps KEMTLS servers).
 */
static void
scalar_decode_decompress_add(scalar *out, const uint8_t in[DEGREE / 8])
{
    static const uint16_t half_q_plus_1 = (ML_KEM_PRIME >> 1) + 1;
    uint16_t *curr = out->c, *end = curr + DEGREE;
    uint16_t mask;
    uint8_t b;

    /*
     * Add |half_q_plus_1| if the bit is set, without exposing a side-channel,
     * avoiding the "clangover" attack.  See |constish_time_non_zero| for a
     * discussion on why the value barrier is by default omitted.
     *
     * The step turns the low bit of |b| into an all-ones/all-zeros |mask|,
     * conditionally adds via the mask (no branch), advances to the next
     * coefficient, and shifts the next bit of |b| into position.
     */
#define decode_decompress_add_bit                        \
    mask = constish_time_non_zero(bit0(b));              \
    *curr = reduce_once(*curr + (mask & half_q_plus_1)); \
    curr++;                                              \
    b >>= 1

    /* Unrolled to process each byte (eight coefficients) in one iteration */
    do {
        b = *in++;
        decode_decompress_add_bit;
        decode_decompress_add_bit;
        decode_decompress_add_bit;
        decode_decompress_add_bit;

        decode_decompress_add_bit;
        decode_decompress_add_bit;
        decode_decompress_add_bit;
        decode_decompress_add_bit;
    } while (curr < end);
#undef decode_decompress_add_bit
}
1119
1120
/*
 * FIPS 203, Section 4.2.1, Equation (4.7): Compress_d.
 *
 * Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
 * numbers close to each other together. The formula used is
 * round(2^|bits|/kPrime*x) mod 2^|bits|.
 * Uses Barrett reduction to achieve constant time. Since we need both the
 * remainder (for rounding) and the quotient (as the result), we cannot use
 * |reduce| here, but need to do the Barrett reduction directly.
 *
 * The statement shape below is deliberately branch-free.  Rounding decisions
 * are taken with constant_time_lt_32() masks rather than comparisons that a
 * compiler might turn into conditional jumps.
 */
static __owur uint16_t compress(uint16_t x, int bits)
{
    /* Numerator 2^bits * x of the fraction being rounded. */
    uint32_t shifted = (uint32_t)x << bits;
    /* Barrett estimate of shifted / kPrime (possibly one too small). */
    uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
    uint32_t quotient = (uint32_t)(product >> kBarrettShift);
    uint32_t remainder = shifted - quotient * kPrime;

    /*
     * Adjust the quotient to round correctly:
     *   0 <= remainder <= kHalfPrime round to 0
     *   kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
     *   kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
     *
     * "1 &" reduces each comparison mask to 0 or 1.
     */
    quotient += 1 & constant_time_lt_32(kHalfPrime, remainder);
    quotient += 1 & constant_time_lt_32(kPrime + kHalfPrime, remainder);
    /* The final "mod 2^bits" of the formula above. */
    return quotient & ((1 << bits) - 1);
}
1147
1148
/*
1149
 * FIPS 203, Section 4.2.1, Equation (4.8): Decompress_d.
1150
1151
 * Decompresses |x| by using a close equi-distant representative. The formula
1152
 * is round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us
1153
 * to implement this logic using only bit operations.
1154
 */
1155
static __owur uint16_t decompress(uint16_t x, int bits)
1156
130k
{
1157
130k
    uint32_t product = (uint32_t)x * kPrime;
1158
130k
    uint32_t power = 1 << bits;
1159
    /* This is |product| % power, since |power| is a power of 2. */
1160
130k
    uint32_t remainder = product & (power - 1);
1161
    /* This is |product| / power, since |power| is a power of 2. */
1162
130k
    uint32_t lower = product >> bits;
1163
1164
    /*
1165
     * The rounding logic works since the first half of numbers mod |power|
1166
     * have a 0 as first bit, and the second half has a 1 as first bit, since
1167
     * |power| is a power of 2. As a 12 bit number, |remainder| is always
1168
     * positive, so we will shift in 0s for a right shift.
1169
     */
1170
130k
    return lower + (remainder >> (bits - 1));
1171
130k
}
1172
1173
/*-
1174
 * FIPS 203, Section 4.2.1, Equation (4.7): "Compress_d".
1175
 * In-place lossy rounding of scalars to 2^d bits.
1176
 */
1177
static void scalar_compress(scalar *s, int bits)
1178
1.25k
{
1179
1.25k
    int i;
1180
1181
321k
    for (i = 0; i < DEGREE; i++)
1182
320k
        s->c[i] = compress(s->c[i], bits);
1183
1.25k
}
1184
1185
/*
1186
 * FIPS 203, Section 4.2.1, Equation (4.8): "Decompress_d".
1187
 * In-place approximate recovery of scalars from 2^d bit compression.
1188
 */
1189
static void scalar_decompress(scalar *s, int bits)
1190
511
{
1191
511
    int i;
1192
1193
131k
    for (i = 0; i < DEGREE; i++)
1194
130k
        s->c[i] = decompress(s->c[i], bits);
1195
511
}
1196
1197
/* Addition updating the LHS vector in-place. */
1198
static void vector_add(scalar *lhs, const scalar *rhs, int rank)
1199
288
{
1200
830
    do {
1201
830
        scalar_add(lhs++, rhs++);
1202
830
    } while (--rank > 0);
1203
288
}
1204
1205
/*
1206
 * Encodes an entire vector into 32*|rank|*|bits| bytes. Note that since 256
1207
 * (DEGREE) is divisible by 8, the individual vector entries will always fill a
1208
 * whole number of bytes, so we do not need to worry about bit packing here.
1209
 */
1210
static void vector_encode(uint8_t *out, const scalar *a, int bits, int rank)
1211
51.5k
{
1212
51.5k
    int stride = bits * DEGREE / 8;
1213
1214
206k
    for (; rank-- > 0; out += stride)
1215
154k
        scalar_encode(out, a++, bits);
1216
51.5k
}
1217
1218
/*
1219
 * Decodes 32*|rank|*|bits| bytes from |in| into |out|. It returns early
1220
 * if any parsed value is >= |ML_KEM_PRIME|.  The resulting scalars are
1221
 * then decompressed and transformed via the NTT.
1222
 *
1223
 * Note: Used only in decrypt_cpa(), which returns void and so does not check
1224
 * the return value of this function.  Side-channels are fine when the input
1225
 * ciphertext to decap() is simply syntactically invalid.
1226
 */
1227
static void
1228
vector_decode_decompress_ntt(scalar *out, const uint8_t *in, int bits, int rank)
1229
132
{
1230
132
    int stride = bits * DEGREE / 8;
1231
1232
511
    for (; rank-- > 0; in += stride, ++out) {
1233
379
        scalar_decode(out, in, bits);
1234
379
        scalar_decompress(out, bits);
1235
379
        scalar_ntt(out);
1236
379
    }
1237
132
}
1238
1239
/* vector_decode(), specialised to bits == 12. */
1240
static __owur int vector_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2], int rank)
1241
671
{
1242
671
    int stride = 3 * DEGREE / 2;
1243
1244
1.66k
    for (; rank-- > 0; in += stride)
1245
1.30k
        if (!scalar_decode_12(out++, in))
1246
314
            return 0;
1247
357
    return 1;
1248
671
}
1249
1250
/* In-place compression of each scalar component */
1251
static void vector_compress(scalar *a, int bits, int rank)
1252
288
{
1253
830
    do {
1254
830
        scalar_compress(a++, bits);
1255
830
    } while (--rank > 0);
1256
288
}
1257
1258
/* The output scalar must not overlap with the inputs */
1259
static void inner_product(scalar *out, const scalar *lhs, const scalar *rhs,
1260
    int rank)
1261
420
{
1262
420
    scalar_mult(out, lhs, rhs);
1263
1.20k
    while (--rank > 0)
1264
789
        scalar_mult_add(out, ++lhs, ++rhs);
1265
420
}
1266
1267
/*
1268
 * Here, the output vector must not overlap with the inputs, the result is
1269
 * directly subjected to inverse NTT.
1270
 */
1271
static void
1272
matrix_mult_intt(scalar *out, const scalar *m, const scalar *a, int rank)
1273
288
{
1274
288
    const scalar *ar;
1275
288
    int i, j;
1276
1277
1.11k
    for (i = rank; i-- > 0; ++out) {
1278
830
        scalar_mult(out, m++, ar = a);
1279
2.57k
        for (j = rank - 1; j > 0; --j)
1280
1.74k
            scalar_mult_add(out, m++, ++ar);
1281
830
        scalar_inverse_ntt(out);
1282
830
    }
1283
288
}
1284
1285
/* Here, the output vector must not overlap with the inputs */
1286
static void
1287
matrix_mult_transpose_add(scalar *out, const scalar *m, const scalar *a, int rank)
1288
51.1k
{
1289
51.1k
    const scalar *mc = m, *mr, *ar;
1290
51.1k
    int i, j;
1291
1292
204k
    for (i = rank; i-- > 0; ++out) {
1293
153k
        scalar_mult_add(out, mr = mc++, ar = a);
1294
460k
        for (j = rank; --j > 0;)
1295
306k
            scalar_mult_add(out, (mr += rank), ++ar);
1296
153k
    }
1297
51.1k
}
1298
1299
/*-
1300
 * Expands the matrix from a seed for key generation and for encaps-CPA.
1301
 * NOTE: FIPS 203 matrix "A" is the transpose of this matrix, computed
1302
 * by appending the (i,j) indices to the seed in the opposite order!
1303
 *
1304
 * Where FIPS 203 computes t = A * s + e, we use the transpose of "m".
1305
 */
1306
static __owur int matrix_expand(EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
1307
51.5k
{
1308
51.5k
    scalar *out = key->m;
1309
51.5k
    uint8_t input[ML_KEM_RANDOM_BYTES + 2];
1310
51.5k
    int rank = key->vinfo->rank;
1311
51.5k
    int i, j;
1312
1313
51.5k
    memcpy(input, key->rho, ML_KEM_RANDOM_BYTES);
1314
205k
    for (i = 0; i < rank; i++) {
1315
617k
        for (j = 0; j < rank; j++) {
1316
462k
            input[ML_KEM_RANDOM_BYTES] = i;
1317
462k
            input[ML_KEM_RANDOM_BYTES + 1] = j;
1318
462k
            if (!EVP_DigestInit_ex(mdctx, key->shake128_md, NULL)
1319
462k
                || !EVP_DigestUpdate(mdctx, input, sizeof(input))
1320
462k
                || !sample_scalar(out++, mdctx))
1321
0
                return 0;
1322
462k
        }
1323
154k
    }
1324
51.5k
    return 1;
1325
51.5k
}
1326
1327
/*
 * FIPS 203, Section 4.2.2, Algorithm 8 (SamplePolyCBD_eta), with eta fixed to
 * two and the PRF call included.  Creates binomially distributed elements by
 * sampling 2*|eta| = 4 bits per coefficient.  The coefficient is the count of
 * set bits among the first two minus the count among the second two,
 * resulting in a centered binomial distribution.  Since eta is two, each of
 * -2 and 2 occurs with probability 1/16, each of -1 and 1 with probability
 * 1/4, and 0 with probability 3/8.
 *
 * Each PRF output byte yields two coefficients, low nibble first.  Negative
 * results are stored as value + kPrime.  Returns 1 on success, 0 if the PRF
 * fails.
 */
static __owur int cbd_2(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
{
    uint16_t *curr = out->c, *end = curr + DEGREE;
    uint8_t randbuf[4 * DEGREE / 8], *r = randbuf; /* 64 * eta slots */
    uint16_t value, mask;
    uint8_t b;

    if (!prf(randbuf, sizeof(randbuf), in, mdctx, key))
        return 0;

    do {
        b = *r++;

        /*
         * Add |kPrime| if |value| underflowed.  See |constish_time_non_zero|
         * for a discussion on why the value barrier is by default omitted.
         * While this could have been written reduce_once(value + kPrime), this
         * is one extra addition and small range of |value| tempts some
         * versions of Clang to emit a branch.
         *
         * Bit 15 of the uint16_t |value| is set exactly when the subtraction
         * wrapped around, i.e. the coefficient is negative.
         */
        value = bit0(b) + bitn(1, b);
        value -= bitn(2, b) + bitn(3, b);
        mask = constish_time_non_zero(value >> 15);
        *curr++ = value + (kPrime & mask);

        value = bitn(4, b) + bitn(5, b);
        value -= bitn(6, b) + bitn(7, b);
        mask = constish_time_non_zero(value >> 15);
        *curr++ = value + (kPrime & mask);
    } while (curr < end);
    return 1;
}
1368
1369
/*
 * FIPS 203, Section 4.2.2, Algorithm 8 (SamplePolyCBD_eta), with eta fixed to
 * three and the PRF call included.  Creates binomially distributed elements
 * by sampling 2*|eta| = 6 bits per coefficient.  The coefficient is the count
 * of set bits among the first three minus the count among the next three,
 * resulting in a centered binomial distribution over [-3, 3].
 *
 * The PRF output is consumed three bytes (24 bits) at a time, giving four
 * coefficients per iteration.  The 6-bit groups straddle byte boundaries as
 * spelled out below.  Negative results are stored as value + kPrime.  Returns
 * 1 on success, 0 if the PRF fails.
 */
static __owur int cbd_3(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
{
    uint16_t *curr = out->c, *end = curr + DEGREE;
    uint8_t randbuf[6 * DEGREE / 8], *r = randbuf; /* 64 * eta slots */
    uint8_t b1, b2, b3;
    uint16_t value, mask;

    if (!prf(randbuf, sizeof(randbuf), in, mdctx, key))
        return 0;

    do {
        b1 = *r++;
        b2 = *r++;
        b3 = *r++;

        /*
         * Add |kPrime| if |value| underflowed.  See |constish_time_non_zero|
         * for a discussion on why the value barrier is by default omitted.
         * While this could have been written reduce_once(value + kPrime), this
         * is one extra addition and small range of |value| tempts some
         * versions of Clang to emit a branch.
         *
         * Bit 15 of the uint16_t |value| is set exactly when the subtraction
         * wrapped around, i.e. the coefficient is negative.
         */
        /* Bits 0-5 of b1. */
        value = bit0(b1) + bitn(1, b1) + bitn(2, b1);
        value -= bitn(3, b1) + bitn(4, b1) + bitn(5, b1);
        mask = constish_time_non_zero(value >> 15);
        *curr++ = value + (kPrime & mask);

        /* Bits 6-7 of b1, then bits 0-3 of b2. */
        value = bitn(6, b1) + bitn(7, b1) + bit0(b2);
        value -= bitn(1, b2) + bitn(2, b2) + bitn(3, b2);
        mask = constish_time_non_zero(value >> 15);
        *curr++ = value + (kPrime & mask);

        /* Bits 4-7 of b2, then bits 0-1 of b3. */
        value = bitn(4, b2) + bitn(5, b2) + bitn(6, b2);
        value -= bitn(7, b2) + bit0(b3) + bitn(1, b3);
        mask = constish_time_non_zero(value >> 15);
        *curr++ = value + (kPrime & mask);

        /* Bits 2-7 of b3. */
        value = bitn(2, b3) + bitn(3, b3) + bitn(4, b3);
        value -= bitn(5, b3) + bitn(6, b3) + bitn(7, b3);
        mask = constish_time_non_zero(value >> 15);
        *curr++ = value + (kPrime & mask);
    } while (curr < end);
    return 1;
}
1420
1421
/*
1422
 * Generates a secret vector by using |cbd| with the given seed to generate
1423
 * scalar elements and incrementing |counter| for each slot of the vector.
1424
 */
1425
static __owur int gencbd_vector(scalar *out, CBD_FUNC cbd, uint8_t *counter,
1426
    const uint8_t seed[ML_KEM_RANDOM_BYTES], int rank,
1427
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1428
288
{
1429
288
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1430
1431
288
    memcpy(input, seed, ML_KEM_RANDOM_BYTES);
1432
830
    do {
1433
830
        input[ML_KEM_RANDOM_BYTES] = (*counter)++;
1434
830
        if (!cbd(out++, input, mdctx, key))
1435
0
            return 0;
1436
830
    } while (--rank > 0);
1437
288
    return 1;
1438
288
}
1439
1440
/*
1441
 * As above plus NTT transform.
1442
 */
1443
static __owur int gencbd_vector_ntt(scalar *out, CBD_FUNC cbd, uint8_t *counter,
1444
    const uint8_t seed[ML_KEM_RANDOM_BYTES], int rank,
1445
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1446
102k
{
1447
102k
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1448
1449
102k
    memcpy(input, seed, ML_KEM_RANDOM_BYTES);
1450
307k
    do {
1451
307k
        input[ML_KEM_RANDOM_BYTES] = (*counter)++;
1452
307k
        if (!cbd(out, input, mdctx, key))
1453
0
            return 0;
1454
307k
        scalar_ntt(out++);
1455
307k
    } while (--rank > 0);
1456
102k
    return 1;
1457
102k
}
1458
1459
/*
 * Picks the sampler for the ETA1 parameter.  ML-KEM-512 uses eta1 = 3
 * (cbd_3); every other variant uses eta1 = 2 (cbd_2).  ETA2 is 2 for all
 * variants, so cbd_2 is used directly for it.
 */
#define CBD1(evp_type) (((evp_type) != EVP_PKEY_ML_KEM_512) ? cbd_2 : cbd_3)
1461
1462
/*
1463
 * FIPS 203, Section 5.2, Algorithm 14: K-PKE.Encrypt.
1464
 *
1465
 * Encrypts a message with given randomness to the ciphertext in |out|. Without
1466
 * applying the Fujisaki-Okamoto transform this would not result in a CCA
1467
 * secure scheme, since lattice schemes are vulnerable to decryption failure
1468
 * oracles.
1469
 *
1470
 * The steps are re-ordered to make more efficient/localised use of storage.
1471
 *
1472
 * Note also that the input public key is assumed to hold a precomputed matrix
1473
 * |A| (our key->m, with the public key holding an expanded (16-bit per scalar
1474
 * coefficient) key->t vector).
1475
 *
1476
 * Caller passes storage in |tmp| for for two temporary vectors.
1477
 */
1478
static __owur int encrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
1479
    const uint8_t message[DEGREE / 8],
1480
    const uint8_t r[ML_KEM_RANDOM_BYTES], scalar *tmp,
1481
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1482
288
{
1483
288
    const ML_KEM_VINFO *vinfo = key->vinfo;
1484
288
    CBD_FUNC cbd_1 = CBD1(vinfo->evp_type);
1485
288
    int rank = vinfo->rank;
1486
    /* We can use tmp[0..rank-1] as storage for |y|, then |e1|, ... */
1487
288
    scalar *y = &tmp[0], *e1 = y, *e2 = y;
1488
    /* We can use tmp[rank]..tmp[2*rank - 1] for |u| */
1489
288
    scalar *u = &tmp[rank];
1490
288
    scalar v;
1491
288
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1492
288
    uint8_t counter = 0;
1493
288
    int du = vinfo->du;
1494
288
    int dv = vinfo->dv;
1495
1496
    /* FIPS 203 "y" vector */
1497
288
    if (!gencbd_vector_ntt(y, cbd_1, &counter, r, rank, mdctx, key))
1498
0
        return 0;
1499
    /* FIPS 203 "v" scalar */
1500
288
    inner_product(&v, key->t, y, rank);
1501
288
    scalar_inverse_ntt(&v);
1502
    /* FIPS 203 "u" vector */
1503
288
    matrix_mult_intt(u, key->m, y, rank);
1504
1505
    /* All done with |y|, now free to reuse tmp[0] for FIPS 203 |e1| */
1506
288
    if (!gencbd_vector(e1, cbd_2, &counter, r, rank, mdctx, key))
1507
0
        return 0;
1508
288
    vector_add(u, e1, rank);
1509
288
    vector_compress(u, du, rank);
1510
288
    vector_encode(out, u, du, rank);
1511
1512
    /* All done with |e1|, now free to reuse tmp[0] for FIPS 203 |e2| */
1513
288
    memcpy(input, r, ML_KEM_RANDOM_BYTES);
1514
288
    input[ML_KEM_RANDOM_BYTES] = counter;
1515
288
    if (!cbd_2(e2, input, mdctx, key))
1516
0
        return 0;
1517
288
    scalar_add(&v, e2);
1518
1519
    /* Combine message with |v| */
1520
288
    scalar_decode_decompress_add(&v, message);
1521
288
    scalar_compress(&v, dv);
1522
288
    scalar_encode(out + vinfo->u_vector_bytes, &v, dv);
1523
288
    return 1;
1524
288
}
1525
1526
/*
1527
 * FIPS 203, Section 5.3, Algorithm 15: K-PKE.Decrypt.
1528
 */
1529
static void
1530
decrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
1531
    const uint8_t *ctext, scalar *u, const ML_KEM_KEY *key)
1532
132
{
1533
132
    const ML_KEM_VINFO *vinfo = key->vinfo;
1534
132
    scalar v, mask;
1535
132
    int rank = vinfo->rank;
1536
132
    int du = vinfo->du;
1537
132
    int dv = vinfo->dv;
1538
1539
132
    vector_decode_decompress_ntt(u, ctext, du, rank);
1540
132
    scalar_decode(&v, ctext + vinfo->u_vector_bytes, dv);
1541
132
    scalar_decompress(&v, dv);
1542
132
    inner_product(&mask, key->s, u, rank);
1543
132
    scalar_inverse_ntt(&mask);
1544
132
    scalar_sub(&v, &mask);
1545
132
    scalar_compress(&v, 1);
1546
132
    scalar_encode_1(out, &v);
1547
132
}
1548
1549
/*-
1550
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1551
 * FIPS 203, Section 7.2, Algorithm 20: "ML-KEM.Encaps".
1552
 *
1553
 * Fills the |out| buffer with the |ek| output of "ML-KEM.KeyGen", or,
1554
 * equivalently, the |ek| input of "ML-KEM.Encaps", i.e. returns the
1555
 * wire-format of an ML-KEM public key.
1556
 */
1557
static void encode_pubkey(uint8_t *out, const ML_KEM_KEY *key)
1558
51.1k
{
1559
51.1k
    const uint8_t *rho = key->rho;
1560
51.1k
    const ML_KEM_VINFO *vinfo = key->vinfo;
1561
1562
51.1k
    vector_encode(out, key->t, 12, vinfo->rank);
1563
51.1k
    memcpy(out + vinfo->vector_bytes, rho, ML_KEM_RANDOM_BYTES);
1564
51.1k
}
1565
1566
/*-
1567
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1568
 *
1569
 * Fills the |out| buffer with the |dk| output of "ML-KEM.KeyGen".
1570
 * This matches the input format of parse_prvkey() below.
1571
 */
1572
static void encode_prvkey(uint8_t *out, const ML_KEM_KEY *key)
1573
116
{
1574
116
    const ML_KEM_VINFO *vinfo = key->vinfo;
1575
1576
116
    vector_encode(out, key->s, 12, vinfo->rank);
1577
116
    out += vinfo->vector_bytes;
1578
116
    encode_pubkey(out, key);
1579
116
    out += vinfo->pubkey_bytes;
1580
116
    memcpy(out, key->pkhash, ML_KEM_PKHASH_BYTES);
1581
116
    out += ML_KEM_PKHASH_BYTES;
1582
116
    memcpy(out, key->z, ML_KEM_RANDOM_BYTES);
1583
116
}
1584
1585
/*-
1586
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1587
 * FIPS 203, Section 7.2, Algorithm 20: "ML-KEM.Encaps".
1588
 *
1589
 * This function parses the |in| buffer as the |ek| output of "ML-KEM.KeyGen",
1590
 * or, equivalently, the |ek| input of "ML-KEM.Encaps", i.e. decodes the
1591
 * wire-format of the ML-KEM public key.
1592
 */
1593
static int parse_pubkey(const uint8_t *in, EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
1594
562
{
1595
562
    const ML_KEM_VINFO *vinfo = key->vinfo;
1596
1597
    /* Decode and check |t| */
1598
562
    if (!vector_decode_12(key->t, in, vinfo->rank)) {
1599
227
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1600
227
            "%s invalid public 't' vector",
1601
227
            vinfo->algorithm_name);
1602
227
        return 0;
1603
227
    }
1604
    /* Save the matrix |m| recovery seed |rho| */
1605
335
    memcpy(key->rho, in + vinfo->vector_bytes, ML_KEM_RANDOM_BYTES);
1606
    /*
1607
     * Pre-compute the public key hash, needed for both encap and decap.
1608
     * Also pre-compute the matrix expansion, stored with the public key.
1609
     */
1610
335
    if (!hash_h(key->pkhash, in, vinfo->pubkey_bytes, mdctx, key)
1611
335
        || !matrix_expand(mdctx, key)) {
1612
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1613
0
            "internal error while parsing %s public key",
1614
0
            vinfo->algorithm_name);
1615
0
        return 0;
1616
0
    }
1617
335
    return 1;
1618
335
}
1619
1620
/*
1621
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1622
 *
1623
 * Parses the |in| buffer as a |dk| output of "ML-KEM.KeyGen".
1624
 * This matches the output format of encode_prvkey() above.
1625
 */
1626
static int parse_prvkey(const uint8_t *in, EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
1627
109
{
1628
109
    const ML_KEM_VINFO *vinfo = key->vinfo;
1629
1630
    /* Decode and check |s|. */
1631
109
    if (!vector_decode_12(key->s, in, vinfo->rank)) {
1632
87
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1633
87
            "%s invalid private 's' vector",
1634
87
            vinfo->algorithm_name);
1635
87
        return 0;
1636
87
    }
1637
22
    in += vinfo->vector_bytes;
1638
1639
22
    if (!parse_pubkey(in, mdctx, key))
1640
13
        return 0;
1641
9
    in += vinfo->pubkey_bytes;
1642
1643
    /* Check public key hash. */
1644
9
    if (memcmp(key->pkhash, in, ML_KEM_PKHASH_BYTES) != 0) {
1645
9
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1646
9
            "%s public key hash mismatch",
1647
9
            vinfo->algorithm_name);
1648
9
        return 0;
1649
9
    }
1650
0
    in += ML_KEM_PKHASH_BYTES;
1651
1652
0
    memcpy(key->z, in, ML_KEM_RANDOM_BYTES);
1653
0
    return 1;
1654
9
}
1655
1656
/*
 * FIPS 203, Section 6.1, Algorithm 16: "ML-KEM.KeyGen_internal".
 *
 * The implementation of Section 5.1, Algorithm 13, "K-PKE.KeyGen(d)" is
 * inlined.
 *
 * The caller MUST pass a pre-allocated digest context that is not shared with
 * any concurrent computation.
 *
 * This function optionally outputs the serialised wire-form |ek| public key
 * into the provided |pubenc| buffer.  It also generates the content of the
 * |rho|, |pkhash|, |t|, |m|, |s| and |z| components of the private |key|,
 * which must have preallocated space for these.  When the key's
 * ML_KEM_KEY_RETAIN_SEED flag is set, |d| is retained as well.
 *
 * Keys are computed from a 32-byte random |d| plus the 1 byte rank for
 * domain separation.  These are concatenated and hashed to produce a pair of
 * 32-byte seeds: the public "rho", used to generate the matrix, and the
 * private "sigma", used to generate the secret vector |s|.
 *
 * The second random input |z| is copied verbatim into the Fujisaki-Okamoto
 * (FO) transform "implicit-rejection" secret (the |z| component of the private
 * key).  This thwarts chosen-ciphertext attacks, provided decap() runs in
 * constant time, with no side channel leaks, on all well-formed (valid length,
 * and correctly encoded) ciphertext inputs.
 *
 * Returns 1 on success.  On failure it returns 0 after raising an internal
 * error; fields of |key| written before the failure are not rolled back.
 */
static __owur int genkey(const uint8_t seed[ML_KEM_SEED_BYTES],
    EVP_MD_CTX *mdctx, uint8_t *pubenc, ML_KEM_KEY *key)
{
    uint8_t hashed[2 * ML_KEM_RANDOM_BYTES];
    const uint8_t *const sigma = hashed + ML_KEM_RANDOM_BYTES;
    uint8_t augmented_seed[ML_KEM_RANDOM_BYTES + 1];
    const ML_KEM_VINFO *vinfo = key->vinfo;
    CBD_FUNC cbd_1 = CBD1(vinfo->evp_type);
    int rank = vinfo->rank;
    /* Shared PRF counter: |s| takes 0..rank-1, |e| continues at rank. */
    uint8_t counter = 0;
    int ret = 0;

    /*
     * Use the "d" seed salted with the rank to derive the public and private
     * seeds rho and sigma.
     */
    memcpy(augmented_seed, seed, ML_KEM_RANDOM_BYTES);
    augmented_seed[ML_KEM_RANDOM_BYTES] = (uint8_t)rank;
    if (!hash_g(hashed, augmented_seed, sizeof(augmented_seed), mdctx, key))
        goto end;
    memcpy(key->rho, hashed, ML_KEM_RANDOM_BYTES);
    /* The |rho| matrix seed is public */
    CONSTTIME_DECLASSIFY(key->rho, ML_KEM_RANDOM_BYTES);

    /*
     * FIPS 203 |e| vector is initial value of key->t.  The order matters: |s|
     * must be sampled before |e| so that each consumes the expected counter
     * values.
     */
    if (!matrix_expand(mdctx, key)
        || !gencbd_vector_ntt(key->s, cbd_1, &counter, sigma, rank, mdctx, key)
        || !gencbd_vector_ntt(key->t, cbd_1, &counter, sigma, rank, mdctx, key))
        goto end;

    /* To |e| we now add the product of transpose |m| and |s|, giving |t|. */
    matrix_mult_transpose_add(key->t, key->m, key->s, rank);
    /* The |t| vector is public */
    CONSTTIME_DECLASSIFY(key->t, vinfo->rank * sizeof(scalar));

    if (pubenc == NULL) {
        /* Incremental digest of public key without in-full serialisation. */
        if (!hash_h_pubkey(key->pkhash, mdctx, key))
            goto end;
    } else {
        encode_pubkey(pubenc, key);
        if (!hash_h(key->pkhash, pubenc, vinfo->pubkey_bytes, mdctx, key))
            goto end;
    }

    /* Save |z| portion of seed for "implicit rejection" on failure. */
    memcpy(key->z, seed + ML_KEM_RANDOM_BYTES, ML_KEM_RANDOM_BYTES);

    /*
     * Optionally save the |d| portion of the seed.
     * NOTE(review): this places |d| directly after |z|, so it assumes the
     * key's storage reserves ML_KEM_RANDOM_BYTES there; confirm against the
     * ML_KEM_KEY allocation.
     */
    key->d = key->z + ML_KEM_RANDOM_BYTES;
    if (key->prov_flags & ML_KEM_KEY_RETAIN_SEED) {
        memcpy(key->d, seed, ML_KEM_RANDOM_BYTES);
    } else {
        OPENSSL_cleanse(key->d, ML_KEM_RANDOM_BYTES);
        key->d = NULL;
    }

    ret = 1;
end:
    /*
     * Wipe the secret inputs on every path: the copy of |d| (the trailing
     * rank byte is public) and |sigma|.  |rho| is public and left as is.
     */
    OPENSSL_cleanse((void *)augmented_seed, ML_KEM_RANDOM_BYTES);
    OPENSSL_cleanse((void *)sigma, ML_KEM_RANDOM_BYTES);
    if (ret == 0) {
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
            "internal error while generating %s private key",
            vinfo->algorithm_name);
    }
    return ret;
}
1749
1750
/*-
1751
 * FIPS 203, Section 6.2, Algorithm 17: "ML-KEM.Encaps_internal".
1752
 * This is the deterministic version with randomness supplied externally.
1753
 *
1754
 * The caller must pass space for two vectors in |tmp|.
1755
 * The |ctext| buffer have space for the ciphertext of the ML-KEM variant
1756
 * of the provided key.
1757
 */
1758
static int encap(uint8_t *ctext, uint8_t secret[ML_KEM_SHARED_SECRET_BYTES],
1759
    const uint8_t entropy[ML_KEM_RANDOM_BYTES],
1760
    scalar *tmp, EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1761
156
{
1762
156
    uint8_t input[ML_KEM_RANDOM_BYTES + ML_KEM_PKHASH_BYTES];
1763
156
    uint8_t Kr[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES];
1764
156
    uint8_t *r = Kr + ML_KEM_SHARED_SECRET_BYTES;
1765
156
    int ret;
1766
1767
156
    memcpy(input, entropy, ML_KEM_RANDOM_BYTES);
1768
156
    memcpy(input + ML_KEM_RANDOM_BYTES, key->pkhash, ML_KEM_PKHASH_BYTES);
1769
156
    ret = hash_g(Kr, input, sizeof(input), mdctx, key)
1770
156
        && encrypt_cpa(ctext, entropy, r, tmp, mdctx, key);
1771
156
    OPENSSL_cleanse((void *)input, sizeof(input));
1772
1773
156
    if (ret)
1774
156
        memcpy(secret, Kr, ML_KEM_SHARED_SECRET_BYTES);
1775
0
    else
1776
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1777
0
            "internal error while performing %s encapsulation",
1778
0
            key->vinfo->algorithm_name);
1779
156
    return ret;
1780
156
}
1781
1782
/*
1783
 * FIPS 203, Section 6.3, Algorithm 18: ML-KEM.Decaps_internal
1784
 *
1785
 * Barring failure of the supporting SHA3/SHAKE primitives, this is fully
1786
 * deterministic, the randomness for the FO transform is extracted during
1787
 * private key generation.
1788
 *
1789
 * The caller must pass space for two vectors in |tmp|.
1790
 * The |ctext| and |tmp_ctext| buffers must each have space for the ciphertext
1791
 * of the key's ML-KEM variant.
1792
 */
1793
static int decap(uint8_t secret[ML_KEM_SHARED_SECRET_BYTES],
1794
    const uint8_t *ctext, uint8_t *tmp_ctext, scalar *tmp,
1795
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1796
86
{
1797
86
    uint8_t decrypted[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_PKHASH_BYTES];
1798
86
    uint8_t failure_key[ML_KEM_RANDOM_BYTES];
1799
86
    uint8_t Kr[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES];
1800
86
    uint8_t *r = Kr + ML_KEM_SHARED_SECRET_BYTES;
1801
86
    const uint8_t *pkhash = key->pkhash;
1802
86
    const ML_KEM_VINFO *vinfo = key->vinfo;
1803
86
    int i;
1804
86
    uint8_t mask;
1805
1806
    /*
1807
     * If our KDF is unavailable, fail early! Otherwise, keep going ignoring
1808
     * any further errors, returning success, and whatever we got for a shared
1809
     * secret.  The decrypt_cpa() function is just arithmetic on secret data,
1810
     * so should not be subject to failure that makes its output predictable.
1811
     *
1812
     * We guard against "should never happen" catastrophic failure of the
1813
     * "pure" function |hash_g| by overwriting the shared secret with the
1814
     * content of the failure key and returning early, if nevertheless hash_g
1815
     * fails.  This is not constant-time, but a failure of |hash_g| already
1816
     * implies loss of side-channel resistance.
1817
     *
1818
     * The same action is taken, if also |encrypt_cpa| should catastrophically
1819
     * fail, due to failure of the |PRF| underlying the CBD functions.
1820
     */
1821
86
    if (!kdf(failure_key, key->z, ctext, vinfo->ctext_bytes, mdctx, key)) {
1822
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1823
0
            "internal error while performing %s decapsulation",
1824
0
            vinfo->algorithm_name);
1825
0
        return 0;
1826
0
    }
1827
86
    decrypt_cpa(decrypted, ctext, tmp, key);
1828
86
    memcpy(decrypted + ML_KEM_SHARED_SECRET_BYTES, pkhash, ML_KEM_PKHASH_BYTES);
1829
86
    if (!hash_g(Kr, decrypted, sizeof(decrypted), mdctx, key)
1830
86
        || !encrypt_cpa(tmp_ctext, decrypted, r, tmp, mdctx, key)) {
1831
0
        memcpy(secret, failure_key, ML_KEM_SHARED_SECRET_BYTES);
1832
0
        OPENSSL_cleanse(decrypted, ML_KEM_SHARED_SECRET_BYTES);
1833
0
        return 1;
1834
0
    }
1835
86
    mask = constant_time_eq_int_8(0,
1836
86
        CRYPTO_memcmp(ctext, tmp_ctext, vinfo->ctext_bytes));
1837
2.83k
    for (i = 0; i < ML_KEM_SHARED_SECRET_BYTES; i++)
1838
2.75k
        secret[i] = constant_time_select_8(mask, Kr[i], failure_key[i]);
1839
86
    OPENSSL_cleanse(decrypted, ML_KEM_SHARED_SECRET_BYTES);
1840
86
    OPENSSL_cleanse(Kr, sizeof(Kr));
1841
86
    return 1;
1842
86
}
1843
1844
/*
1845
 * After allocating storage for public or private key data, update the key
1846
 * component pointers to reference that storage.
1847
 */
1848
static __owur int add_storage(scalar *p, int private, ML_KEM_KEY *key)
1849
17.5k
{
1850
17.5k
    int rank = key->vinfo->rank;
1851
1852
17.5k
    if (p == NULL)
1853
0
        return 0;
1854
1855
    /*
1856
     * We're adding key material, the seed buffer will now hold |rho| and
1857
     * |pkhash|.
1858
     */
1859
17.5k
    memset(key->seedbuf, 0, sizeof(key->seedbuf));
1860
17.5k
    key->rho = key->seedbuf;
1861
17.5k
    key->pkhash = key->seedbuf + ML_KEM_RANDOM_BYTES;
1862
17.5k
    key->d = key->z = NULL;
1863
1864
    /* A public key needs space for |t| and |m| */
1865
17.5k
    key->m = (key->t = p) + rank;
1866
1867
    /*
1868
     * A private key also needs space for |s| and |z|.
1869
     * The |z| buffer always includes additional space for |d|, but a key's |d|
1870
     * pointer is left NULL when parsed from the NIST format, which omits that
1871
     * information.  Only keys generated from a (d, z) seed pair will have a
1872
     * non-NULL |d| pointer.
1873
     */
1874
17.5k
    if (private)
1875
17.3k
        key->z = (uint8_t *)(rank + (key->s = key->m + rank * rank));
1876
17.5k
    return 1;
1877
17.5k
}
1878
1879
/*
1880
 * After freeing the storage associated with a key that failed to be
1881
 * constructed, reset the internal pointers back to NULL.
1882
 */
1883
void ossl_ml_kem_key_reset(ML_KEM_KEY *key)
1884
17.6k
{
1885
17.6k
    if (key->t == NULL)
1886
144
        return;
1887
    /*-
1888
     * Cleanse any sensitive data:
1889
     * - The private vector |s| is immediately followed by the FO failure
1890
     *   secret |z|, and seed |d|, we can cleanse all three in one call.
1891
     *
1892
     * - Otherwise, when key->d is set, cleanse the stashed seed.
1893
     */
1894
17.5k
    if (ossl_ml_kem_have_prvkey(key))
1895
17.3k
        OPENSSL_cleanse(key->s,
1896
17.3k
            key->vinfo->rank * sizeof(scalar) + 2 * ML_KEM_RANDOM_BYTES);
1897
17.5k
    OPENSSL_free(key->t);
1898
17.5k
    key->d = key->z = (uint8_t *)(key->s = key->m = key->t = NULL);
1899
17.5k
}
1900
1901
/*
1902
 * ----- API exported to the provider
1903
 *
1904
 * Parameters with an implicit fixed length in the internal static API of each
1905
 * variant have an explicit checked length argument at this layer.
1906
 */
1907
1908
/* Retrieve the parameters of one of the ML-KEM variants */
1909
const ML_KEM_VINFO *ossl_ml_kem_get_vinfo(int evp_type)
1910
271k
{
1911
271k
    switch (evp_type) {
1912
60.2k
    case EVP_PKEY_ML_KEM_512:
1913
60.2k
        return &vinfo_map[ML_KEM_512_VINFO];
1914
154k
    case EVP_PKEY_ML_KEM_768:
1915
154k
        return &vinfo_map[ML_KEM_768_VINFO];
1916
56.3k
    case EVP_PKEY_ML_KEM_1024:
1917
56.3k
        return &vinfo_map[ML_KEM_1024_VINFO];
1918
271k
    }
1919
0
    return NULL;
1920
271k
}
1921
1922
/*
 * Allocate an empty ML-KEM key (no key material) of the variant given by
 * |evp_type|, fetching the SHA3/SHAKE digests it needs from |libctx| with
 * the given |properties|.
 *
 * Returns NULL, raising an error, for an unsupported key type or when any
 * of the four digests cannot be fetched.  Returns NULL without raising an
 * error here when the key allocation itself fails.
 */
ML_KEM_KEY *ossl_ml_kem_key_new(OSSL_LIB_CTX *libctx, const char *properties,
    int evp_type)
{
    const ML_KEM_VINFO *vinfo = ossl_ml_kem_get_vinfo(evp_type);
    ML_KEM_KEY *key;

    if (vinfo == NULL) {
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_PASSED_INVALID_ARGUMENT,
            "unsupported ML-KEM key type: %d", evp_type);
        return NULL;
    }

    key = OPENSSL_malloc(sizeof(*key));
    if (key == NULL)
        return NULL;

    /* No key material yet: every component pointer starts out NULL */
    key->t = NULL;
    key->m = NULL;
    key->s = NULL;
    key->d = NULL;
    key->z = NULL;
    key->rho = NULL;
    key->pkhash = NULL;
    key->encoded_dk = NULL;

    key->vinfo = vinfo;
    key->libctx = libctx;
    key->prov_flags = ML_KEM_KEY_PROV_FLAGS_DEFAULT;

    /* All four are fetched before any is checked */
    key->shake128_md = EVP_MD_fetch(libctx, "SHAKE128", properties);
    key->shake256_md = EVP_MD_fetch(libctx, "SHAKE256", properties);
    key->sha3_256_md = EVP_MD_fetch(libctx, "SHA3-256", properties);
    key->sha3_512_md = EVP_MD_fetch(libctx, "SHA3-512", properties);

    if (key->shake128_md == NULL
        || key->shake256_md == NULL
        || key->sha3_256_md == NULL
        || key->sha3_512_md == NULL) {
        ossl_ml_kem_key_free(key);
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
            "missing SHA3 digest algorithms while creating %s key",
            vinfo->algorithm_name);
        return NULL;
    }
    return key;
}
1959
1960
ML_KEM_KEY *ossl_ml_kem_key_dup(const ML_KEM_KEY *key, int selection)
1961
27
{
1962
27
    int ok = 0;
1963
27
    ML_KEM_KEY *ret;
1964
1965
    /*
1966
     * Partially decoded keys, not yet imported or loaded, should never be
1967
     * duplicated.
1968
     */
1969
27
    if (ossl_ml_kem_decoded_key(key))
1970
0
        return NULL;
1971
1972
27
    if (key == NULL
1973
27
        || (ret = OPENSSL_memdup(key, sizeof(*key))) == NULL)
1974
0
        return NULL;
1975
27
    ret->d = ret->z = ret->rho = ret->pkhash = NULL;
1976
27
    ret->s = ret->m = ret->t = NULL;
1977
1978
    /* Clear selection bits we can't fulfill */
1979
27
    if (!ossl_ml_kem_have_pubkey(key))
1980
0
        selection = 0;
1981
27
    else if (!ossl_ml_kem_have_prvkey(key))
1982
27
        selection &= ~OSSL_KEYMGMT_SELECT_PRIVATE_KEY;
1983
1984
27
    switch (selection & OSSL_KEYMGMT_SELECT_KEYPAIR) {
1985
0
    case 0:
1986
0
        ok = 1;
1987
0
        break;
1988
27
    case OSSL_KEYMGMT_SELECT_PUBLIC_KEY:
1989
27
        ok = add_storage(OPENSSL_memdup(key->t, key->vinfo->puballoc), 0, ret);
1990
27
        ret->rho = ret->seedbuf;
1991
27
        ret->pkhash = ret->rho + ML_KEM_RANDOM_BYTES;
1992
27
        break;
1993
0
    case OSSL_KEYMGMT_SELECT_PRIVATE_KEY:
1994
0
        ok = add_storage(OPENSSL_memdup(key->t, key->vinfo->prvalloc), 1, ret);
1995
        /* Duplicated keys retain |d|, if available */
1996
0
        if (key->d != NULL)
1997
0
            ret->d = ret->z + ML_KEM_RANDOM_BYTES;
1998
0
        break;
1999
27
    }
2000
2001
27
    if (!ok) {
2002
0
        OPENSSL_free(ret);
2003
0
        return NULL;
2004
0
    }
2005
2006
27
    EVP_MD_up_ref(ret->shake128_md);
2007
27
    EVP_MD_up_ref(ret->shake256_md);
2008
27
    EVP_MD_up_ref(ret->sha3_256_md);
2009
27
    EVP_MD_up_ref(ret->sha3_512_md);
2010
2011
27
    return ret;
2012
27
}
2013
2014
/*
 * Release |key| and everything it owns.  NULL is accepted and ignored.
 *
 * A partially decoded key (not yet imported or loaded) may still carry a
 * stashed seed in its seed buffer, and/or a pending private-key encoding.
 * Both are cleansed, and the encoding is freed.  Any allocated key storage
 * is then cleansed and released via ossl_ml_kem_key_reset().  Finally the
 * digest references the key holds are dropped.
 */
void ossl_ml_kem_key_free(ML_KEM_KEY *key)
{
    if (key == NULL)
        return;

    if (ossl_ml_kem_decoded_key(key)) {
        OPENSSL_cleanse(key->seedbuf, sizeof(key->seedbuf));
        if (ossl_ml_kem_have_dkenc(key)) {
            OPENSSL_cleanse(key->encoded_dk, key->vinfo->prvkey_bytes);
            OPENSSL_free(key->encoded_dk);
        }
    }
    ossl_ml_kem_key_reset(key);

    EVP_MD_free(key->sha3_512_md);
    EVP_MD_free(key->sha3_256_md);
    EVP_MD_free(key->shake256_md);
    EVP_MD_free(key->shake128_md);

    OPENSSL_free(key);
}
2034
2035
/* Serialise the public component of an ML-KEM key */
2036
int ossl_ml_kem_encode_public_key(uint8_t *out, size_t len,
2037
    const ML_KEM_KEY *key)
2038
51.0k
{
2039
51.0k
    if (!ossl_ml_kem_have_pubkey(key)
2040
51.0k
        || len != key->vinfo->pubkey_bytes)
2041
0
        return 0;
2042
51.0k
    encode_pubkey(out, key);
2043
51.0k
    return 1;
2044
51.0k
}
2045
2046
/* Serialise an ML-KEM private key */
2047
int ossl_ml_kem_encode_private_key(uint8_t *out, size_t len,
2048
    const ML_KEM_KEY *key)
2049
116
{
2050
116
    if (!ossl_ml_kem_have_prvkey(key)
2051
116
        || len != key->vinfo->prvkey_bytes)
2052
0
        return 0;
2053
116
    encode_prvkey(out, key);
2054
116
    return 1;
2055
116
}
2056
2057
int ossl_ml_kem_encode_seed(uint8_t *out, size_t len,
2058
    const ML_KEM_KEY *key)
2059
214
{
2060
214
    if (key == NULL || key->d == NULL || len != ML_KEM_SEED_BYTES)
2061
46
        return 0;
2062
    /*
2063
     * Both in the seed buffer, and in the allocated storage, the |d| component
2064
     * of the seed is stored last, so we must copy each separately.
2065
     */
2066
168
    memcpy(out, key->d, ML_KEM_RANDOM_BYTES);
2067
168
    out += ML_KEM_RANDOM_BYTES;
2068
168
    memcpy(out, key->z, ML_KEM_RANDOM_BYTES);
2069
168
    return 1;
2070
214
}
2071
2072
/*
2073
 * Stash the seed without (yet) performing a keygen, used during decoding, to
2074
 * avoid an extra keygen if we're only going to export the key again to load
2075
 * into another provider.
2076
 */
2077
ML_KEM_KEY *ossl_ml_kem_set_seed(const uint8_t *seed, size_t seedlen, ML_KEM_KEY *key)
2078
28
{
2079
28
    if (key == NULL
2080
28
        || ossl_ml_kem_have_pubkey(key)
2081
28
        || ossl_ml_kem_have_seed(key)
2082
28
        || seedlen != ML_KEM_SEED_BYTES)
2083
0
        return NULL;
2084
    /*
2085
     * With no public or private key material on hand, we can use the seed
2086
     * buffer for |z| and |d|, in that order.
2087
     */
2088
28
    key->z = key->seedbuf;
2089
28
    key->d = key->z + ML_KEM_RANDOM_BYTES;
2090
28
    memcpy(key->d, seed, ML_KEM_RANDOM_BYTES);
2091
28
    seed += ML_KEM_RANDOM_BYTES;
2092
28
    memcpy(key->z, seed, ML_KEM_RANDOM_BYTES);
2093
28
    return key;
2094
28
}
2095
2096
/* Parse input as a public key */
2097
int ossl_ml_kem_parse_public_key(const uint8_t *in, size_t len, ML_KEM_KEY *key)
2098
540
{
2099
540
    EVP_MD_CTX *mdctx = NULL;
2100
540
    const ML_KEM_VINFO *vinfo;
2101
540
    int ret = 0;
2102
2103
    /* Keys with key material are immutable */
2104
540
    if (key == NULL
2105
540
        || ossl_ml_kem_have_pubkey(key)
2106
540
        || ossl_ml_kem_have_dkenc(key))
2107
0
        return 0;
2108
540
    vinfo = key->vinfo;
2109
2110
540
    if (len != vinfo->pubkey_bytes
2111
540
        || (mdctx = EVP_MD_CTX_new()) == NULL)
2112
0
        return 0;
2113
2114
540
    if (add_storage(OPENSSL_malloc(vinfo->puballoc), 0, key))
2115
540
        ret = parse_pubkey(in, mdctx, key);
2116
2117
540
    if (!ret)
2118
214
        ossl_ml_kem_key_reset(key);
2119
540
    EVP_MD_CTX_free(mdctx);
2120
540
    return ret;
2121
540
}
2122
2123
/* Parse input as a new private key */
2124
int ossl_ml_kem_parse_private_key(const uint8_t *in, size_t len,
2125
    ML_KEM_KEY *key)
2126
109
{
2127
109
    EVP_MD_CTX *mdctx = NULL;
2128
109
    const ML_KEM_VINFO *vinfo;
2129
109
    int ret = 0;
2130
2131
    /* Keys with key material are immutable */
2132
109
    if (key == NULL
2133
109
        || ossl_ml_kem_have_pubkey(key)
2134
109
        || ossl_ml_kem_have_dkenc(key))
2135
0
        return 0;
2136
109
    vinfo = key->vinfo;
2137
2138
109
    if (len != vinfo->prvkey_bytes
2139
109
        || (mdctx = EVP_MD_CTX_new()) == NULL)
2140
0
        return 0;
2141
2142
109
    if (add_storage(OPENSSL_malloc(vinfo->prvalloc), 1, key))
2143
109
        ret = parse_prvkey(in, mdctx, key);
2144
2145
109
    if (!ret)
2146
109
        ossl_ml_kem_key_reset(key);
2147
109
    EVP_MD_CTX_free(mdctx);
2148
109
    return ret;
2149
109
}
2150
2151
/*
2152
 * Generate a new keypair, either from the saved seed (when non-null), or from
2153
 * the RNG.
2154
 */
2155
int ossl_ml_kem_genkey(uint8_t *pubenc, size_t publen, ML_KEM_KEY *key)
2156
51.1k
{
2157
51.1k
    uint8_t seed[ML_KEM_SEED_BYTES];
2158
51.1k
    EVP_MD_CTX *mdctx = NULL;
2159
51.1k
    const ML_KEM_VINFO *vinfo;
2160
51.1k
    int ret = 0;
2161
2162
51.1k
    if (key == NULL
2163
51.1k
        || ossl_ml_kem_have_pubkey(key)
2164
51.1k
        || ossl_ml_kem_have_dkenc(key))
2165
0
        return 0;
2166
51.1k
    vinfo = key->vinfo;
2167
2168
51.1k
    if (pubenc != NULL && publen != vinfo->pubkey_bytes)
2169
0
        return 0;
2170
2171
51.1k
    if (ossl_ml_kem_have_seed(key)) {
2172
52
        if (!ossl_ml_kem_encode_seed(seed, sizeof(seed), key))
2173
0
            return 0;
2174
52
        key->d = key->z = NULL;
2175
51.1k
    } else if (RAND_priv_bytes_ex(key->libctx, seed, sizeof(seed),
2176
51.1k
                   key->vinfo->secbits)
2177
51.1k
        <= 0) {
2178
0
        return 0;
2179
0
    }
2180
2181
51.1k
    if ((mdctx = EVP_MD_CTX_new()) == NULL)
2182
0
        return 0;
2183
2184
    /*
2185
     * Data derived from (d, z) defaults secret, and to avoid side-channel
2186
     * leaks should not influence control flow.
2187
     */
2188
51.1k
    CONSTTIME_SECRET(seed, ML_KEM_SEED_BYTES);
2189
2190
51.1k
    if (add_storage(OPENSSL_malloc(vinfo->prvalloc), 1, key))
2191
51.1k
        ret = genkey(seed, mdctx, pubenc, key);
2192
51.1k
    OPENSSL_cleanse(seed, sizeof(seed));
2193
2194
    /* Declassify secret inputs and derived outputs before returning control */
2195
51.1k
    CONSTTIME_DECLASSIFY(seed, ML_KEM_SEED_BYTES);
2196
2197
51.1k
    EVP_MD_CTX_free(mdctx);
2198
51.1k
    if (!ret) {
2199
0
        ossl_ml_kem_key_reset(key);
2200
0
        return 0;
2201
0
    }
2202
2203
    /* The public components are already declassified */
2204
51.1k
    CONSTTIME_DECLASSIFY(key->s, vinfo->rank * sizeof(scalar));
2205
51.1k
    CONSTTIME_DECLASSIFY(key->z, 2 * ML_KEM_RANDOM_BYTES);
2206
51.1k
    return 1;
2207
51.1k
}
2208
2209
/*
2210
 * FIPS 203, Section 6.2, Algorithm 17: ML-KEM.Encaps_internal
2211
 * This is the deterministic version with randomness supplied externally.
2212
 */
2213
int ossl_ml_kem_encap_seed(uint8_t *ctext, size_t clen,
2214
    uint8_t *shared_secret, size_t slen,
2215
    const uint8_t *entropy, size_t elen,
2216
    const ML_KEM_KEY *key)
2217
156
{
2218
156
    const ML_KEM_VINFO *vinfo;
2219
156
    EVP_MD_CTX *mdctx;
2220
156
    int ret = 0;
2221
2222
156
    if (key == NULL || !ossl_ml_kem_have_pubkey(key))
2223
0
        return 0;
2224
156
    vinfo = key->vinfo;
2225
2226
156
    if (ctext == NULL || clen != vinfo->ctext_bytes
2227
156
        || shared_secret == NULL || slen != ML_KEM_SHARED_SECRET_BYTES
2228
156
        || entropy == NULL || elen != ML_KEM_RANDOM_BYTES
2229
156
        || (mdctx = EVP_MD_CTX_new()) == NULL)
2230
0
        return 0;
2231
    /*
2232
     * Data derived from the encap entropy defaults secret, and to avoid
2233
     * side-channel leaks should not influence control flow.
2234
     */
2235
156
    CONSTTIME_SECRET(entropy, elen);
2236
2237
    /*-
2238
     * This avoids the need to handle allocation failures for two (max 2KB
2239
     * each) vectors, that are never retained on return from this function.
2240
     * We stack-allocate these.
2241
     */
2242
156
#define case_encap_seed(bits)                                        \
2243
156
    case EVP_PKEY_ML_KEM_##bits: {                                   \
2244
156
        scalar tmp[2 * ML_KEM_##bits##_RANK];                        \
2245
156
                                                                     \
2246
156
        ret = encap(ctext, shared_secret, entropy, tmp, mdctx, key); \
2247
156
        OPENSSL_cleanse((void *)tmp, sizeof(tmp));                   \
2248
156
        break;                                                       \
2249
156
    }
2250
156
    switch (vinfo->evp_type) {
2251
55
        case_encap_seed(512);
2252
63
        case_encap_seed(768);
2253
38
        case_encap_seed(1024);
2254
156
    }
2255
156
#undef case_encap_seed
2256
2257
    /* Declassify secret inputs and derived outputs before returning control */
2258
156
    CONSTTIME_DECLASSIFY(entropy, elen);
2259
156
    CONSTTIME_DECLASSIFY(ctext, clen);
2260
156
    CONSTTIME_DECLASSIFY(shared_secret, slen);
2261
2262
156
    EVP_MD_CTX_free(mdctx);
2263
156
    return ret;
2264
156
}
2265
2266
int ossl_ml_kem_encap_rand(uint8_t *ctext, size_t clen,
2267
    uint8_t *shared_secret, size_t slen,
2268
    const ML_KEM_KEY *key)
2269
156
{
2270
156
    uint8_t r[ML_KEM_RANDOM_BYTES];
2271
2272
156
    if (key == NULL)
2273
0
        return 0;
2274
2275
156
    if (RAND_bytes_ex(key->libctx, r, ML_KEM_RANDOM_BYTES,
2276
156
            key->vinfo->secbits)
2277
156
        < 1)
2278
0
        return 0;
2279
2280
156
    return ossl_ml_kem_encap_seed(ctext, clen, shared_secret, slen,
2281
156
        r, sizeof(r), key);
2282
156
}
2283
2284
int ossl_ml_kem_decap(uint8_t *shared_secret, size_t slen,
2285
    const uint8_t *ctext, size_t clen,
2286
    const ML_KEM_KEY *key)
2287
132
{
2288
132
    const ML_KEM_VINFO *vinfo;
2289
132
    EVP_MD_CTX *mdctx;
2290
132
    int ret = 0;
2291
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
2292
    int classify_bytes;
2293
#endif
2294
2295
    /* Need a private key here */
2296
132
    if (!ossl_ml_kem_have_prvkey(key))
2297
0
        return 0;
2298
132
    vinfo = key->vinfo;
2299
2300
132
    if (shared_secret == NULL || slen != ML_KEM_SHARED_SECRET_BYTES
2301
132
        || ctext == NULL || clen != vinfo->ctext_bytes
2302
132
        || (mdctx = EVP_MD_CTX_new()) == NULL) {
2303
0
        (void)RAND_bytes_ex(key->libctx, shared_secret,
2304
0
            ML_KEM_SHARED_SECRET_BYTES, vinfo->secbits);
2305
0
        return 0;
2306
0
    }
2307
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
2308
    /*
2309
     * Data derived from |s| and |z| defaults secret, and to avoid side-channel
2310
     * leaks should not influence control flow.
2311
     */
2312
    classify_bytes = 2 * sizeof(scalar) + ML_KEM_RANDOM_BYTES;
2313
    CONSTTIME_SECRET(key->s, classify_bytes);
2314
#endif
2315
2316
    /*-
2317
     * This avoids the need to handle allocation failures for two (max 2KB
2318
     * each) vectors and an encoded ciphertext (max 1568 bytes), that are never
2319
     * retained on return from this function.
2320
     * We stack-allocate these.
2321
     */
2322
132
#define case_decap(bits)                                          \
2323
132
    case EVP_PKEY_ML_KEM_##bits: {                                \
2324
132
        uint8_t cbuf[CTEXT_BYTES(bits)];                          \
2325
132
        scalar tmp[2 * ML_KEM_##bits##_RANK];                     \
2326
132
                                                                  \
2327
132
        ret = decap(shared_secret, ctext, cbuf, tmp, mdctx, key); \
2328
132
        OPENSSL_cleanse((void *)tmp, sizeof(tmp));                \
2329
132
        break;                                                    \
2330
132
    }
2331
132
    switch (vinfo->evp_type) {
2332
55
        case_decap(512);
2333
39
        case_decap(768);
2334
38
        case_decap(1024);
2335
132
    }
2336
2337
    /* Declassify secret inputs and derived outputs before returning control */
2338
132
    CONSTTIME_DECLASSIFY(key->s, classify_bytes);
2339
132
    CONSTTIME_DECLASSIFY(shared_secret, slen);
2340
132
    EVP_MD_CTX_free(mdctx);
2341
2342
132
    return ret;
2343
132
#undef case_decap
2344
132
}
2345
2346
int ossl_ml_kem_pubkey_cmp(const ML_KEM_KEY *key1, const ML_KEM_KEY *key2)
2347
138
{
2348
    /*
2349
     * This handles any unexpected differences in the ML-KEM variant rank,
2350
     * giving different key component structures, barring SHA3-256 hash
2351
     * collisions, the keys are the same size.
2352
     */
2353
138
    if (ossl_ml_kem_have_pubkey(key1) && ossl_ml_kem_have_pubkey(key2))
2354
138
        return memcmp(key1->pkhash, key2->pkhash, ML_KEM_PKHASH_BYTES) == 0;
2355
2356
    /*
2357
     * No match if just one of the public keys is not available, otherwise both
2358
     * are unavailable, and for now such keys are considered equal.
2359
     */
2360
0
    return (!(ossl_ml_kem_have_pubkey(key1) ^ ossl_ml_kem_have_pubkey(key2)));
2361
138
}