Coverage Report

Created: 2026-04-09 06:50

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/openssl35/crypto/ml_kem/ml_kem.c
Line
Count
Source
1
/*
2
 * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
3
 *
4
 * Licensed under the Apache License 2.0 (the "License").  You may not use
5
 * this file except in compliance with the License.  You can obtain a copy
6
 * in the file LICENSE in the source distribution or at
7
 * https://www.openssl.org/source/license.html
8
 */
9
10
#include <openssl/byteorder.h>
11
#include <openssl/rand.h>
12
#include <openssl/proverr.h>
13
#include "crypto/ml_kem.h"
14
#include "internal/common.h"
15
#include "internal/constant_time.h"
16
#include "internal/sha3.h"
17
18
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
19
#include <valgrind/memcheck.h>
20
#endif
21
22
#if ML_KEM_SEED_BYTES != ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES
23
#error "ML-KEM keygen seed length != shared secret + random bytes length"
24
#endif
25
#if ML_KEM_SHARED_SECRET_BYTES != ML_KEM_RANDOM_BYTES
26
#error "Invalid unequal lengths of ML-KEM shared secret and random inputs"
27
#endif
28
29
#if UINT_MAX < UINT32_MAX
30
#error "Unsupported compiler: sizeof(unsigned int) < sizeof(uint32_t)"
31
#endif
32
33
/*
 * Handy function-like bit-extraction macros: bit0(b) yields the least
 * significant bit of |b|, and bitn(n, b) yields bit number |n| of |b|.
 * Every argument is parenthesised so that expression arguments (for example
 * "i | 1" or a conditional expression) expand with the intended precedence.
 */
#define bit0(b) ((b) & 1)
#define bitn(n, b) (((b) >> (n)) & 1)
36
37
/*
 * Local shorthands for the ring parameters.
 *
 * LOG2PRIME: 12 bits are enough to represent any value in [0, q-1] without
 * loss.  BARRETT_SHIFT is twice that and is used by the Barrett reduction.
 * INVERSE_DEGREE is (n/2)^-1 mod q, the final scaling factor of the inverse
 * NTT.
 */
#define DEGREE ML_KEM_DEGREE
#define INVERSE_DEGREE (ML_KEM_PRIME - 2 * 13)
#define LOG2PRIME 12
#define BARRETT_SHIFT (2 * LOG2PRIME)

/* Only available when the SHA3 header exposes its block-size macro. */
#ifdef SHA3_BLOCKSIZE
#define SHAKE128_BLOCKSIZE SHA3_BLOCKSIZE(128)
#endif
49
50
/*
 * Return whether a value that can only be 0 or 1 is non-zero, in constant time
 * in practice!  The return value is a mask that is all ones if true, and all
 * zeros otherwise (twos-complement arithmetic assumed for unsigned values).
 *
 * Although this is used in constant-time selects, we omit a value barrier
 * here.  Value barriers impede auto-vectorization (likely because it forces
 * the value to transit through a general-purpose register). On AArch64, this
 * is a difference of 2x.
 *
 * We usually add value barriers to selects because Clang turns consecutive
 * selects with the same condition into a branch instead of CMOV/CSEL. This
 * condition does not occur in Kyber, so omitting it seems to be safe so far,
 * but see |cbd_2|, |cbd_3|, where reduction needs to be specialised to the
 * sign of the input, rather than adding |q| in advance, and using the generic
 * |reduce_once|.  (David Benjamin, Chromium)
 *
 * The disabled variant below is the value-barrier form.  It is kept for
 * reference and, like the active one, must expand to a plain expression
 * (no trailing semicolon) so that it can be used inside larger expressions.
 */
#if 0
#define constish_time_non_zero(b) (~constant_time_is_zero(b))
#else
#define constish_time_non_zero(b) (0u - (b))
#endif
72
73
/*
 * Size of the buffer used for scalar rejection sampling.  It has to be a
 * multiple of 12 but is otherwise arbitrary.  The preferred size equals the
 * internal SHAKE128 block size, which avoids internal buffering and copying
 * in SHAKE128; that block size is (1600 - 256)/8 = 168 bytes, which happens
 * to be divisible by 12.
 *
 * When the block size is unknown, or is not divisible by 12, fall back to 168.
 */
#if !defined(SHAKE128_BLOCKSIZE) || (SHAKE128_BLOCKSIZE) % 12 != 0
#define SCALAR_SAMPLING_BUFSIZE 168
#else
#define SCALAR_SAMPLING_BUFSIZE (SHAKE128_BLOCKSIZE)
#endif
87
88
/*
89
 * Structure of keys
90
 */
91
typedef struct ossl_ml_kem_scalar_st {
92
    /* On every function entry and exit, 0 <= c[i] < ML_KEM_PRIME. */
93
    uint16_t c[ML_KEM_DEGREE];
94
} scalar;
95
96
/* Key material allocation layout */
97
#define DECLARE_ML_KEM_KEYDATA(name, rank, private_sz)               \
98
    struct name##_alloc {                                            \
99
        /* Public vector |t| */                                      \
100
        scalar tbuf[(rank)];                                         \
101
        /* Pre-computed matrix |m| (FIPS 203 |A| transpose) */       \
102
        scalar mbuf[(rank) * (rank)] /* optional private key data */ \
103
            private_sz                                               \
104
    }
