Coverage Report

Created: 2026-04-09 06:50

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/openssl36/crypto/ml_kem/ml_kem.c
Line
Count
Source
1
/*
2
 * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
3
 *
4
 * Licensed under the Apache License 2.0 (the "License").  You may not use
5
 * this file except in compliance with the License.  You can obtain a copy
6
 * in the file LICENSE in the source distribution or at
7
 * https://www.openssl.org/source/license.html
8
 */
9
10
#include <openssl/byteorder.h>
11
#include <openssl/rand.h>
12
#include <openssl/proverr.h>
13
#include "crypto/ml_kem.h"
14
#include "internal/common.h"
15
#include "internal/constant_time.h"
16
#include "internal/sha3.h"
17
18
#ifdef OPENSSL_CONSTANT_TIME_VALIDATION
#include <valgrind/memcheck.h>
#endif

/*
 * Compile-time consistency checks on the ML-KEM size parameters: the keygen
 * seed (d || z) must be exactly one shared secret plus one random input, and
 * those two lengths must agree.
 */
#if (ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES) != ML_KEM_SEED_BYTES
#error "ML-KEM keygen seed length != shared secret + random bytes length"
#endif
#if ML_KEM_RANDOM_BYTES != ML_KEM_SHARED_SECRET_BYTES
#error "Invalid unequal lengths of ML-KEM shared secret and random inputs"
#endif

/* The modular arithmetic below assumes unsigned int is at least 32 bits. */
#if UINT_MAX < UINT32_MAX
#error "Unsupported compiler: sizeof(unsigned int) < sizeof(uint32_t)"
#endif
32
33
/*
 * Function-like bit-extraction helpers: bit0(b) is the least significant bit
 * of |b|, bitn(n, b) is bit |n| of |b|.  Both arguments are parenthesised so
 * that arbitrary expressions may be passed.
 */
#define bit0(b) ((b) & 1)
#define bitn(n, b) (((b) >> (n)) & 1)

/*
 * 12 bits are sufficient to losslessly represent values in [0, q-1].
 * INVERSE_DEGREE is (n/2)^-1 mod q, used to scale the inverse NTT output:
 * with q = 3329 it is 3303, and 128 * 3303 = 127 * 3329 + 1.
 */
#define DEGREE ML_KEM_DEGREE
#define INVERSE_DEGREE (ML_KEM_PRIME - 2 * 13)
#define LOG2PRIME 12
#define BARRETT_SHIFT (2 * LOG2PRIME)

#ifdef SHA3_BLOCKSIZE
#define SHAKE128_BLOCKSIZE SHA3_BLOCKSIZE(128)
#endif

/*
 * Return whether a value that can only be 0 or 1 is non-zero, in constant time
 * in practice!  The return value is a mask that is all ones if true, and all
 * zeros otherwise (twos-complement arithmetic assumed for unsigned values).
 *
 * Although this is used in constant-time selects, we omit a value barrier
 * here.  Value barriers impede auto-vectorization (likely because it forces
 * the value to transit through a general-purpose register). On AArch64, this
 * is a difference of 2x.
 *
 * We usually add value barriers to selects because Clang turns consecutive
 * selects with the same condition into a branch instead of CMOV/CSEL. This
 * condition does not occur in Kyber, so omitting it seems to be safe so far,
 * but see |cbd_2|, |cbd_3|, where reduction needs to be specialised to the
 * sign of the input, rather than adding |q| in advance, and using the generic
 * |reduce_once|.  (David Benjamin, Chromium)
 *
 * The disabled alternative below is the value-barrier variant.  It must stay
 * a plain expression (no trailing semicolon) so that it can be enabled
 * without breaking uses such as |mask & x|.
 */
#if 0
#define constish_time_non_zero(b) (~constant_time_is_zero(b))
#else
#define constish_time_non_zero(b) (0u - (b))
#endif
72
73
/*
 * Buffer size for scalar rejection sampling.  It must be a multiple of 12
 * (in particular of 3, since sample_scalar() consumes it in 3-byte groups),
 * and is otherwise arbitrary.  Preferably it equals the SHAKE128 block size,
 * (1600 - 256)/8 = 168 bytes, so that SHAKE128 never has to buffer and copy
 * internally; 168 = 14 * 12.
 *
 * When the block size is unknown, or is not a multiple of 12, fall back to
 * 168.
 */
#if !defined(SHAKE128_BLOCKSIZE) || (SHAKE128_BLOCKSIZE) % 12 != 0
#define SCALAR_SAMPLING_BUFSIZE 168
#else
#define SCALAR_SAMPLING_BUFSIZE (SHAKE128_BLOCKSIZE)
#endif
87
88
/*
89
 * Structure of keys
90
 */
91
typedef struct ossl_ml_kem_scalar_st {
92
    /* On every function entry and exit, 0 <= c[i] < ML_KEM_PRIME. */
93
    uint16_t c[ML_KEM_DEGREE];
94
} scalar;
95
96
/* Key material allocation layout */
97
#define DECLARE_ML_KEM_PUBKEYDATA(name, rank)                  \
98
    struct name##_alloc {                                      \
99
        /* Public vector |t| */                                \
100
        scalar tbuf[(rank)];                                   \
101
        /* Pre-computed matrix |m| (FIPS 203 |A| transpose) */ \
102
        scalar mbuf[(rank) * (rank)];                          \
103
    }
104
105
#define DECLARE_ML_KEM_PRVKEYDATA(name, rank)  \
106
    struct name##_alloc {                      \
107
        scalar sbuf[rank];                     \
108
        uint8_t zbuf[2 * ML_KEM_RANDOM_BYTES]; \
109
    }
