Coverage Report

Created: 2025-12-31 06:58

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/openssl35/crypto/ml_kem/ml_kem.c
Line
Count
Source
1
/*
2
 * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
3
 *
4
 * Licensed under the Apache License 2.0 (the "License").  You may not use
5
 * this file except in compliance with the License.  You can obtain a copy
6
 * in the file LICENSE in the source distribution or at
7
 * https://www.openssl.org/source/license.html
8
 */
9
10
#include <openssl/byteorder.h>
11
#include <openssl/rand.h>
12
#include <openssl/proverr.h>
13
#include "crypto/ml_kem.h"
14
#include "internal/common.h"
15
#include "internal/constant_time.h"
16
#include "internal/sha3.h"
17
18
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
19
#include <valgrind/memcheck.h>
20
#endif
21
22
#if ML_KEM_SEED_BYTES != ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES
#error "ML-KEM keygen seed length != shared secret + random bytes length"
#endif
#if ML_KEM_SHARED_SECRET_BYTES != ML_KEM_RANDOM_BYTES
#error "Invalid unequal lengths of ML-KEM shared secret and random inputs"
#endif

#if UINT_MAX < UINT32_MAX
#error "Unsupported compiler: sizeof(unsigned int) < sizeof(uint32_t)"
#endif

/*
 * Handy function-like bit-extraction macros: bit0(b) is the least significant
 * bit of |b|, and bitn(n, b) is bit |n| of |b|.  Both macro arguments are
 * parenthesised so that expression arguments (e.g. a conditional shift count)
 * bind as a single operand.
 */
#define bit0(b) ((b) & 1)
#define bitn(n, b) (((b) >> (n)) & 1)

/*
 * 12 bits are sufficient to losslessly represent values in [0, q-1].
 * INVERSE_DEGREE is (n/2)^-1 mod q; used in inverse NTT.  With q = 3329 this
 * is q - 26 = 3303, since 128 * 3303 = 127 * 3329 + 1.
 */
#define DEGREE ML_KEM_DEGREE
#define INVERSE_DEGREE (ML_KEM_PRIME - 2 * 13)
#define LOG2PRIME 12
/* Barrett reduction works with a 2^(2 * 12) = 2^24 scale factor. */
#define BARRETT_SHIFT (2 * LOG2PRIME)

#ifdef SHA3_BLOCKSIZE
#define SHAKE128_BLOCKSIZE SHA3_BLOCKSIZE(128)
#endif
49
50
/*
 * Return whether a value that can only be 0 or 1 is non-zero, in constant time
 * in practice!  The return value is a mask that is all ones if true, and all
 * zeros otherwise (unsigned arithmetic wraps, so 0u - 1 is all ones).
 *
 * Although this is used in constant-time selects, we omit a value barrier
 * here.  Value barriers impede auto-vectorization (likely because it forces
 * the value to transit through a general-purpose register). On AArch64, this
 * is a difference of 2x.
 *
 * We usually add value barriers to selects because Clang turns consecutive
 * selects with the same condition into a branch instead of CMOV/CSEL. This
 * condition does not occur in Kyber, so omitting it seems to be safe so far,
 * but see |cbd_2|, |cbd_3|, where reduction needs to be specialised to the
 * sign of the input, rather than adding |q| in advance, and using the generic
 * |reduce_once|.  (David Benjamin, Chromium)
 *
 * Both alternatives below must expand to a bare parenthesised expression,
 * with no trailing semicolon, because the macro is used inside larger
 * expressions and initialisers.
 */
#if 0
#define constish_time_non_zero(b) (~constant_time_is_zero(b))
#else
#define constish_time_non_zero(b) (0u - (b))
#endif
72
73
/*
 * Size of the scratch buffer used for scalar rejection sampling.  It must be
 * a multiple of 12.  Any such size works, but the SHAKE128 block size is
 * preferred because it matches SHAKE128's internal buffer and so avoids extra
 * buffering and copying.  That block size, (1600 - 256) / 8 = 168 bytes,
 * happens to be divisible by 12.
 *
 * When the block size is unknown, or is not a multiple of 12, 168 is used.
 */
#if !defined(SHAKE128_BLOCKSIZE) || (SHAKE128_BLOCKSIZE) % 12 != 0
#define SCALAR_SAMPLING_BUFSIZE 168
#else
#define SCALAR_SAMPLING_BUFSIZE (SHAKE128_BLOCKSIZE)
#endif
87
88
/*
89
 * Structure of keys
90
 */
91
typedef struct ossl_ml_kem_scalar_st {
92
    /* On every function entry and exit, 0 <= c[i] < ML_KEM_PRIME. */
93
    uint16_t c[ML_KEM_DEGREE];
94
} scalar;
95
96
/* Key material allocation layout */
97
#define DECLARE_ML_KEM_KEYDATA(name, rank, private_sz)               \
98
    struct name##_alloc {                                            \
99
        /* Public vector |t| */                                      \
100
        scalar tbuf[(rank)];                                         \
101
        /* Pre-computed matrix |m| (FIPS 203 |A| transpose) */       \
102
        scalar mbuf[(rank) * (rank)] /* optional private key data */ \
103
            private_sz                                               \
104
    }