105
106
/* Declare variant-specific public and private storage */
107
#define DECLARE_ML_KEM_VARIANT_KEYDATA(bits)                        \
108
    DECLARE_ML_KEM_KEYDATA(pubkey_##bits, ML_KEM_##bits##_RANK, ;); \
109
    DECLARE_ML_KEM_KEYDATA(prvkey_##bits, ML_KEM_##bits##_RANK, ; scalar sbuf[ML_KEM_##bits##_RANK]; uint8_t zbuf[2 * ML_KEM_RANDOM_BYTES];)
110
DECLARE_ML_KEM_VARIANT_KEYDATA(512);
111
DECLARE_ML_KEM_VARIANT_KEYDATA(768);
112
DECLARE_ML_KEM_VARIANT_KEYDATA(1024);
113
#undef DECLARE_ML_KEM_VARIANT_KEYDATA
114
#undef DECLARE_ML_KEM_KEYDATA
115
116
typedef __owur int (*CBD_FUNC)(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
117
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key);
118
static void scalar_encode(uint8_t *out, const scalar *s, int bits);
119
120
/*
 * The wire-form of a losslessly encoded vector uses 12 bits per element,
 * i.e. 3 bytes for every 2 coefficients.
 *
 * The wire-form public key is the lossless encoding of the public vector |t|
 * followed by the public seed |rho|.
 *
 * Our serialised private key is the concatenation of the serialised private
 * vector |s|, the public key, the public key hash and the failure secret |z|;
 * PRVKEY_BYTES adds those four lengths in that order.
 */
#define VECTOR_BYTES(b) ((3 * DEGREE / 2) * ML_KEM_##b##_RANK)
#define PUBKEY_BYTES(b) (VECTOR_BYTES(b) + ML_KEM_RANDOM_BYTES)
#define PRVKEY_BYTES(b) \
    (VECTOR_BYTES(b) + PUBKEY_BYTES(b) + ML_KEM_PKHASH_BYTES + ML_KEM_RANDOM_BYTES)

/*
 * Encapsulation produces a vector "u" and a scalar "v", whose coordinates
 * (numbers modulo the ML-KEM prime "q") are lossily encoded using "du" and
 * "dv" bits each, respectively.  The concatenation of the two encodings is
 * the ciphertext consumed by decapsulation.
 */
#define U_VECTOR_BYTES(b) ((DEGREE / 8) * ML_KEM_##b##_DU * ML_KEM_##b##_RANK)
#define V_SCALAR_BYTES(b) ((DEGREE / 8) * ML_KEM_##b##_DV)
#define CTEXT_BYTES(b) (U_VECTOR_BYTES(b) + V_SCALAR_BYTES(b))
142
143
/*
 * Constant-time validation hooks.
 *
 * CONSTTIME_SECRET takes a pointer and a number of bytes and marks that region
 * of memory as secret.  Secret data is tracked as it flows to registers and
 * other parts of memory; using it as a branch condition or as a memory index
 * triggers warnings in valgrind.
 *
 * CONSTTIME_DECLASSIFY takes a pointer and a number of bytes and marks that
 * region of memory as public, i.e. no longer subject to constant-time rules.
 *
 * Both expand to nothing unless OPENSSL_CONSTANT_TIME_VALIDATION is defined.
 */
#if !defined(OPENSSL_CONSTANT_TIME_VALIDATION)

#define CONSTTIME_SECRET(ptr, len)
#define CONSTTIME_DECLASSIFY(ptr, len)

#else

#define CONSTTIME_SECRET(ptr, len) VALGRIND_MAKE_MEM_UNDEFINED(ptr, len)
#define CONSTTIME_DECLASSIFY(ptr, len) VALGRIND_MAKE_MEM_DEFINED(ptr, len)

#endif
166
167
/*
 * Slot numbers of the three variants in the vinfo table below; the order here
 * must match the order of the table's initialisers.
 */
#define ML_KEM_512_VINFO 0
#define ML_KEM_768_VINFO 1
#define ML_KEM_1024_VINFO 2
173
174
/*
175
 * Per-variant fixed parameters
176
 */
177
static const ML_KEM_VINFO vinfo_map[3] = {
178
    { "ML-KEM-512",
179
        PRVKEY_BYTES(512),
180
        sizeof(struct prvkey_512_alloc),
181
        PUBKEY_BYTES(512),
182
        sizeof(struct pubkey_512_alloc),
183
        CTEXT_BYTES(512),
184
        VECTOR_BYTES(512),
185
        U_VECTOR_BYTES(512),
186
        EVP_PKEY_ML_KEM_512,
187
        ML_KEM_512_BITS,
188
        ML_KEM_512_RANK,
189
        ML_KEM_512_DU,
190
        ML_KEM_512_DV,
191
        ML_KEM_512_SECBITS },
192
    { "ML-KEM-768",
193
        PRVKEY_BYTES(768),
194
        sizeof(struct prvkey_768_alloc),
195
        PUBKEY_BYTES(768),
196
        sizeof(struct pubkey_768_alloc),
197
        CTEXT_BYTES(768),
198
        VECTOR_BYTES(768),
199
        U_VECTOR_BYTES(768),
200
        EVP_PKEY_ML_KEM_768,
201
        ML_KEM_768_BITS,
202
        ML_KEM_768_RANK,
203
        ML_KEM_768_DU,
204
        ML_KEM_768_DV,
205
        ML_KEM_768_SECBITS },
206
    { "ML-KEM-1024",
207
        PRVKEY_BYTES(1024),
208
        sizeof(struct prvkey_1024_alloc),
209
        PUBKEY_BYTES(1024),
210
        sizeof(struct pubkey_1024_alloc),
211
        CTEXT_BYTES(1024),
212
        VECTOR_BYTES(1024),
213
        U_VECTOR_BYTES(1024),
214
        EVP_PKEY_ML_KEM_1024,
215
        ML_KEM_1024_BITS,
216
        ML_KEM_1024_RANK,
217
        ML_KEM_1024_DU,
218
        ML_KEM_1024_DV,
219
        ML_KEM_1024_SECBITS }
220
};
221
222
/*
223
 * Remainders modulo `kPrime`, for sufficiently small inputs, are computed in
224
 * constant time via Barrett reduction, and a final call to reduce_once(),
225
 * which reduces inputs that are at most 2*kPrime and is also constant-time.
226
 */
227
static const int kPrime = ML_KEM_PRIME;
228
static const unsigned int kBarrettShift = BARRETT_SHIFT;
229
static const size_t kBarrettMultiplier = (1 << BARRETT_SHIFT) / ML_KEM_PRIME;
230
static const uint16_t kHalfPrime = (ML_KEM_PRIME - 1) / 2;
231
static const uint16_t kInverseDegree = INVERSE_DEGREE;
232
233
/*
234
 * Python helper:
235
 *
236
 * p = 3329
237
 * def bitreverse(i):
238
 *     ret = 0
239
 *     for n in range(7):
240
 *         bit = i & 1
241
 *         ret <<= 1
242
 *         ret |= bit
243
 *         i >>= 1
244
 *     return ret
245
 */
246
247
/*-
 * First precomputed array from Appendix A of FIPS 203, or else Python:
 * kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)]
 *
 * Laid out eight values per row; row r holds entries 8*r .. 8*r + 7.
 */
static const uint16_t kNTTRoots[128] = {
    1, 1729, 2580, 3289, 2642, 630, 1897, 848,
    1062, 1919, 193, 797, 2786, 3260, 569, 1746,
    296, 2447, 1339, 1476, 3046, 56, 2240, 1333,
    1426, 2094, 535, 2882, 2393, 2879, 1974, 821,
    289, 331, 3253, 1756, 1197, 2304, 2277, 2055,
    650, 1977, 2513, 632, 2865, 33, 1320, 1915,
    2319, 1435, 807, 452, 1438, 2868, 1534, 2402,
    2647, 2617, 1481, 648, 2474, 3110, 1227, 910,
    17, 2761, 583, 2649, 1637, 723, 2288, 1100,
    1409, 2662, 3281, 233, 756, 2156, 3015, 3050,
    1703, 1651, 2789, 1789, 1847, 952, 1461, 2687,
    939, 2308, 2437, 2388, 733, 2337, 268, 641,
    1584, 2298, 2037, 3220, 375, 2549, 2090, 1645,
    1063, 319, 2773, 757, 2099, 561, 2466, 2594,
    2804, 1092, 403, 1026, 1143, 2150, 2775, 886,
    1722, 1212, 1874, 1029, 2110, 2935, 885, 2154,
};
381
382
/*
 * InverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)]
 * Stored in the order in which the inverse NTT loop consumes them (the entry
 * at position 0 is skipped by that loop):
 *
 *  0, 64, 65, ..., 127, 32, 33, ..., 63, 16, 17, ..., 31, 8, 9, ...
 *
 * Laid out eight values per row.
 */
static const uint16_t kInverseNTTRoots[128] = {
    1, 1175, 2444, 394, 1219, 2300, 1455, 2117,
    1607, 2443, 554, 1179, 2186, 2303, 2926, 2237,
    525, 735, 863, 2768, 1230, 2572, 556, 3010,
    2266, 1684, 1239, 780, 2954, 109, 1292, 1031,
    1745, 2688, 3061, 992, 2596, 941, 892, 1021,
    2390, 642, 1868, 2377, 1482, 1540, 540, 1678,
    1626, 279, 314, 1173, 2573, 3096, 48, 667,
    1920, 2229, 1041, 2606, 1692, 680, 2746, 568,
    3312, 2419, 2102, 219, 855, 2681, 1848, 712,
    682, 927, 1795, 461, 1891, 2877, 2522, 1894,
    1010, 1414, 2009, 3296, 464, 2697, 816, 1352,
    2679, 1274, 1052, 1025, 2132, 1573, 76, 2998,
    3040, 2508, 1355, 450, 936, 447, 2794, 1235,
    1903, 1996, 1089, 3273, 283, 1853, 1990, 882,
    3033, 1583, 2760, 69, 543, 2532, 3136, 1410,
    2267, 2481, 1432, 2699, 687, 40, 749, 1600,
};
518
519
/*
 * Second precomputed array from Appendix A of FIPS 203 (normalised positive),
 * or else Python:
 * ModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)]
 *
 * Laid out eight values per row; consecutive pairs are negatives of each
 * other modulo p.
 */
static const uint16_t kModRoots[128] = {
    17, 3312, 2761, 568, 583, 2746, 2649, 680,
    1637, 1692, 723, 2606, 2288, 1041, 1100, 2229,
    1409, 1920, 2662, 667, 3281, 48, 233, 3096,
    756, 2573, 2156, 1173, 3015, 314, 3050, 279,
    1703, 1626, 1651, 1678, 2789, 540, 1789, 1540,
    1847, 1482, 952, 2377, 1461, 1868, 2687, 642,
    939, 2390, 2308, 1021, 2437, 892, 2388, 941,
    733, 2596, 2337, 992, 268, 3061, 641, 2688,
    1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
    375, 2954, 2549, 780, 2090, 1239, 1645, 1684,
    1063, 2266, 319, 3010, 2773, 556, 757, 2572,
    2099, 1230, 561, 2768, 2466, 863, 2594, 735,
    2804, 525, 1092, 2237, 403, 2926, 1026, 2303,
    1143, 2186, 2150, 1179, 2775, 554, 886, 2443,
    1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
    2110, 1219, 2935, 394, 885, 2444, 2154, 1175,
};
654
655
/*
656
 * single_keccak hashes |inlen| bytes from |in| and writes |outlen| bytes of
657
 * output to |out|. If the |md| specifies a fixed-output function, like
658
 * SHA3-256, then |outlen| must be the correct length for that function.
659
 */
660
static __owur int single_keccak(uint8_t *out, size_t outlen, const uint8_t *in, size_t inlen,
661
    EVP_MD_CTX *mdctx)
662
362k
{
663
362k
    unsigned int sz = (unsigned int)outlen;
664
665
362k
    if (!EVP_DigestUpdate(mdctx, in, inlen))
666
0
        return 0;
667
362k
    if (EVP_MD_xof(EVP_MD_CTX_get0_md(mdctx)))
668
310k
        return EVP_DigestFinalXOF(mdctx, out, outlen);
669
52.0k
    return EVP_DigestFinal_ex(mdctx, out, &sz)
670
52.0k
        && ossl_assert((size_t)sz == outlen);
671
362k
}
672
673
/*
674
 * FIPS 203, Section 4.1, equation (4.3): PRF. Takes 32+1 input bytes, and uses
675
 * SHAKE256 to produce the input to SamplePolyCBD_eta: FIPS 203, algorithm 8.
676
 */
677
static __owur int prf(uint8_t *out, size_t len, const uint8_t in[ML_KEM_RANDOM_BYTES + 1],
678
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
679
310k
{
680
310k
    return EVP_DigestInit_ex(mdctx, key->shake256_md, NULL)
681
310k
        && single_keccak(out, len, in, ML_KEM_RANDOM_BYTES + 1, mdctx);
682
310k
}
683
684
/*
685
 * FIPS 203, Section 4.1, equation (4.4): H.  SHA3-256 hash of a variable
686
 * length input, producing 32 bytes of output.
687
 */
688
static __owur int hash_h(uint8_t out[ML_KEM_PKHASH_BYTES], const uint8_t *in, size_t len,
689
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
690
325
{
691
325
    return EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL)
692
325
        && single_keccak(out, ML_KEM_PKHASH_BYTES, in, len, mdctx);
693
325
}
694
695
/* Incremental hash_h of expanded public key */
696
static int
697
hash_h_pubkey(uint8_t pkhash[ML_KEM_PKHASH_BYTES],
698
    EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
699
51.3k
{
700
51.3k
    const ML_KEM_VINFO *vinfo = key->vinfo;
701
51.3k
    const scalar *t = key->t, *end = t + vinfo->rank;
702
51.3k
    unsigned int sz;
703
704
51.3k
    if (!EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL))
705
0
        return 0;
706
707
154k
    do {
708
154k
        uint8_t buf[3 * DEGREE / 2];
709
710
154k
        scalar_encode(buf, t++, 12);
711
154k
        if (!EVP_DigestUpdate(mdctx, buf, sizeof(buf)))
712
0
            return 0;
713
154k
    } while (t < end);
714
715
51.3k
    if (!EVP_DigestUpdate(mdctx, key->rho, ML_KEM_RANDOM_BYTES))
716
0
        return 0;
717
51.3k
    return EVP_DigestFinal_ex(mdctx, pkhash, &sz)
718
51.3k
        && ossl_assert(sz == ML_KEM_PKHASH_BYTES);
719
51.3k
}
720
721
/*
722
 * FIPS 203, Section 4.1, equation (4.5): G.  SHA3-512 hash of a variable
723
 * length input, producing 64 bytes of output, in particular the seeds
724
 * (d,z) for key generation.
725
 */
726
static __owur int hash_g(uint8_t out[ML_KEM_SEED_BYTES], const uint8_t *in, size_t len,
727
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
728
51.7k
{
729
51.7k
    return EVP_DigestInit_ex(mdctx, key->sha3_512_md, NULL)
730
51.7k
        && single_keccak(out, ML_KEM_SEED_BYTES, in, len, mdctx);
731
51.7k
}
732
733
/*
734
 * FIPS 203, Section 4.1, equation (4.4): J. SHAKE256 taking a variable length
735
 * input to compute a 32-byte implicit rejection shared secret, of the same
736
 * length as the expected shared secret.  (Computed even on success to avoid
737
 * side-channel leaks).
738
 */
739
static __owur int kdf(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
740
    const uint8_t z[ML_KEM_RANDOM_BYTES],
741
    const uint8_t *ctext, size_t len,
742
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
743
173
{
744
173
    return EVP_DigestInit_ex(mdctx, key->shake256_md, NULL)
745
173
        && EVP_DigestUpdate(mdctx, z, ML_KEM_RANDOM_BYTES)
746
173
        && EVP_DigestUpdate(mdctx, ctext, len)
747
173
        && EVP_DigestFinalXOF(mdctx, out, ML_KEM_SHARED_SECRET_BYTES);
748
173
}
749
750
/*
751
 * FIPS 203, Section 4.2.2, Algorithm 7: "SampleNTT" (steps 3-17, steps 1, 2
752
 * are performed by the caller). Rejection-samples a Keccak stream to get
753
 * uniformly distributed elements in the range [0,q). This is used for matrix
754
 * expansion and only operates on public inputs.
755
 */
756
static __owur int sample_scalar(scalar *out, EVP_MD_CTX *mdctx)
757
464k
{
758
464k
    uint16_t *curr = out->c, *endout = curr + DEGREE;
759
464k
    uint8_t buf[SCALAR_SAMPLING_BUFSIZE], *in;
760
464k
    uint8_t *endin = buf + sizeof(buf);
761
464k
    uint16_t d;
762
464k
    uint8_t b1, b2, b3;
763
764
1.39M
    do {
765
1.39M
        if (!EVP_DigestSqueeze(mdctx, in = buf, sizeof(buf)))
766
0
            return 0;
767
72.8M
        do {
768
72.8M
            b1 = *in++;
769
72.8M
            b2 = *in++;
770
72.8M
            b3 = *in++;
771
772
72.8M
            if (curr >= endout)
773
156k
                break;
774
72.6M
            if ((d = ((b2 & 0x0f) << 8) + b1) < kPrime)
775
59.9M
                *curr++ = d;
776
72.6M
            if (curr >= endout)
777
308k
                break;
778
72.3M
            if ((d = (b3 << 4) + (b2 >> 4)) < kPrime)
779
58.9M
                *curr++ = d;
780
72.3M
        } while (in < endin);
781
1.39M
    } while (curr < endout);
782
464k
    return 1;
783
464k
}
784
785
/*-
786
 * reduce_once reduces 0 <= x < 2*kPrime, mod kPrime.
787
 *
788
 * Subtract |q| if the input is larger, without exposing a side-channel,
789
 * avoiding the "clangover" attack.  See |constish_time_non_zero| for a
790
 * discussion on why the value barrier is by default omitted.
791
 */
792
static __owur uint16_t reduce_once(uint16_t x)
793
1.01G
{
794
1.01G
    const uint16_t subtracted = x - kPrime;
795
1.01G
    uint16_t mask = constish_time_non_zero(subtracted >> 15);
796
797
1.01G
    return (mask & x) | (~mask & subtracted);
798
1.01G
}
799
800
/*
801
 * Constant-time reduce x mod kPrime using Barrett reduction. x must be less
802
 * than kPrime + 2 * kPrime^2.  This is sufficient to reduce a product of
803
 * two already reduced u_int16 values, in fact it is sufficient for each
804
 * to be less than 2^12, because (kPrime * (2 * kPrime + 1)) > 2^24.
805
 */
806
static __owur uint16_t reduce(uint32_t x)
807
458M
{
808
458M
    uint64_t product = (uint64_t)x * kBarrettMultiplier;
809
458M
    uint32_t quotient = (uint32_t)(product >> kBarrettShift);
810
458M
    uint32_t remainder = x - quotient * kPrime;
811
812
458M
    return reduce_once(remainder);
813
458M
}
814
815
/* Multiply a scalar by a constant. */
816
static void scalar_mult_const(scalar *s, uint16_t a)
817
1.61k
{
818
1.61k
    uint16_t *curr = s->c, *end = curr + DEGREE, tmp;
819
820
414k
    do {
821
414k
        tmp = reduce(*curr * a);
822
414k
        *curr++ = tmp;
823
414k
    } while (curr < end);
824
1.61k
}
825
826
/*-
827
 * FIPS 203, Section 4.3, Algorithm 9: "NTT".
828
 * In-place number theoretic transform of a given scalar.  Note that ML-KEM's
829
 * kPrime 3329 does not have a 512th root of unity, so this transform leaves
830
 * off the last iteration of the usual FFT code, with the 128 relevant roots of
831
 * unity being stored in NTTRoots.  This means the output should be seen as 128
832
 * elements in GF(3329^2), with the coefficients of the elements being
833
 * consecutive entries in |s->c|.
834
 */
835
static void scalar_ntt(scalar *s)
836
309k
{
837
309k
    const uint16_t *roots = kNTTRoots;
838
309k
    uint16_t *end = s->c + DEGREE;
839
309k
    int offset = DEGREE / 2;
840
841
2.16M
    do {
842
2.16M
        uint16_t *curr = s->c, *peer;
843
844
39.3M
        do {
845
39.3M
            uint16_t *pause = curr + offset, even, odd;
846
39.3M
            uint32_t zeta = *++roots;
847
848
39.3M
            peer = pause;
849
277M
            do {
850
277M
                even = *curr;
851
277M
                odd = reduce(*peer * zeta);
852
277M
                *peer++ = reduce_once(even - odd + kPrime);
853
277M
                *curr++ = reduce_once(odd + even);
854
277M
            } while (curr < pause);
855
39.3M
        } while ((curr = peer) < end);
856
2.16M
    } while ((offset >>= 1) >= 2);
857
309k
}
858
859
/*-
860
 * FIPS 203, Section 4.3, Algorithm 10: "NTT^(-1)".
861
 * In-place inverse number theoretic transform of a given scalar, with pairs of
862
 * entries of s->v being interpreted as elements of GF(3329^2). Just as with
863
 * the number theoretic transform, this leaves off the first step of the normal
864
 * iFFT to account for the fact that 3329 does not have a 512th root of unity,
865
 * using the precomputed 128 roots of unity stored in InverseNTTRoots.
866
 */
867
static void scalar_inverse_ntt(scalar *s)
868
1.61k
{
869
1.61k
    const uint16_t *roots = kInverseNTTRoots;
870
1.61k
    uint16_t *end = s->c + DEGREE;
871
1.61k
    int offset = 2;
872
873
11.3k
    do {
874
11.3k
        uint16_t *curr = s->c, *peer;
875
876
205k
        do {
877
205k
            uint16_t *pause = curr + offset, even, odd;
878
205k
            uint32_t zeta = *++roots;
879
880
205k
            peer = pause;
881
1.45M
            do {
882
1.45M
                even = *curr;
883
1.45M
                odd = *peer;
884
1.45M
                *peer++ = reduce(zeta * (even - odd + kPrime));
885
1.45M
                *curr++ = reduce_once(odd + even);
886
1.45M
            } while (curr < pause);
887
205k
        } while ((curr = peer) < end);
888
11.3k
    } while ((offset <<= 1) < DEGREE);
889
1.61k
    scalar_mult_const(s, kInverseDegree);
890
1.61k
}
891
892
/* Addition updating the LHS scalar in-place. */
893
static void scalar_add(scalar *lhs, const scalar *rhs)
894
1.44k
{
895
1.44k
    int i;
896
897
371k
    for (i = 0; i < DEGREE; i++)
898
370k
        lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
899
1.44k
}
900
901
/* Subtraction updating the LHS scalar in-place. */
902
static void scalar_sub(scalar *lhs, const scalar *rhs)
903
173
{
904
173
    int i;
905
906
44.4k
    for (i = 0; i < DEGREE; i++)
907
44.2k
        lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime);
908
173
}
909
910
/*
911
 * Multiplying two scalars in the number theoretically transformed state. Since
912
 * 3329 does not have a 512th root of unity, this means we have to interpret
913
 * the 2*ith and (2*i+1)th entries of the scalar as elements of
914
 * GF(3329)[X]/(X^2 - 17^(2*bitreverse(i)+1)).
915
 *
916
 * The value of 17^(2*bitreverse(i)+1) mod 3329 is stored in the precomputed
917
 * ModRoots table. Note that our Barrett transform only allows us to multiply
918
 * two reduced numbers together, so we need some intermediate reduction steps,
919
 * even if an uint64_t could hold 3 multiplied numbers.
920
 */
921
static void scalar_mult(scalar *out, const scalar *lhs,
922
    const scalar *rhs)
923
1.61k
{
924
1.61k
    uint16_t *curr = out->c, *end = curr + DEGREE;
925
1.61k
    const uint16_t *lc = lhs->c, *rc = rhs->c;
926
1.61k
    const uint16_t *roots = kModRoots;
927
928
207k
    do {
929
207k
        uint32_t l0 = *lc++, r0 = *rc++;
930
207k
        uint32_t l1 = *lc++, r1 = *rc++;
931
207k
        uint32_t zetapow = *roots++;
932
933
207k
        *curr++ = reduce(l0 * r0 + reduce(l1 * r1) * zetapow);
934
207k
        *curr++ = reduce(l0 * r1 + l1 * r0);
935
207k
    } while (curr < end);
936
1.61k
}
937
938
/* Above, but add the result to an existing scalar */
939
static ossl_inline void scalar_mult_add(scalar *out, const scalar *lhs,
940
    const scalar *rhs)
941
465k
{
942
465k
    uint16_t *curr = out->c, *end = curr + DEGREE;
943
465k
    const uint16_t *lc = lhs->c, *rc = rhs->c;
944
465k
    const uint16_t *roots = kModRoots;
945
946
59.5M
    do {
947
59.5M
        uint32_t l0 = *lc++, r0 = *rc++;
948
59.5M
        uint32_t l1 = *lc++, r1 = *rc++;
949
59.5M
        uint16_t *c0 = curr++;
950
59.5M
        uint16_t *c1 = curr++;
951
59.5M
        uint32_t zetapow = *roots++;
952
953
59.5M
        *c0 = reduce(*c0 + l0 * r0 + reduce(l1 * r1) * zetapow);
954
59.5M
        *c1 = reduce(*c1 + l0 * r1 + l1 * r0);
955
59.5M
    } while (curr < end);
956
465k
}
957
958
/*-
959
 * FIPS 203, Section 4.2.1, Algorithm 5: "ByteEncode_d", for 2<=d<=12.
960
 * Here |bits| is |d|.  For efficiency, we handle the d=1 case separately.
961
 */
962
static void scalar_encode(uint8_t *out, const scalar *s, int bits)
963
309k
{
964
309k
    const uint16_t *curr = s->c, *end = curr + DEGREE;
965
309k
    uint64_t accum = 0, element;
966
309k
    int used = 0;
967
968
79.2M
    do {
969
79.2M
        element = *curr++;
970
79.2M
        if (used + bits < 64) {
971
64.4M
            accum |= element << used;
972
64.4M
            used += bits;
973
64.4M
        } else if (used + bits > 64) {
974
9.89M
            out = OPENSSL_store_u64_le(out, accum | (element << used));
975
9.89M
            accum = element >> (64 - used);
976
9.89M
            used = (used + bits) - 64;
977
9.89M
        } else {
978
4.94M
            out = OPENSSL_store_u64_le(out, accum | (element << used));
979
4.94M
            accum = 0;
980
4.94M
            used = 0;
981
4.94M
        }
982
79.2M
    } while (curr < end);
983
309k
}
984
985
/*
986
 * scalar_encode_1 is |scalar_encode| specialised for |bits| == 1.
987
 */
988
static void scalar_encode_1(uint8_t out[DEGREE / 8], const scalar *s)
989
173
{
990
173
    int i, j;
991
173
    uint8_t out_byte;
992
993
5.70k
    for (i = 0; i < DEGREE; i += 8) {
994
5.53k
        out_byte = 0;
995
49.8k
        for (j = 0; j < 8; j++)
996
44.2k
            out_byte |= bit0(s->c[i + j]) << j;
997
5.53k
        *out = out_byte;
998
5.53k
        out++;
999
5.53k
    }
1000
173
}
1001
1002
/*-
1003
 * FIPS 203, Section 4.2.1, Algorithm 6: "ByteDecode_d", for 2<=d<12.
1004
 * Here |bits| is |d|.  For efficiency, we handle the d=1 and d=12 cases
1005
 * separately.
1006
 *
1007
 * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
1008
 * |out|.
1009
 */
1010
static void scalar_decode(scalar *out, const uint8_t *in, int bits)
1011
675
{
1012
675
    uint16_t *curr = out->c, *end = curr + DEGREE;
1013
675
    uint64_t accum = 0;
1014
675
    int accum_bits = 0, todo = bits;
1015
675
    uint16_t bitmask = (((uint16_t)1) << bits) - 1, mask = bitmask;
1016
675
    uint16_t element = 0;
1017
1018
191k
    do {
1019
191k
        if (accum_bits == 0) {
1020
23.8k
            in = OPENSSL_load_u64_le(&accum, in);
1021
23.8k
            accum_bits = 64;
1022
23.8k
        }
1023
191k
        if (todo == bits && accum_bits >= bits) {
1024
            /* No partial "element", and all the required bits available */
1025
154k
            *curr++ = ((uint16_t)accum) & mask;
1026
154k
            accum >>= bits;
1027
154k
            accum_bits -= bits;
1028
154k
        } else if (accum_bits >= todo) {
1029
            /* A partial "element", and all the required bits available */
1030
18.5k
            *curr++ = element | ((((uint16_t)accum) & mask) << (bits - todo));
1031
18.5k
            accum >>= todo;
1032
18.5k
            accum_bits -= todo;
1033
18.5k
            element = 0;
1034
18.5k
            todo = bits;
1035
18.5k
            mask = bitmask;
1036
18.5k
        } else {
1037
            /*
1038
             * Only some of the requisite bits accumulated, store |accum_bits|
1039
             * of these in |element|.  The accumulated bitcount becomes 0, but
1040
             * as soon as we have more bits we'll want to merge accum_bits
1041
             * fewer of them into the final |element|.
1042
             *
1043
             * Note that with a 64-bit accumulator and |bits| always 12 or
1044
             * less, if we're here, the previous iteration had all the
1045
             * requisite bits, and so there are no kept bits in |element|.
1046
             */
1047
18.5k
            element = ((uint16_t)accum) & mask;
1048
18.5k
            todo -= accum_bits;
1049
18.5k
            mask = bitmask >> accum_bits;
1050
18.5k
            accum_bits = 0;
1051
18.5k
        }
1052
191k
    } while (curr < end);
1053
675
}
1054
1055
/*
 * Unpack DEGREE 12-bit coefficients from the 3 * DEGREE / 2 bytes at |in|
 * into |out|, consuming three input bytes per pair of coefficients.
 *
 * Returns 0 as soon as a pair contains a value >= kPrime (that pair has
 * already been stored in |out|), and 1 once every coefficient has been
 * decoded and found to be in range.
 */
static __owur int scalar_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2])
{
    const uint8_t *stop = in + 3 * DEGREE / 2;
    uint16_t *coeff = out->c;

    while (in < stop) {
        uint8_t lo = in[0];
        uint8_t mid = in[1];
        uint8_t hi = in[2];
        /* First byte plus the low nibble of the middle byte */
        uint16_t even = lo | ((mid & 0x0f) << 8);
        /* High nibble of the middle byte plus the third byte */
        uint16_t odd = (mid >> 4) | (hi << 4);

        in += 3;
        *coeff++ = even;
        *coeff++ = odd;
        if ((even >= kPrime) | (odd >= kPrime))
            return 0;
    }
    return 1;
}
1072
1073
/*-
1074
 * scalar_decode_decompress_add is a combination of decoding and decompression
1075
 * both specialised for |bits| == 1, with the result added (and sum reduced) to
1076
 * the output scalar.
1077
 *
1078
 * NOTE: this function MUST not leak an input-data-depedennt timing signal.
1079
 * A timing leak in a related function in the reference Kyber implementation
1080
 * made the "clangover" attack (CVE-2024-37880) possible, giving key recovery
1081
 * for ML-KEM-512 in minutes, provided the attacker has access to precise
1082
 * timing of a CPU performing chosen-ciphertext decap.  Admittedly this is only
1083
 * a risk when private keys are reused (perhaps KEMTLS servers).
1084
 */
1085
static void
1086
scalar_decode_decompress_add(scalar *out, const uint8_t in[DEGREE / 8])
1087
370
{
1088
370
    static const uint16_t half_q_plus_1 = (ML_KEM_PRIME >> 1) + 1;
1089
370
    uint16_t *curr = out->c, *end = curr + DEGREE;
1090
370
    uint16_t mask;
1091
370
    uint8_t b;
1092
1093
    /*
1094
     * Add |half_q_plus_1| if the bit is set, without exposing a side-channel,
1095
     * avoiding the "clangover" attack.  See |constish_time_non_zero| for a
1096
     * discussion on why the value barrier is by default omitted.
1097
     */
1098
370
#define decode_decompress_add_bit                        \
1099
94.7k
    mask = constish_time_non_zero(bit0(b));              \
1100
94.7k
    *curr = reduce_once(*curr + (mask & half_q_plus_1)); \
1101
94.7k
    curr++;                                              \
1102
94.7k
    b >>= 1
1103
1104
    /* Unrolled to process each byte in one iteration */
1105
11.8k
    do {
1106
11.8k
        b = *in++;
1107
11.8k
        decode_decompress_add_bit;
1108
11.8k
        decode_decompress_add_bit;
1109
11.8k
        decode_decompress_add_bit;
1110
11.8k
        decode_decompress_add_bit;
1111
1112
11.8k
        decode_decompress_add_bit;
1113
11.8k
        decode_decompress_add_bit;
1114
11.8k
        decode_decompress_add_bit;
1115
11.8k
        decode_decompress_add_bit;
1116
11.8k
    } while (curr < end);
1117
370
#undef decode_decompress_add_bit
1118
370
}
1119
1120
/*
1121
 * FIPS 203, Section 4.2.1, Equation (4.7): Compress_d.
1122
 *
1123
 * Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
1124
 * numbers close to each other together. The formula used is
1125
 * round(2^|bits|/kPrime*x) mod 2^|bits|.
1126
 * Uses Barrett reduction to achieve constant time. Since we need both the
1127
 * remainder (for rounding) and the quotient (as the result), we cannot use
1128
 * |reduce| here, but need to do the Barrett reduction directly.
1129
 */
1130
static __owur uint16_t compress(uint16_t x, int bits)
1131
414k
{
1132
414k
    uint32_t shifted = (uint32_t)x << bits;
1133
414k
    uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
1134
414k
    uint32_t quotient = (uint32_t)(product >> kBarrettShift);
1135
414k
    uint32_t remainder = shifted - quotient * kPrime;
1136
1137
    /*
1138
     * Adjust the quotient to round correctly:
1139
     *   0 <= remainder <= kHalfPrime round to 0
1140
     *   kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
1141
     *   kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
1142
     */
1143
414k
    quotient += 1 & constant_time_lt_32(kHalfPrime, remainder);
1144
414k
    quotient += 1 & constant_time_lt_32(kPrime + kHalfPrime, remainder);
1145
414k
    return quotient & ((1 << bits) - 1);
1146
414k
}
1147
1148
/*
1149
 * FIPS 203, Section 4.2.1, Equation (4.8): Decompress_d.
1150
1151
 * Decompresses |x| by using a close equi-distant representative. The formula
1152
 * is round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us
1153
 * to implement this logic using only bit operations.
1154
 */
1155
static __owur uint16_t decompress(uint16_t x, int bits)
1156
172k
{
1157
172k
    uint32_t product = (uint32_t)x * kPrime;
1158
172k
    uint32_t power = 1 << bits;
1159
    /* This is |product| % power, since |power| is a power of 2. */
1160
172k
    uint32_t remainder = product & (power - 1);
1161
    /* This is |product| / power, since |power| is a power of 2. */
1162
172k
    uint32_t lower = product >> bits;
1163
1164
    /*
1165
     * The rounding logic works since the first half of numbers mod |power|
1166
     * have a 0 as first bit, and the second half has a 1 as first bit, since
1167
     * |power| is a power of 2. As a 12 bit number, |remainder| is always
1168
     * positive, so we will shift in 0s for a right shift.
1169
     */
1170
172k
    return lower + (remainder >> (bits - 1));
1171
172k
}
1172
1173
/*-
1174
 * FIPS 203, Section 4.2.1, Equation (4.7): "Compress_d".
1175
 * In-place lossy rounding of scalars to 2^d bits.
1176
 */
1177
static void scalar_compress(scalar *s, int bits)
1178
1.61k
{
1179
1.61k
    int i;
1180
1181
416k
    for (i = 0; i < DEGREE; i++)
1182
414k
        s->c[i] = compress(s->c[i], bits);
1183
1.61k
}
1184
1185
/*
1186
 * FIPS 203, Section 4.2.1, Equation (4.8): "Decompress_d".
1187
 * In-place approximate recovery of scalars from 2^d bit compression.
1188
 */
1189
static void scalar_decompress(scalar *s, int bits)
1190
675
{
1191
675
    int i;
1192
1193
173k
    for (i = 0; i < DEGREE; i++)
1194
172k
        s->c[i] = decompress(s->c[i], bits);
1195
675
}
1196
1197
/* Addition updating the LHS vector in-place. */
1198
static void vector_add(scalar *lhs, const scalar *rhs, int rank)
1199
370
{
1200
1.07k
    do {
1201
1.07k
        scalar_add(lhs++, rhs++);
1202
1.07k
    } while (--rank > 0);
1203
370
}
1204
1205
/*
1206
 * Encodes an entire vector into 32*|rank|*|bits| bytes. Note that since 256
1207
 * (DEGREE) is divisible by 8, the individual vector entries will always fill a
1208
 * whole number of bytes, so we do not need to worry about bit packing here.
1209
 */
1210
static void vector_encode(uint8_t *out, const scalar *a, int bits, int rank)
1211
51.7k
{
1212
51.7k
    int stride = bits * DEGREE / 8;
1213
1214
206k
    for (; rank-- > 0; out += stride)
1215
155k
        scalar_encode(out, a++, bits);
1216
51.7k
}
1217
1218
/*
1219
 * Decodes 32*|rank|*|bits| bytes from |in| into |out|. It returns early
1220
 * if any parsed value is >= |ML_KEM_PRIME|.  The resulting scalars are
1221
 * then decompressed and transformed via the NTT.
1222
 *
1223
 * Note: Used only in decrypt_cpa(), which returns void and so does not check
1224
 * the return value of this function.  Side-channels are fine when the input
1225
 * ciphertext to decap() is simply syntactically invalid.
1226
 */
1227
static void
1228
vector_decode_decompress_ntt(scalar *out, const uint8_t *in, int bits, int rank)
1229
173
{
1230
173
    int stride = bits * DEGREE / 8;
1231
1232
675
    for (; rank-- > 0; in += stride, ++out) {
1233
502
        scalar_decode(out, in, bits);
1234
502
        scalar_decompress(out, bits);
1235
502
        scalar_ntt(out);
1236
502
    }
1237
173
}
1238
1239
/* vector_decode(), specialised to bits == 12. */
1240
static __owur int vector_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2], int rank)
1241
670
{
1242
670
    int stride = 3 * DEGREE / 2;
1243
1244
1.60k
    for (; rank-- > 0; in += stride)
1245
1.26k
        if (!scalar_decode_12(out++, in))
1246
332
            return 0;
1247
338
    return 1;
1248
670
}
1249
1250
/* In-place compression of each scalar component */
1251
static void vector_compress(scalar *a, int bits, int rank)
1252
370
{
1253
1.07k
    do {
1254
1.07k
        scalar_compress(a++, bits);
1255
1.07k
    } while (--rank > 0);
1256
370
}
1257
1258
/* The output scalar must not overlap with the inputs */
1259
static void inner_product(scalar *out, const scalar *lhs, const scalar *rhs,
1260
    int rank)