110
111
/* Declare variant-specific public and private storage */
112
#define DECLARE_ML_KEM_VARIANT_KEYDATA(bits)                        \
113
    DECLARE_ML_KEM_PUBKEYDATA(pubkey_##bits, ML_KEM_##bits##_RANK); \
114
    DECLARE_ML_KEM_PRVKEYDATA(prvkey_##bits, ML_KEM_##bits##_RANK)
115
116
DECLARE_ML_KEM_VARIANT_KEYDATA(512);
117
DECLARE_ML_KEM_VARIANT_KEYDATA(768);
118
DECLARE_ML_KEM_VARIANT_KEYDATA(1024);
119
#undef DECLARE_ML_KEM_VARIANT_KEYDATA
120
#undef DECLARE_ML_KEM_PUBKEYDATA
121
#undef DECLARE_ML_KEM_PRVKEYDATA
122
123
typedef __owur int (*CBD_FUNC)(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
124
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key);
125
static void scalar_encode(uint8_t *out, const scalar *s, int bits);
126
127
/*
 * Wire-format sizes, in bytes, for the variant with bit-strength suffix |b|.
 *
 * A losslessly encoded vector stores each coefficient in 12 bits, so every
 * scalar takes 3 * DEGREE / 2 bytes.  The wire-form public key is the
 * encoding of the public vector |t| followed by the public seed |rho|.  Our
 * serialised private key concatenates the encoded private vector |s|, the
 * public key, the public key hash, and the failure secret |z|.
 */
#define VECTOR_BYTES(b) ((ML_KEM_##b##_RANK) * (3 * DEGREE / 2))
#define PUBKEY_BYTES(b) (ML_KEM_RANDOM_BYTES + VECTOR_BYTES(b))
#define PRVKEY_BYTES(b) (ML_KEM_PKHASH_BYTES + 2 * PUBKEY_BYTES(b))

/*
 * The ciphertext is the lossy encoding of the vector "u" (du bits per
 * coefficient) followed by that of the scalar "v" (dv bits per coefficient);
 * coordinates are numbers modulo the ML-KEM prime "q".  Decapsulation takes
 * exactly this encoding as input.
 */
#define U_VECTOR_BYTES(b) ((ML_KEM_##b##_RANK) * (ML_KEM_##b##_DU) * (DEGREE / 8))
#define V_SCALAR_BYTES(b) ((ML_KEM_##b##_DV) * (DEGREE / 8))
#define CTEXT_BYTES(b) (V_SCALAR_BYTES(b) + U_VECTOR_BYTES(b))
149
150
#ifdef OPENSSL_CONSTANT_TIME_VALIDATION

/*
 * CONSTTIME_SECRET(ptr, len) marks |len| bytes at |ptr| as secret.  Valgrind
 * then tracks that data as it flows into registers and other parts of
 * memory; using secret data as a branch condition or a memory index triggers
 * a warning.
 */
#define CONSTTIME_SECRET(ptr, len) VALGRIND_MAKE_MEM_UNDEFINED(ptr, len)

/*
 * CONSTTIME_DECLASSIFY(ptr, len) marks |len| bytes at |ptr| as public again;
 * public data is not subject to constant-time rules.
 */
#define CONSTTIME_DECLASSIFY(ptr, len) VALGRIND_MAKE_MEM_DEFINED(ptr, len)

#else

/* Without constant-time validation both annotations expand to nothing. */
#define CONSTTIME_SECRET(ptr, len)
#define CONSTTIME_DECLASSIFY(ptr, len)

#endif
173
174
/*
175
 * Indices of slots in the vinfo tables below
176
 */
177
60.0k
#define ML_KEM_512_VINFO 0
178
161k
#define ML_KEM_768_VINFO 1
179
56.4k
#define ML_KEM_1024_VINFO 2
180
181
/*
182
 * Per-variant fixed parameters
183
 */
184
static const ML_KEM_VINFO vinfo_map[3] = {
185
    { "ML-KEM-512",
186
        PRVKEY_BYTES(512),
187
        sizeof(struct prvkey_512_alloc),
188
        PUBKEY_BYTES(512),
189
        sizeof(struct pubkey_512_alloc),
190
        CTEXT_BYTES(512),
191
        VECTOR_BYTES(512),
192
        U_VECTOR_BYTES(512),
193
        EVP_PKEY_ML_KEM_512,
194
        ML_KEM_512_BITS,
195
        ML_KEM_512_RANK,
196
        ML_KEM_512_DU,
197
        ML_KEM_512_DV,
198
        ML_KEM_512_SECBITS,
199
        ML_KEM_512_SECURITY_CATEGORY },
200
    { "ML-KEM-768",
201
        PRVKEY_BYTES(768),
202
        sizeof(struct prvkey_768_alloc),
203
        PUBKEY_BYTES(768),
204
        sizeof(struct pubkey_768_alloc),
205
        CTEXT_BYTES(768),
206
        VECTOR_BYTES(768),
207
        U_VECTOR_BYTES(768),
208
        EVP_PKEY_ML_KEM_768,
209
        ML_KEM_768_BITS,
210
        ML_KEM_768_RANK,
211
        ML_KEM_768_DU,
212
        ML_KEM_768_DV,
213
        ML_KEM_768_SECBITS,
214
        ML_KEM_768_SECURITY_CATEGORY },
215
    { "ML-KEM-1024",
216
        PRVKEY_BYTES(1024),
217
        sizeof(struct prvkey_1024_alloc),
218
        PUBKEY_BYTES(1024),
219
        sizeof(struct pubkey_1024_alloc),
220
        CTEXT_BYTES(1024),
221
        VECTOR_BYTES(1024),
222
        U_VECTOR_BYTES(1024),
223
        EVP_PKEY_ML_KEM_1024,
224
        ML_KEM_1024_BITS,
225
        ML_KEM_1024_RANK,
226
        ML_KEM_1024_DU,
227
        ML_KEM_1024_DV,
228
        ML_KEM_1024_SECBITS,
229
        ML_KEM_1024_SECURITY_CATEGORY }
230
};
231
232
/*
233
 * Remainders modulo `kPrime`, for sufficiently small inputs, are computed in
234
 * constant time via Barrett reduction, and a final call to reduce_once(),
235
 * which reduces inputs that are at most 2*kPrime and is also constant-time.
236
 */
237
static const int kPrime = ML_KEM_PRIME;
238
static const unsigned int kBarrettShift = BARRETT_SHIFT;
239
static const size_t kBarrettMultiplier = (1 << BARRETT_SHIFT) / ML_KEM_PRIME;
240
static const uint16_t kHalfPrime = (ML_KEM_PRIME - 1) / 2;
241
static const uint16_t kInverseDegree = INVERSE_DEGREE;
242
243
/*
 * The NTT tables in this file can be regenerated with this Python helper,
 * where bitreverse() reverses the low 7 bits of its argument:
 *
 *   p = 3329
 *   def bitreverse(i):
 *       ret = 0
 *       for n in range(7):
 *           bit = i & 1
 *           ret <<= 1
 *           ret |= bit
 *           i >>= 1
 *       return ret
 */

/*-
 * Forward NTT twiddle factors ("zetas", first table of FIPS 203 Appendix A):
 *
 *   kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)]
 *
 * scalar_ntt() starts reading at index 1; entry 0 is never used.
 */
static const uint16_t kNTTRoots[128] = {
    1, 1729, 2580, 3289, 2642, 630, 1897, 848,
    1062, 1919, 193, 797, 2786, 3260, 569, 1746,
    296, 2447, 1339, 1476, 3046, 56, 2240, 1333,
    1426, 2094, 535, 2882, 2393, 2879, 1974, 821,
    289, 331, 3253, 1756, 1197, 2304, 2277, 2055,
    650, 1977, 2513, 632, 2865, 33, 1320, 1915,
    2319, 1435, 807, 452, 1438, 2868, 1534, 2402,
    2647, 2617, 1481, 648, 2474, 3110, 1227, 910,
    17, 2761, 583, 2649, 1637, 723, 2288, 1100,
    1409, 2662, 3281, 233, 756, 2156, 3015, 3050,
    1703, 1651, 2789, 1789, 1847, 952, 1461, 2687,
    939, 2308, 2437, 2388, 733, 2337, 268, 641,
    1584, 2298, 2037, 3220, 375, 2549, 2090, 1645,
    1063, 319, 2773, 757, 2099, 561, 2466, 2594,
    2804, 1092, 403, 1026, 1143, 2150, 2775, 886,
    1722, 1212, 1874, 1029, 2110, 2935, 885, 2154,
};
391
392
/*
 * Inverse NTT twiddle factors, i.e. the values of
 *
 *   [pow(17, -bitreverse(i), p) for i in range(128)]
 *
 * permuted into the order in which scalar_inverse_ntt() consumes them:
 * i = 0, then 64..127, 32..63, 16..31, 8..15, 4..7, 2..3, 1.
 * The loop starts reading at index 1, so entry 0 is never used.
 */
static const uint16_t kInverseNTTRoots[128] = {
    1, 1175, 2444, 394, 1219, 2300, 1455, 2117,
    1607, 2443, 554, 1179, 2186, 2303, 2926, 2237,
    525, 735, 863, 2768, 1230, 2572, 556, 3010,
    2266, 1684, 1239, 780, 2954, 109, 1292, 1031,
    1745, 2688, 3061, 992, 2596, 941, 892, 1021,
    2390, 642, 1868, 2377, 1482, 1540, 540, 1678,
    1626, 279, 314, 1173, 2573, 3096, 48, 667,
    1920, 2229, 1041, 2606, 1692, 680, 2746, 568,
    3312, 2419, 2102, 219, 855, 2681, 1848, 712,
    682, 927, 1795, 461, 1891, 2877, 2522, 1894,
    1010, 1414, 2009, 3296, 464, 2697, 816, 1352,
    2679, 1274, 1052, 1025, 2132, 1573, 76, 2998,
    3040, 2508, 1355, 450, 936, 447, 2794, 1235,
    1903, 1996, 1089, 3273, 283, 1853, 1990, 882,
    3033, 1583, 2760, 69, 543, 2532, 3136, 1410,
    2267, 2481, 1432, 2699, 687, 40, 749, 1600,
};
528
529
/*
 * Moduli of the degree-2 base-case multiplications (second table of FIPS 203
 * Appendix A, normalised to positive residues):
 *
 *   kModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)]
 *
 * Consecutive entries come in pairs (gamma, q - gamma).
 */
static const uint16_t kModRoots[128] = {
    17, 3312, 2761, 568, 583, 2746, 2649, 680,
    1637, 1692, 723, 2606, 2288, 1041, 1100, 2229,
    1409, 1920, 2662, 667, 3281, 48, 233, 3096,
    756, 2573, 2156, 1173, 3015, 314, 3050, 279,
    1703, 1626, 1651, 1678, 2789, 540, 1789, 1540,
    1847, 1482, 952, 2377, 1461, 1868, 2687, 642,
    939, 2390, 2308, 1021, 2437, 892, 2388, 941,
    733, 2596, 2337, 992, 268, 3061, 641, 2688,
    1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
    375, 2954, 2549, 780, 2090, 1239, 1645, 1684,
    1063, 2266, 319, 3010, 2773, 556, 757, 2572,
    2099, 1230, 561, 2768, 2466, 863, 2594, 735,
    2804, 525, 1092, 2237, 403, 2926, 1026, 2303,
    1143, 2186, 2150, 1179, 2775, 554, 886, 2443,
    1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
    2110, 1219, 2935, 394, 885, 2444, 2154, 1175,
};
664
665
/*
666
 * single_keccak hashes |inlen| bytes from |in| and writes |outlen| bytes of
667
 * output to |out|. If the |md| specifies a fixed-output function, like
668
 * SHA3-256, then |outlen| must be the correct length for that function.
669
 */
670
static __owur int single_keccak(uint8_t *out, size_t outlen, const uint8_t *in, size_t inlen,
671
    EVP_MD_CTX *mdctx)
672
362k
{
673
362k
    unsigned int sz = (unsigned int)outlen;
674
675
362k
    if (!EVP_DigestUpdate(mdctx, in, inlen))
676
0
        return 0;
677
362k
    if (EVP_MD_xof(EVP_MD_CTX_get0_md(mdctx)))
678
310k
        return EVP_DigestFinalXOF(mdctx, out, outlen);
679
52.0k
    return EVP_DigestFinal_ex(mdctx, out, &sz)
680
52.0k
        && ossl_assert((size_t)sz == outlen);
681
362k
}
682
683
/*
684
 * FIPS 203, Section 4.1, equation (4.3): PRF. Takes 32+1 input bytes, and uses
685
 * SHAKE256 to produce the input to SamplePolyCBD_eta: FIPS 203, algorithm 8.
686
 */
687
static __owur int prf(uint8_t *out, size_t len, const uint8_t in[ML_KEM_RANDOM_BYTES + 1],
688
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
689
310k
{
690
310k
    return EVP_DigestInit_ex(mdctx, key->shake256_md, NULL)
691
310k
        && single_keccak(out, len, in, ML_KEM_RANDOM_BYTES + 1, mdctx);
692
310k
}
693
694
/*
695
 * FIPS 203, Section 4.1, equation (4.4): H.  SHA3-256 hash of a variable
696
 * length input, producing 32 bytes of output.
697
 */
698
static __owur int hash_h(uint8_t out[ML_KEM_PKHASH_BYTES], const uint8_t *in, size_t len,
699
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
700
325
{
701
325
    return EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL)
702
325
        && single_keccak(out, ML_KEM_PKHASH_BYTES, in, len, mdctx);
703
325
}
704
705
/* Incremental hash_h of expanded public key */
706
static int
707
hash_h_pubkey(uint8_t pkhash[ML_KEM_PKHASH_BYTES],
708
    EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
709
51.3k
{
710
51.3k
    const ML_KEM_VINFO *vinfo = key->vinfo;
711
51.3k
    const scalar *t = key->t, *end = t + vinfo->rank;
712
51.3k
    unsigned int sz;
713
714
51.3k
    if (!EVP_DigestInit_ex(mdctx, key->sha3_256_md, NULL))
715
0
        return 0;
716
717
154k
    do {
718
154k
        uint8_t buf[3 * DEGREE / 2];
719
720
154k
        scalar_encode(buf, t++, 12);
721
154k
        if (!EVP_DigestUpdate(mdctx, buf, sizeof(buf)))
722
0
            return 0;
723
154k
    } while (t < end);
724
725
51.3k
    if (!EVP_DigestUpdate(mdctx, key->rho, ML_KEM_RANDOM_BYTES))
726
0
        return 0;
727
51.3k
    return EVP_DigestFinal_ex(mdctx, pkhash, &sz)
728
51.3k
        && ossl_assert(sz == ML_KEM_PKHASH_BYTES);
729
51.3k
}
730
731
/*
732
 * FIPS 203, Section 4.1, equation (4.5): G.  SHA3-512 hash of a variable
733
 * length input, producing 64 bytes of output, in particular the seeds
734
 * (d,z) for key generation.
735
 */
736
static __owur int hash_g(uint8_t out[ML_KEM_SEED_BYTES], const uint8_t *in, size_t len,
737
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
738
51.7k
{
739
51.7k
    return EVP_DigestInit_ex(mdctx, key->sha3_512_md, NULL)
740
51.7k
        && single_keccak(out, ML_KEM_SEED_BYTES, in, len, mdctx);
741
51.7k
}
742
743
/*
744
 * FIPS 203, Section 4.1, equation (4.4): J. SHAKE256 taking a variable length
745
 * input to compute a 32-byte implicit rejection shared secret, of the same
746
 * length as the expected shared secret.  (Computed even on success to avoid
747
 * side-channel leaks).
748
 */
749
static __owur int kdf(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
750
    const uint8_t z[ML_KEM_RANDOM_BYTES],
751
    const uint8_t *ctext, size_t len,
752
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
753
173
{
754
173
    return EVP_DigestInit_ex(mdctx, key->shake256_md, NULL)
755
173
        && EVP_DigestUpdate(mdctx, z, ML_KEM_RANDOM_BYTES)
756
173
        && EVP_DigestUpdate(mdctx, ctext, len)
757
173
        && EVP_DigestFinalXOF(mdctx, out, ML_KEM_SHARED_SECRET_BYTES);
758
173
}
759
760
/*
761
 * FIPS 203, Section 4.2.2, Algorithm 7: "SampleNTT" (steps 3-17, steps 1, 2
762
 * are performed by the caller). Rejection-samples a Keccak stream to get
763
 * uniformly distributed elements in the range [0,q). This is used for matrix
764
 * expansion and only operates on public inputs.
765
 */
766
static __owur int sample_scalar(scalar *out, EVP_MD_CTX *mdctx)
767
464k
{
768
464k
    uint16_t *curr = out->c, *endout = curr + DEGREE;
769
464k
    uint8_t buf[SCALAR_SAMPLING_BUFSIZE], *in;
770
464k
    uint8_t *endin = buf + sizeof(buf);
771
464k
    uint16_t d;
772
464k
    uint8_t b1, b2, b3;
773
774
1.39M
    do {
775
1.39M
        if (!EVP_DigestSqueeze(mdctx, in = buf, sizeof(buf)))
776
0
            return 0;
777
72.8M
        do {
778
72.8M
            b1 = *in++;
779
72.8M
            b2 = *in++;
780
72.8M
            b3 = *in++;
781
782
72.8M
            if (curr >= endout)
783
156k
                break;
784
72.6M
            if ((d = ((b2 & 0x0f) << 8) + b1) < kPrime)
785
59.9M
                *curr++ = d;
786
72.6M
            if (curr >= endout)
787
308k
                break;
788
72.3M
            if ((d = (b3 << 4) + (b2 >> 4)) < kPrime)
789
58.9M
                *curr++ = d;
790
72.3M
        } while (in < endin);
791
1.39M
    } while (curr < endout);
792
464k
    return 1;
793
464k
}
794
795
/*-
796
 * reduce_once reduces 0 <= x < 2*kPrime, mod kPrime.
797
 *
798
 * Subtract |q| if the input is larger, without exposing a side-channel,
799
 * avoiding the "clangover" attack.  See |constish_time_non_zero| for a
800
 * discussion on why the value barrier is by default omitted.
801
 */
802
static __owur uint16_t reduce_once(uint16_t x)
803
1.01G
{
804
1.01G
    const uint16_t subtracted = x - kPrime;
805
1.01G
    uint16_t mask = constish_time_non_zero(subtracted >> 15);
806
807
1.01G
    return (mask & x) | (~mask & subtracted);
808
1.01G
}
809
810
/*
811
 * Constant-time reduce x mod kPrime using Barrett reduction. x must be less
812
 * than kPrime + 2 * kPrime^2.  This is sufficient to reduce a product of
813
 * two already reduced u_int16 values, in fact it is sufficient for each
814
 * to be less than 2^12, because (kPrime * (2 * kPrime + 1)) > 2^24.
815
 */
816
static __owur uint16_t reduce(uint32_t x)
817
458M
{
818
458M
    uint64_t product = (uint64_t)x * kBarrettMultiplier;
819
458M
    uint32_t quotient = (uint32_t)(product >> kBarrettShift);
820
458M
    uint32_t remainder = x - quotient * kPrime;
821
822
458M
    return reduce_once(remainder);
823
458M
}
824
825
/* Multiply a scalar by a constant. */
826
static void scalar_mult_const(scalar *s, uint16_t a)
827
1.61k
{
828
1.61k
    uint16_t *curr = s->c, *end = curr + DEGREE, tmp;
829
830
414k
    do {
831
414k
        tmp = reduce(*curr * a);
832
414k
        *curr++ = tmp;
833
414k
    } while (curr < end);
834
1.61k
}
835
836
/*-
837
 * FIPS 203, Section 4.3, Algorithm 9: "NTT".
838
 * In-place number theoretic transform of a given scalar.  Note that ML-KEM's
839
 * kPrime 3329 does not have a 512th root of unity, so this transform leaves
840
 * off the last iteration of the usual FFT code, with the 128 relevant roots of
841
 * unity being stored in NTTRoots.  This means the output should be seen as 128
842
 * elements in GF(3329^2), with the coefficients of the elements being
843
 * consecutive entries in |s->c|.
844
 */
845
static void scalar_ntt(scalar *s)
846
309k
{
847
309k
    const uint16_t *roots = kNTTRoots;
848
309k
    uint16_t *end = s->c + DEGREE;
849
309k
    int offset = DEGREE / 2;
850
851
2.16M
    do {
852
2.16M
        uint16_t *curr = s->c, *peer;
853
854
39.3M
        do {
855
39.3M
            uint16_t *pause = curr + offset, even, odd;
856
39.3M
            uint32_t zeta = *++roots;
857
858
39.3M
            peer = pause;
859
277M
            do {
860
277M
                even = *curr;
861
277M
                odd = reduce(*peer * zeta);
862
277M
                *peer++ = reduce_once(even - odd + kPrime);
863
277M
                *curr++ = reduce_once(odd + even);
864
277M
            } while (curr < pause);
865
39.3M
        } while ((curr = peer) < end);
866
2.16M
    } while ((offset >>= 1) >= 2);
867
309k
}
868
869
/*-
870
 * FIPS 203, Section 4.3, Algorithm 10: "NTT^(-1)".
871
 * In-place inverse number theoretic transform of a given scalar, with pairs of
872
 * entries of s->v being interpreted as elements of GF(3329^2). Just as with
873
 * the number theoretic transform, this leaves off the first step of the normal
874
 * iFFT to account for the fact that 3329 does not have a 512th root of unity,
875
 * using the precomputed 128 roots of unity stored in InverseNTTRoots.
876
 */
877
static void scalar_inverse_ntt(scalar *s)
878
1.61k
{
879
1.61k
    const uint16_t *roots = kInverseNTTRoots;
880
1.61k
    uint16_t *end = s->c + DEGREE;
881
1.61k
    int offset = 2;
882
883
11.3k
    do {
884
11.3k
        uint16_t *curr = s->c, *peer;
885
886
205k
        do {
887
205k
            uint16_t *pause = curr + offset, even, odd;
888
205k
            uint32_t zeta = *++roots;
889
890
205k
            peer = pause;
891
1.45M
            do {
892
1.45M
                even = *curr;
893
1.45M
                odd = *peer;
894
1.45M
                *peer++ = reduce(zeta * (even - odd + kPrime));
895
1.45M
                *curr++ = reduce_once(odd + even);
896
1.45M
            } while (curr < pause);
897
205k
        } while ((curr = peer) < end);
898
11.3k
    } while ((offset <<= 1) < DEGREE);
899
1.61k
    scalar_mult_const(s, kInverseDegree);
900
1.61k
}
901
902
/* Addition updating the LHS scalar in-place. */
903
static void scalar_add(scalar *lhs, const scalar *rhs)
904
1.44k
{
905
1.44k
    int i;
906
907
371k
    for (i = 0; i < DEGREE; i++)
908
370k
        lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
909
1.44k
}
910
911
/* Subtraction updating the LHS scalar in-place. */
912
static void scalar_sub(scalar *lhs, const scalar *rhs)
913
173
{
914
173
    int i;
915
916
44.4k
    for (i = 0; i < DEGREE; i++)
917
44.2k
        lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime);
918
173
}
919
920
/*
921
 * Multiplying two scalars in the number theoretically transformed state. Since
922
 * 3329 does not have a 512th root of unity, this means we have to interpret
923
 * the 2*ith and (2*i+1)th entries of the scalar as elements of
924
 * GF(3329)[X]/(X^2 - 17^(2*bitreverse(i)+1)).
925
 *
926
 * The value of 17^(2*bitreverse(i)+1) mod 3329 is stored in the precomputed
927
 * ModRoots table. Note that our Barrett transform only allows us to multiply
928
 * two reduced numbers together, so we need some intermediate reduction steps,
929
 * even if an uint64_t could hold 3 multiplied numbers.
930
 */
931
static void scalar_mult(scalar *out, const scalar *lhs,
932
    const scalar *rhs)
933
1.61k
{
934
1.61k
    uint16_t *curr = out->c, *end = curr + DEGREE;
935
1.61k
    const uint16_t *lc = lhs->c, *rc = rhs->c;
936
1.61k
    const uint16_t *roots = kModRoots;
937
938
207k
    do {
939
207k
        uint32_t l0 = *lc++, r0 = *rc++;
940
207k
        uint32_t l1 = *lc++, r1 = *rc++;
941
207k
        uint32_t zetapow = *roots++;
942
943
207k
        *curr++ = reduce(l0 * r0 + reduce(l1 * r1) * zetapow);
944
207k
        *curr++ = reduce(l0 * r1 + l1 * r0);
945
207k
    } while (curr < end);
946
1.61k
}
947
948
/* Above, but add the result to an existing scalar */
949
static ossl_inline void scalar_mult_add(scalar *out, const scalar *lhs,
950
    const scalar *rhs)
951
465k
{
952
465k
    uint16_t *curr = out->c, *end = curr + DEGREE;
953
465k
    const uint16_t *lc = lhs->c, *rc = rhs->c;
954
465k
    const uint16_t *roots = kModRoots;
955
956
59.5M
    do {
957
59.5M
        uint32_t l0 = *lc++, r0 = *rc++;
958
59.5M
        uint32_t l1 = *lc++, r1 = *rc++;
959
59.5M
        uint16_t *c0 = curr++;
960
59.5M
        uint16_t *c1 = curr++;
961
59.5M
        uint32_t zetapow = *roots++;
962
963
59.5M
        *c0 = reduce(*c0 + l0 * r0 + reduce(l1 * r1) * zetapow);
964
59.5M
        *c1 = reduce(*c1 + l0 * r1 + l1 * r0);
965
59.5M
    } while (curr < end);
966
465k
}
967
968
/*-
969
 * FIPS 203, Section 4.2.1, Algorithm 5: "ByteEncode_d", for 2<=d<=12.
970
 * Here |bits| is |d|.  For efficiency, we handle the d=1 case separately.
971
 */
972
static void scalar_encode(uint8_t *out, const scalar *s, int bits)
973
309k
{
974
309k
    const uint16_t *curr = s->c, *end = curr + DEGREE;
975
309k
    uint64_t accum = 0, element;
976
309k
    int used = 0;
977
978
79.2M
    do {
979
79.2M
        element = *curr++;
980
79.2M
        if (used + bits < 64) {
981
64.4M
            accum |= element << used;
982
64.4M
            used += bits;
983
64.4M
        } else if (used + bits > 64) {
984
9.89M
            out = OPENSSL_store_u64_le(out, accum | (element << used));
985
9.89M
            accum = element >> (64 - used);
986
9.89M
            used = (used + bits) - 64;
987
9.89M
        } else {
988
4.94M
            out = OPENSSL_store_u64_le(out, accum | (element << used));
989
4.94M
            accum = 0;
990
4.94M
            used = 0;
991
4.94M
        }
992
79.2M
    } while (curr < end);
993
309k
}
994
995
/*
996
 * scalar_encode_1 is |scalar_encode| specialised for |bits| == 1.
997
 */
998
static void scalar_encode_1(uint8_t out[DEGREE / 8], const scalar *s)
999
173
{
1000
173
    int i, j;
1001
173
    uint8_t out_byte;
1002
1003
5.70k
    for (i = 0; i < DEGREE; i += 8) {
1004
5.53k
        out_byte = 0;
1005
49.8k
        for (j = 0; j < 8; j++)
1006
44.2k
            out_byte |= bit0(s->c[i + j]) << j;
1007
5.53k
        *out = out_byte;
1008
5.53k
        out++;
1009
5.53k
    }
1010
173
}
1011
1012
/*-
1013
 * FIPS 203, Section 4.2.1, Algorithm 6: "ByteDecode_d", for 2<=d<12.
1014
 * Here |bits| is |d|.  For efficiency, we handle the d=1 and d=12 cases
1015
 * separately.
1016
 *
1017
 * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
1018
 * |out|.
1019
 */
1020
static void scalar_decode(scalar *out, const uint8_t *in, int bits)
1021
675
{
1022
675
    uint16_t *curr = out->c, *end = curr + DEGREE;
1023
675
    uint64_t accum = 0;
1024
675
    int accum_bits = 0, todo = bits;
1025
675
    uint16_t bitmask = (((uint16_t)1) << bits) - 1, mask = bitmask;
1026
675
    uint16_t element = 0;
1027
1028
191k
    do {
1029
191k
        if (accum_bits == 0) {
1030
23.8k
            in = OPENSSL_load_u64_le(&accum, in);
1031
23.8k
            accum_bits = 64;
1032
23.8k
        }
1033
191k
        if (todo == bits && accum_bits >= bits) {
1034
            /* No partial "element", and all the required bits available */
1035
154k
            *curr++ = ((uint16_t)accum) & mask;
1036
154k
            accum >>= bits;
1037
154k
            accum_bits -= bits;
1038
154k
        } else if (accum_bits >= todo) {
1039
            /* A partial "element", and all the required bits available */
1040
18.5k
            *curr++ = element | ((((uint16_t)accum) & mask) << (bits - todo));
1041
18.5k
            accum >>= todo;
1042
18.5k
            accum_bits -= todo;
1043
18.5k
            element = 0;
1044
18.5k
            todo = bits;
1045
18.5k
            mask = bitmask;
1046
18.5k
        } else {
1047
            /*
1048
             * Only some of the requisite bits accumulated, store |accum_bits|
1049
             * of these in |element|.  The accumulated bitcount becomes 0, but
1050
             * as soon as we have more bits we'll want to merge accum_bits
1051
             * fewer of them into the final |element|.
1052
             *
1053
             * Note that with a 64-bit accumulator and |bits| always 12 or
1054
             * less, if we're here, the previous iteration had all the
1055
             * requisite bits, and so there are no kept bits in |element|.
1056
             */
1057
18.5k
            element = ((uint16_t)accum) & mask;
1058
18.5k
            todo -= accum_bits;
1059
18.5k
            mask = bitmask >> accum_bits;
1060
18.5k
            accum_bits = 0;
1061
18.5k
        }
1062
191k
    } while (curr < end);
1063
675
}
1064
1065
/*
 * Unpacks DEGREE 12-bit coefficients from |in| into |out|, two coefficients
 * per three input bytes: the first is byte 0 plus the low nibble of byte 1 as
 * bits 8-11, the second is the high nibble of byte 1 plus byte 2 as bits 4-11.
 *
 * Returns 0 as soon as a pair contains a value >= kPrime (both values of that
 * pair have already been stored), otherwise 1.  The early exit depends on the
 * input, so this is not constant time.
 */
static __owur int scalar_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2])
{
    const uint8_t *src = in;
    uint16_t *dst = out->c;
    uint16_t *const stop = out->c + DEGREE;

    while (dst < stop) {
        uint16_t lo = (uint16_t)(src[0] | ((src[1] & 0x0f) << 8));
        uint16_t hi = (uint16_t)((src[1] >> 4) | (src[2] << 4));

        src += 3;
        dst[0] = lo;
        dst[1] = hi;
        dst += 2;
        if ((lo >= kPrime) | (hi >= kPrime))
            return 0;
    }
    return 1;
}
1082
1083
/*-
 * scalar_decode_decompress_add is a combination of decoding and decompression
 * both specialised for |bits| == 1, with the result added (and sum reduced) to
 * the output scalar.
 *
 * Each of the DEGREE input bits (least-significant bit of each byte first)
 * selects whether |half_q_plus_1| is added to the corresponding coefficient
 * of |out|; the sum is passed through reduce_once().
 *
 * NOTE: this function MUST NOT leak an input-data-dependent timing signal.
 * A timing leak in a related function in the reference Kyber implementation
 * made the "clangover" attack (CVE-2024-37880) possible, giving key recovery
 * for ML-KEM-512 in minutes, provided the attacker has access to precise
 * timing of a CPU performing chosen-ciphertext decap.  Admittedly this is only
 * a risk when private keys are reused (perhaps KEMTLS servers).
 */
static void
scalar_decode_decompress_add(scalar *out, const uint8_t in[DEGREE / 8])
{
    /* (q >> 1) + 1, the value a set message bit decompresses to */
    static const uint16_t half_q_plus_1 = (ML_KEM_PRIME >> 1) + 1;
    uint16_t *curr = out->c, *end = curr + DEGREE;
    uint16_t mask;
    uint8_t b;

    /*
     * Add |half_q_plus_1| if the bit is set, without exposing a side-channel,
     * avoiding the "clangover" attack.  See |constish_time_non_zero| for a
     * discussion on why the value barrier is by default omitted.
     */
#define decode_decompress_add_bit                        \
    mask = constish_time_non_zero(bit0(b));              \
    *curr = reduce_once(*curr + (mask & half_q_plus_1)); \
    curr++;                                              \
    b >>= 1

    /* Unrolled to process each byte in one iteration */
    do {
        b = *in++;
        decode_decompress_add_bit;
        decode_decompress_add_bit;
        decode_decompress_add_bit;
        decode_decompress_add_bit;

        decode_decompress_add_bit;
        decode_decompress_add_bit;
        decode_decompress_add_bit;
        decode_decompress_add_bit;
    } while (curr < end);
#undef decode_decompress_add_bit
}
1129
1130
/*
1131
 * FIPS 203, Section 4.2.1, Equation (4.7): Compress_d.
1132
 *
1133
 * Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
1134
 * numbers close to each other together. The formula used is
1135
 * round(2^|bits|/kPrime*x) mod 2^|bits|.
1136
 * Uses Barrett reduction to achieve constant time. Since we need both the
1137
 * remainder (for rounding) and the quotient (as the result), we cannot use
1138
 * |reduce| here, but need to do the Barrett reduction directly.
1139
 */
1140
static __owur uint16_t compress(uint16_t x, int bits)
1141
414k
{
1142
414k
    uint32_t shifted = (uint32_t)x << bits;
1143
414k
    uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
1144
414k
    uint32_t quotient = (uint32_t)(product >> kBarrettShift);
1145
414k
    uint32_t remainder = shifted - quotient * kPrime;
1146
1147
    /*
1148
     * Adjust the quotient to round correctly:
1149
     *   0 <= remainder <= kHalfPrime round to 0
1150
     *   kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
1151
     *   kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
1152
     */
1153
414k
    quotient += 1 & constant_time_lt_32(kHalfPrime, remainder);
1154
414k
    quotient += 1 & constant_time_lt_32(kPrime + kHalfPrime, remainder);
1155
414k
    return quotient & ((1 << bits) - 1);
1156
414k
}
1157
1158
/*
1159
 * FIPS 203, Section 4.2.1, Equation (4.8): Decompress_d.
1160
1161
 * Decompresses |x| by using a close equi-distant representative. The formula
1162
 * is round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us
1163
 * to implement this logic using only bit operations.
1164
 */
1165
static __owur uint16_t decompress(uint16_t x, int bits)
1166
172k
{
1167
172k
    uint32_t product = (uint32_t)x * kPrime;
1168
172k
    uint32_t power = 1 << bits;
1169
    /* This is |product| % power, since |power| is a power of 2. */
1170
172k
    uint32_t remainder = product & (power - 1);
1171
    /* This is |product| / power, since |power| is a power of 2. */
1172
172k
    uint32_t lower = product >> bits;
1173
1174
    /*
1175
     * The rounding logic works since the first half of numbers mod |power|
1176
     * have a 0 as first bit, and the second half has a 1 as first bit, since
1177
     * |power| is a power of 2. As a 12 bit number, |remainder| is always
1178
     * positive, so we will shift in 0s for a right shift.
1179
     */
1180
172k
    return lower + (remainder >> (bits - 1));
1181
172k
}
1182
1183
/*-
1184
 * FIPS 203, Section 4.2.1, Equation (4.7): "Compress_d".
1185
 * In-place lossy rounding of scalars to 2^d bits.
1186
 */
1187
static void scalar_compress(scalar *s, int bits)
1188
1.61k
{
1189
1.61k
    int i;
1190
1191
416k
    for (i = 0; i < DEGREE; i++)
1192
414k
        s->c[i] = compress(s->c[i], bits);
1193
1.61k
}
1194
1195
/*
1196
 * FIPS 203, Section 4.2.1, Equation (4.8): "Decompress_d".
1197
 * In-place approximate recovery of scalars from 2^d bit compression.
1198
 */
1199
static void scalar_decompress(scalar *s, int bits)
1200
675
{
1201
675
    int i;
1202
1203
173k
    for (i = 0; i < DEGREE; i++)
1204
172k
        s->c[i] = decompress(s->c[i], bits);
1205
675
}
1206
1207
/* Addition updating the LHS vector in-place. */
1208
static void vector_add(scalar *lhs, const scalar *rhs, int rank)
1209
370
{
1210
1.07k
    do {
1211
1.07k
        scalar_add(lhs++, rhs++);
1212
1.07k
    } while (--rank > 0);
1213
370
}
1214
1215
/*
1216
 * Encodes an entire vector into 32*|rank|*|bits| bytes. Note that since 256
1217
 * (DEGREE) is divisible by 8, the individual vector entries will always fill a
1218
 * whole number of bytes, so we do not need to worry about bit packing here.
1219
 */
1220
static void vector_encode(uint8_t *out, const scalar *a, int bits, int rank)
1221
51.7k
{
1222
51.7k
    int stride = bits * DEGREE / 8;
1223
1224
206k
    for (; rank-- > 0; out += stride)
1225
155k
        scalar_encode(out, a++, bits);
1226
51.7k
}
1227
1228
/*
1229
 * Decodes 32*|rank|*|bits| bytes from |in| into |out|. It returns early
1230
 * if any parsed value is >= |ML_KEM_PRIME|.  The resulting scalars are
1231
 * then decompressed and transformed via the NTT.
1232
 *
1233
 * Note: Used only in decrypt_cpa(), which returns void and so does not check
1234
 * the return value of this function.  Side-channels are fine when the input
1235
 * ciphertext to decap() is simply syntactically invalid.
1236
 */
1237
static void
1238
vector_decode_decompress_ntt(scalar *out, const uint8_t *in, int bits, int rank)
1239
173
{
1240
173
    int stride = bits * DEGREE / 8;
1241
1242
675
    for (; rank-- > 0; in += stride, ++out) {
1243
502
        scalar_decode(out, in, bits);
1244
502
        scalar_decompress(out, bits);
1245
502
        scalar_ntt(out);
1246
502
    }
1247
173
}
1248
1249
/* vector_decode(), specialised to bits == 12. */
1250
static __owur int vector_decode_12(scalar *out, const uint8_t in[3 * DEGREE / 2], int rank)
1251
670
{
1252
670
    int stride = 3 * DEGREE / 2;
1253
1254
1.60k
    for (; rank-- > 0; in += stride)
1255
1.26k
        if (!scalar_decode_12(out++, in))
1256
332
            return 0;
1257
338
    return 1;
1258
670
}
1259
1260
/* In-place compression of each scalar component */
1261
static void vector_compress(scalar *a, int bits, int rank)
1262
370
{
1263
1.07k
    do {
1264
1.07k
        scalar_compress(a++, bits);
1265
1.07k
    } while (--rank > 0);
1266
370
}
1267
1268
/* The output scalar must not overlap with the inputs */
1269
static void inner_product(scalar *out, const scalar *lhs, const scalar *rhs,
1270
    int rank)
1271
543
{
1272
543
    scalar_mult(out, lhs, rhs);
1273
1.57k
    while (--rank > 0)
1274
1.03k
        scalar_mult_add(out, ++lhs, ++rhs);
1275
543
}
1276
1277
/*
1278
 * Here, the output vector must not overlap with the inputs, the result is
1279
 * directly subjected to inverse NTT.
1280
 */
1281
static void
1282
matrix_mult_intt(scalar *out, const scalar *m, const scalar *a, int rank)
1283
370
{
1284
370
    const scalar *ar;
1285
370
    int i, j;
1286
1287
1.44k
    for (i = rank; i-- > 0; ++out) {
1288
1.07k
        scalar_mult(out, m++, ar = a);
1289
3.36k
        for (j = rank - 1; j > 0; --j)
1290
2.29k
            scalar_mult_add(out, m++, ++ar);
1291
1.07k
        scalar_inverse_ntt(out);
1292
1.07k
    }
1293
370
}
1294
1295
/* Here, the output vector must not overlap with the inputs */
1296
static void
1297
matrix_mult_transpose_add(scalar *out, const scalar *m, const scalar *a, int rank)
1298
51.3k
{
1299
51.3k
    const scalar *mc = m, *mr, *ar;
1300
51.3k
    int i, j;
1301
1302
205k
    for (i = rank; i-- > 0; ++out) {
1303
154k
        scalar_mult_add(out, mr = mc++, ar = a);
1304
462k
        for (j = rank; --j > 0;)
1305
308k
            scalar_mult_add(out, (mr += rank), ++ar);
1306
154k
    }
1307
51.3k
}
1308
1309
/*-
1310
 * Expands the matrix from a seed for key generation and for encaps-CPA.
1311
 * NOTE: FIPS 203 matrix "A" is the transpose of this matrix, computed
1312
 * by appending the (i,j) indices to the seed in the opposite order!
1313
 *
1314
 * Where FIPS 203 computes t = A * s + e, we use the transpose of "m".
1315
 */
1316
static __owur int matrix_expand(EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
1317
51.7k
{
1318
51.7k
    scalar *out = key->m;
1319
51.7k
    uint8_t input[ML_KEM_RANDOM_BYTES + 2];
1320
51.7k
    int rank = key->vinfo->rank;
1321
51.7k
    int i, j;
1322
1323
51.7k
    memcpy(input, key->rho, ML_KEM_RANDOM_BYTES);
1324
206k
    for (i = 0; i < rank; i++) {
1325
619k
        for (j = 0; j < rank; j++) {
1326
464k
            input[ML_KEM_RANDOM_BYTES] = i;
1327
464k
            input[ML_KEM_RANDOM_BYTES + 1] = j;
1328
464k
            if (!EVP_DigestInit_ex(mdctx, key->shake128_md, NULL)
1329
464k
                || !EVP_DigestUpdate(mdctx, input, sizeof(input))
1330
464k
                || !sample_scalar(out++, mdctx))
1331
0
                return 0;
1332
464k
        }
1333
154k
    }
1334
51.7k
    return 1;
1335
51.7k
}
1336
1337
/*
1338
 * Algorithm 7 from the spec, with eta fixed to two and the PRF call
1339
 * included. Creates binominally distributed elements by sampling 2*|eta| bits,
1340
 * and setting the coefficient to the count of the first bits minus the count of
1341
 * the second bits, resulting in a centered binomial distribution. Since eta is
1342
 * two this gives -2/2 with a probability of 1/16, -1/1 with probability 1/4,
1343
 * and 0 with probability 3/8.
1344
 */
1345
static __owur int cbd_2(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
1346
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1347
308k
{
1348
308k
    uint16_t *curr = out->c, *end = curr + DEGREE;
1349
308k
    uint8_t randbuf[4 * DEGREE / 8], *r = randbuf; /* 64 * eta slots */
1350
308k
    uint16_t value, mask;
1351
308k
    uint8_t b;
1352
1353
308k
    if (!prf(randbuf, sizeof(randbuf), in, mdctx, key))
1354
0
        return 0;
1355
1356
39.5M
    do {
1357
39.5M
        b = *r++;
1358
1359
        /*
1360
         * Add |kPrime| if |value| underflowed.  See |constish_time_non_zero|
1361
         * for a discussion on why the value barrier is by default omitted.
1362
         * While this could have been written reduce_once(value + kPrime), this
1363
         * is one extra addition and small range of |value| tempts some
1364
         * versions of Clang to emit a branch.
1365
         */
1366
39.5M
        value = bit0(b) + bitn(1, b);
1367
39.5M
        value -= bitn(2, b) + bitn(3, b);
1368
39.5M
        mask = constish_time_non_zero(value >> 15);
1369
39.5M
        *curr++ = value + (kPrime & mask);
1370
1371
39.5M
        value = bitn(4, b) + bitn(5, b);
1372
39.5M
        value -= bitn(6, b) + bitn(7, b);
1373
39.5M
        mask = constish_time_non_zero(value >> 15);
1374
39.5M
        *curr++ = value + (kPrime & mask);
1375
39.5M
    } while (curr < end);
1376
308k
    return 1;
1377
308k
}
1378
1379
/*
1380
 * Algorithm 7 from the spec, with eta fixed to three and the PRF call
1381
 * included. Creates binominally distributed elements by sampling 3*|eta| bits,
1382
 * and setting the coefficient to the count of the first bits minus the count of
1383
 * the second bits, resulting in a centered binomial distribution.
1384
 */
1385
static __owur int cbd_3(scalar *out, uint8_t in[ML_KEM_RANDOM_BYTES + 1],
1386
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1387
1.63k
{
1388
1.63k
    uint16_t *curr = out->c, *end = curr + DEGREE;
1389
1.63k
    uint8_t randbuf[6 * DEGREE / 8], *r = randbuf; /* 64 * eta slots */
1390
1.63k
    uint8_t b1, b2, b3;
1391
1.63k
    uint16_t value, mask;
1392
1393
1.63k
    if (!prf(randbuf, sizeof(randbuf), in, mdctx, key))
1394
0
        return 0;
1395
1396
104k
    do {
1397
104k
        b1 = *r++;
1398
104k
        b2 = *r++;
1399
104k
        b3 = *r++;
1400
1401
        /*
1402
         * Add |kPrime| if |value| underflowed.  See |constish_time_non_zero|
1403
         * for a discussion on why the value barrier is by default omitted.
1404
         * While this could have been written reduce_once(value + kPrime), this
1405
         * is one extra addition and small range of |value| tempts some
1406
         * versions of Clang to emit a branch.
1407
         */
1408
104k
        value = bit0(b1) + bitn(1, b1) + bitn(2, b1);
1409
104k
        value -= bitn(3, b1) + bitn(4, b1) + bitn(5, b1);
1410
104k
        mask = constish_time_non_zero(value >> 15);
1411
104k
        *curr++ = value + (kPrime & mask);
1412
1413
104k
        value = bitn(6, b1) + bitn(7, b1) + bit0(b2);
1414
104k
        value -= bitn(1, b2) + bitn(2, b2) + bitn(3, b2);
1415
104k
        mask = constish_time_non_zero(value >> 15);
1416
104k
        *curr++ = value + (kPrime & mask);
1417
1418
104k
        value = bitn(4, b2) + bitn(5, b2) + bitn(6, b2);
1419
104k
        value -= bitn(7, b2) + bit0(b3) + bitn(1, b3);
1420
104k
        mask = constish_time_non_zero(value >> 15);
1421
104k
        *curr++ = value + (kPrime & mask);
1422
1423
104k
        value = bitn(2, b3) + bitn(3, b3) + bitn(4, b3);
1424
104k
        value -= bitn(5, b3) + bitn(6, b3) + bitn(7, b3);
1425
104k
        mask = constish_time_non_zero(value >> 15);
1426
104k
        *curr++ = value + (kPrime & mask);
1427
104k
    } while (curr < end);
1428
1.63k
    return 1;
1429
1.63k
}
1430
1431
/*
1432
 * Generates a secret vector by using |cbd| with the given seed to generate
1433
 * scalar elements and incrementing |counter| for each slot of the vector.
1434
 */
1435
static __owur int gencbd_vector(scalar *out, CBD_FUNC cbd, uint8_t *counter,
1436
    const uint8_t seed[ML_KEM_RANDOM_BYTES], int rank,
1437
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1438
370
{
1439
370
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1440
1441
370
    memcpy(input, seed, ML_KEM_RANDOM_BYTES);
1442
1.07k
    do {
1443
1.07k
        input[ML_KEM_RANDOM_BYTES] = (*counter)++;
1444
1.07k
        if (!cbd(out++, input, mdctx, key))
1445
0
            return 0;
1446
1.07k
    } while (--rank > 0);
1447
370
    return 1;
1448
370
}
1449
1450
/*
1451
 * As above plus NTT transform.
1452
 */
1453
static __owur int gencbd_vector_ntt(scalar *out, CBD_FUNC cbd, uint8_t *counter,
1454
    const uint8_t seed[ML_KEM_RANDOM_BYTES], int rank,
1455
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1456
103k
{
1457
103k
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1458
1459
103k
    memcpy(input, seed, ML_KEM_RANDOM_BYTES);
1460
309k
    do {
1461
309k
        input[ML_KEM_RANDOM_BYTES] = (*counter)++;
1462
309k
        if (!cbd(out, input, mdctx, key))
1463
0
            return 0;
1464
309k
        scalar_ntt(out++);
1465
309k
    } while (--rank > 0);
1466
103k
    return 1;
1467
103k
}
1468
1469
/*
 * Picks the sampler for ETA1: ML-KEM-512 uses eta = 3 (cbd_3), every other
 * parameter set uses eta = 2 (cbd_2).  ETA2 is always 2, so cbd_2 is used
 * directly wherever ETA2 sampling is needed.
 */
#define CBD1(evp_type) ((evp_type) != EVP_PKEY_ML_KEM_512 ? cbd_2 : cbd_3)
1471
1472
/*
1473
 * FIPS 203, Section 5.2, Algorithm 14: K-PKE.Encrypt.
1474
 *
1475
 * Encrypts a message with given randomness to the ciphertext in |out|. Without
1476
 * applying the Fujisaki-Okamoto transform this would not result in a CCA
1477
 * secure scheme, since lattice schemes are vulnerable to decryption failure
1478
 * oracles.
1479
 *
1480
 * The steps are re-ordered to make more efficient/localised use of storage.
1481
 *
1482
 * Note also that the input public key is assumed to hold a precomputed matrix
1483
 * |A| (our key->m, with the public key holding an expanded (16-bit per scalar
1484
 * coefficient) key->t vector).
1485
 *
1486
 * Caller passes storage in |tmp| for for two temporary vectors.
1487
 */
1488
static __owur int encrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
1489
    const uint8_t message[DEGREE / 8],
1490
    const uint8_t r[ML_KEM_RANDOM_BYTES], scalar *tmp,
1491
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1492
370
{
1493
370
    const ML_KEM_VINFO *vinfo = key->vinfo;
1494
370
    CBD_FUNC cbd_1 = CBD1(vinfo->evp_type);
1495
370
    int rank = vinfo->rank;
1496
    /* We can use tmp[0..rank-1] as storage for |y|, then |e1|, ... */
1497
370
    scalar *y = &tmp[0], *e1 = y, *e2 = y;
1498
    /* We can use tmp[rank]..tmp[2*rank - 1] for |u| */
1499
370
    scalar *u = &tmp[rank];
1500
370
    scalar v;
1501
370
    uint8_t input[ML_KEM_RANDOM_BYTES + 1];
1502
370
    uint8_t counter = 0;
1503
370
    int du = vinfo->du;
1504
370
    int dv = vinfo->dv;
1505
1506
    /* FIPS 203 "y" vector */
1507
370
    if (!gencbd_vector_ntt(y, cbd_1, &counter, r, rank, mdctx, key))
1508
0
        return 0;
1509
    /* FIPS 203 "v" scalar */
1510
370
    inner_product(&v, key->t, y, rank);
1511
370
    scalar_inverse_ntt(&v);
1512
    /* FIPS 203 "u" vector */
1513
370
    matrix_mult_intt(u, key->m, y, rank);
1514
1515
    /* All done with |y|, now free to reuse tmp[0] for FIPS 203 |e1| */
1516
370
    if (!gencbd_vector(e1, cbd_2, &counter, r, rank, mdctx, key))
1517
0
        return 0;
1518
370
    vector_add(u, e1, rank);
1519
370
    vector_compress(u, du, rank);
1520
370
    vector_encode(out, u, du, rank);
1521
1522
    /* All done with |e1|, now free to reuse tmp[0] for FIPS 203 |e2| */
1523
370
    memcpy(input, r, ML_KEM_RANDOM_BYTES);
1524
370
    input[ML_KEM_RANDOM_BYTES] = counter;
1525
370
    if (!cbd_2(e2, input, mdctx, key))
1526
0
        return 0;
1527
370
    scalar_add(&v, e2);
1528
1529
    /* Combine message with |v| */
1530
370
    scalar_decode_decompress_add(&v, message);
1531
370
    scalar_compress(&v, dv);
1532
370
    scalar_encode(out + vinfo->u_vector_bytes, &v, dv);
1533
370
    return 1;
1534
370
}
1535
1536
/*
1537
 * FIPS 203, Section 5.3, Algorithm 15: K-PKE.Decrypt.
1538
 */
1539
static void
1540
decrypt_cpa(uint8_t out[ML_KEM_SHARED_SECRET_BYTES],
1541
    const uint8_t *ctext, scalar *u, const ML_KEM_KEY *key)
1542
173
{
1543
173
    const ML_KEM_VINFO *vinfo = key->vinfo;
1544
173
    scalar v, mask;
1545
173
    int rank = vinfo->rank;
1546
173
    int du = vinfo->du;
1547
173
    int dv = vinfo->dv;
1548
1549
173
    vector_decode_decompress_ntt(u, ctext, du, rank);
1550
173
    scalar_decode(&v, ctext + vinfo->u_vector_bytes, dv);
1551
173
    scalar_decompress(&v, dv);
1552
173
    inner_product(&mask, key->s, u, rank);
1553
173
    scalar_inverse_ntt(&mask);
1554
173
    scalar_sub(&v, &mask);
1555
173
    scalar_compress(&v, 1);
1556
173
    scalar_encode_1(out, &v);
1557
173
}
1558
1559
/*-
1560
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1561
 * FIPS 203, Section 7.2, Algorithm 20: "ML-KEM.Encaps".
1562
 *
1563
 * Fills the |out| buffer with the |ek| output of "ML-KEM.KeyGen", or,
1564
 * equivalently, the |ek| input of "ML-KEM.Encaps", i.e. returns the
1565
 * wire-format of an ML-KEM public key.
1566
 */
1567
static void encode_pubkey(uint8_t *out, const ML_KEM_KEY *key)
1568
51.2k
{
1569
51.2k
    const uint8_t *rho = key->rho;
1570
51.2k
    const ML_KEM_VINFO *vinfo = key->vinfo;
1571
1572
51.2k
    vector_encode(out, key->t, 12, vinfo->rank);
1573
51.2k
    memcpy(out + vinfo->vector_bytes, rho, ML_KEM_RANDOM_BYTES);
1574
51.2k
}
1575
1576
/*-
1577
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1578
 *
1579
 * Fills the |out| buffer with the |dk| output of "ML-KEM.KeyGen".
1580
 * This matches the input format of parse_prvkey() below.
1581
 */
1582
static void encode_prvkey(uint8_t *out, const ML_KEM_KEY *key)
1583
138
{
1584
138
    const ML_KEM_VINFO *vinfo = key->vinfo;
1585
1586
138
    vector_encode(out, key->s, 12, vinfo->rank);
1587
138
    out += vinfo->vector_bytes;
1588
138
    encode_pubkey(out, key);
1589
138
    out += vinfo->pubkey_bytes;
1590
138
    memcpy(out, key->pkhash, ML_KEM_PKHASH_BYTES);
1591
138
    out += ML_KEM_PKHASH_BYTES;
1592
138
    memcpy(out, key->z, ML_KEM_RANDOM_BYTES);
1593
138
}
1594
1595
/*-
1596
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1597
 * FIPS 203, Section 7.2, Algorithm 20: "ML-KEM.Encaps".
1598
 *
1599
 * This function parses the |in| buffer as the |ek| output of "ML-KEM.KeyGen",
1600
 * or, equivalently, the |ek| input of "ML-KEM.Encaps", i.e. decodes the
1601
 * wire-format of the ML-KEM public key.
1602
 */
1603
static int parse_pubkey(const uint8_t *in, EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
1604
563
{
1605
563
    const ML_KEM_VINFO *vinfo = key->vinfo;
1606
1607
    /* Decode and check |t| */
1608
563
    if (!vector_decode_12(key->t, in, vinfo->rank)) {
1609
238
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1610
238
            "%s invalid public 't' vector",
1611
238
            vinfo->algorithm_name);
1612
238
        return 0;
1613
238
    }
1614
    /* Save the matrix |m| recovery seed |rho| */
1615
325
    memcpy(key->rho, in + vinfo->vector_bytes, ML_KEM_RANDOM_BYTES);
1616
    /*
1617
     * Pre-compute the public key hash, needed for both encap and decap.
1618
     * Also pre-compute the matrix expansion, stored with the public key.
1619
     */
1620
325
    if (!hash_h(key->pkhash, in, vinfo->pubkey_bytes, mdctx, key)
1621
325
        || !matrix_expand(mdctx, key)) {
1622
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1623
0
            "internal error while parsing %s public key",
1624
0
            vinfo->algorithm_name);
1625
0
        return 0;
1626
0
    }
1627
325
    return 1;
1628
325
}
1629
1630
/*
1631
 * FIPS 203, Section 7.1, Algorithm 19: "ML-KEM.KeyGen".
1632
 *
1633
 * Parses the |in| buffer as a |dk| output of "ML-KEM.KeyGen".
1634
 * This matches the output format of encode_prvkey() above.
1635
 */
1636
static int parse_prvkey(const uint8_t *in, EVP_MD_CTX *mdctx, ML_KEM_KEY *key)
1637
107
{
1638
107
    const ML_KEM_VINFO *vinfo = key->vinfo;
1639
1640
    /* Decode and check |s|. */
1641
107
    if (!vector_decode_12(key->s, in, vinfo->rank)) {
1642
94
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1643
94
            "%s invalid private 's' vector",
1644
94
            vinfo->algorithm_name);
1645
94
        return 0;
1646
94
    }
1647
13
    in += vinfo->vector_bytes;
1648
1649
13
    if (!parse_pubkey(in, mdctx, key))
1650
9
        return 0;
1651
4
    in += vinfo->pubkey_bytes;
1652
1653
    /* Check public key hash. */
1654
4
    if (memcmp(key->pkhash, in, ML_KEM_PKHASH_BYTES) != 0) {
1655
4
        ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_KEY,
1656
4
            "%s public key hash mismatch",
1657
4
            vinfo->algorithm_name);
1658
4
        return 0;
1659
4
    }
1660
0
    in += ML_KEM_PKHASH_BYTES;
1661
1662
0
    memcpy(key->z, in, ML_KEM_RANDOM_BYTES);
1663
0
    return 1;
1664
4
}
1665
1666
/*
 * FIPS 203, Section 6.1, Algorithm 16: "ML-KEM.KeyGen_internal".
 *
 * The implementation of Section 5.1, Algorithm 13, "K-PKE.KeyGen(d)" is
 * inlined.
 *
 * The caller MUST pass a pre-allocated digest context that is not shared with
 * any concurrent computation.
 *
 * This function optionally outputs the serialised wire-form |ek| public key
 * into the provided |pubenc| buffer, and generates the content of the |rho|,
 * |pkhash|, |t|, |m|, |s| and |z| components of the private |key| (which must
 * have preallocated space for these).
 *
 * Keys are computed from a 32-byte random |d| plus the 1 byte rank for
 * domain separation.  These are concatenated and hashed to produce a pair of
 * 32-byte seeds: the public "rho", used to generate the matrix, and the
 * private "sigma", used to generate the secret vector |s| (and the error
 * vector |e|).
 *
 * The second random input |z| is copied verbatim into the Fujisaki-Okamoto
 * (FO) transform "implicit-rejection" secret (the |z| component of the private
 * key), which thwarts chosen-ciphertext attacks, provided decap() runs in
 * constant time, with no side channel leaks, on all well-formed (valid length,
 * and correctly encoded) ciphertext inputs.
 *
 * Returns 1 on success, or 0 after raising ERR_R_INTERNAL_ERROR.  The local
 * copies of the seed and of sigma are cleansed on every path.
 */
static __owur int genkey(const uint8_t seed[ML_KEM_SEED_BYTES],
    EVP_MD_CTX *mdctx, uint8_t *pubenc, ML_KEM_KEY *key)
{
    /* hashed = rho || sigma */
    uint8_t hashed[2 * ML_KEM_RANDOM_BYTES];
    const uint8_t *const sigma = hashed + ML_KEM_RANDOM_BYTES;
    uint8_t augmented_seed[ML_KEM_RANDOM_BYTES + 1];
    const ML_KEM_VINFO *vinfo = key->vinfo;
    CBD_FUNC cbd_1 = CBD1(vinfo->evp_type);
    int rank = vinfo->rank;
    uint8_t counter = 0;
    int ret = 0;

    /*
     * Use the "d" seed salted with the rank to derive the public and private
     * seeds rho and sigma.
     */
    memcpy(augmented_seed, seed, ML_KEM_RANDOM_BYTES);
    augmented_seed[ML_KEM_RANDOM_BYTES] = (uint8_t)rank;
    if (!hash_g(hashed, augmented_seed, sizeof(augmented_seed), mdctx, key))
        goto end;
    memcpy(key->rho, hashed, ML_KEM_RANDOM_BYTES);
    /* The |rho| matrix seed is public */
    CONSTTIME_DECLASSIFY(key->rho, ML_KEM_RANDOM_BYTES);

    /*
     * FIPS 203 |e| vector is initial value of key->t.  |counter| continues
     * across the two gencbd_vector_ntt() calls, so |s| and |e| are sampled
     * from distinct PRF inputs.
     */
    if (!matrix_expand(mdctx, key)
        || !gencbd_vector_ntt(key->s, cbd_1, &counter, sigma, rank, mdctx, key)
        || !gencbd_vector_ntt(key->t, cbd_1, &counter, sigma, rank, mdctx, key))
        goto end;

    /* To |e| we now add the product of transpose |m| and |s|, giving |t|. */
    matrix_mult_transpose_add(key->t, key->m, key->s, rank);
    /* The |t| vector is public */
    CONSTTIME_DECLASSIFY(key->t, vinfo->rank * sizeof(scalar));

    if (pubenc == NULL) {
        /* Incremental digest of public key without in-full serialisation. */
        if (!hash_h_pubkey(key->pkhash, mdctx, key))
            goto end;
    } else {
        encode_pubkey(pubenc, key);
        if (!hash_h(key->pkhash, pubenc, vinfo->pubkey_bytes, mdctx, key))
            goto end;
    }

    /* Save |z| portion of seed for "implicit rejection" on failure. */
    memcpy(key->z, seed + ML_KEM_RANDOM_BYTES, ML_KEM_RANDOM_BYTES);

    /*
     * Optionally save the |d| portion of the seed.  |d| is stored directly
     * after |z|; presumably the key's storage reserves room for it there
     * (NOTE(review): confirm against the key allocation).
     */
    key->d = key->z + ML_KEM_RANDOM_BYTES;
    if (key->prov_flags & ML_KEM_KEY_RETAIN_SEED) {
        memcpy(key->d, seed, ML_KEM_RANDOM_BYTES);
    } else {
        OPENSSL_cleanse(key->d, ML_KEM_RANDOM_BYTES);
        key->d = NULL;
    }

    ret = 1;
end:
    /* Only the secret halves need scrubbing: the rank byte and rho are public */
    OPENSSL_cleanse((void *)augmented_seed, ML_KEM_RANDOM_BYTES);
    OPENSSL_cleanse((void *)sigma, ML_KEM_RANDOM_BYTES);
    if (ret == 0) {
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
            "internal error while generating %s private key",
            vinfo->algorithm_name);
    }
    return ret;
}
1759
1760
/*-
1761
 * FIPS 203, Section 6.2, Algorithm 17: "ML-KEM.Encaps_internal".
1762
 * This is the deterministic version with randomness supplied externally.
1763
 *
1764
 * The caller must pass space for two vectors in |tmp|.
1765
 * The |ctext| buffer have space for the ciphertext of the ML-KEM variant
1766
 * of the provided key.
1767
 */
1768
static int encap(uint8_t *ctext, uint8_t secret[ML_KEM_SHARED_SECRET_BYTES],
1769
    const uint8_t entropy[ML_KEM_RANDOM_BYTES],
1770
    scalar *tmp, EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1771
197
{
1772
197
    uint8_t input[ML_KEM_RANDOM_BYTES + ML_KEM_PKHASH_BYTES];
1773
197
    uint8_t Kr[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES];
1774
197
    uint8_t *r = Kr + ML_KEM_SHARED_SECRET_BYTES;
1775
197
    int ret;
1776
1777
197
    memcpy(input, entropy, ML_KEM_RANDOM_BYTES);
1778
197
    memcpy(input + ML_KEM_RANDOM_BYTES, key->pkhash, ML_KEM_PKHASH_BYTES);
1779
197
    ret = hash_g(Kr, input, sizeof(input), mdctx, key)
1780
197
        && encrypt_cpa(ctext, entropy, r, tmp, mdctx, key);
1781
197
    OPENSSL_cleanse((void *)input, sizeof(input));
1782
1783
197
    if (ret)
1784
197
        memcpy(secret, Kr, ML_KEM_SHARED_SECRET_BYTES);
1785
0
    else
1786
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1787
0
            "internal error while performing %s encapsulation",
1788
0
            key->vinfo->algorithm_name);
1789
197
    return ret;
1790
197
}
1791
1792
/*
1793
 * FIPS 203, Section 6.3, Algorithm 18: ML-KEM.Decaps_internal
1794
 *
1795
 * Barring failure of the supporting SHA3/SHAKE primitives, this is fully
1796
 * deterministic, the randomness for the FO transform is extracted during
1797
 * private key generation.
1798
 *
1799
 * The caller must pass space for two vectors in |tmp|.
1800
 * The |ctext| and |tmp_ctext| buffers must each have space for the ciphertext
1801
 * of the key's ML-KEM variant.
1802
 */
1803
static int decap(uint8_t secret[ML_KEM_SHARED_SECRET_BYTES],
1804
    const uint8_t *ctext, uint8_t *tmp_ctext, scalar *tmp,
1805
    EVP_MD_CTX *mdctx, const ML_KEM_KEY *key)
1806
110
{
1807
110
    uint8_t decrypted[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_PKHASH_BYTES];
1808
110
    uint8_t failure_key[ML_KEM_RANDOM_BYTES];
1809
110
    uint8_t Kr[ML_KEM_SHARED_SECRET_BYTES + ML_KEM_RANDOM_BYTES];
1810
110
    uint8_t *r = Kr + ML_KEM_SHARED_SECRET_BYTES;
1811
110
    const uint8_t *pkhash = key->pkhash;
1812
110
    const ML_KEM_VINFO *vinfo = key->vinfo;
1813
110
    int i;
1814
110
    uint8_t mask;
1815
1816
    /*
1817
     * If our KDF is unavailable, fail early! Otherwise, keep going ignoring
1818
     * any further errors, returning success, and whatever we got for a shared
1819
     * secret.  The decrypt_cpa() function is just arithmetic on secret data,
1820
     * so should not be subject to failure that makes its output predictable.
1821
     *
1822
     * We guard against "should never happen" catastrophic failure of the
1823
     * "pure" function |hash_g| by overwriting the shared secret with the
1824
     * content of the failure key and returning early, if nevertheless hash_g
1825
     * fails.  This is not constant-time, but a failure of |hash_g| already
1826
     * implies loss of side-channel resistance.
1827
     *
1828
     * The same action is taken, if also |encrypt_cpa| should catastrophically
1829
     * fail, due to failure of the |PRF| underlying the CBD functions.
1830
     */
1831
110
    if (!kdf(failure_key, key->z, ctext, vinfo->ctext_bytes, mdctx, key)) {
1832
0
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
1833
0
            "internal error while performing %s decapsulation",
1834
0
            vinfo->algorithm_name);
1835
0
        return 0;
1836
0
    }
1837
110
    decrypt_cpa(decrypted, ctext, tmp, key);
1838
110
    memcpy(decrypted + ML_KEM_SHARED_SECRET_BYTES, pkhash, ML_KEM_PKHASH_BYTES);
1839
110
    if (!hash_g(Kr, decrypted, sizeof(decrypted), mdctx, key)
1840
110
        || !encrypt_cpa(tmp_ctext, decrypted, r, tmp, mdctx, key)) {
1841
0
        memcpy(secret, failure_key, ML_KEM_SHARED_SECRET_BYTES);
1842
0
        OPENSSL_cleanse(decrypted, ML_KEM_SHARED_SECRET_BYTES);
1843
0
        return 1;
1844
0
    }
1845
110
    mask = constant_time_eq_int_8(0,
1846
110
        CRYPTO_memcmp(ctext, tmp_ctext, vinfo->ctext_bytes));
1847
3.63k
    for (i = 0; i < ML_KEM_SHARED_SECRET_BYTES; i++)
1848
3.52k
        secret[i] = constant_time_select_8(mask, Kr[i], failure_key[i]);
1849
110
    OPENSSL_cleanse(decrypted, ML_KEM_SHARED_SECRET_BYTES);
1850
110
    OPENSSL_cleanse(Kr, sizeof(Kr));
1851
110
    return 1;
1852
110
}
1853
1854
/*
1855
 * After allocating storage for public or private key data, update the key
1856
 * component pointers to reference that storage.
1857
 *
1858
 * The caller should only store private data in `priv` *after* a successful
1859
 * (non-zero) return from this function.
1860
 */
1861
static __owur int add_storage(scalar *pub, scalar *priv, int private, ML_KEM_KEY *key)
1862
35.0k
{
1863
35.0k
    int rank = key->vinfo->rank;
1864
1865
35.0k
    if (pub == NULL || (private && priv == NULL)) {
1866
        /*
1867
         * One of these could be allocated correctly. It is legal to call free with a NULL
1868
         * pointer, so always attempt to free both allocations here
1869
         */
1870
0
        OPENSSL_free(pub);
1871
0
        OPENSSL_secure_free(priv);
1872
0
        return 0;
1873
0
    }
1874
1875
    /*
1876
     * We're adding key material, set up rho and pkhash to point to the rho_pkhash buffer
1877
     */
1878
35.0k
    memset(key->rho_pkhash, 0, sizeof(key->rho_pkhash));
1879
35.0k
    key->rho = key->rho_pkhash;
1880
35.0k
    key->pkhash = key->rho_pkhash + ML_KEM_RANDOM_BYTES;
1881
35.0k
    key->d = key->z = NULL;
1882
1883
    /* A public key needs space for |t| and |m| */
1884
35.0k
    key->m = (key->t = pub) + rank;
1885
1886
    /*
1887
     * A private key also needs space for |s| and |z|.
1888
     * The |z| buffer always includes additional space for |d|, but a key's |d|
1889
     * pointer is left NULL when parsed from the NIST format, which omits that
1890
     * information.  Only keys generated from a (d, z) seed pair will have a
1891
     * non-NULL |d| pointer.
1892
     */
1893
35.0k
    if (private)
1894
34.5k
        key->z = (uint8_t *)(rank + (key->s = priv));
1895
35.0k
    return 1;
1896
35.0k
}
1897
1898
/*
1899
 * After freeing the storage associated with a key that failed to be
1900
 * constructed, reset the internal pointers back to NULL.
1901
 */
1902
void ossl_ml_kem_key_reset(ML_KEM_KEY *key)
1903
35.3k
{
1904
    /*
1905
     * seedbuf can be allocated and contain |z| and |d| if the key is
1906
     * being created from a private key encoding.  Similarly a pending
1907
     * serialised (encoded) private key may be queued up to load.
1908
     * Clear and free that data now.
1909
     */
1910
35.3k
    if (key->seedbuf != NULL)
1911
28
        OPENSSL_secure_clear_free(key->seedbuf, ML_KEM_SEED_BYTES);
1912
35.3k
    if (ossl_ml_kem_have_dkenc(key))
1913
0
        OPENSSL_secure_clear_free(key->encoded_dk, key->vinfo->prvkey_bytes);
1914
1915
    /*-
1916
     * Cleanse any sensitive data:
1917
     * - The private vector |s| is immediately followed by the FO failure
1918
     *   secret |z|, and seed |d|, we can cleanse all three in one call.
1919
     */
1920
35.3k
    if (key->t != NULL) {
1921
35.0k
        if (ossl_ml_kem_have_prvkey(key))
1922
34.5k
            OPENSSL_secure_clear_free(key->s, key->vinfo->prvalloc);
1923
35.0k
        OPENSSL_free(key->t);
1924
35.0k
    }
1925
35.3k
    key->d = key->z = key->seedbuf = key->encoded_dk = (uint8_t *)(key->s = key->m = key->t = NULL);
1926
35.3k
}
1927
1928
/*
1929
 * ----- API exported to the provider
1930
 *
1931
 * Parameters with an implicit fixed length in the internal static API of each
1932
 * variant have an explicit checked length argument at this layer.
1933
 */
1934
1935
/* Retrieve the parameters of one of the ML-KEM variants */
1936
const ML_KEM_VINFO *ossl_ml_kem_get_vinfo(int evp_type)
1937
277k
{
1938
277k
    switch (evp_type) {
1939
60.0k
    case EVP_PKEY_ML_KEM_512:
1940
60.0k
        return &vinfo_map[ML_KEM_512_VINFO];
1941
161k
    case EVP_PKEY_ML_KEM_768:
1942
161k
        return &vinfo_map[ML_KEM_768_VINFO];
1943
56.4k
    case EVP_PKEY_ML_KEM_1024:
1944
56.4k
        return &vinfo_map[ML_KEM_1024_VINFO];
1945
277k
    }
1946
0
    return NULL;
1947
277k
}
1948
1949
/*
 * Allocate an empty ML-KEM key of the variant named by |evp_type|.
 *
 * The SHAKE128, SHAKE256, SHA3-256 and SHA3-512 digests are fetched up
 * front from |libctx| with |properties|.  The new key holds no key material
 * yet.  Returns NULL on failure; an error is raised for an unsupported
 * |evp_type| or when any of the digests cannot be fetched.
 */
ML_KEM_KEY *ossl_ml_kem_key_new(OSSL_LIB_CTX *libctx, const char *properties,
    int evp_type)
{
    const ML_KEM_VINFO *vinfo = ossl_ml_kem_get_vinfo(evp_type);
    ML_KEM_KEY *key;

    if (vinfo == NULL) {
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_PASSED_INVALID_ARGUMENT,
            "unsupported ML-KEM key type: %d", evp_type);
        return NULL;
    }

    key = OPENSSL_malloc(sizeof(*key));
    if (key == NULL)
        return NULL;

    key->vinfo = vinfo;
    key->libctx = libctx;
    key->prov_flags = ML_KEM_KEY_PROV_FLAGS_DEFAULT;

    /* No key material of any kind yet */
    key->t = NULL;
    key->m = NULL;
    key->s = NULL;
    key->rho = NULL;
    key->pkhash = NULL;
    key->z = NULL;
    key->d = NULL;
    key->seedbuf = NULL;
    key->encoded_dk = NULL;

    key->shake128_md = EVP_MD_fetch(libctx, "SHAKE128", properties);
    key->shake256_md = EVP_MD_fetch(libctx, "SHAKE256", properties);
    key->sha3_256_md = EVP_MD_fetch(libctx, "SHA3-256", properties);
    key->sha3_512_md = EVP_MD_fetch(libctx, "SHA3-512", properties);

    if (key->shake128_md == NULL
        || key->shake256_md == NULL
        || key->sha3_256_md == NULL
        || key->sha3_512_md == NULL) {
        ossl_ml_kem_key_free(key);
        ERR_raise_data(ERR_LIB_CRYPTO, ERR_R_INTERNAL_ERROR,
            "missing SHA3 digest algorithms while creating %s key",
            vinfo->algorithm_name);
        return NULL;
    }
    return key;
}
1986
1987
ML_KEM_KEY *ossl_ml_kem_key_dup(const ML_KEM_KEY *key, int selection)
1988
57
{
1989
57
    int ok = 0;
1990
57
    ML_KEM_KEY *ret;
1991
57
    void *tmp_pub;
1992
57
    void *tmp_priv;
1993
1994
57
    if (key == NULL)
1995
0
        return NULL;
1996
    /*
1997
     * Partially decoded keys, not yet imported or loaded, should never be
1998
     * duplicated.
1999
     */
2000
57
    if (ossl_ml_kem_decoded_key(key))
2001
0
        return NULL;
2002
2003
57
    else if ((ret = OPENSSL_memdup(key, sizeof(*key))) == NULL)
2004
0
        return NULL;
2005
2006
57
    ret->d = ret->z = ret->rho = ret->pkhash = NULL;
2007
57
    ret->s = ret->m = ret->t = NULL;
2008
2009
    /* Clear selection bits we can't fulfill */
2010
57
    if (!ossl_ml_kem_have_pubkey(key))
2011
0
        selection = 0;
2012
57
    else if (!ossl_ml_kem_have_prvkey(key))
2013
57
        selection &= ~OSSL_KEYMGMT_SELECT_PRIVATE_KEY;
2014
2015
57
    switch (selection & OSSL_KEYMGMT_SELECT_KEYPAIR) {
2016
0
    case 0:
2017
0
        ok = 1;
2018
0
        break;
2019
57
    case OSSL_KEYMGMT_SELECT_PUBLIC_KEY:
2020
57
        ok = add_storage(OPENSSL_memdup(key->t, key->vinfo->puballoc), NULL, 0, ret);
2021
57
        break;
2022
0
    case OSSL_KEYMGMT_SELECT_PRIVATE_KEY:
2023
0
        tmp_pub = OPENSSL_memdup(key->t, key->vinfo->puballoc);
2024
0
        if (tmp_pub == NULL)
2025
0
            break;
2026
0
        tmp_priv = OPENSSL_secure_malloc(key->vinfo->prvalloc);
2027
0
        if (tmp_priv == NULL) {
2028
0
            OPENSSL_free(tmp_pub);
2029
0
            break;
2030
0
        }
2031
0
        if ((ok = add_storage(tmp_pub, tmp_priv, 1, ret)) != 0)
2032
0
            memcpy(tmp_priv, key->s, key->vinfo->prvalloc);
2033
        /* Duplicated keys retain |d|, if available */
2034
0
        if (key->d != NULL)
2035
0
            ret->d = ret->z + ML_KEM_RANDOM_BYTES;
2036
0
        break;
2037
57
    }
2038
2039
57
    if (!ok) {
2040
0
        OPENSSL_free(ret);
2041
0
        return NULL;
2042
0
    }
2043
2044
57
    EVP_MD_up_ref(ret->shake128_md);
2045
57
    EVP_MD_up_ref(ret->shake256_md);
2046
57
    EVP_MD_up_ref(ret->sha3_256_md);
2047
57
    EVP_MD_up_ref(ret->sha3_512_md);
2048
2049
57
    return ret;
2050
57
}
2051
2052
/*
 * Destroy |key|: drop its references to the fetched digests, release (and
 * cleanse) all key material via ossl_ml_kem_key_reset(), then free the key
 * itself.  A NULL |key| is a no-op.
 */
void ossl_ml_kem_key_free(ML_KEM_KEY *key)
{
    if (key != NULL) {
        EVP_MD_free(key->shake128_md);
        EVP_MD_free(key->shake256_md);
        EVP_MD_free(key->sha3_256_md);
        EVP_MD_free(key->sha3_512_md);

        ossl_ml_kem_key_reset(key);
        OPENSSL_free(key);
    }
}
2065
2066
/* Serialise the public component of an ML-KEM key */
2067
int ossl_ml_kem_encode_public_key(uint8_t *out, size_t len,
2068
    const ML_KEM_KEY *key)
2069
51.1k
{
2070
51.1k
    if (!ossl_ml_kem_have_pubkey(key)
2071
51.1k
        || len != key->vinfo->pubkey_bytes)
2072
0
        return 0;
2073
51.1k
    encode_pubkey(out, key);
2074
51.1k
    return 1;
2075
51.1k
}
2076
2077
/* Serialise an ML-KEM private key */
2078
int ossl_ml_kem_encode_private_key(uint8_t *out, size_t len,
2079
    const ML_KEM_KEY *key)
2080
138
{
2081
138
    if (!ossl_ml_kem_have_prvkey(key)
2082
138
        || len != key->vinfo->prvkey_bytes)
2083
0
        return 0;
2084
138
    encode_prvkey(out, key);
2085
138
    return 1;
2086
138
}
2087
2088
int ossl_ml_kem_encode_seed(uint8_t *out, size_t len,
2089
    const ML_KEM_KEY *key)
2090
226
{
2091
226
    if (key == NULL || key->d == NULL || len != ML_KEM_SEED_BYTES)
2092
28
        return 0;
2093
    /*
2094
     * Both in the seed buffer, and in the allocated storage, the |d| component
2095
     * of the seed is stored last, so we must copy each separately.
2096
     */
2097
198
    memcpy(out, key->d, ML_KEM_RANDOM_BYTES);
2098
198
    out += ML_KEM_RANDOM_BYTES;
2099
198
    memcpy(out, key->z, ML_KEM_RANDOM_BYTES);
2100
198
    return 1;
2101
226
}
2102
2103
/*
2104
 * Stash the seed without (yet) performing a keygen, used during decoding, to
2105
 * avoid an extra keygen if we're only going to export the key again to load
2106
 * into another provider.
2107
 */
2108
ML_KEM_KEY *ossl_ml_kem_set_seed(const uint8_t *seed, size_t seedlen, ML_KEM_KEY *key)
2109
28
{
2110
28
    if (key == NULL
2111
28
        || ossl_ml_kem_have_pubkey(key)
2112
28
        || ossl_ml_kem_have_seed(key)
2113
28
        || seedlen != ML_KEM_SEED_BYTES)
2114
0
        return NULL;
2115
2116
28
    if (key->seedbuf == NULL) {
2117
28
        key->seedbuf = OPENSSL_secure_malloc(seedlen);
2118
28
        if (key->seedbuf == NULL)
2119
0
            return NULL;
2120
28
    }
2121
2122
28
    key->z = key->seedbuf;
2123
28
    key->d = key->z + ML_KEM_RANDOM_BYTES;
2124
28
    memcpy(key->d, seed, ML_KEM_RANDOM_BYTES);
2125
28
    seed += ML_KEM_RANDOM_BYTES;
2126
28
    memcpy(key->z, seed, ML_KEM_RANDOM_BYTES);
2127
28
    return key;
2128
28
}
2129
2130
/* Parse input as a public key */
2131
int ossl_ml_kem_parse_public_key(const uint8_t *in, size_t len, ML_KEM_KEY *key)
2132
550
{
2133
550
    EVP_MD_CTX *mdctx = NULL;
2134
550
    const ML_KEM_VINFO *vinfo;
2135
550
    int ret = 0;
2136
2137
    /* Keys with key material are immutable */
2138
550
    if (key == NULL
2139
550
        || ossl_ml_kem_have_pubkey(key)
2140
550
        || ossl_ml_kem_have_dkenc(key))
2141
0
        return 0;
2142
550
    vinfo = key->vinfo;
2143
2144
550
    if (len != vinfo->pubkey_bytes
2145
550
        || (mdctx = EVP_MD_CTX_new()) == NULL)
2146
0
        return 0;
2147
2148
550
    if (add_storage(OPENSSL_malloc(vinfo->puballoc), NULL, 0, key))
2149
550
        ret = parse_pubkey(in, mdctx, key);
2150
2151
550
    if (!ret)
2152
229
        ossl_ml_kem_key_reset(key);
2153
550
    EVP_MD_CTX_free(mdctx);
2154
550
    return ret;
2155
550
}
2156
2157
/* Parse input as a new private key */
2158
int ossl_ml_kem_parse_private_key(const uint8_t *in, size_t len,
2159
    ML_KEM_KEY *key)
2160
107
{
2161
107
    EVP_MD_CTX *mdctx = NULL;
2162
107
    const ML_KEM_VINFO *vinfo;
2163
107
    int ret = 0;
2164
2165
    /* Keys with key material are immutable */
2166
107
    if (key == NULL
2167
107
        || ossl_ml_kem_have_pubkey(key)
2168
107
        || ossl_ml_kem_have_dkenc(key))
2169
0
        return 0;
2170
107
    vinfo = key->vinfo;
2171
2172
107
    if (len != vinfo->prvkey_bytes
2173
107
        || (mdctx = EVP_MD_CTX_new()) == NULL)
2174
0
        return 0;
2175
2176
107
    if (add_storage(OPENSSL_malloc(vinfo->puballoc),
2177
107
            OPENSSL_secure_malloc(vinfo->prvalloc), 1, key))
2178
107
        ret = parse_prvkey(in, mdctx, key);
2179
2180
107
    if (!ret)
2181
107
        ossl_ml_kem_key_reset(key);
2182
107
    EVP_MD_CTX_free(mdctx);
2183
107
    return ret;
2184
107
}
2185
2186
/*
2187
 * Generate a new keypair, either from the saved seed (when non-null), or from
2188
 * the RNG.
2189
 */
2190
int ossl_ml_kem_genkey(uint8_t *pubenc, size_t publen, ML_KEM_KEY *key)
2191
51.3k
{
2192
51.3k
    uint8_t seed[ML_KEM_SEED_BYTES];
2193
51.3k
    EVP_MD_CTX *mdctx = NULL;
2194
51.3k
    const ML_KEM_VINFO *vinfo;
2195
51.3k
    int ret = 0;
2196
2197
51.3k
    if (key == NULL
2198
51.3k
        || ossl_ml_kem_have_pubkey(key)
2199
51.3k
        || ossl_ml_kem_have_dkenc(key))
2200
0
        return 0;
2201
51.3k
    vinfo = key->vinfo;
2202
2203
51.3k
    if (pubenc != NULL && publen != vinfo->pubkey_bytes)
2204
0
        return 0;
2205
2206
51.3k
    if (key->seedbuf != NULL) {
2207
60
        if (!ossl_ml_kem_encode_seed(seed, sizeof(seed), key))
2208
0
            return 0;
2209
60
        ossl_ml_kem_key_reset(key);
2210
51.3k
    } else if (RAND_priv_bytes_ex(key->libctx, seed, sizeof(seed),
2211
51.3k
                   key->vinfo->secbits)
2212
51.3k
        <= 0) {
2213
0
        return 0;
2214
0
    }
2215
2216
51.3k
    if ((mdctx = EVP_MD_CTX_new()) == NULL)
2217
0
        return 0;
2218
2219
    /*
2220
     * Data derived from (d, z) defaults secret, and to avoid side-channel
2221
     * leaks should not influence control flow.
2222
     */
2223
51.3k
    CONSTTIME_SECRET(seed, ML_KEM_SEED_BYTES);
2224
2225
51.3k
    if (add_storage(OPENSSL_malloc(vinfo->puballoc),
2226
51.3k
            OPENSSL_secure_malloc(vinfo->prvalloc), 1, key))
2227
51.3k
        ret = genkey(seed, mdctx, pubenc, key);
2228
51.3k
    OPENSSL_cleanse(seed, sizeof(seed));
2229
2230
    /* Declassify secret inputs and derived outputs before returning control */
2231
51.3k
    CONSTTIME_DECLASSIFY(seed, ML_KEM_SEED_BYTES);
2232
2233
51.3k
    EVP_MD_CTX_free(mdctx);
2234
51.3k
    if (!ret) {
2235
0
        ossl_ml_kem_key_reset(key);
2236
0
        return 0;
2237
0
    }
2238
2239
    /* The public components are already declassified */
2240
51.3k
    CONSTTIME_DECLASSIFY(key->s, vinfo->rank * sizeof(scalar));
2241
51.3k
    CONSTTIME_DECLASSIFY(key->z, 2 * ML_KEM_RANDOM_BYTES);
2242
51.3k
    return 1;
2243
51.3k
}
2244
2245
/*
2246
 * FIPS 203, Section 6.2, Algorithm 17: ML-KEM.Encaps_internal
2247
 * This is the deterministic version with randomness supplied externally.
2248
 */
2249
int ossl_ml_kem_encap_seed(uint8_t *ctext, size_t clen,
2250
    uint8_t *shared_secret, size_t slen,
2251
    const uint8_t *entropy, size_t elen,
2252
    const ML_KEM_KEY *key)
2253
197
{
2254
197
    const ML_KEM_VINFO *vinfo;
2255
197
    EVP_MD_CTX *mdctx;
2256
197
    int ret = 0;
2257
2258
197
    if (key == NULL || !ossl_ml_kem_have_pubkey(key))
2259
0
        return 0;
2260
197
    vinfo = key->vinfo;
2261
2262
197
    if (ctext == NULL || clen != vinfo->ctext_bytes
2263
197
        || shared_secret == NULL || slen != ML_KEM_SHARED_SECRET_BYTES
2264
197
        || entropy == NULL || elen != ML_KEM_RANDOM_BYTES
2265
197
        || (mdctx = EVP_MD_CTX_new()) == NULL)
2266
0
        return 0;
2267
    /*
2268
     * Data derived from the encap entropy defaults secret, and to avoid
2269
     * side-channel leaks should not influence control flow.
2270
     */
2271
197
    CONSTTIME_SECRET(entropy, elen);
2272
2273
    /*-
2274
     * This avoids the need to handle allocation failures for two (max 2KB
2275
     * each) vectors, that are never retained on return from this function.
2276
     * We stack-allocate these.
2277
     */
2278
197
#define case_encap_seed(bits)                                        \
2279
197
    case EVP_PKEY_ML_KEM_##bits: {                                   \
2280
197
        scalar tmp[2 * ML_KEM_##bits##_RANK];                        \
2281
197
                                                                     \
2282
197
        ret = encap(ctext, shared_secret, entropy, tmp, mdctx, key); \
2283
197
        OPENSSL_cleanse((void *)tmp, sizeof(tmp));                   \
2284
197
        break;                                                       \
2285
197
    }
2286
197
    switch (vinfo->evp_type) {
2287
69
        case_encap_seed(512);
2288
76
        case_encap_seed(768);
2289
52
        case_encap_seed(1024);
2290
197
    }
2291
197
#undef case_encap_seed
2292
2293
    /* Declassify secret inputs and derived outputs before returning control */
2294
197
    CONSTTIME_DECLASSIFY(entropy, elen);
2295
197
    CONSTTIME_DECLASSIFY(ctext, clen);
2296
197
    CONSTTIME_DECLASSIFY(shared_secret, slen);
2297
2298
197
    EVP_MD_CTX_free(mdctx);
2299
197
    return ret;
2300
197
}
2301
2302
int ossl_ml_kem_encap_rand(uint8_t *ctext, size_t clen,
2303
    uint8_t *shared_secret, size_t slen,
2304
    const ML_KEM_KEY *key)
2305
197
{
2306
197
    uint8_t r[ML_KEM_RANDOM_BYTES];
2307
2308
197
    if (key == NULL)
2309
0
        return 0;
2310
2311
197
    if (RAND_bytes_ex(key->libctx, r, ML_KEM_RANDOM_BYTES,
2312
197
            key->vinfo->secbits)
2313
197
        < 1)
2314
0
        return 0;
2315
2316
197
    return ossl_ml_kem_encap_seed(ctext, clen, shared_secret, slen,
2317
197
        r, sizeof(r), key);
2318
197
}
2319
2320
int ossl_ml_kem_decap(uint8_t *shared_secret, size_t slen,
2321
    const uint8_t *ctext, size_t clen,
2322
    const ML_KEM_KEY *key)
2323
173
{
2324
173
    const ML_KEM_VINFO *vinfo;
2325
173
    EVP_MD_CTX *mdctx;
2326
173
    int ret = 0;
2327
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
2328
    int classify_bytes;
2329
#endif
2330
2331
    /* Need a private key here */
2332
173
    if (!ossl_ml_kem_have_prvkey(key))
2333
0
        return 0;
2334
173
    vinfo = key->vinfo;
2335
2336
173
    if (shared_secret == NULL || slen != ML_KEM_SHARED_SECRET_BYTES
2337
173
        || ctext == NULL || clen != vinfo->ctext_bytes
2338
173
        || (mdctx = EVP_MD_CTX_new()) == NULL) {
2339
0
        (void)RAND_bytes_ex(key->libctx, shared_secret,
2340
0
            ML_KEM_SHARED_SECRET_BYTES, vinfo->secbits);
2341
0
        return 0;
2342
0
    }
2343
#if defined(OPENSSL_CONSTANT_TIME_VALIDATION)
2344
    /*
2345
     * Data derived from |s| and |z| defaults secret, and to avoid side-channel
2346
     * leaks should not influence control flow.
2347
     */
2348
    classify_bytes = 2 * sizeof(scalar) + ML_KEM_RANDOM_BYTES;
2349
    CONSTTIME_SECRET(key->s, classify_bytes);
2350
#endif
2351
2352
    /*-
2353
     * This avoids the need to handle allocation failures for two (max 2KB
2354
     * each) vectors and an encoded ciphertext (max 1568 bytes), that are never
2355
     * retained on return from this function.
2356
     * We stack-allocate these.
2357
     */
2358
173
#define case_decap(bits)                                          \
2359
173
    case EVP_PKEY_ML_KEM_##bits: {                                \
2360
173
        uint8_t cbuf[CTEXT_BYTES(bits)];                          \
2361
173
        scalar tmp[2 * ML_KEM_##bits##_RANK];                     \
2362
173
                                                                  \
2363
173
        ret = decap(shared_secret, ctext, cbuf, tmp, mdctx, key); \
2364
173
        OPENSSL_cleanse((void *)tmp, sizeof(tmp));                \
2365
173
        break;                                                    \
2366
173
    }
2367
173
    switch (vinfo->evp_type) {
2368
69
        case_decap(512);
2369
52
        case_decap(768);
2370
52
        case_decap(1024);
2371
173
    }
2372
2373
    /* Declassify secret inputs and derived outputs before returning control */
2374
173
    CONSTTIME_DECLASSIFY(key->s, classify_bytes);
2375
173
    CONSTTIME_DECLASSIFY(shared_secret, slen);
2376
173
    EVP_MD_CTX_free(mdctx);
2377
2378
173
    return ret;
2379
173
#undef case_decap
2380
173
}
2381
2382
int ossl_ml_kem_pubkey_cmp(const ML_KEM_KEY *key1, const ML_KEM_KEY *key2)
2383
159
{
2384
    /*
2385
     * This handles any unexpected differences in the ML-KEM variant rank,
2386
     * giving different key component structures, barring SHA3-256 hash
2387
     * collisions, the keys are the same size.
2388
     */
2389
159
    if (ossl_ml_kem_have_pubkey(key1) && ossl_ml_kem_have_pubkey(key2))
2390
159
        return memcmp(key1->pkhash, key2->pkhash, ML_KEM_PKHASH_BYTES) == 0;
2391
2392
    /*
2393
     * No match if just one of the public keys is not available, otherwise both
2394
     * are unavailable, and for now such keys are considered equal.
2395
     */
2396
0
    return (!(ossl_ml_kem_have_pubkey(key1) ^ ossl_ml_kem_have_pubkey(key2)));
2397
159
}