105
106
/* Declare variant-specific public and private storage */
107
#define DECLARE_ML_KEM_VARIANT_KEYDATA(bits)                        \
108
    DECLARE_ML_KEM_KEYDATA(pubkey_##bits, ML_KEM_##bits##_RANK, ;); \
109
    DECLARE_ML_KEM_KEYDATA(prvkey_##bits, ML_KEM_##bits##_RANK, ; scalar sbuf[ML_KEM_##bits##_RANK]; uint8_t zbuf[2 * ML_KEM_RANDOM_BYTES];)
110
DECLARE_ML_KEM_VARIANT_KEYDATA(512);
111
DECLARE_ML_KEM_VARIANT_KEYDATA(768);
112
DECLARE_ML_KEM_VARIANT_KEYDATA(1024);
113
#undef DECLARE_ML_KEM_VARIANT_KEYDATA
114
#undef DECLARE_ML_KEM_KEYDATA
115
116
typedef __owur int (*CBD_FUNC)(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
117
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key);
118
static void scalar_encode(uint8_t *out, const scalar *s, int bits);
119
120
/*
 * Wire-format sizes for parameter set |b| (512, 768 or 1024).
 *
 * A losslessly encoded vector uses 12 bits per coefficient, i.e.
 * 3 * DEGREE / 2 bytes per scalar.  The wire-form public key is the encoded
 * public vector |t| followed by the public seed |rho|.  Our serialised private
 * key is, in order, the encoded private vector |s|, the public key, the public
 * key hash, and the failure secret |z|.
 */
#define VECTOR_BYTES(b) ((3 * DEGREE / 2) * ML_KEM_##b##_RANK)
#define PUBKEY_BYTES(b) (VECTOR_BYTES(b) + ML_KEM_RANDOM_BYTES)
#define PRVKEY_BYTES(b) \
    (VECTOR_BYTES(b) + PUBKEY_BYTES(b) + ML_KEM_PKHASH_BYTES + ML_KEM_RANDOM_BYTES)

/*
 * The ciphertext is a vector "u" and a scalar "v", whose coordinates (numbers
 * modulo the ML-KEM prime "q") are lossily encoded with "du" and "dv" bits
 * respectively.  Decapsulation takes this encoding as its input.
 */
#define U_VECTOR_BYTES(b) ((DEGREE / 8) * ML_KEM_##b##_DU * ML_KEM_##b##_RANK)
#define V_SCALAR_BYTES(b) ((DEGREE / 8) * ML_KEM_##b##_DV)
#define CTEXT_BYTES(b) (U_VECTOR_BYTES(b) + V_SCALAR_BYTES(b))
142
143
#if !defined(OPENSSL_CONSTANT_TIME_VALIDATION)

/* Without constant-time validation, the annotations expand to nothing. */
#define CONSTTIME_SECRET(ptr, len)
#define CONSTTIME_DECLASSIFY(ptr, len)

#else

/*
 * CONSTTIME_SECRET(ptr, len) marks |len| bytes at |ptr| as secret by telling
 * valgrind's memcheck that they are undefined.  Valgrind then warns whenever
 * data derived from them is used as a branch condition or a memory index.
 */
#define CONSTTIME_SECRET(ptr, len) VALGRIND_MAKE_MEM_UNDEFINED(ptr, len)

/*
 * CONSTTIME_DECLASSIFY(ptr, len) marks |len| bytes at |ptr| as public again
 * (defined, for memcheck).  Public data is not subject to constant-time rules.
 */
#define CONSTTIME_DECLASSIFY(ptr, len) VALGRIND_MAKE_MEM_DEFINED(ptr, len)

#endif
166
167
/*
168
 * Indices of slots in the vinfo tables below
169
 */
170
57.3k
#define ML_KEM_512_VINFO 0
171
158k
#define ML_KEM_768_VINFO 1
172
59.7k
#define ML_KEM_1024_VINFO 2
173
174
/*
175
 * Per-variant fixed parameters
176
 */
177
static const ML_KEM_VINFO vinfo_map[3] = {
178
    { "ML-KEM-512",
179
        PRVKEY_BYTES(512),
180
        sizeof(struct prvkey_512_alloc),
181
        PUBKEY_BYTES(512),
182
        sizeof(struct pubkey_512_alloc),
183
        CTEXT_BYTES(512),
184
        VECTOR_BYTES(512),
185
        U_VECTOR_BYTES(512),
186
        EVP_PKEY_ML_KEM_512,
187
        ML_KEM_512_BITS,
188
        ML_KEM_512_RANK,
189
        ML_KEM_512_DU,
190
        ML_KEM_512_DV,
191
        ML_KEM_512_SECBITS },
192
    { "ML-KEM-768",
193
        PRVKEY_BYTES(768),
194
        sizeof(struct prvkey_768_alloc),
195
        PUBKEY_BYTES(768),
196
        sizeof(struct pubkey_768_alloc),
197
        CTEXT_BYTES(768),
198
        VECTOR_BYTES(768),
199
        U_VECTOR_BYTES(768),
200
        EVP_PKEY_ML_KEM_768,
201
        ML_KEM_768_BITS,
202
        ML_KEM_768_RANK,
203
        ML_KEM_768_DU,
204
        ML_KEM_768_DV,
205
        ML_KEM_768_SECBITS },
206
    { "ML-KEM-1024",
207
        PRVKEY_BYTES(1024),
208
        sizeof(struct prvkey_1024_alloc),
209
        PUBKEY_BYTES(1024),
210
        sizeof(struct pubkey_1024_alloc),
211
        CTEXT_BYTES(1024),
212
        VECTOR_BYTES(1024),
213
        U_VECTOR_BYTES(1024),
214
        EVP_PKEY_ML_KEM_1024,
215
        ML_KEM_1024_BITS,
216
        ML_KEM_1024_RANK,
217
        ML_KEM_1024_DU,
218
        ML_KEM_1024_DV,
219
        ML_KEM_1024_SECBITS }
220
};
221
222
/*
223
 * Remainders modulo `kPrime`, for sufficiently small inputs, are computed in
224
 * constant time via Barrett reduction, and a final call to reduce_once(),
225
 * which reduces inputs that are at most 2*kPrime and is also constant-time.
226
 */
227
static const int kPrime = ML_KEM_PRIME;
228
static const unsigned int kBarrettShift = BARRETT_SHIFT;
229
static const size_t kBarrettMultiplier = (1 << BARRETT_SHIFT) / ML_KEM_PRIME;
230
static const uint16_t kHalfPrime = (ML_KEM_PRIME - 1) / 2;
231
static const uint16_t kInverseDegree = INVERSE_DEGREE;
232
233
/*
 * The tables below are defined in terms of a 7-bit bit reversal (Python):
 *
 * p = 3329
 * def bitreverse(i):
 *     ret = 0
 *     for n in range(7):
 *         bit = i & 1
 *         ret <<= 1
 *         ret |= bit
 *         i >>= 1
 *     return ret
 */

/*-
 * First precomputed array from Appendix A of FIPS 203, or else Python:
 * kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)]
 *
 * Eight entries per row.  scalar_ntt() pre-increments its pointer into this
 * table, so entry 0 is never read there.
 */
static const uint16_t kNTTRoots[128] = {
    1, 1729, 2580, 3289, 2642, 630, 1897, 848,
    1062, 1919, 193, 797, 2786, 3260, 569, 1746,
    296, 2447, 1339, 1476, 3046, 56, 2240, 1333,
    1426, 2094, 535, 2882, 2393, 2879, 1974, 821,
    289, 331, 3253, 1756, 1197, 2304, 2277, 2055,
    650, 1977, 2513, 632, 2865, 33, 1320, 1915,
    2319, 1435, 807, 452, 1438, 2868, 1534, 2402,
    2647, 2617, 1481, 648, 2474, 3110, 1227, 910,
    17, 2761, 583, 2649, 1637, 723, 2288, 1100,
    1409, 2662, 3281, 233, 756, 2156, 3015, 3050,
    1703, 1651, 2789, 1789, 1847, 952, 1461, 2687,
    939, 2308, 2437, 2388, 733, 2337, 268, 641,
    1584, 2298, 2037, 3220, 375, 2549, 2090, 1645,
    1063, 319, 2773, 757, 2099, 561, 2466, 2594,
    2804, 1092, 403, 1026, 1143, 2150, 2775, 886,
    1722, 1212, 1874, 1029, 2110, 2935, 885, 2154,
};
381
382
/*
 * InverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)], stored
 * in the order in which the inverse NTT loop consumes them.  Entry 0 (1) is
 * unused.  Entries 1..64 hold the values for i = 64..127, entries 65..96 for
 * i = 32..63, 97..112 for i = 16..31, 113..120 for i = 8..15, 121..124 for
 * i = 4..7, 125..126 for i = 2..3, and entry 127 for i = 1.  In other words:
 *
 *  0, 64, 65, ..., 127, 32, 33, ..., 63, 16, 17, ..., 31, 8, 9, ...
 */
static const uint16_t kInverseNTTRoots[128] = {
    1, 1175, 2444, 394, 1219, 2300, 1455, 2117,
    1607, 2443, 554, 1179, 2186, 2303, 2926, 2237,
    525, 735, 863, 2768, 1230, 2572, 556, 3010,
    2266, 1684, 1239, 780, 2954, 109, 1292, 1031,
    1745, 2688, 3061, 992, 2596, 941, 892, 1021,
    2390, 642, 1868, 2377, 1482, 1540, 540, 1678,
    1626, 279, 314, 1173, 2573, 3096, 48, 667,
    1920, 2229, 1041, 2606, 1692, 680, 2746, 568,
    3312, 2419, 2102, 219, 855, 2681, 1848, 712,
    682, 927, 1795, 461, 1891, 2877, 2522, 1894,
    1010, 1414, 2009, 3296, 464, 2697, 816, 1352,
    2679, 1274, 1052, 1025, 2132, 1573, 76, 2998,
    3040, 2508, 1355, 450, 936, 447, 2794, 1235,
    1903, 1996, 1089, 3273, 283, 1853, 1990, 882,
    3033, 1583, 2760, 69, 543, 2532, 3136, 1410,
    2267, 2481, 1432, 2699, 687, 40, 749, 1600,
};
518
519
/*
 * Second precomputed array from Appendix A of FIPS 203, with the values
 * normalised to be positive, or else Python:
 * ModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)]
 *
 * Entry |i| is the constant of the quadratic modulus X^2 - zeta_i used when
 * multiplying the |i|-th pair of NTT-domain coefficients.
 */
static const uint16_t kModRoots[128] = {
    17, 3312, 2761, 568, 583, 2746, 2649, 680,
    1637, 1692, 723, 2606, 2288, 1041, 1100, 2229,
    1409, 1920, 2662, 667, 3281, 48, 233, 3096,
    756, 2573, 2156, 1173, 3015, 314, 3050, 279,
    1703, 1626, 1651, 1678, 2789, 540, 1789, 1540,
    1847, 1482, 952, 2377, 1461, 1868, 2687, 642,
    939, 2390, 2308, 1021, 2437, 892, 2388, 941,
    733, 2596, 2337, 992, 268, 3061, 641, 2688,
    1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
    375, 2954, 2549, 780, 2090, 1239, 1645, 1684,
    1063, 2266, 319, 3010, 2773, 556, 757, 2572,
    2099, 1230, 561, 2768, 2466, 863, 2594, 735,
    2804, 525, 1092, 2237, 403, 2926, 1026, 2303,
    1143, 2186, 2150, 1179, 2775, 554, 886, 2443,
    1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
    2110, 1219, 2935, 394, 885, 2444, 2154, 1175,
};
654
655
/*
656
 * single_keccak hashes |inlen| bytes from |in| and writes |outlen| bytes of
657
 * output to |out|. If the |md| specifies a fixed-output function, like
658
 * SHA3-256, then |outlen| must be the correct length for that function.
659
 */
660
static __owur int single_keccak(uint8_t *out, size_t outlen, const uint8_t *in, size_t inlen,
661
    EVP_MD_CTX *mdctx)
662
350k
{
663
350k
    unsigned int sz = (unsigned int)outlen;
664
665
350k
    if (!EVP_DigestUpdate(mdctx, in, inlen))
666
0
        return 0;
667
350k
    if (EVP_MD_xof(EVP_MD_CTX_get0_md(mdctx)))
668
300k
        return EVP_DigestFinalXOF(mdctx, out, outlen);
669
50.3k
    return EVP_DigestFinal_ex(mdctx, out, &sz)
670
50.3k
        && ossl_assert((size_t)sz == outlen);
671
350k
}
672
673
/*
674
 * FIPS 203, Section 4.1, equation (4.3): PRF. Takes 32+1 input bytes, and uses
675
 * SHAKE256 to produce the input to SamplePolyCBD_eta: FIPS 203, algorithm 8.
676
 */
677
static __owur int prf(uint8_t *out, size_t len, const uint8_t in[ML_KEM_RANDOM_BYTES + 1],
678
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
679
300k
{
680
300k
    return EVP_DigestInit_ex(mdctx, key->shake256_md, NULL)
681
300k
        && single_keccak(out, len, in, ML_KEM_RANDOM_BYTES + 1, mdctx);
682
300k
}
683
684
/*
685
 * FIPS 203, Section 4.1, equation (4.4): H.  SHA3-256 hash of a variable
686
 * length input, producing 32 bytes of output.
687
 */
688
static __owur int hash_h(uint8_t out[ML_KEM_PKHASH_BYTES], const uint8_t *in, size_t len,
689
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
690
293
{
691
293
    return EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL)
692
293
        && single_keccak(out, ML_KEM_PKHASH_BYTES, in, len, mdctx);
693
293
}
694
695
/* Incremental hash_h of expanded public key */
696
static int
697
hash_h_pubkey(uint8_t pkhash[ML_KEM_PKHASH_BYTES],
698
    EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
699
49.8k
{
700
49.8k
    const ML_KEM_VINFO *vinfo = key->vinfo;
701
49.8k
    const scalar *t = key->t, *end = t + vinfo->rank;
702
49.8k
    unsigned int sz;
703
704
49.8k
    if (!EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL))
705
0
        return 0;
706
707
149k
    do {
708
149k
        uint8_t buf[3 * DEGREE / 2];
709
710
149k
        scalar_encode(buf, t++, 12);
711
149k
        if (!EVP_DigestUpdate(mdctx, buf, sizeof(buf)))
712
0
            return 0;
713
149k
    } while (t < end);
714
715
49.8k
    if (!EVP_DigestUpdate(mdctx, key->rho, ML_KEM_RANDOM_BYTES))
716
0
        return 0;
717
49.8k
    return EVP_DigestFinal_ex(mdctx, pkhash, &sz)
718
49.8k
        && ossl_assert(sz == ML_KEM_PKHASH_BYTES);
719
49.8k
}
720
721
/*
722
 * FIPS 203, Section 4.1, equation (4.5): G.  SHA3-512 hash of a variable
723
 * length input, producing 64 bytes of output, in particular the seeds
724
 * (d,z) for key generation.
725
 */
726
static __owur int hash_g(uint8_t out[ML_KEM_SEED_BYTES], const uint8_t *in, size_t len,
727
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
728
50.0k
{
729
50.0k
    return EVP_DigestInit_ex(mdctx, key->sha3_512_md, NULL)
730
50.0k
        && single_keccak(out, ML_KEM_SEED_BYTES, in, len, mdctx);
731
50.0k
}
732
733
/*
734
 * FIPS 203, Section 4.1, equation (4.4): J. SHAKE256 taking a variable length
735
 * input to compute a 32-byte implicit rejection shared secret, of the same
736
 * length as the expected shared secret.  (Computed even on success to avoid
737
 * side-channel leaks).
738
 */
739
static __owur int kdf(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
740
    const uint8_t z[ML_KEM_RANDOM_BYTES],
741
    const uint8_t *ctext, size_t len,
742
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
743
100
{
744
100
    return EVP_DigestInit_ex(mdctx, key->shake256_md, NULL)
745
100
        && EVP_DigestUpdate(mdctx, z, ML_KEM_RANDOM_BYTES)
746
100
        && EVP_DigestUpdate(mdctx, ctext, len)
747
100
        && EVP_DigestFinalXOF(mdctx, out, ML_KEM_SHARED_SECRET_BYTES);
748
100
}
749
750
/*
751
 * FIPS 203, Section 4.2.2, Algorithm 7: "SampleNTT" (steps 3-17, steps 1, 2
752
 * are performed by the caller). Rejection-samples a Keccak stream to get
753
 * uniformly distributed elements in the range [0,q). This is used for matrix
754
 * expansion and only operates on public inputs.
755
 */
756
static __owur int sample_scalar(scalar *out, EVP_MD_CTX *mdctx)
757
450k
{
758
450k
    uint16_t *curr = out->c, *endout = curr + DEGREE;
759
450k
    uint8_t buf[SCALAR_SAMPLING_BUFSIZE], *in;
760
450k
    uint8_t *endin = buf + sizeof(buf);
761
450k
    uint16_t d;
762
450k
    uint8_t b1, b2, b3;
763
764
1.35M
    do {
765
1.35M
        if (!EVP_DigestSqueeze(mdctx, in = buf, sizeof(buf)))
766
0
            return 0;
767
70.6M
        do {
768
70.6M
            b1 = *in++;
769
70.6M
            b2 = *in++;
770
70.6M
            b3 = *in++;
771
772
70.6M
            if (curr >= endout)
773
151k
                break;
774
70.4M
            if ((d = ((b2 & 0x0f) << 8) + b1) < kPrime)
775
58.1M
                *curr++ = d;
776
70.4M
            if (curr >= endout)
777
299k
                break;
778
70.1M
            if ((d = (b3 << 4) + (b2 >> 4)) < kPrime)
779
57.1M
                *curr++ = d;
780
70.1M
        } while (in < endin);
781
1.35M
    } while (curr < endout);
782
450k
    return 1;
783
450k
}
784
785
/*-
 * reduce_once reduces 0 <= x < 2*kPrime, mod kPrime.
 *
 * Subtract |q| if the input is larger, without exposing a side-channel,
 * avoiding the "clangover" attack.  See |constish_time_non_zero| for a
 * discussion on why the value barrier is by default omitted.
 *
 * How the select works (kPrime = 3329 < 2^15): |subtracted| is x - kPrime
 * computed in uint16_t.  When x < kPrime it wraps to at least
 * 2^16 - kPrime, so bit 15 is set.  When kPrime <= x < 2*kPrime it lies in
 * [0, kPrime), so bit 15 is clear.  |mask| is therefore all ones exactly
 * when x < kPrime, which selects |x|; otherwise it is zero and selects
 * |subtracted|.
 */
static __owur uint16_t reduce_once(uint16_t x)
{
    const uint16_t subtracted = x - kPrime;
    /* Bit 15 of |subtracted| is the "x < kPrime" borrow. */
    uint16_t mask = constish_time_non_zero(subtracted >> 15);

    return (mask & x) | (~mask & subtracted);
}
799
800
/*
801
 * Constant-time reduce x mod kPrime using Barrett reduction. x must be less
802
 * than kPrime + 2 * kPrime^2.  This is sufficient to reduce a product of
803
 * two already reduced u_int16 values, in fact it is sufficient for each
804
 * to be less than 2^12, because (kPrime * (2 * kPrime + 1)) > 2^24.
805
 */
806
static __owur uint16_t reduce(uint32_t x)
807
442M
{
808
442M
    uint64_t product = (uint64_t)x * kBarrettMultiplier;
809
442M
    uint32_t quotient = (uint32_t)(product >> kBarrettShift);
810
442M
    uint32_t remainder = x - quotient * kPrime;
811
812
442M
    return reduce_once(remainder);
813
442M
}
814
815
/* Multiply a scalar by a constant. */
816
static void scalar_mult_const(scalar *s, uint16_t a)
817
1.00k
{
818
1.00k
    uint16_t *curr = s->c, *end = curr + DEGREE, tmp;
819
820
256k
    do {
821
256k
        tmp = reduce(*curr * a);
822
256k
        *curr++ = tmp;
823
256k
    } while (curr < end);
824
1.00k
}
825
826
/*-
 * FIPS 203, Section 4.3, Algorithm 9: "NTT".
 * In-place number theoretic transform of a given scalar.  Note that ML-KEM's
 * kPrime 3329 does not have a 512th root of unity, so this transform leaves
 * off the last iteration of the usual FFT code, with the 128 relevant roots of
 * unity being stored in NTTRoots.  This means the output should be seen as 128
 * elements in GF(3329^2), with the coefficients of the elements being
 * consecutive entries in |s->c|.
 *
 * Loop structure: |offset| starts at DEGREE / 2 and halves after each layer,
 * stopping once it would drop below 2.  Each layer walks the coefficients in
 * blocks of 2 * |offset|.  Every block takes the next root |zeta| from
 * kNTTRoots (pre-incremented, so entry 0 is skipped) and, for each j in the
 * lower half of the block, applies
 *     c[j] <- c[j] + zeta * c[j + offset]
 *     c[j + offset] <- c[j] - zeta * c[j + offset]
 * with all arithmetic reduced mod kPrime.
 */
static void scalar_ntt(scalar *s)
{
    const uint16_t *roots = kNTTRoots;
    uint16_t *end = s->c + DEGREE;
    int offset = DEGREE / 2;

    do {
        uint16_t *curr = s->c, *peer;

        do {
            /* |curr| walks the lower half of this block, |peer| the upper. */
            uint16_t *pause = curr + offset, even, odd;
            uint32_t zeta = *++roots;

            peer = pause;
            do {
                even = *curr;
                odd = reduce(*peer * zeta);
                /* Adding kPrime keeps the difference non-negative. */
                *peer++ = reduce_once(even - odd + kPrime);
                *curr++ = reduce_once(odd + even);
            } while (curr < pause);
        } while ((curr = peer) < end);
    } while ((offset >>= 1) >= 2);
}
858
859
/*-
 * FIPS 203, Section 4.3, Algorithm 10: "NTT^(-1)".
 * In-place inverse number theoretic transform of a given scalar, with pairs of
 * entries of s->c being interpreted as elements of GF(3329^2). Just as with
 * the number theoretic transform, this leaves off the first step of the normal
 * iFFT to account for the fact that 3329 does not have a 512th root of unity,
 * using the precomputed 128 roots of unity stored in InverseNTTRoots.
 *
 * Loop structure: |offset| starts at 2 and doubles after each layer while it
 * stays below DEGREE.  Each layer walks the coefficients in blocks of
 * 2 * |offset|.  Every block takes the next root |zeta| from kInverseNTTRoots
 * (pre-incremented, so entry 0 is skipped) and, for each j in the lower half
 * of the block, applies
 *     c[j] <- c[j] + c[j + offset]
 *     c[j + offset] <- zeta * (c[j] - c[j + offset])
 * mod kPrime.  Finally every coefficient is scaled by kInverseDegree.
 */
static void scalar_inverse_ntt(scalar *s)
{
    const uint16_t *roots = kInverseNTTRoots;
    uint16_t *end = s->c + DEGREE;
    int offset = 2;

    do {
        uint16_t *curr = s->c, *peer;

        do {
            /* |curr| walks the lower half of this block, |peer| the upper. */
            uint16_t *pause = curr + offset, even, odd;
            uint32_t zeta = *++roots;

            peer = pause;
            do {
                even = *curr;
                odd = *peer;
                /* Adding kPrime keeps the difference non-negative. */
                *peer++ = reduce(zeta * (even - odd + kPrime));
                *curr++ = reduce_once(odd + even);
            } while (curr < pause);
        } while ((curr = peer) < end);
    } while ((offset <<= 1) < DEGREE);
    /* Undo the scaling accumulated by the butterflies: multiply by (n/2)^-1. */
    scalar_mult_const(s, kInverseDegree);
}
891
892
/* Addition updating the LHS scalar in-place. */
893
static void scalar_add(scalar *lhs, const scalar *rhs)
894
900
{
895
900
    int i;
896
897
231k
    for (i = 0; i < DEGREE; i++)
898
230k
        lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
899
900
}
900
901
/* Subtraction updating the LHS scalar in-place. */
902
static void scalar_sub(scalar *lhs, const scalar *rhs)
903
100
{
904
100
    int i;
905
906
25.7k
    for (i = 0; i < DEGREE; i++)
907
25.6k
        lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime);
908
100
}
909
910
/*
911
 * Multiplying two scalars in the number theoretically transformed state. Since
912
 * 3329 does not have a 512th root of unity, this means we have to interpret
913
 * the 2*ith and (2*i+1)th entries of the scalar as elements of
914
 * GF(3329)[X]/(X^2 - 17^(2*bitreverse(i)+1)).
915
 *
916
 * The value of 17^(2*bitreverse(i)+1) mod 3329 is stored in the precomputed
917
 * ModRoots table. Note that our Barrett transform only allows us to multiply
918
 * two reduced numbers together, so we need some intermediate reduction steps,
919
 * even if an uint64_t could hold 3 multiplied numbers.
920
 */
921
static void scalar_mult(scalar *out, const scalar *lhs,
922
    const scalar *rhs)
923
1.00k
{
924
1.00k
    uint16_t *curr = out->c, *end = curr + DEGREE;
925
1.00k
    const uint16_t *lc = lhs->c, *rc = rhs->c;
926
1.00k
    const uint16_t *roots = kModRoots;
927
928
128k
    do {
929
128k
        uint32_t l0 = *lc++, r0 = *rc++;
930
128k
        uint32_t l1 = *lc++, r1 = *rc++;
931
128k
        uint32_t zetapow = *roots++;
932
933
128k
        *curr++ = reduce(l0 * r0 + reduce(l1 * r1) * zetapow);
934
128k
        *curr++ = reduce(l0 * r1 + l1 * r0);
935
128k
    } while (curr < end);
936
1.00k
}
937
938
/* Above, but add the result to an existing scalar */
939
static ossl_inline void scalar_mult_add(scalar *out, const scalar *lhs,
940
    const scalar *rhs)
941
450k
{
942
450k
    uint16_t *curr = out->c, *end = curr + DEGREE;
943
450k
    const uint16_t *lc = lhs->c, *rc = rhs->c;
944
450k
    const uint16_t *roots = kModRoots;
945
946
57.6M
    do {
947
57.6M
        uint32_t l0 = *lc++, r0 = *rc++;
948
57.6M
        uint32_t l1 = *lc++, r1 = *rc++;
949
57.6M
        uint16_t *c0 = curr++;
950
57.6M
        uint16_t *c1 = curr++;
951
57.6M
        uint32_t zetapow = *roots++;
952
953
57.6M
        *c0 = reduce(*c0 + l0 * r0 + reduce(l1 * r1) * zetapow);
954
57.6M
        *c1 = reduce(*c1 + l0 * r1 + l1 * r0);
955
57.6M
    } while (curr < end);
956
450k
}
957
958
/*-
 * FIPS 203, Section 4.2.1, Algorithm 5: "ByteEncode_d", for 2<=d<=12.
 * Here |bits| is |d|.  For efficiency, we handle the d=1 case separately.
 *
 * Each coefficient contributes its low |bits| bits, packed little-endian into
 * a 64-bit accumulator, and every full accumulator is stored with
 * OPENSSL_store_u64_le.  Coefficients are not masked.
 * NOTE(review): each coefficient is assumed to be below 2^bits, otherwise its
 * high bits would corrupt the next field; confirm callers compress first.
 * With DEGREE == 256, DEGREE * bits is a multiple of 64, so the last
 * coefficient always completes a word, and exactly DEGREE * bits / 8 bytes
 * are written.
 */
static void scalar_encode(uint8_t *out, const scalar *s, int bits)
{
    const uint16_t *curr = s->c, *end = curr + DEGREE;
    uint64_t accum = 0, element;
    int used = 0;  /* number of bits currently held in |accum| */

    do {
        element = *curr++;
        if (used + bits < 64) {
            /* Fits entirely in the current word. */
            accum |= element << used;
            used += bits;
        } else if (used + bits > 64) {
            /* Straddles a word boundary: flush, carry the high part over. */
            out = OPENSSL_store_u64_le(out, accum | (element << used));
            accum = element >> (64 - used);
            used = (used + bits) - 64;
        } else {
            /* Exactly fills the current word: flush and start afresh. */
            out = OPENSSL_store_u64_le(out, accum | (element << used));
            accum = 0;
            used = 0;
        }
    } while (curr < end);
}
984
985
/*
986
 * scalar_encode_1 is |scalar_encode| specialised for |bits| == 1.
987
 */
988
static void scalar_encode_1(uint8_t out[DEGREE / 8], const scalar *s)
989
100
{
990
100
    int i, j;
991
100
    uint8_t out_byte;
992
993
3.30k
    for (i = 0; i < DEGREE; i += 8) {
994
3.20k
        out_byte = 0;
995
28.8k
        for (j = 0; j < 8; j++)
996
25.6k
            out_byte |= bit0(s->c[i + j]) << j;
997
3.20k
        *out = out_byte;
998
3.20k
        out++;
999
3.20k
    }
1000
100
}
1001
1002
/*-
 * FIPS 203, Section 4.2.1, Algorithm 6: "ByteDecode_d", for 2<=d<12.
 * Here |bits| is |d|.  For efficiency, we handle the d=1 and d=12 cases
 * separately.
 *
 * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
 * |out|.
 *
 * Input is consumed in whole 64-bit little-endian words, loaded with
 * OPENSSL_load_u64_le.  With DEGREE == 256 that is 4 * bits words, i.e.
 * exactly DEGREE * bits / 8 bytes.  Decoded values are not range-checked.
 *
 * State: |accum| holds the |accum_bits| unconsumed bits of the current word.
 * |element| holds the low bits of a value split across two words, |todo| is
 * how many bits that value still needs, and |mask| selects exactly |todo|
 * bits.
 */
static void scalar_decode(scalar *out, const uint8_t *in, int bits)
{
    uint16_t *curr = out->c, *end = curr + DEGREE;
    uint64_t accum = 0;
    int accum_bits = 0, todo = bits;
    uint16_t bitmask = (((uint16_t)1) << bits) - 1, mask = bitmask;
    uint16_t element = 0;

    do {
        if (accum_bits == 0) {
            in = OPENSSL_load_u64_le(&accum, in);
            accum_bits = 64;
        }
        if (todo == bits && accum_bits >= bits) {
            /* No partial "element", and all the required bits available */
            *curr++ = ((uint16_t)accum) & mask;
            accum >>= bits;
            accum_bits -= bits;
        } else if (accum_bits >= todo) {
            /* A partial "element", and all the required bits available */
            *curr++ = element | ((((uint16_t)accum) & mask) << (bits - todo));
            accum >>= todo;
            accum_bits -= todo;
            element = 0;
            todo = bits;
            mask = bitmask;
        } else {
            /*
             * Only some of the requisite bits accumulated, store |accum_bits|
             * of these in |element|.  The accumulated bitcount becomes 0, but
             * as soon as we have more bits we'll want to merge accum_bits
             * fewer of them into the final |element|.
             *
             * Note that with a 64-bit accumulator and |bits| always 12 or
             * less, if we're here, the previous iteration had all the
             * requisite bits, and so there are no kept bits in |element|.
             */
            element = ((uint16_t)accum) & mask;
            todo -= accum_bits;
            mask = bitmask >> accum_bits;
            accum_bits = 0;
        }
    } while (curr < end);
}
1054
1055
/*
 * Unpacks DEGREE 12-bit coefficients from |in| (3 bytes per pair of
 * coefficients, little-endian bit order) into |out|.
 *
 * Returns 1 when every coefficient is < kPrime, and 0 as soon as a pair
 * containing an out-of-range coefficient has been decoded (both members of
 * that pair are still stored in |out| before returning).
 */
static __owur int scalar_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2])
{
    const uint8_t *src = in;
    uint16_t *dst = out->c;
    int pairs_left;

    for (pairs_left = DEGREE / 2; pairs_left > 0; --pairs_left) {
        /* Low 12 bits: all of byte 0 plus the low nibble of byte 1. */
        uint16_t lo = (uint16_t)(src[0] | ((src[1] & 0x0f) << 8));
        /* High 12 bits: the high nibble of byte 1 plus all of byte 2. */
        uint16_t hi = (uint16_t)((src[1] >> 4) | (src[2] << 4));

        src += 3;
        *dst++ = lo;
        *dst++ = hi;
        if ((lo >= kPrime) | (hi >= kPrime))
            return 0;
    }
    return 1;
}
1072
1073
/*-
1074
 * scalar_decode_decompress_add is a combination of decoding and decompression
1075
 * both specialised for |bits| == 1, with the result added (and sum reduced) to
1076
 * the output scalar.
1077
 *
1078
 * NOTE: this function MUST not leak an input-data-depedennt timing signal.
1079
 * A timing leak in a related function in the reference Kyber implementation
1080
 * made the "clangover" attack (CVE-2024-37880) possible, giving key recovery
1081
 * for ML-KEM-512 in minutes, provided the attacker has access to precise
1082
 * timing of a CPU performing chosen-ciphertext decap.  Admittedly this is only
1083
 * a risk when private keys are reused (perhaps KEMTLS servers).
1084
 */
1085
static void
1086
scalar_decode_decompress_add(scalar *out, const uint8_t in[DEGREE / 8])
1087
228
{
1088
228
    static const uint16_t half_q_plus_1 = (ML_KEM_PRIME >> 1) + 1;
1089
228
    uint16_t *curr = out->c, *end = curr + DEGREE;
1090
228
    uint16_t mask;
1091
228
    uint8_t b;
1092
1093
    /*
1094
     * Add |half_q_plus_1| if the bit is set, without exposing a side-channel,
1095
     * avoiding the "clangover" attack.  See |constish_time_non_zero| for a
1096
     * discussion on why the value barrier is by default omitted.
1097
     */
1098
228
#define decode_decompress_add_bit                        \
1099
58.3k
    mask = constish_time_non_zero(bit0(b));              \
1100
58.3k
    *curr = reduce_once(*curr + (mask & half_q_plus_1)); \
1101
58.3k
    curr++;                                              \
1102
58.3k
    b >>= 1
1103
1104
    /* Unrolled to process each byte in one iteration */
1105
7.29k
    do {
1106
7.29k
        b = *in++;
1107
7.29k
        decode_decompress_add_bit;
1108
7.29k
        decode_decompress_add_bit;
1109
7.29k
        decode_decompress_add_bit;
1110
7.29k
        decode_decompress_add_bit;
1111
1112
7.29k
        decode_decompress_add_bit;
1113
7.29k
        decode_decompress_add_bit;
1114
7.29k
        decode_decompress_add_bit;
1115
7.29k
        decode_decompress_add_bit;
1116
7.29k
    } while (curr < end);
1117
228
#undef decode_decompress_add_bit
1118
228
}
1119
1120
/*
1121
 * FIPS 203, Section 4.2.1, Equation (4.7): Compress_d.
1122
 *
1123
 * Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
1124
 * numbers close to each other together. The formula used is
1125
 * round(2^|bits|/kPrime*x) mod 2^|bits|.
1126
 * Uses Barrett reduction to achieve constant time. Since we need both the
1127
 * remainder (for rounding) and the quotient (as the result), we cannot use
1128
 * |reduce| here, but need to do the Barrett reduction directly.
1129
 */
1130
static __owur uint16_t compress(uint16_t x, int bits)
1131
256k
{
1132
256k
    uint32_t shifted = (uint32_t)x << bits;
1133
256k
    uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
1134
256k
    uint32_t quotient = (uint32_t)(product >> kBarrettShift);
1135
256k
    uint32_t remainder = shifted - quotient * kPrime;
1136
1137
    /*
1138
     * Adjust the quotient to round correctly:
1139
     *   0 <= remainder <= kHalfPrime round to 0
1140
     *   kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
1141
     *   kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
1142
     */
1143
256k
    quotient += 1 & constant_time_lt_32(kHalfPrime, remainder);
1144
256k
    quotient += 1 & constant_time_lt_32(kPrime + kHalfPrime, remainder);
1145
256k
    return quotient & ((1 << bits) - 1);
1146
256k
}
1147
1148
/*
1149
 * FIPS 203, Section 4.2.1, Equation (4.8): Decompress_d.
1150
1151
 * Decompresses |x| by using a close equi-distant representative. The formula
1152
 * is round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us
1153
 * to implement this logic using only bit operations.
1154
 */
1155
static __owur uint16_t decompress(uint16_t x, int bits)
1156
100k
{
1157
100k
    uint32_t product = (uint32_t)x * kPrime;
1158
100k
    uint32_t power = 1 << bits;
1159
    /* This is |product| % power, since |power| is a power of 2. */
1160
100k
    uint32_t remainder = product & (power - 1);
1161
    /* This is |product| / power, since |power| is a power of 2. */
1162
100k
    uint32_t lower = product >> bits;
1163
1164
    /*
1165
     * The rounding logic works since the first half of numbers mod |power|
1166
     * have a 0 as first bit, and the second half has a 1 as first bit, since
1167
     * |power| is a power of 2. As a 12 bit number, |remainder| is always
1168
     * positive, so we will shift in 0s for a right shift.
1169
     */
1170
100k
    return lower + (remainder >> (bits - 1));
1171
100k
}
1172
1173
/*-
1174
 * FIPS 203, Section 4.2.1, Equation (4.7): "Compress_d".
1175
 * In-place lossy rounding of scalars to 2^d bits.
1176
 */
1177
static void scalar_compress(scalar *s, int bits)
1178
1.00k
{
1179
1.00k
    int i;
1180
1181
257k
    for (i = 0; i < DEGREE; i++)
1182
256k
        s->c[i] = compress(s->c[i], bits);
1183
1.00k
}
1184
1185
/*
1186
 * FIPS 203, Section 4.2.1, Equation (4.8): "Decompress_d".
1187
 * In-place approximate recovery of scalars from 2^d bit compression.
1188
 */
1189
static void scalar_decompress(scalar *s, int bits)
1190
394
{
1191
394
    int i;
1192
1193
101k
    for (i = 0; i < DEGREE; i++)
1194
100k
        s->c[i] = decompress(s->c[i], bits);
1195
394
}
1196
1197
/* Addition updating the LHS vector in-place. */
1198
static void vector_add(scalar *lhs, const scalar *rhs, int rank)
1199
228
{
1200
672
    do {
1201
672
        scalar_add(lhs++, rhs++);
1202
672
    } while (--rank > 0);
1203
228
}
1204
1205
/*
1206
 * Encodes an entire vector into 32*|rank|*|bits| bytes. Note that since 256
1207
 * (DEGREE) is divisible by 8, the individual vector entries will always fill a
1208
 * whole number of bytes, so we do not need to worry about bit packing here.
1209
 */
1210
static void vector_encode(uint8_t *out, const scalar *a, int bits, int rank)
1211
50.0k
{
1212
50.0k
    int stride = bits * DEGREE / 8;
1213
1214
199k
    for (; rank-- > 0; out += stride)
1215
149k
        scalar_encode(out, a++, bits);
1216
50.0k
}
1217
1218
/*
1219
 * Decodes 32*|rank|*|bits| bytes from |in| into |out|. It returns early
1220
 * if any parsed value is >= |ML_KEM_PRIME|.  The resulting scalars are
1221
 * then decompressed and transformed via the NTT.
1222
 *
1223
 * Note: Used only in decrypt_cpa(), which returns void and so does not check
1224
 * the return value of this function.  Side-channels are fine when the input
1225
 * ciphertext to decap() is simply syntactically invalid.
1226
 */
1227
static void
1228
vector_decode_decompress_ntt(scalar *out, const uint8_t *in, int bits, int rank)
1229
100
{
1230
100
    int stride = bits * DEGREE / 8;
1231
1232
394
    for (; rank-- > 0; in += stride, ++out) {
1233
294
        scalar_decode(out, in, bits);
1234
294
        scalar_decompress(out, bits);
1235
294
        scalar_ntt(out);
1236
294
    }
1237
100
}
1238
1239
/* vector_decode(), specialised to bits == 12. */
1240
static __owur int vector_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2], int rank)
1241
552
{
1242
552
    int stride = 3 * DEGREE / 2;
1243
1244
1.41k
    for (; rank-- > 0; in += stride)
1245
1.12k
        if (!scalar_decode_12(out++, in))
1246
259
            return 0;
1247
293
    return 1;
1248
552
}
1249
1250
/* In-place compression of each scalar component */
1251
static void vector_compress(scalar *a, int bits, int rank)
1252
228
{
1253
672
    do {
1254
672
        scalar_compress(a++, bits);
1255
672
    } while (--rank > 0);
1256
228
}
1257
1258
/* The output scalar must not overlap with the inputs */
1259
static void inner_product(scalar *out, const scalar *lhs, const scalar *rhs,
1260
    int rank)