1261
543
{
1262
543
    scalar_mult(out, lhs, rhs);
1263
1.57k
    while (--rank > 0)
1264
1.03k
        scalar_mult_add(out, ++lhs, ++rhs);
1265
543
}
1266
1267
/*
1268
 * Here, the output vector must not overlap with the inputs, the result is
1269
 * directly subjected to inverse NTT.
1270
 */
1271
static void
1272
matrix_mult_intt(scalar *out, const scalar *m, const scalar *a, int rank)
1273
370
{
1274
370
    const scalar *ar;
1275
370
    int i, j;
1276
1277
1.44k
    for (i = rank; i-- > 0; ++out) {
1278
1.07k
        scalar_mult(out, m++, ar = a);
1279
3.36k
        for (j = rank - 1; j > 0; --j)
1280
2.29k
            scalar_mult_add(out, m++, ++ar);
1281
1.07k
        scalar_inverse_ntt(out);
1282
1.07k
    }
1283
370
}
1284
1285
/* Here, the output vector must not overlap with the inputs */
1286
static void
1287
matrix_mult_transpose_add(scalar *out, const scalar *m, const scalar *a, int rank)
1288
51.3k
{
1289
51.3k
    const scalar *mc = m, *mr, *ar;
1290
51.3k
    int i, j;
1291
1292
205k
    for (i = rank; i-- > 0; ++out) {
1293
154k
        scalar_mult_add(out, mr = mc++, ar = a);
1294
462k
        for (j = rank; --j > 0;)
1295
308k
            scalar_mult_add(out, (mr += rank), ++ar);
1296
154k
    }
1297
51.3k
}
1298
1299
/*-
1300
 * Expands the matrix from a seed for key generation and for encaps-CPA.
1301
 * NOTE: FIPS 203 matrix "A" is the transpose of this matrix, computed
1302
 * by appending the (i,j) indices to the seed in the opposite order!
1303
 *
1304
 * Where FIPS 203 computes t = A * s + e, we use the transpose of "m".
1305
 */