1261
328
{
1262
328
    scalar_mult(out, lhs, rhs);
1263
966
    while (--rank > 0)
1264
638
        scalar_mult_add(out, ++lhs, ++rhs);
1265
328
}
1266
1267
/*
1268
 * Here, the output vector must not overlap with the inputs, the result is
1269
 * directly subjected to inverse NTT.
1270
 */
1271
static void
1272
matrix_mult_intt(scalar *out, const scalar *m, const scalar *a, int rank)
1273
228
{
1274
228
    const scalar *ar;
1275
228
    int i, j;
1276
1277
900
    for (i = rank; i-- > 0; ++out) {
1278
672
        scalar_mult(out, m++, ar = a);
1279
2.13k
        for (j = rank - 1; j > 0; --j)
1280
1.46k
            scalar_mult_add(out, m++, ++ar);
1281
672
        scalar_inverse_ntt(out);
1282
672
    }
1283
228
}
1284
1285
/* Here, the output vector must not overlap with the inputs */
1286
static void
1287
matrix_mult_transpose_add(scalar *out, const scalar *m, const scalar *a, int rank)
1288
49.8k
{
1289
49.8k
    const scalar *mc = m, *mr, *ar;
1290
49.8k
    int i, j;
1291
1292
199k
    for (i = rank; i-- > 0; ++out) {
1293
149k
        scalar_mult_add(out, mr = mc++, ar = a);
1294
447k
        for (j = rank; --j > 0;)
1295
298k
            scalar_mult_add(out, (mr += rank), ++ar);
1296
149k
    }
1297
49.8k
}
1298
1299
/*-
1300
 * Expands the matrix from a seed for key generation and for encaps-CPA.
1301
 * NOTE: FIPS 203 matrix "A" is the transpose of this matrix, computed
1302
 * by appending the (i,j) indices to the seed in the opposite order!
1303
 *
1304
 * Where FIPS 203 computes t = A * s + e, we use the transpose of "m".
1305
 */