1306
static __owur int matrix_expand(EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
1307
51.7k
{
1308
51.7k
    scalar *out = key->m;
1309
51.7k
    uint8_t input[ML_KEM_RANDOM_BYTES + 2];
1310
51.7k
    int rank = key->vinfo->rank;
1311
51.7k
    int i, j;
1312
1313
51.7k
    memcpy(input, key->rho, ML_KEM_RANDOM_BYTES);
1314
206k
    for (i = 0; i < rank; i++) {
1315
619k
        for (j = 0; j < rank; j++) {
1316
464k
            input[ML_KEM_RANDOM_BYTES] = i;
1317
464k
            input[ML_KEM_RANDOM_BYTES + 1] = j;
1318
464k
            if (!EVP_DigestInit_ex(mdctx, key->shake128_md, NULL)
1319
464k
                || !EVP_DigestUpdate(mdctx, input, sizeof(input))
1320
464k
                || !sample_scalar(out++, mdctx))
1321
0
                return 0;
1322
464k
        }
1323
154k
    }
1324
51.7k
    return 1;
1325
51.7k
}
1326
1327
/*
1328
 * Algorithm 7 from the spec, with eta fixed to two and the PRF call
1329
 * included. Creates binominally distributed elements by sampling 2*|eta| bits,
1330
 * and setting the coefficient to the count of the first bits minus the count of
1331
 * the second bits, resulting in a centered binomial distribution. Since eta is
1332
 * two this gives -2/2 with a probability of 1/16, -1/1 with probability 1/4,
1333
 * and 0 with probability 3/8.
1334
 */
1335
static __owur int cbd_2(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
1336
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1337
308k
{
1338
308k
    uint16_t *curr = out->c, *end = curr + DEGREE;
1339
308k
    uint8_t randbuf[4 * DEGREE / 8], *r = randbuf; /* 64 * eta slots */
1340
308k
    uint16_t value, mask;
1341
308k
    uint8_t b;
1342
1343
308k
    if (!prf(randbuf, sizeof(randbuf), in, mdctx, key))
1344
0
        return 0;
1345
1346
39.5M
    do {
1347
39.5M
        b = *r++;
1348
1349
        /*
1350
         * Add |kPrime| if |value| underflowed.  See |constish_time_non_zero|
1351
         * for a discussion on why the value barrier is by default omitted.
1352
         * While this could have been written reduce_once(value + kPrime), this
1353
         * is one extra addition and small range of |value| tempts some
1354
         * versions of Clang to emit a branch.
1355
         */
1356
39.5M
        value = bit0(b) + bitn(1, b);
1357
39.5M
        value -= bitn(2, b) + bitn(3, b);
1358
39.5M
        mask = constish_time_non_zero(value >> 15);
1359
39.5M
        *curr++ = value + (kPrime & mask);
1360
1361
39.5M
        value = bitn(4, b) + bitn(5, b);
1362
39.5M
        value -= bitn(6, b) + bitn(7, b);
1363
39.5M
        mask = constish_time_non_zero(value >> 15);
1364
39.5M
        *curr++ = value + (kPrime & mask);
1365
39.5M
    } while (curr < end);
1366
308k
    return 1;
1367
308k
}
1368
1369
/*
1370
 * Algorithm 7 from the spec, with eta fixed to three and the PRF call
1371
 * included. Creates binominally distributed elements by sampling 3*|eta| bits,
1372
 * and setting the coefficient to the count of the first bits minus the count of
1373
 * the second bits, resulting in a centered binomial distribution.
1374
 */
1375
static __owur int cbd_3(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
1376
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1377
1.63k
{
1378
1.63k
    uint16_t *curr = out->c, *end = curr + DEGREE;
1379
1.63k
    uint8_t randbuf[6 * DEGREE / 8], *r = randbuf; /* 64 * eta slots */
1380
1.63k
    uint8_t b1, b2, b3;
1381
1.63k
    uint16_t value, mask;
1382
1383
1.63k
    if (!prf(randbuf, sizeof(randbuf), in, mdctx, key))
1384
0
        return 0;
1385
1386
104k
    do {
1387
104k
        b1 = *r++;
1388
104k
        b2 = *r++;
1389
104k
        b3 = *r++;
1390
1391
        /*
1392
         * Add |kPrime| if |value| underflowed.  See |constish_time_non_zero|
1393
         * for a discussion on why the value barrier is by default omitted.
1394
         * While this could have been written reduce_once(value + kPrime), this
1395
         * is one extra addition and small range of |value| tempts some
1396
         * versions of Clang to emit a branch.
1397
         */
1398
104k
        value = bit0(b1) + bitn(1, b1) + bitn(2, b1);
1399
104k
        value -= bitn(3, b1) + bitn(4, b1) + bitn(5, b1);
1400
104k
        mask = constish_time_non_zero(value >> 15);
1401
104k
        *curr++ = value + (kPrime & mask);
1402
1403
104k
        value = bitn(6, b1) + bitn(7, b1) + bit0(b2);
1404
104k
        value -= bitn(1, b2) + bitn(2, b2) + bitn(3, b2);
1405
104k
        mask = constish_time_non_zero(value >> 15);
1406
104k
        *curr++ = value + (kPrime & mask);
1407
1408
104k
        value = bitn(4, b2) + bitn(5, b2) + bitn(6, b2);
1409
104k
        value -= bitn(7, b2) + bit0(b3) + bitn(1, b3);
1410
104k
        mask = constish_time_non_zero(value >> 15);
1411
104k
        *curr++ = value + (kPrime & mask);
1412
1413
104k
        value = bitn(2, b3) + bitn(3, b3) + bitn(4, b3);
1414
104k
        value -= bitn(5, b3) + bitn(6, b3) + bitn(7, b3);
1415
104k
        mask = constish_time_non_zero(value >> 15);
1416
104k
        *curr++ = value + (kPrime & mask);
1417
104k
    } while (curr < end);
1418
1.63k
    return 1;
1419
1.63k
}
1420
1421
/*
1422
 * Generates a secret vector by using |cbd| with the given seed to generate
1423
 * scalar elements and incrementing |counter| for each slot of the vector.
1424
 */
1425
static __owur int gencbd_vector(scalar *out, CBD_FUNC cbd, uint8_t *counter,
1426
    const uint8_t seed[ML_KEM_RANDOM_BYTES], int rank,
1427
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1428
370
{
1429
370
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1430
1431
370
    memcpy(input, seed, ML_KEM_RANDOM_BYTES);
1432
1.07k
    do {
1433
1.07k
        input[ML_KEM_RANDOM_BYTES] = (*counter)++;
1434
1.07k
        if (!cbd(out++, input, mdctx, key))
1435
0
            return 0;
1436
1.07k
    } while (--rank > 0);
1437
370
    return 1;
1438
370
}
1439
1440
/*
1441
 * As above plus NTT transform.
1442
 */
1443
static __owur int gencbd_vector_ntt(scalar *out, CBD_FUNC cbd, uint8_t *counter,
1444
    const uint8_t seed[ML_KEM_RANDOM_BYTES], int rank,
1445
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1446
103k
{
1447
103k
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1448
1449
103k
    memcpy(input, seed, ML_KEM_RANDOM_BYTES);
1450
309k
    do {
1451
309k
        input[ML_KEM_RANDOM_BYTES] = (*counter)++;
1452
309k
        if (!cbd(out, input, mdctx, key))
1453
0
            return 0;
1454
309k
        scalar_ntt(out++);
1455
309k
    } while (--rank > 0);
1456
103k
    return 1;
1457
103k
}
1458
1459
/*
 * Selects the CBD sampler for |ETA1|: cbd_3 for ML-KEM-512 (where |ETA1| is
 * 3), cbd_2 for every other variant.  All |ETA2| values are 2.
 */
#define CBD1(evp_type) ((evp_type) != EVP_PKEY_ML_KEM_512 ? cbd_2 : cbd_3)
1461
1462
/*
1463
 * FIPS 203, Section 5.2, Algorithm 14: K-PKE.Encrypt.
1464
 *
1465
 * Encrypts a message with given randomness to the ciphertext in |out|. Without
1466
 * applying the Fujisaki-Okamoto transform this would not result in a CCA
1467
 * secure scheme, since lattice schemes are vulnerable to decryption failure
1468
 * oracles.
1469
 *
1470
 * The steps are re-ordered to make more efficient/localised use of storage.
1471
 *
1472
 * Note also that the input public key is assumed to hold a precomputed matrix
1473
 * |A| (our key->m, with the public key holding an expanded (16-bit per scalar
1474
 * coefficient) key->t vector).
1475
 *
1476
 * Caller passes storage in |tmp| for for two temporary vectors.
1477
 */
1478
static __owur int encrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
1479
    const uint8_t message[DEGREE / 8],
1480
    const uint8_t r[ML_KEM_RANDOM_BYTES], scalar *tmp,
1481
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1482
370
{
1483
370
    const ML_KEM_VINFO *vinfo = key->vinfo;
1484
370
    CBD_FUNC cbd_1 = CBD1(vinfo->evp_type);
1485
370
    int rank = vinfo->rank;
1486
    /* We can use tmp[0..rank-1] as storage for |y|, then |e1|, ... */
1487
370
    scalar *y = &tmp[0], *e1 = y, *e2 = y;
1488
    /* We can use tmp[rank]..tmp[2*rank - 1] for |u| */
1489
370
    scalar *u = &tmp[rank];
1490
370
    scalar v;
1491
370
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1492
370
    uint8_t counter = 0;
1493
370
    int du = vinfo->du;
1494
370
    int dv = vinfo->dv;
1495
1496
    /* FIPS 203 "y" vector */
1497
370
    if (!gencbd_vector_ntt(y, cbd_1, &counter, r, rank, mdctx, key))
1498
0
        return 0;
1499
    /* FIPS 203 "v" scalar */
1500
370
    inner_product(&v, key->t, y, rank);
1501
370
    scalar_inverse_ntt(&v);
1502
    /* FIPS 203 "u" vector */
1503
370
    matrix_mult_intt(u, key->m, y, rank);
1504
1505
    /* All done with |y|, now free to reuse tmp[0] for FIPS 203 |e1| */
1506
370
    if (!gencbd_vector(e1, cbd_2, &counter, r, rank, mdctx, key))
1507
0
        return 0;
1508
370
    vector_add(u, e1, rank);
1509
370
    vector_compress(u, du, rank);
1510
370
    vector_encode(out, u, du, rank);
1511
1512
    /* All done with |e1|, now free to reuse tmp[0] for FIPS 203 |e2| */
1513
370
    memcpy(input, r, ML_KEM_RANDOM_BYTES);
1514
370
    input[ML_KEM_RANDOM_BYTES] = counter;
1515
370
    if (!cbd_2(e2, input, mdctx, key))
1516
0
        return 0;
1517
370
    scalar_add(&v, e2);
1518
1519
    /* Combine message with |v| */
1520
370
    scalar_decode_decompress_add(&v, message);
1521
370
    scalar_compress(&v, dv);
1522
370
    scalar_encode(out + vinfo->u_vector_bytes, &v, dv);
1523
370
    return 1;
1524
370
}
1525
1526
/*
1527
 * FIPS 203, Section 5.3, Algorithm 15: K-PKE.Decrypt.
1528
 */
1529
static void
1530
decrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
1531
    const uint8_t *ctext, scalar *u, const ML_KEM_KEY *key)
1532
173
{
1533
173
    const ML_KEM_VINFO *vinfo = key->vinfo;
1534
173
    scalar v, mask;
1535
173
    int rank = vinfo->rank;
1536
173
    int du = vinfo->du;
1537
173
    int dv = vinfo->dv;
1538
1539
173
    vector_decode_decompress_ntt(u, ctext, du, rank);
1540
173
    scalar_decode(&v, ctext + vinfo->u_vector_bytes, dv);
1541
173
    scalar_decompress(&v, dv);
1542
173
    inner_product(&mask, key->s, u, rank);
1543
173
    scalar_inverse_ntt(&mask);
1544
173
    scalar_sub(&v, &mask);
1545
173
    scalar_compress(&v, 1);
1546
173
    scalar_encode_1(out, &v);
1547
173
}
1548
1549
/*-
1550
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1551
 * FIPS 203, Section 7.2, Algorithm 20: "ML-KEM.Encaps".
1552
 *
1553
 * Fills the |out| buffer with the |ek| output of "ML-KEM.KeyGen", or,
1554
 * equivalently, the |ek| input of "ML-KEM.Encaps", i.e. returns the
1555
 * wire-format of an ML-KEM public key.
1556
 */
1557
static void encode_pubkey(uint8_t *out, const ML_KEM_KEY *key)
1558
51.2k
{
1559
51.2k
    const uint8_t *rho = key->rho;
1560
51.2k
    const ML_KEM_VINFO *vinfo = key->vinfo;
1561
1562
51.2k
    vector_encode(out, key->t, 12, vinfo->rank);
1563
51.2k
    memcpy(out + vinfo->vector_bytes, rho, ML_KEM_RANDOM_BYTES);
1564
51.2k
}
1565
1566
/*-
1567
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1568
 *
1569
 * Fills the |out| buffer with the |dk| output of "ML-KEM.KeyGen".
1570
 * This matches the input format of parse_prvkey() below.
1571
 */
1572
static void encode_prvkey(uint8_t *out, const ML_KEM_KEY *key)
1573
138
{
1574
138
    const ML_KEM_VINFO *vinfo = key->vinfo;
1575
1576
138
    vector_encode(out, key->s, 12, vinfo->rank);
1577
138
    out += vinfo->vector_bytes;
1578
138
    encode_pubkey(out, key);
1579
138
    out += vinfo->pubkey_bytes;
1580
138
    memcpy(out, key->pkhash, ML_KEM_PKHASH_BYTES);
1581
138
    out += ML_KEM_PKHASH_BYTES;
1582
138
    memcpy(out, key->z, ML_KEM_RANDOM_BYTES);
1583
138
}
1584
1585
/*-
1586
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1587
 * FIPS 203, Section 7.2, Algorithm 20: "ML-KEM.Encaps".
1588
 *
1589
 * This function parses the |in| buffer as the |ek| output of "ML-KEM.KeyGen",
1590
 * or, equivalently, the |ek| input of "ML-KEM.Encaps", i.e. decodes the
1591
 * wire-format of the ML-KEM public key.
1592
 */
1593
static int parse_pubkey(const uint8_t *in, EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
1594
563
{
1595
563
    const ML_KEM_VINFO *vinfo = key->vinfo;
1596
1597
    /* Decode and check |t| */
1598
563
    if (!vector_decode_12(key->t, in, vinfo->rank)) {
1599
238
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1600
238
            "%s invalid public 't' vector",
1601
238
            vinfo->algorithm_name);
1602
238
        return 0;
1603
238
    }
1604
    /* Save the matrix |m| recovery seed |rho| */
1605
325
    memcpy(key->rho, in + vinfo->vector_bytes, ML_KEM_RANDOM_BYTES);
1606
    /*
1607
     * Pre-compute the public key hash, needed for both encap and decap.
1608
     * Also pre-compute the matrix expansion, stored with the public key.
1609
     */
1610
325
    if (!hash_h(key->pkhash, in, vinfo->pubkey_bytes, mdctx, key)
1611
325
        || !matrix_expand(mdctx, key)) {
1612
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1613
0
            "internal error while parsing %s public key",
1614
0
            vinfo->algorithm_name);
1615
0
        return 0;
1616
0
    }
1617
325
    return 1;
1618
325
}
1619
1620
/*
1621
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1622
 *
1623
 * Parses the |in| buffer as a |dk| output of "ML-KEM.KeyGen".
1624
 * This matches the output format of encode_prvkey() above.
1625
 */
1626
static int parse_prvkey(const uint8_t *in, EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
1627
107
{
1628
107
    const ML_KEM_VINFO *vinfo = key->vinfo;
1629
1630
    /* Decode and check |s|. */
1631
107
    if (!vector_decode_12(key->s, in, vinfo->rank)) {
1632
94
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1633
94
            "%s invalid private 's' vector",
1634
94
            vinfo->algorithm_name);
1635
94
        return 0;
1636
94
    }
1637
13
    in += vinfo->vector_bytes;
1638
1639
13
    if (!parse_pubkey(in, mdctx, key))
1640
9
        return 0;
1641
4
    in += vinfo->pubkey_bytes;
1642
1643
    /* Check public key hash. */
1644
4
    if (memcmp(key->pkhash, in, ML_KEM_PKHASH_BYTES) != 0) {
1645
4
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1646
4
            "%s public key hash mismatch",
1647
4
            vinfo->algorithm_name);
1648
4
        return 0;
1649
4
    }
1650
0
    in += ML_KEM_PKHASH_BYTES;
1651
1652
0
    memcpy(key->z, in, ML_KEM_RANDOM_BYTES);
1653
0
    return 1;
1654
4
}
1655
1656
/*
1657
 * FIPS 203, Section 6.1, Algorithm 16: "ML-KEM.KeyGen_internal".
1658
 *
1659
 * The implementation of Section 5.1, Algorithm 13, "K-PKE.KeyGen(d)" is
1660
 * inlined.
1661
 *
1662
 * The caller MUST pass a pre-allocated digest context that is not shared with
1663
 * any concurrent computation.
1664
 *
1665
 * This function optionally outputs the serialised wire-form |ek| public key
1666
 * into the provided |pubenc| buffer, and generates the content of the |rho|,
1667
 * |pkhash|, |t|, |m|, |s| and |z| components of the private |key| (which must
1668
 * have preallocated space for these).
1669
 *
1670
 * Keys are computed from a 32-byte random |d| plus the 1 byte rank for
1671
 * domain separation.  These are concatenated and hashed to produce a pair of
1672
 * 32-byte seeds public "rho", used to generate the matrix, and private "sigma",
1673
 * used to generate the secret vector |s|.
1674
 *
1675
 * The second random input |z| is copied verbatim into the Fujisaki-Okamoto
1676
 * (FO) transform "implicit-rejection" secret (the |z| component of the private
1677
 * key), which thwarts chosen-ciphertext attacks, provided decap() runs in
1678
 * constant time, with no side channel leaks, on all well-formed (valid length,
1679
 * and correctly encoded) ciphertext inputs.
1680
 */
1681
static __owur int genkey(const uint8_t seed[ML_KEM_SEED_BYTES],
1682
    EVP_MD_CTX *mdctx, uint8_t *pubenc, ML_KEM_KEY *key)
1683
34.5k
{
1684
34.5k
    uint8_t hashed[2 * ML_KEM_RANDOM_BYTES];
1685
34.5k
    const uint8_t *const sigma = hashed + ML_KEM_RANDOM_BYTES;
1686
34.5k
    uint8_t augmented_seed[ML_KEM_RANDOM_BYTES + 1];
1687
34.5k
    const ML_KEM_VINFO *vinfo = key->vinfo;
1688
34.5k
    CBD_FUNC cbd_1 = CBD1(vinfo->evp_type);
1689
34.5k
    int rank = vinfo->rank;
1690
34.5k
    uint8_t counter = 0;
1691
34.5k
    int ret = 0;
1692
1693
    /*
1694
     * Use the "d" seed salted with the rank to derive the public and private
1695
     * seeds rho and sigma.
1696
     */
1697
34.5k
    memcpy(augmented_seed, seed, ML_KEM_RANDOM_BYTES);
1698
34.5k
    augmented_seed[ML_KEM_RANDOM_BYTES] = (uint8_t)rank;
1699
34.5k
    if (!hash_g(hashed, augmented_seed, sizeof(augmented_seed), mdctx, key))
1700
0
        goto end;
1701
34.5k
    memcpy(key->rho, hashed, ML_KEM_RANDOM_BYTES);
1702
    /* The |rho| matrix seed is public */
1703
34.5k
    CONSTTIME_DECLASSIFY(key->rho, ML_KEM_RANDOM_BYTES);
1704
1705
    /* FIPS 203 |e| vector is initial value of key->t */
1706
34.5k
    if (!matrix_expand(mdctx, key)
1707
34.5k
        || !gencbd_vector_ntt(key->s, cbd_1, &counter, sigma, rank, mdctx, key)
1708
34.5k
        || !gencbd_vector_ntt(key->t, cbd_1, &counter, sigma, rank, mdctx, key))
1709
0
        goto end;
1710
1711
    /* To |e| we now add the product of transpose |m| and |s|, giving |t|. */
1712
34.5k
    matrix_mult_transpose_add(key->t, key->m, key->s, rank);
1713
    /* The |t| vector is public */
1714
34.5k
    CONSTTIME_DECLASSIFY(key->t, vinfo->rank * sizeof(scalar));
1715
1716
34.5k
    if (pubenc == NULL) {
1717
        /* Incremental digest of public key without in-full serialisation. */
1718
34.5k
        if (!hash_h_pubkey(key->pkhash, mdctx, key))
1719
0
            goto end;
1720
34.5k
    } else {
1721
0
        encode_pubkey(pubenc, key);
1722
0
        if (!hash_h(key->pkhash, pubenc, vinfo->pubkey_bytes, mdctx, key))
1723
0
            goto end;
1724
0
    }
1725
1726
    /* Save |z| portion of seed for "implicit rejection" on failure. */
1727
34.5k
    memcpy(key->z, seed + ML_KEM_RANDOM_BYTES, ML_KEM_RANDOM_BYTES);
1728
1729
    /* Optionally save the |d| portion of the seed */
1730
34.5k
    key->d = key->z + ML_KEM_RANDOM_BYTES;
1731
34.5k
    if (key->prov_flags & ML_KEM_KEY_RETAIN_SEED) {
1732
34.5k
        memcpy(key->d, seed, ML_KEM_RANDOM_BYTES);
1733
34.5k
    } else {
1734
0
        OPENSSL_cleanse(key->d, ML_KEM_RANDOM_BYTES);
1735
0
        key->d = NULL;
1736
0
    }
1737
1738
34.5k
    ret = 1;
1739
34.5k
end:
1740
34.5k
    OPENSSL_cleanse((void *)augmented_seed, ML_KEM_RANDOM_BYTES);
1741
34.5k
    OPENSSL_cleanse((void *)sigma, ML_KEM_RANDOM_BYTES);
1742
34.5k
    if (ret == 0) {
1743
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1744
0
            "internal error while generating %s private key",
1745
0
            vinfo->algorithm_name);
1746
0
    }
1747
34.5k
    return ret;
1748
34.5k
}
1749
1750
/*-
1751
 * FIPS 203, Section 6.2, Algorithm 17: "ML-KEM.Encaps_internal".
1752
 * This is the deterministic version with randomness supplied externally.
1753
 *
1754
 * The caller must pass space for two vectors in |tmp|.
1755
 * The |ctext| buffer have space for the ciphertext of the ML-KEM variant
1756
 * of the provided key.
1757
 */
1758
static int encap(uint8_t *ctext, uint8_t secret[ML_KEM_SHARED_SECRET_BYTES],
1759
    const uint8_t entropy[ML_KEM_RANDOM_BYTES],
1760
    scalar *tmp, EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1761
197
{
1762
197
    uint8_t input[ML_KEM_RANDOM_BYTES + ML_KEM_PKHASH_BYTES];
1763
197
    uint8_t Kr[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES];
1764
197
    uint8_t *r = Kr + ML_KEM_SHARED_SECRET_BYTES;
1765
197
    int ret;
1766
1767
197
    memcpy(input, entropy, ML_KEM_RANDOM_BYTES);
1768
197
    memcpy(input + ML_KEM_RANDOM_BYTES, key->pkhash, ML_KEM_PKHASH_BYTES);
1769
197
    ret = hash_g(Kr, input, sizeof(input), mdctx, key)
1770
197
        && encrypt_cpa(ctext, entropy, r, tmp, mdctx, key);
1771
197
    OPENSSL_cleanse((void *)input, sizeof(input));
1772
1773
197
    if (ret)
1774
197
        memcpy(secret, Kr, ML_KEM_SHARED_SECRET_BYTES);
1775
0
    else
1776
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1777
0
            "internal error while performing %s encapsulation",
1778
0
            key->vinfo->algorithm_name);
1779
197
    return ret;
1780
197
}
1781
1782
/*
1783
 * FIPS 203, Section 6.3, Algorithm 18: ML-KEM.Decaps_internal
1784
 *
1785
 * Barring failure of the supporting SHA3/SHAKE primitives, this is fully
1786
 * deterministic, the randomness for the FO transform is extracted during
1787
 * private key generation.
1788
 *
1789
 * The caller must pass space for two vectors in |tmp|.
1790
 * The |ctext| and |tmp_ctext| buffers must each have space for the ciphertext
1791
 * of the key's ML-KEM variant.
1792
 */
1793
static int decap(uint8_t secret[ML_KEM_SHARED_SECRET_BYTES],
1794
    const uint8_t *ctext, uint8_t *tmp_ctext, scalar *tmp,
1795
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1796
110
{
1797
110
    uint8_t decrypted[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_PKHASH_BYTES];
1798
110
    uint8_t failure_key[ML_KEM_RANDOM_BYTES];
1799
110
    uint8_t Kr[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES];
1800
110
    uint8_t *r = Kr + ML_KEM_SHARED_SECRET_BYTES;
1801
110
    const uint8_t *pkhash = key->pkhash;
1802
110
    const ML_KEM_VINFO *vinfo = key->vinfo;
1803
110
    int i;
1804
110
    uint8_t mask;
1805
1806
    /*
1807
     * If our KDF is unavailable, fail early! Otherwise, keep going ignoring
1808
     * any further errors, returning success, and whatever we got for a shared
1809
     * secret.  The decrypt_cpa() function is just arithmetic on secret data,
1810
     * so should not be subject to failure that makes its output predictable.
1811
     *
1812
     * We guard against "should never happen" catastrophic failure of the
1813
     * "pure" function |hash_g| by overwriting the shared secret with the
1814
     * content of the failure key and returning early, if nevertheless hash_g
1815
     * fails.  This is not constant-time, but a failure of |hash_g| already
1816
     * implies loss of side-channel resistance.
1817
     *
1818
     * The same action is taken, if also |encrypt_cpa| should catastrophically
1819
     * fail, due to failure of the |PRF| underlying the CBD functions.
1820
     */
1821
110
    if (!kdf(failure_key, key->z, ctext, vinfo->ctext_bytes, mdctx, key)) {
1822
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1823
0
            "internal error while performing %s decapsulation",
1824
0
            vinfo->algorithm_name);
1825
0
        return 0;
1826
0
    }
1827
110
    decrypt_cpa(decrypted, ctext, tmp, key);
1828
110
    memcpy(decrypted + ML_KEM_SHARED_SECRET_BYTES, pkhash, ML_KEM_PKHASH_BYTES);
1829
110
    if (!hash_g(Kr, decrypted, sizeof(decrypted), mdctx, key)
1830
110
        || !encrypt_cpa(tmp_ctext, decrypted, r, tmp, mdctx, key)) {
1831
0
        memcpy(secret, failure_key, ML_KEM_SHARED_SECRET_BYTES);
1832
0
        OPENSSL_cleanse(decrypted, ML_KEM_SHARED_SECRET_BYTES);
1833
0
        return 1;
1834
0
    }
1835
110
    mask = constant_time_eq_int_8(0,
1836
110
        CRYPTO_memcmp(ctext, tmp_ctext, vinfo->ctext_bytes));
1837
3.63k
    for (i = 0; i < ML_KEM_SHARED_SECRET_BYTES; i++)
1838
3.52k
        secret[i] = constant_time_select_8(mask, Kr[i], failure_key[i]);
1839
110
    OPENSSL_cleanse(decrypted, ML_KEM_SHARED_SECRET_BYTES);
1840
110
    OPENSSL_cleanse(Kr, sizeof(Kr));
1841
110
    return 1;
1842
110
}
1843
1844
/*
1845
 * After allocating storage for public or private key data, update the key
1846
 * component pointers to reference that storage.
1847
 */
1848
static __owur int add_storage(scalar *p, int private, ML_KEM_KEY *key)
1849
17.1k
{
1850
17.1k
    int rank = key->vinfo->rank;
1851
1852
17.1k
    if (p == NULL)
1853
0
        return 0;
1854
1855
    /*
1856
     * We're adding key material, the seed buffer will now hold |rho| and
1857
     * |pkhash|.
1858
     */
1859
17.1k
    memset(key->seedbuf, 0, sizeof(key->seedbuf));
1860
17.1k
    key->rho = key->seedbuf;
1861
17.1k
    key->pkhash = key->seedbuf + ML_KEM_RANDOM_BYTES;
1862
17.1k
    key->d = key->z = NULL;
1863
1864
    /* A public key needs space for |t| and |m| */
1865
17.1k
    key->m = (key->t = p) + rank;
1866
1867
    /*
1868
     * A private key also needs space for |s| and |z|.
1869
     * The |z| buffer always includes additional space for |d|, but a key's |d|
1870
     * pointer is left NULL when parsed from the NIST format, which omits that
1871
     * information.  Only keys generated from a (d, z) seed pair will have a
1872
     * non-NULL |d| pointer.
1873
     */
1874
17.1k
    if (private)
1875
16.9k
        key->z = (uint8_t *)(rank + (key->s = key->m + rank * rank));
1876
17.1k
    return 1;
1877
17.1k
}
1878
1879
/*
1880
 * After freeing the storage associated with a key that failed to be
1881
 * constructed, reset the internal pointers back to NULL.
1882
 */
1883
void ossl_ml_kem_key_reset(ML_KEM_KEY *key)
1884
17.2k
{
1885
17.2k
    if (key->t == NULL)
1886
134
        return;
1887
    /*-
1888
     * Cleanse any sensitive data:
1889
     * - The private vector |s| is immediately followed by the FO failure
1890
     *   secret |z|, and seed |d|, we can cleanse all three in one call.
1891
     *
1892
     * - Otherwise, when key->d is set, cleanse the stashed seed.
1893
     */
1894
17.1k
    if (ossl_ml_kem_have_prvkey(key))
1895
16.9k
        OPENSSL_cleanse(key->s,
1896
16.9k
            key->vinfo->rank * sizeof(scalar) + 2 * ML_KEM_RANDOM_BYTES);
1897
17.1k
    OPENSSL_free(key->t);
1898
17.1k
    key->d = key->z = (uint8_t *)(key->s = key->m = key->t = NULL);
1899
17.1k
}
1900
1901
/*
1902
 * ----- API exported to the provider
1903
 *
1904
 * Parameters with an implicit fixed length in the internal static API of each
1905
 * variant have an explicit checked length argument at this layer.
1906
 */
1907
1908
/* Retrieve the parameters of one of the ML-KEM variants */
1909
const ML_KEM_VINFO *ossl_ml_kem_get_vinfo(int evp_type)
1910
277k
{
1911
277k
    switch (evp_type) {
1912
60.0k
    case EVP_PKEY_ML_KEM_512:
1913
60.0k
        return &vinfo_map[ML_KEM_512_VINFO];
1914
161k
    case EVP_PKEY_ML_KEM_768:
1915
161k
        return &vinfo_map[ML_KEM_768_VINFO];
1916
56.4k
    case EVP_PKEY_ML_KEM_1024:
1917
56.4k
        return &vinfo_map[ML_KEM_1024_VINFO];
1918
277k
    }
1919
0
    return NULL;
1920
277k
}
1921
1922
/*
 * Allocate an empty key object for the ML-KEM variant |evp_type|.
 *
 * The four SHA3/SHAKE digests used by the implementation are fetched from
 * |libctx| with the given |properties| and stored in the key.  The new key
 * holds no key material: all component pointers are NULL.
 *
 * Returns NULL, raising an error except on allocation failure, when the
 * variant is unsupported, memory cannot be allocated, or any of the digests
 * is unavailable.
 */
ML_KEM_KEY *ossl_ml_kem_key_new(OSSL_LIB_CTX *libctx, const char *properties,
    int evp_type)
{
    const ML_KEM_VINFO *vinfo = ossl_ml_kem_get_vinfo(evp_type);
    ML_KEM_KEY *key;

    if (vinfo == NULL) {
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_PASSED_INVALID_ARGUMENT,
            "unsupported ML-KEM key type: %d", evp_type);
        return NULL;
    }

    key = OPENSSL_malloc(sizeof(*key));
    if (key == NULL)
        return NULL;

    key->vinfo = vinfo;
    key->libctx = libctx;
    key->prov_flags = ML_KEM_KEY_PROV_FLAGS_DEFAULT;
    key->shake128_md = EVP_MD_fetch(libctx, "SHAKE128", properties);
    key->shake256_md = EVP_MD_fetch(libctx, "SHAKE256", properties);
    key->sha3_256_md = EVP_MD_fetch(libctx, "SHA3-256", properties);
    key->sha3_512_md = EVP_MD_fetch(libctx, "SHA3-512", properties);

    /* No key material yet */
    key->encoded_dk = NULL;
    key->pkhash = NULL;
    key->rho = NULL;
    key->z = NULL;
    key->d = NULL;
    key->t = NULL;
    key->m = NULL;
    key->s = NULL;

    if (key->shake128_md == NULL
        || key->shake256_md == NULL
        || key->sha3_256_md == NULL
        || key->sha3_512_md == NULL) {
        /* Releases whichever digests were fetched, and the key itself */
        ossl_ml_kem_key_free(key);
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
            "missing SHA3 digest algorithms while creating %s key",
            vinfo->algorithm_name);
        return NULL;
    }

    return key;
}
1959
1960
ML_KEM_KEY *ossl_ml_kem_key_dup(const ML_KEM_KEY *key, int selection)
1961
27
{
1962
27
    int ok = 0;
1963
27
    ML_KEM_KEY *ret;
1964
1965
    /*
1966
     * Partially decoded keys, not yet imported or loaded, should never be
1967
     * duplicated.
1968
     */
1969
27
    if (ossl_ml_kem_decoded_key(key))
1970
0
        return NULL;
1971
1972
27
    if (key == NULL
1973
27
        || (ret = OPENSSL_memdup(key, sizeof(*key))) == NULL)
1974
0
        return NULL;
1975
27
    ret->d = ret->z = ret->rho = ret->pkhash = NULL;
1976
27
    ret->s = ret->m = ret->t = NULL;
1977
1978
    /* Clear selection bits we can't fulfill */
1979
27
    if (!ossl_ml_kem_have_pubkey(key))
1980
0
        selection = 0;
1981
27
    else if (!ossl_ml_kem_have_prvkey(key))
1982
27
        selection &= ~OSSL_KEYMGMT_SELECT_PRIVATE_KEY;
1983
1984
27
    switch (selection & OSSL_KEYMGMT_SELECT_KEYPAIR) {
1985
0
    case 0:
1986
0
        ok = 1;
1987
0
        break;
1988
27
    case OSSL_KEYMGMT_SELECT_PUBLIC_KEY:
1989
27
        ok = add_storage(OPENSSL_memdup(key->t, key->vinfo->puballoc), 0, ret);
1990
27
        ret->rho = ret->seedbuf;
1991
27
        ret->pkhash = ret->rho + ML_KEM_RANDOM_BYTES;
1992
27
        break;
1993
0
    case OSSL_KEYMGMT_SELECT_PRIVATE_KEY:
1994
0
        ok = add_storage(OPENSSL_memdup(key->t, key->vinfo->prvalloc), 1, ret);
1995
        /* Duplicated keys retain |d|, if available */
1996
0
        if (key->d != NULL)
1997
0
            ret->d = ret->z + ML_KEM_RANDOM_BYTES;
1998
0
        break;
1999
27
    }
2000
2001
27
    if (!ok) {
2002
0
        OPENSSL_free(ret);
2003
0
        return NULL;
2004
0
    }
2005
2006
27
    EVP_MD_up_ref(ret->shake128_md);
2007
27
    EVP_MD_up_ref(ret->shake256_md);
2008
27
    EVP_MD_up_ref(ret->sha3_256_md);
2009
27
    EVP_MD_up_ref(ret->sha3_512_md);
2010
2011
27
    return ret;
2012
27
}
2013
2014
/*
 * Release a key object: drop the digest references, wipe and free any key
 * material, then free the key itself.  A NULL |key| is ignored.
 */