1306
static __owur int matrix_expand(EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
1307
50.0k
{
1308
50.0k
    scalar *out = key->m;
1309
50.0k
    uint8_t input[ML_KEM_RANDOM_BYTES + 2];
1310
50.0k
    int rank = key->vinfo->rank;
1311
50.0k
    int i, j;
1312
1313
50.0k
    memcpy(input, key->rho, ML_KEM_RANDOM_BYTES);
1314
200k
    for (i = 0; i < rank; i++) {
1315
600k
        for (j = 0; j < rank; j++) {
1316
450k
            input[ML_KEM_RANDOM_BYTES] = i;
1317
450k
            input[ML_KEM_RANDOM_BYTES + 1] = j;
1318
450k
            if (!EVP_DigestInit_ex(mdctx, key->shake128_md, NULL)
1319
450k
                || !EVP_DigestUpdate(mdctx, input, sizeof(input))
1320
450k
                || !sample_scalar(out++, mdctx))
1321
0
                return 0;
1322
450k
        }
1323
150k
    }
1324
50.0k
    return 1;
1325
50.0k
}
1326
1327
/*
1328
 * Algorithm 7 from the spec, with eta fixed to two and the PRF call
1329
 * included. Creates binominally distributed elements by sampling 2*|eta| bits,
1330
 * and setting the coefficient to the count of the first bits minus the count of
1331
 * the second bits, resulting in a centered binomial distribution. Since eta is
1332
 * two this gives -2/2 with a probability of 1/16, -1/1 with probability 1/4,
1333
 * and 0 with probability 3/8.
1334
 */
1335
static __owur int cbd_2(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
1336
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1337
299k
{
1338
299k
    uint16_t *curr = out->c, *end = curr + DEGREE;
1339
299k
    uint8_t randbuf[4 * DEGREE / 8], *r = randbuf; /* 64 * eta slots */
1340
299k
    uint16_t value, mask;
1341
299k
    uint8_t b;
1342
1343
299k
    if (!prf(randbuf, sizeof(randbuf), in, mdctx, key))
1344
0
        return 0;
1345
1346
38.2M
    do {
1347
38.2M
        b = *r++;
1348
1349
        /*
1350
         * Add |kPrime| if |value| underflowed.  See |constish_time_non_zero|
1351
         * for a discussion on why the value barrier is by default omitted.
1352
         * While this could have been written reduce_once(value + kPrime), this
1353
         * is one extra addition and small range of |value| tempts some
1354
         * versions of Clang to emit a branch.
1355
         */
1356
38.2M
        value = bit0(b) + bitn(1, b);
1357
38.2M
        value -= bitn(2, b) + bitn(3, b);
1358
38.2M
        mask = constish_time_non_zero(value >> 15);
1359
38.2M
        *curr++ = value + (kPrime & mask);
1360
1361
38.2M
        value = bitn(4, b) + bitn(5, b);
1362
38.2M
        value -= bitn(6, b) + bitn(7, b);
1363
38.2M
        mask = constish_time_non_zero(value >> 15);
1364
38.2M
        *curr++ = value + (kPrime & mask);
1365
38.2M
    } while (curr < end);
1366
299k
    return 1;
1367
299k
}
1368
1369
/*
1370
 * Algorithm 7 from the spec, with eta fixed to three and the PRF call
1371
 * included. Creates binominally distributed elements by sampling 3*|eta| bits,
1372
 * and setting the coefficient to the count of the first bits minus the count of
1373
 * the second bits, resulting in a centered binomial distribution.
1374
 */
1375
static __owur int cbd_3(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
1376
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1377
1.06k
{
1378
1.06k
    uint16_t *curr = out->c, *end = curr + DEGREE;
1379
1.06k
    uint8_t randbuf[6 * DEGREE / 8], *r = randbuf; /* 64 * eta slots */
1380
1.06k
    uint8_t b1, b2, b3;
1381
1.06k
    uint16_t value, mask;
1382
1383
1.06k
    if (!prf(randbuf, sizeof(randbuf), in, mdctx, key))
1384
0
        return 0;
1385
1386
68.3k
    do {
1387
68.3k
        b1 = *r++;
1388
68.3k
        b2 = *r++;
1389
68.3k
        b3 = *r++;
1390
1391
        /*
1392
         * Add |kPrime| if |value| underflowed.  See |constish_time_non_zero|
1393
         * for a discussion on why the value barrier is by default omitted.
1394
         * While this could have been written reduce_once(value + kPrime), this
1395
         * is one extra addition and small range of |value| tempts some
1396
         * versions of Clang to emit a branch.
1397
         */
1398
68.3k
        value = bit0(b1) + bitn(1, b1) + bitn(2, b1);
1399
68.3k
        value -= bitn(3, b1) + bitn(4, b1) + bitn(5, b1);
1400
68.3k
        mask = constish_time_non_zero(value >> 15);
1401
68.3k
        *curr++ = value + (kPrime & mask);
1402
1403
68.3k
        value = bitn(6, b1) + bitn(7, b1) + bit0(b2);
1404
68.3k
        value -= bitn(1, b2) + bitn(2, b2) + bitn(3, b2);
1405
68.3k
        mask = constish_time_non_zero(value >> 15);
1406
68.3k
        *curr++ = value + (kPrime & mask);
1407
1408
68.3k
        value = bitn(4, b2) + bitn(5, b2) + bitn(6, b2);
1409
68.3k
        value -= bitn(7, b2) + bit0(b3) + bitn(1, b3);
1410
68.3k
        mask = constish_time_non_zero(value >> 15);
1411
68.3k
        *curr++ = value + (kPrime & mask);
1412
1413
68.3k
        value = bitn(2, b3) + bitn(3, b3) + bitn(4, b3);
1414
68.3k
        value -= bitn(5, b3) + bitn(6, b3) + bitn(7, b3);
1415
68.3k
        mask = constish_time_non_zero(value >> 15);
1416
68.3k
        *curr++ = value + (kPrime & mask);
1417
68.3k
    } while (curr < end);
1418
1.06k
    return 1;
1419
1.06k
}
1420
1421
/*
1422
 * Generates a secret vector by using |cbd| with the given seed to generate
1423
 * scalar elements and incrementing |counter| for each slot of the vector.
1424
 */
1425
static __owur int gencbd_vector(scalar *out, CBD_FUNC cbd, uint8_t *counter,
1426
    const uint8_t seed[ML_KEM_RANDOM_BYTES], int rank,
1427
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1428
228
{
1429
228
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1430
1431
228
    memcpy(input, seed, ML_KEM_RANDOM_BYTES);
1432
672
    do {
1433
672
        input[ML_KEM_RANDOM_BYTES] = (*counter)++;
1434
672
        if (!cbd(out++, input, mdctx, key))
1435
0
            return 0;
1436
672
    } while (--rank > 0);
1437
228
    return 1;
1438
228
}
1439
1440
/*
1441
 * As above plus NTT transform.
1442
 */
1443
static __owur int gencbd_vector_ntt(scalar *out, CBD_FUNC cbd, uint8_t *counter,
1444
    const uint8_t seed[ML_KEM_RANDOM_BYTES], int rank,
1445
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1446
99.8k
{
1447
99.8k
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1448
1449
99.8k
    memcpy(input, seed, ML_KEM_RANDOM_BYTES);
1450
299k
    do {
1451
299k
        input[ML_KEM_RANDOM_BYTES] = (*counter)++;
1452
299k
        if (!cbd(out, input, mdctx, key))
1453
0
            return 0;
1454
299k
        scalar_ntt(out++);
1455
299k
    } while (--rank > 0);
1456
99.8k
    return 1;
1457
99.8k
}
1458
1459
/*
 * Picks the sampler for the |ETA1| parameter: ML-KEM-512 uses eta1 = 3
 * (cbd_3), every other variant uses eta1 = 2 (cbd_2).  All ETA2 values are 2.
 */
#define CBD1(evp_type) ((evp_type) != EVP_PKEY_ML_KEM_512 ? cbd_2 : cbd_3)
1461
1462
/*
1463
 * FIPS 203, Section 5.2, Algorithm 14: K-PKE.Encrypt.
1464
 *
1465
 * Encrypts a message with given randomness to the ciphertext in |out|. Without
1466
 * applying the Fujisaki-Okamoto transform this would not result in a CCA
1467
 * secure scheme, since lattice schemes are vulnerable to decryption failure
1468
 * oracles.
1469
 *
1470
 * The steps are re-ordered to make more efficient/localised use of storage.
1471
 *
1472
 * Note also that the input public key is assumed to hold a precomputed matrix
1473
 * |A| (our key->m, with the public key holding an expanded (16-bit per scalar
1474
 * coefficient) key->t vector).
1475
 *
1476
 * Caller passes storage in |tmp| for for two temporary vectors.
1477
 */
1478
static __owur int encrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
1479
    const uint8_t message[DEGREE / 8],
1480
    const uint8_t r[ML_KEM_RANDOM_BYTES], scalar *tmp,
1481
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1482
228
{
1483
228
    const ML_KEM_VINFO *vinfo = key->vinfo;
1484
228
    CBD_FUNC cbd_1 = CBD1(vinfo->evp_type);
1485
228
    int rank = vinfo->rank;
1486
    /* We can use tmp[0..rank-1] as storage for |y|, then |e1|, ... */
1487
228
    scalar *y = &tmp[0], *e1 = y, *e2 = y;
1488
    /* We can use tmp[rank]..tmp[2*rank - 1] for |u| */
1489
228
    scalar *u = &tmp[rank];
1490
228
    scalar v;
1491
228
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1492
228
    uint8_t counter = 0;
1493
228
    int du = vinfo->du;
1494
228
    int dv = vinfo->dv;
1495
1496
    /* FIPS 203 "y" vector */
1497
228
    if (!gencbd_vector_ntt(y, cbd_1, &counter, r, rank, mdctx, key))
1498
0
        return 0;
1499
    /* FIPS 203 "v" scalar */
1500
228
    inner_product(&v, key->t, y, rank);
1501
228
    scalar_inverse_ntt(&v);
1502
    /* FIPS 203 "u" vector */
1503
228
    matrix_mult_intt(u, key->m, y, rank);
1504
1505
    /* All done with |y|, now free to reuse tmp[0] for FIPS 203 |e1| */
1506
228
    if (!gencbd_vector(e1, cbd_2, &counter, r, rank, mdctx, key))
1507
0
        return 0;
1508
228
    vector_add(u, e1, rank);
1509
228
    vector_compress(u, du, rank);
1510
228
    vector_encode(out, u, du, rank);
1511
1512
    /* All done with |e1|, now free to reuse tmp[0] for FIPS 203 |e2| */
1513
228
    memcpy(input, r, ML_KEM_RANDOM_BYTES);
1514
228
    input[ML_KEM_RANDOM_BYTES] = counter;
1515
228
    if (!cbd_2(e2, input, mdctx, key))
1516
0
        return 0;
1517
228
    scalar_add(&v, e2);
1518
1519
    /* Combine message with |v| */
1520
228
    scalar_decode_decompress_add(&v, message);
1521
228
    scalar_compress(&v, dv);
1522
228
    scalar_encode(out + vinfo->u_vector_bytes, &v, dv);
1523
228
    return 1;
1524
228
}
1525
1526
/*
1527
 * FIPS 203, Section 5.3, Algorithm 15: K-PKE.Decrypt.
1528
 */
1529
static void
1530
decrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
1531
    const uint8_t *ctext, scalar *u, const ML_KEM_KEY *key)
1532
100
{
1533
100
    const ML_KEM_VINFO *vinfo = key->vinfo;
1534
100
    scalar v, mask;
1535
100
    int rank = vinfo->rank;
1536
100
    int du = vinfo->du;
1537
100
    int dv = vinfo->dv;
1538
1539
100
    vector_decode_decompress_ntt(u, ctext, du, rank);
1540
100
    scalar_decode(&v, ctext + vinfo->u_vector_bytes, dv);
1541
100
    scalar_decompress(&v, dv);
1542
100
    inner_product(&mask, key->s, u, rank);
1543
100
    scalar_inverse_ntt(&mask);
1544
100
    scalar_sub(&v, &mask);
1545
100
    scalar_compress(&v, 1);
1546
100
    scalar_encode_1(out, &v);
1547
100
}
1548
1549
/*-
1550
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1551
 * FIPS 203, Section 7.2, Algorithm 20: "ML-KEM.Encaps".
1552
 *
1553
 * Fills the |out| buffer with the |ek| output of "ML-KEM.KeyGen", or,
1554
 * equivalently, the |ek| input of "ML-KEM.Encaps", i.e. returns the
1555
 * wire-format of an ML-KEM public key.
1556
 */
1557
static void encode_pubkey(uint8_t *out, const ML_KEM_KEY *key)
1558
49.6k
{
1559
49.6k
    const uint8_t *rho = key->rho;
1560
49.6k
    const ML_KEM_VINFO *vinfo = key->vinfo;
1561
1562
49.6k
    vector_encode(out, key->t, 12, vinfo->rank);
1563
49.6k
    memcpy(out + vinfo->vector_bytes, rho, ML_KEM_RANDOM_BYTES);
1564
49.6k
}
1565
1566
/*-
1567
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1568
 *
1569
 * Fills the |out| buffer with the |dk| output of "ML-KEM.KeyGen".
1570
 * This matches the input format of parse_prvkey() below.
1571
 */
1572
static void encode_prvkey(uint8_t *out, const ML_KEM_KEY *key)
1573
74
{
1574
74
    const ML_KEM_VINFO *vinfo = key->vinfo;
1575
1576
74
    vector_encode(out, key->s, 12, vinfo->rank);
1577
74
    out += vinfo->vector_bytes;
1578
74
    encode_pubkey(out, key);
1579
74
    out += vinfo->pubkey_bytes;
1580
74
    memcpy(out, key->pkhash, ML_KEM_PKHASH_BYTES);
1581
74
    out += ML_KEM_PKHASH_BYTES;
1582
74
    memcpy(out, key->z, ML_KEM_RANDOM_BYTES);
1583
74
}
1584
1585
/*-
1586
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1587
 * FIPS 203, Section 7.2, Algorithm 20: "ML-KEM.Encaps".
1588
 *
1589
 * This function parses the |in| buffer as the |ek| output of "ML-KEM.KeyGen",
1590
 * or, equivalently, the |ek| input of "ML-KEM.Encaps", i.e. decodes the
1591
 * wire-format of the ML-KEM public key.
1592
 */
1593
static int parse_pubkey(const uint8_t *in, EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
1594
497
{
1595
497
    const ML_KEM_VINFO *vinfo = key->vinfo;
1596
1597
    /* Decode and check |t| */
1598
497
    if (!vector_decode_12(key->t, in, vinfo->rank)) {
1599
204
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1600
204
            "%s invalid public 't' vector",
1601
204
            vinfo->algorithm_name);
1602
204
        return 0;
1603
204
    }
1604
    /* Save the matrix |m| recovery seed |rho| */
1605
293
    memcpy(key->rho, in + vinfo->vector_bytes, ML_KEM_RANDOM_BYTES);
1606
    /*
1607
     * Pre-compute the public key hash, needed for both encap and decap.
1608
     * Also pre-compute the matrix expansion, stored with the public key.
1609
     */
1610
293
    if (!hash_h(key->pkhash, in, vinfo->pubkey_bytes, mdctx, key)
1611
293
        || !matrix_expand(mdctx, key)) {
1612
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1613
0
            "internal error while parsing %s public key",
1614
0
            vinfo->algorithm_name);
1615
0
        return 0;
1616
0
    }
1617
293
    return 1;
1618
293
}
1619
1620
/*
1621
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1622
 *
1623
 * Parses the |in| buffer as a |dk| output of "ML-KEM.KeyGen".
1624
 * This matches the output format of encode_prvkey() above.
1625
 */
1626
static int parse_prvkey(const uint8_t *in, EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
1627
55
{
1628
55
    const ML_KEM_VINFO *vinfo = key->vinfo;
1629
1630
    /* Decode and check |s|. */
1631
55
    if (!vector_decode_12(key->s, in, vinfo->rank)) {
1632
55
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1633
55
            "%s invalid private 's' vector",
1634
55
            vinfo->algorithm_name);
1635
55
        return 0;
1636
55
    }
1637
0
    in += vinfo->vector_bytes;
1638
1639
0
    if (!parse_pubkey(in, mdctx, key))
1640
0
        return 0;
1641
0
    in += vinfo->pubkey_bytes;
1642
1643
    /* Check public key hash. */
1644
0
    if (memcmp(key->pkhash, in, ML_KEM_PKHASH_BYTES) != 0) {
1645
0
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1646
0
            "%s public key hash mismatch",
1647
0
            vinfo->algorithm_name);
1648
0
        return 0;
1649
0
    }
1650
0
    in += ML_KEM_PKHASH_BYTES;
1651
1652
0
    memcpy(key->z, in, ML_KEM_RANDOM_BYTES);
1653
0
    return 1;
1654
0
}
1655
1656
/*
1657
 * FIPS 203, Section 6.1, Algorithm 16: "ML-KEM.KeyGen_internal".
1658
 *
1659
 * The implementation of Section 5.1, Algorithm 13, "K-PKE.KeyGen(d)" is
1660
 * inlined.
1661
 *
1662
 * The caller MUST pass a pre-allocated digest context that is not shared with
1663
 * any concurrent computation.
1664
 *
1665
 * This function optionally outputs the serialised wire-form |ek| public key
1666
 * into the provided |pubenc| buffer, and generates the content of the |rho|,
1667
 * |pkhash|, |t|, |m|, |s| and |z| components of the private |key| (which must
1668
 * have preallocated space for these).
1669
 *
1670
 * Keys are computed from a 32-byte random |d| plus the 1 byte rank for
1671
 * domain separation.  These are concatenated and hashed to produce a pair of
1672
 * 32-byte seeds public "rho", used to generate the matrix, and private "sigma",
1673
 * used to generate the secret vector |s|.
1674
 *
1675
 * The second random input |z| is copied verbatim into the Fujisaki-Okamoto
1676
 * (FO) transform "implicit-rejection" secret (the |z| component of the private
1677
 * key), which thwarts chosen-ciphertext attacks, provided decap() runs in
1678
 * constant time, with no side channel leaks, on all well-formed (valid length,
1679
 * and correctly encoded) ciphertext inputs.
1680
 */
1681
static __owur int genkey(const uint8_t seed[ML_KEM_SEED_BYTES],
1682
    EVP_MD_CTX *mdctx, uint8_t *pubenc, ML_KEM_KEY *key)
1683
33.2k
{
1684
33.2k
    uint8_t hashed[2 * ML_KEM_RANDOM_BYTES];
1685
33.2k
    const uint8_t *const sigma = hashed + ML_KEM_RANDOM_BYTES;
1686
33.2k
    uint8_t augmented_seed[ML_KEM_RANDOM_BYTES + 1];
1687
33.2k
    const ML_KEM_VINFO *vinfo = key->vinfo;
1688
33.2k
    CBD_FUNC cbd_1 = CBD1(vinfo->evp_type);
1689
33.2k
    int rank = vinfo->rank;
1690
33.2k
    uint8_t counter = 0;
1691
33.2k
    int ret = 0;
1692
1693
    /*
1694
     * Use the "d" seed salted with the rank to derive the public and private
1695
     * seeds rho and sigma.
1696
     */
1697
33.2k
    memcpy(augmented_seed, seed, ML_KEM_RANDOM_BYTES);
1698
33.2k
    augmented_seed[ML_KEM_RANDOM_BYTES] = (uint8_t)rank;
1699
33.2k
    if (!hash_g(hashed, augmented_seed, sizeof(augmented_seed), mdctx, key))
1700
0
        goto end;
1701
33.2k
    memcpy(key->rho, hashed, ML_KEM_RANDOM_BYTES);
1702
    /* The |rho| matrix seed is public */
1703
33.2k
    CONSTTIME_DECLASSIFY(key->rho, ML_KEM_RANDOM_BYTES);
1704
1705
    /* FIPS 203 |e| vector is initial value of key->t */
1706
33.2k
    if (!matrix_expand(mdctx, key)
1707
33.2k
        || !gencbd_vector_ntt(key->s, cbd_1, &counter, sigma, rank, mdctx, key)
1708
33.2k
        || !gencbd_vector_ntt(key->t, cbd_1, &counter, sigma, rank, mdctx, key))
1709
0
        goto end;
1710
1711
    /* To |e| we now add the product of transpose |m| and |s|, giving |t|. */
1712
33.2k
    matrix_mult_transpose_add(key->t, key->m, key->s, rank);
1713
    /* The |t| vector is public */
1714
33.2k
    CONSTTIME_DECLASSIFY(key->t, vinfo->rank * sizeof(scalar));
1715
1716
33.2k
    if (pubenc == NULL) {
1717
        /* Incremental digest of public key without in-full serialisation. */
1718
33.2k
        if (!hash_h_pubkey(key->pkhash, mdctx, key))
1719
0
            goto end;
1720
33.2k
    } else {
1721
0
        encode_pubkey(pubenc, key);
1722
0
        if (!hash_h(key->pkhash, pubenc, vinfo->pubkey_bytes, mdctx, key))
1723
0
            goto end;
1724
0
    }
1725
1726
    /* Save |z| portion of seed for "implicit rejection" on failure. */
1727
33.2k
    memcpy(key->z, seed + ML_KEM_RANDOM_BYTES, ML_KEM_RANDOM_BYTES);
1728
1729
    /* Optionally save the |d| portion of the seed */
1730
33.2k
    key->d = key->z + ML_KEM_RANDOM_BYTES;
1731
33.2k
    if (key->prov_flags & ML_KEM_KEY_RETAIN_SEED) {
1732
33.2k
        memcpy(key->d, seed, ML_KEM_RANDOM_BYTES);
1733
33.2k
    } else {
1734
0
        OPENSSL_cleanse(key->d, ML_KEM_RANDOM_BYTES);
1735
0
        key->d = NULL;
1736
0
    }
1737
1738
33.2k
    ret = 1;
1739
33.2k
end:
1740
33.2k
    OPENSSL_cleanse((void *)augmented_seed, ML_KEM_RANDOM_BYTES);
1741
33.2k
    OPENSSL_cleanse((void *)sigma, ML_KEM_RANDOM_BYTES);
1742
33.2k
    if (ret == 0) {
1743
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1744
0
            "internal error while generating %s private key",
1745
0
            vinfo->algorithm_name);
1746
0
    }
1747
33.2k
    return ret;
1748
33.2k
}
1749
1750
/*-
1751
 * FIPS 203, Section 6.2, Algorithm 17: "ML-KEM.Encaps_internal".
1752
 * This is the deterministic version with randomness supplied externally.
1753
 *
1754
 * The caller must pass space for two vectors in |tmp|.
1755
 * The |ctext| buffer have space for the ciphertext of the ML-KEM variant
1756
 * of the provided key.
1757
 */
1758
static int encap(uint8_t *ctext, uint8_t secret[ML_KEM_SHARED_SECRET_BYTES],
1759
    const uint8_t entropy[ML_KEM_RANDOM_BYTES],
1760
    scalar *tmp, EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1761
128
{
1762
128
    uint8_t input[ML_KEM_RANDOM_BYTES + ML_KEM_PKHASH_BYTES];
1763
128
    uint8_t Kr[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES];
1764
128
    uint8_t *r = Kr + ML_KEM_SHARED_SECRET_BYTES;
1765
128
    int ret;
1766
1767
128
    memcpy(input, entropy, ML_KEM_RANDOM_BYTES);
1768
128
    memcpy(input + ML_KEM_RANDOM_BYTES, key->pkhash, ML_KEM_PKHASH_BYTES);
1769
128
    ret = hash_g(Kr, input, sizeof(input), mdctx, key)
1770
128
        && encrypt_cpa(ctext, entropy, r, tmp, mdctx, key);
1771
128
    OPENSSL_cleanse((void *)input, sizeof(input));
1772
1773
128
    if (ret)
1774
128
        memcpy(secret, Kr, ML_KEM_SHARED_SECRET_BYTES);
1775
0
    else
1776
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1777
0
            "internal error while performing %s encapsulation",
1778
0
            key->vinfo->algorithm_name);
1779
128
    return ret;
1780
128
}
1781
1782
/*
1783
 * FIPS 203, Section 6.3, Algorithm 18: ML-KEM.Decaps_internal
1784
 *
1785
 * Barring failure of the supporting SHA3/SHAKE primitives, this is fully
1786
 * deterministic, the randomness for the FO transform is extracted during
1787
 * private key generation.
1788
 *
1789
 * The caller must pass space for two vectors in |tmp|.
1790
 * The |ctext| and |tmp_ctext| buffers must each have space for the ciphertext
1791
 * of the key's ML-KEM variant.
1792
 */
1793
static int decap(uint8_t secret[ML_KEM_SHARED_SECRET_BYTES],
1794
    const uint8_t *ctext, uint8_t *tmp_ctext, scalar *tmp,
1795
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1796
100
{
1797
100
    uint8_t decrypted[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_PKHASH_BYTES];
1798
100
    uint8_t failure_key[ML_KEM_RANDOM_BYTES];
1799
100
    uint8_t Kr[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES];
1800
100
    uint8_t *r = Kr + ML_KEM_SHARED_SECRET_BYTES;
1801
100
    const uint8_t *pkhash = key->pkhash;
1802
100
    const ML_KEM_VINFO *vinfo = key->vinfo;
1803
100
    int i;
1804
100
    uint8_t mask;
1805
1806
    /*
1807
     * If our KDF is unavailable, fail early! Otherwise, keep going ignoring
1808
     * any further errors, returning success, and whatever we got for a shared
1809
     * secret.  The decrypt_cpa() function is just arithmetic on secret data,
1810
     * so should not be subject to failure that makes its output predictable.
1811
     *
1812
     * We guard against "should never happen" catastrophic failure of the
1813
     * "pure" function |hash_g| by overwriting the shared secret with the
1814
     * content of the failure key and returning early, if nevertheless hash_g
1815
     * fails.  This is not constant-time, but a failure of |hash_g| already
1816
     * implies loss of side-channel resistance.
1817
     *
1818
     * The same action is taken, if also |encrypt_cpa| should catastrophically
1819
     * fail, due to failure of the |PRF| underlying the CBD functions.
1820
     */
1821
100
    if (!kdf(failure_key, key->z, ctext, vinfo->ctext_bytes, mdctx, key)) {
1822
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1823
0
            "internal error while performing %s decapsulation",
1824
0
            vinfo->algorithm_name);
1825
0
        return 0;
1826
0
    }
1827
100
    decrypt_cpa(decrypted, ctext, tmp, key);
1828
100
    memcpy(decrypted + ML_KEM_SHARED_SECRET_BYTES, pkhash, ML_KEM_PKHASH_BYTES);
1829
100
    if (!hash_g(Kr, decrypted, sizeof(decrypted), mdctx, key)
1830
100
        || !encrypt_cpa(tmp_ctext, decrypted, r, tmp, mdctx, key)) {
1831
0
        memcpy(secret, failure_key, ML_KEM_SHARED_SECRET_BYTES);
1832
0
        OPENSSL_cleanse(decrypted, ML_KEM_SHARED_SECRET_BYTES);
1833
0
        return 1;
1834
0
    }
1835
100
    mask = constant_time_eq_int_8(0,
1836
100
        CRYPTO_memcmp(ctext, tmp_ctext, vinfo->ctext_bytes));
1837
3.30k
    for (i = 0; i < ML_KEM_SHARED_SECRET_BYTES; i++)
1838
3.20k
        secret[i] = constant_time_select_8(mask, Kr[i], failure_key[i]);
1839
100
    OPENSSL_cleanse(decrypted, ML_KEM_SHARED_SECRET_BYTES);
1840
100
    OPENSSL_cleanse(Kr, sizeof(Kr));
1841
100
    return 1;
1842
100
}
1843
1844
/*
1845
 * After allocating storage for public or private key data, update the key
1846
 * component pointers to reference that storage.
1847
 */
1848
static __owur int add_storage(scalar *p, int private, ML_KEM_KEY *key)
1849
16.6k
{
1850
16.6k
    int rank = key->vinfo->rank;
1851
1852
16.6k
    if (p == NULL)
1853
0
        return 0;
1854
1855
    /*
1856
     * We're adding key material, the seed buffer will now hold |rho| and
1857
     * |pkhash|.
1858
     */
1859
16.6k
    memset(key->seedbuf, 0, sizeof(key->seedbuf));
1860
16.6k
    key->rho = key->seedbuf;
1861
16.6k
    key->pkhash = key->seedbuf + ML_KEM_RANDOM_BYTES;
1862
16.6k
    key->d = key->z = NULL;
1863
1864
    /* A public key needs space for |t| and |m| */
1865
16.6k
    key->m = (key->t = p) + rank;
1866
1867
    /*
1868
     * A private key also needs space for |s| and |z|.
1869
     * The |z| buffer always includes additional space for |d|, but a key's |d|
1870
     * pointer is left NULL when parsed from the NIST format, which omits that
1871
     * information.  Only keys generated from a (d, z) seed pair will have a
1872
     * non-NULL |d| pointer.
1873
     */
1874
16.6k
    if (private)
1875
16.5k
        key->z = (uint8_t *)(rank + (key->s = key->m + rank * rank));
1876
16.6k
    return 1;
1877
16.6k
}
1878
1879
/*
1880
 * After freeing the storage associated with a key that failed to be
1881
 * constructed, reset the internal pointers back to NULL.
1882
 */
1883
void ossl_ml_kem_key_reset(ML_KEM_KEY *key)
1884
16.7k
{
1885
16.7k
    if (key->t == NULL)
1886
101
        return;
1887
    /*-
1888
     * Cleanse any sensitive data:
1889
     * - The private vector |s| is immediately followed by the FO failure
1890
     *   secret |z|, and seed |d|, we can cleanse all three in one call.
1891
     *
1892
     * - Otherwise, when key->d is set, cleanse the stashed seed.
1893
     */
1894
16.6k
    if (ossl_ml_kem_have_prvkey(key))
1895
16.5k
        OPENSSL_cleanse(key->s,
1896
16.5k
            key->vinfo->rank * sizeof(scalar) + 2 * ML_KEM_RANDOM_BYTES);
1897
16.6k
    OPENSSL_free(key->t);
1898
16.6k
    key->d = key->z = (uint8_t *)(key->s = key->m = key->t = NULL);
1899
16.6k
}
1900
1901
/*
1902
 * ----- API exported to the provider
1903
 *
1904
 * Parameters with an implicit fixed length in the internal static API of each
1905
 * variant have an explicit checked length argument at this layer.
1906
 */
1907
1908
/* Retrieve the parameters of one of the ML-KEM variants */
1909
const ML_KEM_VINFO *ossl_ml_kem_get_vinfo(int evp_type)
1910
275k
{
1911
275k
    switch (evp_type) {
1912
57.3k
    case EVP_PKEY_ML_KEM_512:
1913
57.3k
        return &vinfo_map[ML_KEM_512_VINFO];
1914
158k
    case EVP_PKEY_ML_KEM_768:
1915
158k
        return &vinfo_map[ML_KEM_768_VINFO];
1916
59.7k
    case EVP_PKEY_ML_KEM_1024:
1917
59.7k
        return &vinfo_map[ML_KEM_1024_VINFO];
1918
275k
    }
1919
0
    return NULL;
1920
275k
}
1921
1922
/*
 * Allocate an empty ML-KEM key (no key material) of the variant named by
 * |evp_type|, fetching the four SHA3-family digests it needs from |libctx|
 * using |properties|.  Returns NULL for an unsupported type, on allocation
 * failure, or when any digest cannot be fetched.
 */
ML_KEM_KEY *ossl_ml_kem_key_new(OSSL_LIB_CTX *libctx, const char *properties,
    int evp_type)
{
    const ML_KEM_VINFO *vinfo = ossl_ml_kem_get_vinfo(evp_type);
    ML_KEM_KEY *key = NULL;
    int have_digests;

    if (vinfo == NULL) {
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_PASSED_INVALID_ARGUMENT,
            "unsupported ML-KEM key type: %d", evp_type);
        return NULL;
    }

    key = OPENSSL_malloc(sizeof(*key));
    if (key == NULL)
        return NULL;

    key->vinfo = vinfo;
    key->libctx = libctx;
    key->prov_flags = ML_KEM_KEY_PROV_FLAGS_DEFAULT;

    /* Fetch all digests up front, so later operations need not */
    key->shake128_md = EVP_MD_fetch(libctx, "SHAKE128", properties);
    key->shake256_md = EVP_MD_fetch(libctx, "SHAKE256", properties);
    key->sha3_256_md = EVP_MD_fetch(libctx, "SHA3-256", properties);
    key->sha3_512_md = EVP_MD_fetch(libctx, "SHA3-512", properties);

    /* No key material, seed or pending encoded key yet */
    key->d = key->z = key->rho = key->pkhash = key->encoded_dk = NULL;
    key->s = key->m = key->t = NULL;

    have_digests = key->shake128_md != NULL
        && key->shake256_md != NULL
        && key->sha3_256_md != NULL
        && key->sha3_512_md != NULL;
    if (have_digests)
        return key;

    ossl_ml_kem_key_free(key);
    ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
        "missing SHA3 digest algorithms while creating %s key",
        vinfo->algorithm_name);
    return NULL;
}
1959
1960
ML_KEM_KEY *ossl_ml_kem_key_dup(const ML_KEM_KEY *key, int selection)
1961
28
{
1962
28
    int ok = 0;
1963
28
    ML_KEM_KEY *ret;
1964
1965
    /*
1966
     * Partially decoded keys, not yet imported or loaded, should never be
1967
     * duplicated.
1968
     */
1969
28
    if (ossl_ml_kem_decoded_key(key))
1970
0
        return NULL;
1971
1972
28
    if (key == NULL
1973
28
        || (ret = OPENSSL_memdup(key, sizeof(*key))) == NULL)
1974
0
        return NULL;
1975
28
    ret->d = ret->z = ret->rho = ret->pkhash = NULL;
1976
28
    ret->s = ret->m = ret->t = NULL;
1977
1978
    /* Clear selection bits we can't fulfill */
1979
28
    if (!ossl_ml_kem_have_pubkey(key))
1980
0
        selection = 0;
1981
28
    else if (!ossl_ml_kem_have_prvkey(key))
1982
28
        selection &= ~OSSL_KEYMGMT_SELECT_PRIVATE_KEY;
1983
1984
28
    switch (selection & OSSL_KEYMGMT_SELECT_KEYPAIR) {
1985
0
    case 0:
1986
0
        ok = 1;
1987
0
        break;
1988
28
    case OSSL_KEYMGMT_SELECT_PUBLIC_KEY:
1989
28
        ok = add_storage(OPENSSL_memdup(key->t, key->vinfo->puballoc), 0, ret);
1990
28
        ret->rho = ret->seedbuf;
1991
28
        ret->pkhash = ret->rho + ML_KEM_RANDOM_BYTES;
1992
28
        break;
1993
0
    case OSSL_KEYMGMT_SELECT_PRIVATE_KEY:
1994
0
        ok = add_storage(OPENSSL_memdup(key->t, key->vinfo->prvalloc), 1, ret);
1995
        /* Duplicated keys retain |d|, if available */
1996
0
        if (key->d != NULL)
1997
0
            ret->d = ret->z + ML_KEM_RANDOM_BYTES;
1998
0
        break;
1999
28
    }
2000
2001
28
    if (!ok) {
2002
0
        OPENSSL_free(ret);
2003
0
        return NULL;
2004
0
    }
2005
2006
28
    EVP_MD_up_ref(ret->shake128_md);
2007
28
    EVP_MD_up_ref(ret->shake256_md);
2008
28
    EVP_MD_up_ref(ret->sha3_256_md);
2009
28
    EVP_MD_up_ref(ret->sha3_512_md);
2010
2011
28
    return ret;
2012
28
}
2013
2014
/*
 * Release |key| and everything it owns: the fetched digests, any pending
 * decoded data (stashed seed, encoded private key) and the key material.
 * Secret data is cleansed before being freed.  A NULL |key| is a no-op.
 */