void ossl_ml_kem_key_free(ML_KEM_KEY *key)
{
    if (key != NULL) {
        EVP_MD_free(key->shake128_md);
        EVP_MD_free(key->shake256_md);
        EVP_MD_free(key->sha3_256_md);
        EVP_MD_free(key->sha3_512_md);

        /*
         * A decoded, but not yet loaded, key keeps its secrets in the seed
         * buffer and possibly in a separately allocated encoded private key.
         */
        if (ossl_ml_kem_decoded_key(key)) {
            OPENSSL_cleanse(key->seedbuf, sizeof(key->seedbuf));
            if (ossl_ml_kem_have_dkenc(key)) {
                OPENSSL_cleanse(key->encoded_dk, key->vinfo->prvkey_bytes);
                OPENSSL_free(key->encoded_dk);
            }
        }

        /* Wipes and frees the allocated key components, if any */
        ossl_ml_kem_key_reset(key);
        OPENSSL_free(key);
    }
}
2034
2035
/* Serialise the public component of an ML-KEM key */
2036
int ossl_ml_kem_encode_public_key(uint8_t *out, size_t len,
2037
    const ML_KEM_KEY *key)
2038
51.1k
{
2039
51.1k
    if (!ossl_ml_kem_have_pubkey(key)
2040
51.1k
        || len != key->vinfo->pubkey_bytes)
2041
0
        return 0;
2042
51.1k
    encode_pubkey(out, key);
2043
51.1k
    return 1;
2044
51.1k
}
2045
2046
/* Serialise an ML-KEM private key */
2047
int ossl_ml_kem_encode_private_key(uint8_t *out, size_t len,
2048
    const ML_KEM_KEY *key)