void ossl_ml_kem_key_free(ML_KEM_KEY *key)
{
    int decoded;

    if (key == NULL)
        return;

    EVP_MD_free(key->shake128_md);
    EVP_MD_free(key->shake256_md);
    EVP_MD_free(key->sha3_256_md);
    EVP_MD_free(key->sha3_512_md);

    /* Partially decoded keys may hold a seed and/or an encoded private key */
    decoded = ossl_ml_kem_decoded_key(key);
    if (decoded) {
        OPENSSL_cleanse(key->seedbuf, sizeof(key->seedbuf));
        if (ossl_ml_kem_have_dkenc(key)) {
            OPENSSL_cleanse(key->encoded_dk, key->vinfo->prvkey_bytes);
            OPENSSL_free(key->encoded_dk);
        }
    }

    ossl_ml_kem_key_reset(key);
    OPENSSL_free(key);
}
2034
2035
/* Serialise the public component of an ML-KEM key */
2036
int ossl_ml_kem_encode_public_key(uint8_t *out, size_t len,
2037
    const ML_KEM_KEY *key)
2038
49.6k
{
2039
49.6k
    if (!ossl_ml_kem_have_pubkey(key)
2040
49.6k
        || len != key->vinfo->pubkey_bytes)
2041
0
        return 0;
2042
49.6k
    encode_pubkey(out, key);
2043
49.6k
    return 1;
2044
49.6k
}
2045
2046
/* Serialise an ML-KEM private key */
2047
int ossl_ml_kem_encode_private_key(uint8_t *out, size_t len,
2048
    const ML_KEM_KEY *key)