2049
138
{
2050
138
    if (!ossl_ml_kem_have_prvkey(key)
2051
138
        || len != key->vinfo->prvkey_bytes)
2052
0
        return 0;
2053
138
    encode_prvkey(out, key);
2054
138
    return 1;
2055
138
}
2056
2057
int ossl_ml_kem_encode_seed(uint8_t *out, size_t len,
2058
    const ML_KEM_KEY *key)
2059
226
{
2060
226
    if (key == NULL || key->d == NULL || len != ML_KEM_SEED_BYTES)
2061
28
        return 0;
2062
    /*
2063
     * Both in the seed buffer, and in the allocated storage, the |d| component
2064
     * of the seed is stored last, so we must copy each separately.
2065
     */
2066
198
    memcpy(out, key->d, ML_KEM_RANDOM_BYTES);
2067
198
    out += ML_KEM_RANDOM_BYTES;
2068
198
    memcpy(out, key->z, ML_KEM_RANDOM_BYTES);
2069
198
    return 1;
2070
226
}
2071
2072
/*
2073
 * Stash the seed without (yet) performing a keygen, used during decoding, to
2074
 * avoid an extra keygen if we're only going to export the key again to load
2075
 * into another provider.
2076
 */
2077
ML_KEM_KEY *ossl_ml_kem_set_seed(const uint8_t *seed, size_t seedlen, ML_KEM_KEY *key)
2078
32
{
2079
32
    if (key == NULL
2080
32
        || ossl_ml_kem_have_pubkey(key)
2081
32
        || ossl_ml_kem_have_seed(key)
2082
32
        || seedlen != ML_KEM_SEED_BYTES)
2083
0
        return NULL;
2084
    /*
2085
     * With no public or private key material on hand, we can use the seed
2086
     * buffer for |z| and |d|, in that order.
2087
     */
2088
32
    key->z = key->seedbuf;
2089
32
    key->d = key->z + ML_KEM_RANDOM_BYTES;
2090
32
    memcpy(key->d, seed, ML_KEM_RANDOM_BYTES);
2091
32
    seed += ML_KEM_RANDOM_BYTES;
2092
32
    memcpy(key->z, seed, ML_KEM_RANDOM_BYTES);
2093
32
    return key;
2094
32
}
2095
2096
/* Parse input as a public key */
2097
int ossl_ml_kem_parse_public_key(const uint8_t *in, size_t len, ML_KEM_KEY *key)
2098
550
{
2099
550
    EVP_MD_CTX *mdctx = NULL;
2100
550
    const ML_KEM_VINFO *vinfo;
2101
550
    int ret = 0;
2102
2103
    /* Keys with key material are immutable */
2104
550
    if (key == NULL
2105
550
        || ossl_ml_kem_have_pubkey(key)
2106
550
        || ossl_ml_kem_have_dkenc(key))
2107
0
        return 0;
2108
550
    vinfo = key->vinfo;
2109
2110
550
    if (len != vinfo->pubkey_bytes
2111
550
        || (mdctx = EVP_MD_CTX_new()) == NULL)
2112
0
        return 0;
2113
2114
550
    if (add_storage(OPENSSL_malloc(vinfo->puballoc), 0, key))
2115
550
        ret = parse_pubkey(in, mdctx, key);
2116
2117
550
    if (!ret)
2118
229
        ossl_ml_kem_key_reset(key);
2119
550
    EVP_MD_CTX_free(mdctx);
2120
550
    return ret;
2121
550
}
2122
2123
/* Parse input as a new private key */
2124
int ossl_ml_kem_parse_private_key(const uint8_t *in, size_t len,
2125
    ML_KEM_KEY *key)