2049
74
{
2050
74
    if (!ossl_ml_kem_have_prvkey(key)
2051
74
        || len != key->vinfo->prvkey_bytes)
2052
0
        return 0;
2053
74
    encode_prvkey(out, key);
2054
74
    return 1;
2055
74
}
2056
2057
int ossl_ml_kem_encode_seed(uint8_t *out, size_t len,
2058
    const ML_KEM_KEY *key)
2059
103
{
2060
103
    if (key == NULL || key->d == NULL || len != ML_KEM_SEED_BYTES)
2061
0
        return 0;
2062
    /*
2063
     * Both in the seed buffer, and in the allocated storage, the |d| component
2064
     * of the seed is stored last, so we must copy each separately.
2065
     */
2066
103
    memcpy(out, key->d, ML_KEM_RANDOM_BYTES);
2067
103
    out += ML_KEM_RANDOM_BYTES;
2068
103
    memcpy(out, key->z, ML_KEM_RANDOM_BYTES);
2069
103
    return 1;
2070
103
}
2071
2072
/*
2073
 * Stash the seed without (yet) performing a keygen, used during decoding, to
2074
 * avoid an extra keygen if we're only going to export the key again to load
2075
 * into another provider.
2076
 */
2077
ML_KEM_KEY *ossl_ml_kem_set_seed(const uint8_t *seed, size_t seedlen, ML_KEM_KEY *key)
2078
29
{
2079
29
    if (key == NULL
2080
29
        || ossl_ml_kem_have_pubkey(key)
2081
29
        || ossl_ml_kem_have_seed(key)
2082
29
        || seedlen != ML_KEM_SEED_BYTES)
2083
0
        return NULL;
2084
    /*
2085
     * With no public or private key material on hand, we can use the seed
2086
     * buffer for |z| and |d|, in that order.
2087
     */
2088
29
    key->z = key->seedbuf;
2089
29
    key->d = key->z + ML_KEM_RANDOM_BYTES;
2090
29
    memcpy(key->d, seed, ML_KEM_RANDOM_BYTES);
2091
29
    seed += ML_KEM_RANDOM_BYTES;
2092
29
    memcpy(key->z, seed, ML_KEM_RANDOM_BYTES);
2093
29
    return key;
2094
29
}
2095
2096
/* Parse input as a public key */
2097
int ossl_ml_kem_parse_public_key(const uint8_t *in, size_t len, ML_KEM_KEY *key)
2098
497
{
2099
497
    EVP_MD_CTX *mdctx = NULL;
2100
497
    const ML_KEM_VINFO *vinfo;
2101
497
    int ret = 0;
2102
2103
    /* Keys with key material are immutable */
2104
497
    if (key == NULL
2105
497
        || ossl_ml_kem_have_pubkey(key)
2106
497
        || ossl_ml_kem_have_dkenc(key))
2107
0
        return 0;
2108
497
    vinfo = key->vinfo;
2109
2110
497
    if (len != vinfo->pubkey_bytes
2111
497
        || (mdctx = EVP_MD_CTX_new()) == NULL)
2112
0
        return 0;
2113
2114
497
    if (add_storage(OPENSSL_malloc(vinfo->puballoc), 0, key))
2115
497
        ret = parse_pubkey(in, mdctx, key);
2116
2117
497
    if (!ret)
2118
204
        ossl_ml_kem_key_reset(key);
2119
497
    EVP_MD_CTX_free(mdctx);
2120
497
    return ret;
2121
497
}
2122
2123
/* Parse input as a new private key */
2124
int ossl_ml_kem_parse_private_key(const uint8_t *in, size_t len,
2125
    ML_KEM_KEY *key)