2126
107
{
2127
107
    EVP_MD_CTX *mdctx = NULL;
2128
107
    const ML_KEM_VINFO *vinfo;
2129
107
    int ret = 0;
2130
2131
    /* Keys with key material are immutable */
2132
107
    if (key == NULL
2133
107
        || ossl_ml_kem_have_pubkey(key)
2134
107
        || ossl_ml_kem_have_dkenc(key))
2135
0
        return 0;
2136
107
    vinfo = key->vinfo;
2137
2138
107
    if (len != vinfo->prvkey_bytes
2139
107
        || (mdctx = EVP_MD_CTX_new()) == NULL)
2140
0
        return 0;
2141
2142
107
    if (add_storage(OPENSSL_malloc(vinfo->prvalloc), 1, key))
2143
107
        ret = parse_prvkey(in, mdctx, key);
2144
2145
107
    if (!ret)
2146
107
        ossl_ml_kem_key_reset(key);
2147
107
    EVP_MD_CTX_free(mdctx);
2148
107
    return ret;
2149
107
}
2150
2151
/*
2152
 * Generate a new keypair, either from the saved seed (when non-null), or from
2153
 * the RNG.
2154
 */
2155
int ossl_ml_kem_genkey(uint8_t *pubenc, size_t publen, ML_KEM_KEY *key)
2156
51.3k
{
2157
51.3k
    uint8_t seed[ML_KEM_SEED_BYTES];
2158
51.3k
    EVP_MD_CTX *mdctx = NULL;
2159
51.3k
    const ML_KEM_VINFO *vinfo;
2160
51.3k
    int ret = 0;
2161
2162
51.3k
    if (key == NULL
2163
51.3k
        || ossl_ml_kem_have_pubkey(key)
2164
51.3k
        || ossl_ml_kem_have_dkenc(key))
2165
0
        return 0;
2166
51.3k
    vinfo = key->vinfo;
2167
2168
51.3k
    if (pubenc != NULL && publen != vinfo->pubkey_bytes)
2169
0
        return 0;
2170
2171
51.3k
    if (ossl_ml_kem_have_seed(key)) {
2172
60
        if (!ossl_ml_kem_encode_seed(seed, sizeof(seed), key))
2173
0
            return 0;
2174
60
        key->d = key->z = NULL;
2175
51.3k
    } else if (RAND_priv_bytes_ex(key->libctx, seed, sizeof(seed),
2176
51.3k
                   key->vinfo->secbits)
2177
51.3k
        <= 0) {
2178
0
        return 0;
2179
0
    }
2180
2181
51.3k
    if ((mdctx = EVP_MD_CTX_new()) == NULL)
2182
0
        return 0;
2183
2184
    /*
2185
     * Data derived from (d, z) defaults secret, and to avoid side-channel
2186
     * leaks should not influence control flow.
2187
     */
2188
51.3k
    CONSTTIME_SECRET(seed, ML_KEM_SEED_BYTES);
2189
2190
51.3k
    if (add_storage(OPENSSL_malloc(vinfo->prvalloc), 1, key))
2191
51.3k
        ret = genkey(seed, mdctx, pubenc, key);
2192
51.3k
    OPENSSL_cleanse(seed, sizeof(seed));
2193
2194
    /* Declassify secret inputs and derived outputs before returning control */
2195
51.3k
    CONSTTIME_DECLASSIFY(seed, ML_KEM_SEED_BYTES);
2196
2197
51.3k
    EVP_MD_CTX_free(mdctx);
2198
51.3k
    if (!ret) {
2199
0
        ossl_ml_kem_key_reset(key);
2200
0
        return 0;
2201
0
    }
2202
2203
    /* The public components are already declassified */
2204
51.3k
    CONSTTIME_DECLASSIFY(key->s, vinfo->rank * sizeof(scalar));
2205
51.3k
    CONSTTIME_DECLASSIFY(key->z, 2 * ML_KEM_RANDOM_BYTES);
2206
51.3k
    return 1;
2207
51.3k
}
2208
2209
/*
2210
 * FIPS 203, Section 6.2, Algorithm 17: ML-KEM.Encaps_internal
2211
 * This is the deterministic version with randomness supplied externally.
2212
 */
2213
int ossl_ml_kem_encap_seed(uint8_t *ctext, size_t clen,
2214
    uint8_t *shared_secret, size_t slen,
2215
    const uint8_t *entropy, size_t elen,
2216
    const ML_KEM_KEY *key)
2217
197
{
2218
197
    const ML_KEM_VINFO *vinfo;
2219
197
    EVP_MD_CTX *mdctx;
2220
197
    int ret = 0;
2221
2222
197
    if (key == NULL || !ossl_ml_kem_have_pubkey(key))
2223
0
        return 0;
2224
197
    vinfo = key->vinfo;
2225
2226
197
    if (ctext == NULL || clen != vinfo->ctext_bytes
2227
197
        || shared_secret == NULL || slen != ML_KEM_SHARED_SECRET_BYTES
2228
197
        || entropy == NULL || elen != ML_KEM_RANDOM_BYTES
2229
197
        || (mdctx = EVP_MD_CTX_new()) == NULL)
2230
0
        return 0;
2231
    /*
2232
     * Data derived from the encap entropy defaults secret, and to avoid
2233
     * side-channel leaks should not influence control flow.
2234
     */
2235
197
    CONSTTIME_SECRET(entropy, elen);
2236
2237
    /*-
2238
     * This avoids the need to handle allocation failures for two (max 2KB
2239
     * each) vectors, that are never retained on return from this function.
2240
     * We stack-allocate these.
2241
     */
2242
197
#define case_encap_seed(bits)                                        \
2243
197
    case EVP_PKEY_ML_KEM_##bits: {                                   \
2244
197
        scalar tmp[2 * ML_KEM_##bits##_RANK];                        \
2245
197
                                                                     \
2246
197
        ret = encap(ctext, shared_secret, entropy, tmp, mdctx, key); \
2247
197
        OPENSSL_cleanse((void *)tmp, sizeof(tmp));                   \
2248
197
        break;                                                       \
2249
197
    }
2250
197
    switch (vinfo->evp_type) {
2251
69
        case_encap_seed(512);
2252
76
        case_encap_seed(768);
2253
52
        case_encap_seed(1024);
2254
197
    }
2255
197
#undef case_encap_seed
2256
2257
    /* Declassify secret inputs and derived outputs before returning control */
2258
197
    CONSTTIME_DECLASSIFY(entropy, elen);
2259
197
    CONSTTIME_DECLASSIFY(ctext, clen);
2260
197
    CONSTTIME_DECLASSIFY(shared_secret, slen);
2261
2262
197
    EVP_MD_CTX_free(mdctx);
2263
197
    return ret;
2264
197
}
2265
2266
int ossl_ml_kem_encap_rand(uint8_t *ctext, size_t clen,
2267
    uint8_t *shared_secret, size_t slen,
2268
    const ML_KEM_KEY *key)
2269
197
{
2270
197
    uint8_t r[ML_KEM_RANDOM_BYTES];
2271
2272
197
    if (key == NULL)
2273
0
        return 0;
2274
2275
197
    if (RAND_bytes_ex(key->libctx, r, ML_KEM_RANDOM_BYTES,
2276
197
            key->vinfo->secbits)
2277
197
        < 1)
2278
0
        return 0;
2279
2280
197
    return ossl_ml_kem_encap_seed(ctext, clen, shared_secret, slen,
2281
197
        r, sizeof(r), key);
2282
197
}
2283
2284
int ossl_ml_kem_decap(uint8_t *shared_secret, size_t slen,
2285
    const uint8_t *ctext, size_t clen,
2286
    const ML_KEM_KEY *key)
2287
173
{
2288
173
    const ML_KEM_VINFO *vinfo;
2289
173
    EVP_MD_CTX *mdctx;
2290
173
    int ret = 0;
2291
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
2292
    int classify_bytes;
2293
#endif
2294
2295
    /* Need a private key here */
2296
173
    if (!ossl_ml_kem_have_prvkey(key))
2297
0
        return 0;
2298
173
    vinfo = key->vinfo;
2299
2300
173
    if (shared_secret == NULL || slen != ML_KEM_SHARED_SECRET_BYTES
2301
173
        || ctext == NULL || clen != vinfo->ctext_bytes
2302
173
        || (mdctx = EVP_MD_CTX_new()) == NULL) {
2303
0
        (void)RAND_bytes_ex(key->libctx, shared_secret,
2304
0
            ML_KEM_SHARED_SECRET_BYTES, vinfo->secbits);
2305
0
        return 0;
2306
0
    }
2307
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
2308
    /*
2309
     * Data derived from |s| and |z| defaults secret, and to avoid side-channel
2310
     * leaks should not influence control flow.
2311
     */
2312
    classify_bytes = 2 * sizeof(scalar) + ML_KEM_RANDOM_BYTES;
2313
    CONSTTIME_SECRET(key->s, classify_bytes);
2314
#endif
2315
2316
    /*-
2317
     * This avoids the need to handle allocation failures for two (max 2KB
2318
     * each) vectors and an encoded ciphertext (max 1568 bytes), that are never
2319
     * retained on return from this function.
2320
     * We stack-allocate these.
2321
     */
2322
173
#define case_decap(bits)                                          \
2323
173
    case EVP_PKEY_ML_KEM_##bits: {                                \
2324
173
        uint8_t cbuf[CTEXT_BYTES(bits)];                          \
2325
173
        scalar tmp[2 * ML_KEM_##bits##_RANK];                     \
2326
173
                                                                  \
2327
173
        ret = decap(shared_secret, ctext, cbuf, tmp, mdctx, key); \
2328
173
        OPENSSL_cleanse((void *)tmp, sizeof(tmp));                \
2329
173
        break;                                                    \
2330
173
    }
2331
173
    switch (vinfo->evp_type) {
2332
69
        case_decap(512);
2333
52
        case_decap(768);
2334
52
        case_decap(1024);
2335
173
    }
2336
2337
    /* Declassify secret inputs and derived outputs before returning control */
2338
173
    CONSTTIME_DECLASSIFY(key->s, classify_bytes);
2339
173
    CONSTTIME_DECLASSIFY(shared_secret, slen);
2340
173
    EVP_MD_CTX_free(mdctx);
2341
2342
173
    return ret;
2343
173
#undef case_decap
2344
173
}
2345
2346
int ossl_ml_kem_pubkey_cmp(const ML_KEM_KEY *key1, const ML_KEM_KEY *key2)
2347
159
{
2348
    /*
2349
     * This handles any unexpected differences in the ML-KEM variant rank,
2350
     * giving different key component structures, barring SHA3-256 hash
2351
     * collisions, the keys are the same size.
2352
     */
2353
159
    if (ossl_ml_kem_have_pubkey(key1) && ossl_ml_kem_have_pubkey(key2))
2354
159
        return memcmp(key1->pkhash, key2->pkhash, ML_KEM_PKHASH_BYTES) == 0;
2355
2356
    /*
2357
     * No match if just one of the public keys is not available, otherwise both
2358
     * are unavailable, and for now such keys are considered equal.
2359
     */
2360
0
    return (!(ossl_ml_kem_have_pubkey(key1) ^ ossl_ml_kem_have_pubkey(key2)));
2361
159
}