2126
55
{
2127
55
    EVP_MD_CTX *mdctx = NULL;
2128
55
    const ML_KEM_VINFO *vinfo;
2129
55
    int ret = 0;
2130
2131
    /* Keys with key material are immutable */
2132
55
    if (key == NULL
2133
55
        || ossl_ml_kem_have_pubkey(key)
2134
55
        || ossl_ml_kem_have_dkenc(key))
2135
0
        return 0;
2136
55
    vinfo = key->vinfo;
2137
2138
55
    if (len != vinfo->prvkey_bytes
2139
55
        || (mdctx = EVP_MD_CTX_new()) == NULL)
2140
0
        return 0;
2141
2142
55
    if (add_storage(OPENSSL_malloc(vinfo->prvalloc), 1, key))
2143
55
        ret = parse_prvkey(in, mdctx, key);
2144
2145
55
    if (!ret)
2146
55
        ossl_ml_kem_key_reset(key);
2147
55
    EVP_MD_CTX_free(mdctx);
2148
55
    return ret;
2149
55
}
2150
2151
/*
2152
 * Generate a new keypair, either from the saved seed (when non-null), or from
2153
 * the RNG.
2154
 */
2155
int ossl_ml_kem_genkey(uint8_t *pubenc, size_t publen, ML_KEM_KEY *key)
2156
49.8k
{
2157
49.8k
    uint8_t seed[ML_KEM_SEED_BYTES];
2158
49.8k
    EVP_MD_CTX *mdctx = NULL;
2159
49.8k
    const ML_KEM_VINFO *vinfo;
2160
49.8k
    int ret = 0;
2161
2162
49.8k
    if (key == NULL
2163
49.8k
        || ossl_ml_kem_have_pubkey(key)
2164
49.8k
        || ossl_ml_kem_have_dkenc(key))
2165
0
        return 0;
2166
49.8k
    vinfo = key->vinfo;
2167
2168
49.8k
    if (pubenc != NULL && publen != vinfo->pubkey_bytes)
2169
0
        return 0;
2170
2171
49.8k
    if (ossl_ml_kem_have_seed(key)) {
2172
29
        if (!ossl_ml_kem_encode_seed(seed, sizeof(seed), key))
2173
0
            return 0;
2174
29
        key->d = key->z = NULL;
2175
49.7k
    } else if (RAND_priv_bytes_ex(key->libctx, seed, sizeof(seed),
2176
49.7k
                   key->vinfo->secbits)
2177
49.7k
        <= 0) {
2178
0
        return 0;
2179
0
    }
2180
2181
49.8k
    if ((mdctx = EVP_MD_CTX_new()) == NULL)
2182
0
        return 0;
2183
2184
    /*
2185
     * Data derived from (d, z) defaults secret, and to avoid side-channel
2186
     * leaks should not influence control flow.
2187
     */
2188
49.8k
    CONSTTIME_SECRET(seed, ML_KEM_SEED_BYTES);
2189
2190
49.8k
    if (add_storage(OPENSSL_malloc(vinfo->prvalloc), 1, key))
2191
49.8k
        ret = genkey(seed, mdctx, pubenc, key);
2192
49.8k
    OPENSSL_cleanse(seed, sizeof(seed));
2193
2194
    /* Declassify secret inputs and derived outputs before returning control */
2195
49.8k
    CONSTTIME_DECLASSIFY(seed, ML_KEM_SEED_BYTES);
2196
2197
49.8k
    EVP_MD_CTX_free(mdctx);
2198
49.8k
    if (!ret) {
2199
0
        ossl_ml_kem_key_reset(key);
2200
0
        return 0;
2201
0
    }
2202
2203
    /* The public components are already declassified */
2204
49.8k
    CONSTTIME_DECLASSIFY(key->s, vinfo->rank * sizeof(scalar));
2205
49.8k
    CONSTTIME_DECLASSIFY(key->z, 2 * ML_KEM_RANDOM_BYTES);
2206
49.8k
    return 1;
2207
49.8k
}
2208
2209
/*
2210
 * FIPS 203, Section 6.2, Algorithm 17: ML-KEM.Encaps_internal
2211
 * This is the deterministic version with randomness supplied externally.
2212
 */
2213
int ossl_ml_kem_encap_seed(uint8_t *ctext, size_t clen,
2214
    uint8_t *shared_secret, size_t slen,
2215
    const uint8_t *entropy, size_t elen,
2216
    const ML_KEM_KEY *key)
2217
128
{
2218
128
    const ML_KEM_VINFO *vinfo;
2219
128
    EVP_MD_CTX *mdctx;
2220
128
    int ret = 0;
2221
2222
128
    if (key == NULL || !ossl_ml_kem_have_pubkey(key))
2223
0
        return 0;
2224
128
    vinfo = key->vinfo;
2225
2226
128
    if (ctext == NULL || clen != vinfo->ctext_bytes
2227
128
        || shared_secret == NULL || slen != ML_KEM_SHARED_SECRET_BYTES
2228
128
        || entropy == NULL || elen != ML_KEM_RANDOM_BYTES
2229
128
        || (mdctx = EVP_MD_CTX_new()) == NULL)
2230
0
        return 0;
2231
    /*
2232
     * Data derived from the encap entropy defaults secret, and to avoid
2233
     * side-channel leaks should not influence control flow.
2234
     */
2235
128
    CONSTTIME_SECRET(entropy, elen);
2236
2237
    /*-
2238
     * This avoids the need to handle allocation failures for two (max 2KB
2239
     * each) vectors, that are never retained on return from this function.
2240
     * We stack-allocate these.
2241
     */
2242
128
#define case_encap_seed(bits)                                        \
2243
128
    case EVP_PKEY_ML_KEM_##bits: {                                   \
2244
128
        scalar tmp[2 * ML_KEM_##bits##_RANK];                        \
2245
128
                                                                     \
2246
128
        ret = encap(ctext, shared_secret, entropy, tmp, mdctx, key); \
2247
128
        OPENSSL_cleanse((void *)tmp, sizeof(tmp));                   \
2248
128
        break;                                                       \
2249
128
    }
2250
128
    switch (vinfo->evp_type) {
2251
41
        case_encap_seed(512);
2252
52
        case_encap_seed(768);
2253
35
        case_encap_seed(1024);
2254
128
    }
2255
128
#undef case_encap_seed
2256
2257
    /* Declassify secret inputs and derived outputs before returning control */
2258
128
    CONSTTIME_DECLASSIFY(entropy, elen);
2259
128
    CONSTTIME_DECLASSIFY(ctext, clen);
2260
128
    CONSTTIME_DECLASSIFY(shared_secret, slen);
2261
2262
128
    EVP_MD_CTX_free(mdctx);
2263
128
    return ret;
2264
128
}
2265
2266
int ossl_ml_kem_encap_rand(uint8_t *ctext, size_t clen,
2267
    uint8_t *shared_secret, size_t slen,
2268
    const ML_KEM_KEY *key)
2269
128
{
2270
128
    uint8_t r[ML_KEM_RANDOM_BYTES];
2271
2272
128
    if (key == NULL)
2273
0
        return 0;
2274
2275
128
    if (RAND_bytes_ex(key->libctx, r, ML_KEM_RANDOM_BYTES,
2276
128
            key->vinfo->secbits)
2277
128
        < 1)
2278
0
        return 0;
2279
2280
128
    return ossl_ml_kem_encap_seed(ctext, clen, shared_secret, slen,
2281
128
        r, sizeof(r), key);
2282
128
}
2283
2284
int ossl_ml_kem_decap(uint8_t *shared_secret, size_t slen,
2285
    const uint8_t *ctext, size_t clen,
2286
    const ML_KEM_KEY *key)
2287
100
{
2288
100
    const ML_KEM_VINFO *vinfo;
2289
100
    EVP_MD_CTX *mdctx;
2290
100
    int ret = 0;
2291
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
2292
    int classify_bytes;
2293
#endif
2294
2295
    /* Need a private key here */
2296
100
    if (!ossl_ml_kem_have_prvkey(key))
2297
0
        return 0;
2298
100
    vinfo = key->vinfo;
2299
2300
100
    if (shared_secret == NULL || slen != ML_KEM_SHARED_SECRET_BYTES
2301
100
        || ctext == NULL || clen != vinfo->ctext_bytes
2302
100
        || (mdctx = EVP_MD_CTX_new()) == NULL) {
2303
0
        (void)RAND_bytes_ex(key->libctx, shared_secret,
2304
0
            ML_KEM_SHARED_SECRET_BYTES, vinfo->secbits);
2305
0
        return 0;
2306
0
    }
2307
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
2308
    /*
2309
     * Data derived from |s| and |z| defaults secret, and to avoid side-channel
2310
     * leaks should not influence control flow.
2311
     */
2312
    classify_bytes = 2 * sizeof(scalar) + ML_KEM_RANDOM_BYTES;
2313
    CONSTTIME_SECRET(key->s, classify_bytes);
2314
#endif
2315
2316
    /*-
2317
     * This avoids the need to handle allocation failures for two (max 2KB
2318
     * each) vectors and an encoded ciphertext (max 1568 bytes), that are never
2319
     * retained on return from this function.
2320
     * We stack-allocate these.
2321
     */
2322
100
#define case_decap(bits)                                          \
2323
100
    case EVP_PKEY_ML_KEM_##bits: {                                \
2324
100
        uint8_t cbuf[CTEXT_BYTES(bits)];                          \
2325
100
        scalar tmp[2 * ML_KEM_##bits##_RANK];                     \
2326
100
                                                                  \
2327
100
        ret = decap(shared_secret, ctext, cbuf, tmp, mdctx, key); \
2328
100
        OPENSSL_cleanse((void *)tmp, sizeof(tmp));                \
2329
100
        break;                                                    \
2330
100
    }
2331
100
    switch (vinfo->evp_type) {
2332
41
        case_decap(512);
2333
24
        case_decap(768);
2334
35
        case_decap(1024);
2335
100
    }
2336
2337
    /* Declassify secret inputs and derived outputs before returning control */
2338
100
    CONSTTIME_DECLASSIFY(key->s, classify_bytes);
2339
100
    CONSTTIME_DECLASSIFY(shared_secret, slen);
2340
100
    EVP_MD_CTX_free(mdctx);
2341
2342
100
    return ret;
2343
100
#undef case_decap
2344
100
}
2345
2346
int ossl_ml_kem_pubkey_cmp(const ML_KEM_KEY *key1, const ML_KEM_KEY *key2)
2347
138
{
2348
    /*
2349
     * This handles any unexpected differences in the ML-KEM variant rank,
2350
     * giving different key component structures, barring SHA3-256 hash
2351
     * collisions, the keys are the same size.
2352
     */
2353
138
    if (ossl_ml_kem_have_pubkey(key1) && ossl_ml_kem_have_pubkey(key2))
2354
138
        return memcmp(key1->pkhash, key2->pkhash, ML_KEM_PKHASH_BYTES) == 0;
2355
2356
    /*
2357
     * No match if just one of the public keys is not available, otherwise both
2358
     * are unavailable, and for now such keys are considered equal.
2359
     */
2360
0
    return (!(ossl_ml_kem_have_pubkey(key1) ^ ossl_ml_kem_have_pubkey(key2)));
2361
138
}