Coverage Report

Created: 2024-11-21 07:03

/src/SymCrypt/lib/mlkem_primitives.c
Line
Count
Source (jump to first uncovered line)
1
//
2
// mlkem_primitives.c   ML-KEM related functionality
3
//
4
// Copyright (c) Microsoft Corporation. Licensed under the MIT license.
5
//
6
7
#include "precomp.h"
8
9
//
10
// Current approach is to represent polynomial ring elements as a 512-byte buffer (256 UINT16s).
11
//
12
13
// Coefficients are added and subtracted when polynomials are in the NTT domain and in the lattice domain.
14
//
15
// Coefficients are only multiplied in the NTT/INTT operations, and in MulAdd which only operates on
16
// polynomials in NTT form.
17
// We choose to perform modular multiplication exclusively using Montgomery multiplication, that is, we choose
18
// a Montgomery divisor R, and modular multiplication always divides by R, as this makes reduction logic easy
19
// and quick.
20
// i.e. MontMul(a,b) -> ((a*b) / R) mod Q
21
//
22
// For powers of Zeta used as multiplication twiddle factors in NTT/INTT and base polynomial multiplication,
23
// we pre-multiply the constants by R s.t.
24
//  MontMul(x, twiddleForZetaToTheK) -> x*(Zeta^K) mod Q.
25
//
26
// Most other modular multiplication can be done with a fixup deferred until the INTT. The one exception is in key
27
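// generation, where A o s + e = t: because the product A o s comes out divided by R, s must first be
// pre-multiplied by R so that this division cancels before e is added.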
28
29
// Montgomery divisor R = 2^16.
// Rlog2 = log2(R), the shift that divides by R; Rmask = R - 1, which extracts a value mod R.
30
const UINT32 SYMCRYPT_MLKEM_Rlog2 = 16;
31
const UINT32 SYMCRYPT_MLKEM_Rmask = 0xffff;
32
33
// NegQInvModR = -Q^(-1) mod R, i.e. (Q * 3327) mod 2^16 == 2^16 - 1.
// Multiplying a value x by this (mod R) gives the Montgomery quotient inv for which x + inv*Q is divisible by R.
34
const UINT32 SYMCRYPT_MLKEM_NegQInvModR = 3327;
35
36
// Rsqr = R^2 mod Q = (1<<32) mod Q.
// MontMul(x, Rsqr) = x * R^2 / R = x * R mod Q, i.e. it multiplies x by R.
37
const UINT32 SYMCRYPT_MLKEM_Rsqr = 1353;
38
// RsqrTimesNegQInvModR = ((R^2 mod Q) * -Q^(-1)) mod R = (1353 * 3327) mod 2^16.
// This is the precomputed bMont argument that SymCryptMlKemMontMul requires alongside b == Rsqr.
39
const UINT32 SYMCRYPT_MLKEM_RsqrTimesNegQInvModR = 44983;
40
41
//
42
// Zeta tables.
43
// Zeta = 17, which is a primitive 256-th root of unity modulo Q
44
//
45
// In ML-KEM we use powers of zeta to convert to and from NTT form
46
// and to perform multiplication between polynomials in NTT form
47
//
48
49
// This table is a lookup for (Zeta^(BitRev(index)) * R) mod Q, where BitRev is the 7-bit
// bit reversal of index (e.g. BitRev(1) == 64). Every entry is < Q.
50
// Used in NTT and INTT as the twiddle factor b of a Montgomery multiplication; the matching bMont
// values (entry * -Q^(-1) mod R) are in MlKemZetaBitRevTimesRTimesNegQInvModR.
// The NTT layers read k upwards starting at 1 and the INTT layers read k downwards starting at 127,
// so with the layer calls made in this file only indices 1..127 are read (index 0 == R mod Q).
51
// i.e. element 1 is Zeta^(BitRev(1)) * (2^16) mod Q == (17^64)*(2^16) mod 3329 == 2571
52
//
53
// MlKemZetaBitRevTimesR = [ (pow(17, bitRev(i), 3329) << 16) % 3329 for i in range(128) ]
54
const UINT16 MlKemZetaBitRevTimesR[128] =
55
{
56
    2285, 2571, 2970, 1812, 1493, 1422,  287,  202,
57
    3158,  622, 1577,  182,  962, 2127, 1855, 1468,
58
     573, 2004,  264,  383, 2500, 1458, 1727, 3199,
59
    2648, 1017,  732,  608, 1787,  411, 3124, 1758,
60
    1223,  652, 2777, 1015, 2036, 1491, 3047, 1785,
61
     516, 3321, 3009, 2663, 1711, 2167,  126, 1469,
62
    2476, 3239, 3058,  830,  107, 1908, 3082, 2378,
63
    2931,  961, 1821, 2604,  448, 2264,  677, 2054,
64
    2226,  430,  555,  843, 2078,  871, 1550,  105,
65
     422,  587,  177, 3094, 3038, 2869, 1574, 1653,
66
    3083,  778, 1159, 3182, 2552, 1483, 2727, 1119,
67
    1739,  644, 2457,  349,  418,  329, 3173, 3254,
68
     817, 1097,  603,  610, 1322, 2044, 1864,  384,
69
    2114, 3193, 1218, 1994, 2455,  220, 2142, 1670,
70
    2144, 1799, 2051,  794, 1819, 2475, 2459,  478,
71
    3221, 3021,  996,  991,  958, 1869, 1522, 1628,
72
};
73
74
// This table is a lookup for ((Zeta^(BitRev(index)) * R) mod Q) * -Q^(-1) mod R
75
// Used in NTT and INTT
// Entry i == (MlKemZetaBitRevTimesR[i] * SYMCRYPT_MLKEM_NegQInvModR) & SYMCRYPT_MLKEM_Rmask, i.e. the
// precomputed bMont value that must accompany MlKemZetaBitRevTimesR[i] in a Montgomery multiplication
// (e.g. entry 0: (2285 * 3327) mod 2^16 == 19).
76
//
77
// MlKemZetaBitRevTimesRTimesNegQInvModR = [ (((pow(17, bitRev(i), Q) << 16) % Q) * 3327) & 0xffff for i in range(128) ]
78
const UINT16 MlKemZetaBitRevTimesRTimesNegQInvModR[128] =
79
{
80
       19, 34037, 50790, 64748, 52011, 12402, 37345, 16694,
81
    20906, 37778,  3799, 15690, 54846, 64177, 11201, 34372,
82
     5827, 48172, 26360, 29057, 59964,  1102, 44097, 26241,
83
    28072, 41223, 10532, 56736, 47109, 56677, 38860, 16162,
84
     5689,  6516, 64039, 34569, 23564, 45357, 44825, 40455,
85
    12796, 38919, 49471, 12441, 56401,   649, 25986, 37699,
86
    45652, 28249, 15886,  8898, 28309, 56460, 30198, 47286,
87
    52109, 51519, 29155, 12756, 48704, 61224, 24155, 17914,
88
      334, 54354, 11477, 52149, 32226, 14233, 45042, 21655,
89
    27738, 52405, 64591,  4586, 14882, 42443, 59354, 60043,
90
    33525, 32502, 54905, 35218, 36360, 18741, 28761, 52897,
91
    18485, 45436, 47975, 47011, 14430, 46007,  5275, 12618,
92
    31183, 45239, 40101, 63390,  7382, 50180, 41144, 32384,
93
    20926,  6279, 54590, 14902, 41321, 11044, 48546, 51066,
94
    55200, 21497,  7933, 20198, 22501, 42325, 54629, 17442,
95
    33899, 23859, 36892, 20257, 41538, 57779, 17422, 42404,
96
};
97
98
// This table is a lookup for ((Zeta^(2*BitRev(index) + 1) * R) mod Q)
99
// Used in multiplication of 2 NTT-form polynomials
// SymCryptMlKemPolyElementMulAndAccumulate reads entry i (0..127) for coefficient pair i, where
// X^2 is replaced by Zeta^(2*BitRev(i)+1). Consecutive entries 2j and 2j+1 sum to Q (they are
// negatives of each other mod Q), and the largest entry is 3254.
100
//
101
// zetaTwoTimesBitRevPlus1TimesR =  [ (pow(17, 2*bitRev(i)+1, 3329) << 16) % 3329 for i in range(128) ]
102
const UINT16 zetaTwoTimesBitRevPlus1TimesR[128] =
103
{
104
    2226, 1103,  430, 2899,  555, 2774,  843, 2486,
105
    2078, 1251,  871, 2458, 1550, 1779,  105, 3224,
106
     422, 2907,  587, 2742,  177, 3152, 3094,  235,
107
    3038,  291, 2869,  460, 1574, 1755, 1653, 1676,
108
    3083,  246,  778, 2551, 1159, 2170, 3182,  147,
109
    2552,  777, 1483, 1846, 2727,  602, 1119, 2210,
110
    1739, 1590,  644, 2685, 2457,  872,  349, 2980,
111
     418, 2911,  329, 3000, 3173,  156, 3254,   75,
112
     817, 2512, 1097, 2232,  603, 2726,  610, 2719,
113
    1322, 2007, 2044, 1285, 1864, 1465,  384, 2945,
114
    2114, 1215, 3193,  136, 1218, 2111, 1994, 1335,
115
    2455,  874,  220, 3109, 2142, 1187, 1670, 1659,
116
    2144, 1185, 1799, 1530, 2051, 1278,  794, 2535,
117
    1819, 1510, 2475,  854, 2459,  870,  478, 2851,
118
    3221,  108, 3021,  308,  996, 2333,  991, 2338,
119
     958, 2371, 1869, 1460, 1522, 1807, 1628, 1701,
120
};
121
122
PSYMCRYPT_MLKEM_POLYELEMENT
123
SYMCRYPT_CALL
124
SymCryptMlKemPolyElementCreate(
125
    _Out_writes_bytes_( cbBuffer )  PBYTE   pbBuffer,
126
                                    UINT32  cbBuffer )
127
0
{
128
0
    PSYMCRYPT_MLKEM_POLYELEMENT pDst = (PSYMCRYPT_MLKEM_POLYELEMENT) pbBuffer;
129
130
0
    UNREFERENCED_PARAMETER( cbBuffer );
131
132
0
    SYMCRYPT_ASSERT_ASYM_ALIGNED( pbBuffer );
133
0
    SYMCRYPT_ASSERT( cbBuffer == SYMCRYPT_INTERNAL_MLKEM_SIZEOF_POLYRINGELEMENT );
134
135
0
    return pDst;
136
0
}
137
138
PSYMCRYPT_MLKEM_POLYELEMENT_ACCUMULATOR
139
SYMCRYPT_CALL
140
SymCryptMlKemPolyElementAccumulatorCreate(
141
    _Out_writes_bytes_( cbBuffer )  PBYTE   pbBuffer,
142
                                    UINT32  cbBuffer )
143
0
{
144
0
    PSYMCRYPT_MLKEM_POLYELEMENT_ACCUMULATOR pDst = (PSYMCRYPT_MLKEM_POLYELEMENT_ACCUMULATOR) pbBuffer;
145
146
0
    UNREFERENCED_PARAMETER( cbBuffer );
147
148
0
    SYMCRYPT_ASSERT_ASYM_ALIGNED( pbBuffer );
149
0
    SYMCRYPT_ASSERT( cbBuffer == SYMCRYPT_INTERNAL_MLKEM_SIZEOF_POLYRINGELEMENT_ACCUMULATOR );
150
151
0
    return pDst;
152
0
}
153
154
PSYMCRYPT_MLKEM_VECTOR
155
SYMCRYPT_CALL
156
SymCryptMlKemVectorCreate(
157
    _Out_writes_bytes_( cbBuffer )  PBYTE   pbBuffer,
158
                                    UINT32  cbBuffer,
159
                                    UINT32  nRows )
160
0
{
161
0
    PSYMCRYPT_MLKEM_VECTOR pDst = NULL;
162
0
    PSYMCRYPT_MLKEM_VECTOR pVector = (PSYMCRYPT_MLKEM_VECTOR)pbBuffer;
163
0
    PSYMCRYPT_MLKEM_POLYELEMENT peTmp = NULL;
164
0
    UINT32 i;
165
0
    PBYTE pbTmp = pbBuffer + sizeof(SYMCRYPT_MLKEM_VECTOR);
166
167
0
    SYMCRYPT_ASSERT_ASYM_ALIGNED( pbBuffer );
168
169
0
    SYMCRYPT_ASSERT( nRows > 0 );
170
0
    SYMCRYPT_ASSERT( nRows <= SYMCRYPT_MLKEM_MATRIX_MAX_NROWS );
171
172
0
    pVector->nRows = nRows;
173
0
    pVector->cbTotalSize = cbBuffer;
174
175
0
    for( i=0; i<nRows; i++ )
176
0
    {
177
0
        peTmp = SymCryptMlKemPolyElementCreate( pbTmp, SYMCRYPT_INTERNAL_MLKEM_SIZEOF_POLYRINGELEMENT );
178
0
        if( peTmp == NULL )
179
0
        {
180
0
            goto cleanup;
181
0
        }
182
183
0
        pbTmp += SYMCRYPT_INTERNAL_MLKEM_SIZEOF_POLYRINGELEMENT;
184
0
    }
185
186
0
    SYMCRYPT_ASSERT( pbTmp == (pbBuffer + cbBuffer) );
187
188
0
    pDst = pVector;
189
190
0
cleanup:
191
0
    return pDst;
192
0
}
193
194
PSYMCRYPT_MLKEM_MATRIX
195
SYMCRYPT_CALL
196
SymCryptMlKemMatrixCreate(
197
    _Out_writes_bytes_( cbBuffer )  PBYTE   pbBuffer,
198
                                    UINT32  cbBuffer,
199
                                    UINT32  nRows )
200
0
{
201
0
    PSYMCRYPT_MLKEM_MATRIX pDst = NULL;
202
0
    PSYMCRYPT_MLKEM_MATRIX pMatrix = (PSYMCRYPT_MLKEM_MATRIX)pbBuffer;
203
0
    UINT32 i;
204
0
    PBYTE pbTmp = pbBuffer + sizeof(SYMCRYPT_MLKEM_MATRIX);
205
206
0
    SYMCRYPT_ASSERT_ASYM_ALIGNED( pbBuffer );
207
208
0
    SYMCRYPT_ASSERT( nRows > 0 );
209
0
    SYMCRYPT_ASSERT( nRows <= SYMCRYPT_MLKEM_MATRIX_MAX_NROWS );
210
211
0
    pMatrix->nRows = nRows;
212
0
    pMatrix->cbTotalSize = cbBuffer;
213
214
0
    for( i=0; i<(nRows*nRows); i++ )
215
0
    {
216
0
        pMatrix->apPolyElements[i] = SymCryptMlKemPolyElementCreate( pbTmp, SYMCRYPT_INTERNAL_MLKEM_SIZEOF_POLYRINGELEMENT );
217
0
        if( pMatrix->apPolyElements[i] == NULL )
218
0
        {
219
0
            goto cleanup;
220
0
        }
221
222
0
        pbTmp += SYMCRYPT_INTERNAL_MLKEM_SIZEOF_POLYRINGELEMENT;
223
0
    }
224
225
0
    SYMCRYPT_ASSERT( pbTmp == (pbBuffer + cbBuffer) );
226
227
0
    pDst = pMatrix;
228
229
0
cleanup:
230
0
    return pDst;
231
0
}
232
233
#if SYMCRYPT_CPU_AMD64 | SYMCRYPT_CPU_X86 | SYMCRYPT_CPU_ARM64
234
235
#if SYMCRYPT_CPU_AMD64 | SYMCRYPT_CPU_X86

//
// 128-bit vector helpers operating on 8 lanes of 16-bit coefficients (SSE2).
//
// The multi-statement helpers are wrapped in do { ... } while( 0 ) so that each use expands to exactly
// one statement (and therefore stays correct inside an unbraced if/else). Arguments may be evaluated
// more than once, so pass plain variables: res and the tmp arguments are written to, res may be the
// same variable as a, and the tmp arguments must be distinct scratch variables.
//
#define VEC128_TYPE_UINT16 __m128i

#define VEC128_LOAD_UINT16( addr )       _mm_loadu_si128( (__m128i*) (addr) )
#define VEC64_LOAD_UINT16( addr )        _mm_loadu_si64( (PBYTE) (addr) )
#define VEC32_LOAD_UINT16( addr )        _mm_cvtsi32_si128( SYMCRYPT_LOAD_LSBFIRST32( addr ) )

#define VEC128_STORE_UINT16( addr, vec ) _mm_storeu_si128( (__m128i*) (addr), (vec) )
#define VEC64_STORE_UINT16( addr, vec )  _mm_storeu_si64( (PBYTE) (addr), (vec) )
#define VEC32_STORE_UINT16( addr, vec )  SYMCRYPT_STORE_LSBFIRST32( (addr), _mm_cvtsi128_si32( vec ) )

#define VEC128_SET_UINT16( value )       _mm_set1_epi16( (value) )

// res = (a - b) mod Q per lane, for a - b in [-Q, Q); negativity is tested with a signed 16-bit compare.
#define VEC128_MOD_SUB_UINT16( res, a, b, Q, zero, tmp1 ) \
    do { \
        /* res = a - b */ \
        (res) = _mm_sub_epi16( (a), (b) ); \
        /* tmp1 = (a - b) < 0 ? -1 : 0 */ \
        (tmp1) = _mm_cmpgt_epi16( (zero), (res) ); \
        /* tmp1 = (a - b) < 0 ? Q : 0 */ \
        (tmp1) = _mm_and_si128( (tmp1), (Q) ); \
        /* res = (a - b) mod Q */ \
        (res) = _mm_add_epi16( (res), (tmp1) ); \
    } while( 0 )

// res = (a + b) mod Q per lane, for a + b in [0, 2Q); the compare is signed 16-bit.
#define VEC128_MOD_ADD_UINT16( res, a, b, Q, tmp1 ) \
    do { \
        /* res = a + b */ \
        (res) = _mm_add_epi16( (a), (b) ); \
        /* tmp1 = (a + b) < Q ? -1 : 0 */ \
        (tmp1) = _mm_cmpgt_epi16( (Q), (res) ); \
        /* tmp1 = (a + b) < Q ? 0 : Q */ \
        (tmp1) = _mm_andnot_si128( (tmp1), (Q) ); \
        /* res = (a + b) mod Q */ \
        (res) = _mm_sub_epi16( (res), (tmp1) ); \
    } while( 0 )

// res = MontMul(a, b) = (a * b / R) mod Q per lane, with R = 2^16, for a, b < Q and
// bTimesNegQInvModR = (b * -Q^(-1)) mod R. With inv = (a * bTimesNegQInvModR) mod R, the low 16 bits
// of a*b + inv*Q are zero, so (a*b + inv*Q) >> 16 == mulhi(a, b) + mulhi(inv, Q) + carry, where the
// carry out of the low halves is 1 exactly when inv != 0 (equivalently, when a*b mod R != 0).
// zero must hold 0 in every lane and one must hold 1 in every lane.
#define VEC128_MONTGOMERY_MUL_UINT16( res, a, b, bTimesNegQInvModR, Q, zero, one, tmp1, tmp2 ) \
    do { \
        /* tmp1 = inv = a *low bTimesNegQInvModR */ \
        (tmp1) = _mm_mullo_epi16( (a), (bTimesNegQInvModR) ); \
        /* res  = a *high b */ \
        (res) = _mm_mulhi_epu16( (a), (b) ); \
        /* tmp2 = (inv == 0) ? -1 : 0 */ \
        (tmp2) = _mm_cmpeq_epi16( (tmp1), (zero) ); \
        /* tmp1 = inv *high Q */ \
        (tmp1) = _mm_mulhi_epu16( (tmp1), (Q) ); \
        /* res  = a *high b + 1 */ \
        (res) = _mm_add_epi16( (res), (one) ); \
        /* res  = a *high b + carry, where carry = (inv != 0) ? 1 : 0 */ \
        (res) = _mm_add_epi16( (res), (tmp2) ); \
        /* res  = (a*b + inv*Q) >> 16, in [0, 2Q) */ \
        (res) = _mm_add_epi16( (res), (tmp1) ); \
        /* res  = (a*b + inv*Q >> 16) mod Q */ \
        VEC128_MOD_SUB_UINT16( res, res, Q, Q, zero, tmp1 ); \
    } while( 0 )

#elif SYMCRYPT_CPU_ARM64

//
// 128-bit vector helpers operating on 8 lanes of 16-bit coefficients (NEON).
//
// Same contracts as the SSE2 versions above (single-statement expansion, res may alias a, tmp
// arguments are distinct scratch variables).
// NOTE(review): tmp2/res receive 32-bit widening products (uint32x4_t) and vcltzq_s16 takes an
// int16x8_t, while callers declare these as VEC128_TYPE_UINT16; this appears to rely on the compiler
// accepting implicit conversions between same-size NEON vector types. Confirm for non-MSVC toolchains.
//
#define VEC128_TYPE_UINT16 uint16x8_t

#define VEC128_LOAD_UINT16( addr )       vld1q_u16( addr )
#define VEC64_LOAD_UINT16( addr )        vld1q_dup_u64( addr )
#define VEC32_LOAD_UINT16( addr )        vld1q_dup_u32( addr )

#define VEC128_STORE_UINT16( addr, vec ) vst1q_u16( (addr), (vec) )
#define VEC64_STORE_UINT16( addr, vec )  vst1_u16( (PBYTE) (addr), vget_low_u16(vec) )
#define VEC32_STORE_UINT16( addr, vec )  vst1_lane_u32( (PBYTE) (addr), vget_low_u32(vec), 0 )

#define VEC128_SET_UINT16( value )       vdupq_n_u16( (value) )

// res = (a - b) mod Q per lane, for a - b in [-Q, Q); the zero argument is unused on ARM64.
#define VEC128_MOD_SUB_UINT16( res, a, b, Q, zero, tmp1 ) \
    do { \
        /* res = a - b */ \
        (res) = vsubq_u16( (a), (b) ); \
        /* tmp1 = (a - b) < 0 ? -1 : 0  (lanes interpreted as signed) */ \
        (tmp1) = vcltzq_s16( (res) ); \
        /* tmp1 = (a - b) < 0 ? Q : 0 */ \
        (tmp1) = vandq_u16( (tmp1), (Q) ); \
        /* res = (a - b) mod Q */ \
        (res) = vaddq_u16( (res), (tmp1) ); \
    } while( 0 )

// res = (a + b) mod Q per lane, for a + b in [0, 2Q).
#define VEC128_MOD_ADD_UINT16( res, a, b, Q, tmp1 ) \
    do { \
        /* res = a + b */ \
        (res) = vaddq_u16( (a), (b) ); \
        /* tmp1 = (a + b) >= Q ? -1 : 0 */ \
        (tmp1) = vcgeq_u16( (res), (Q) ); \
        /* tmp1 = (a + b) >= Q ? Q : 0 */ \
        (tmp1) = vandq_u16( (tmp1), (Q) ); \
        /* res = (a + b) mod Q */ \
        (res) = vsubq_u16( (res), (tmp1) ); \
    } while( 0 )

// res = MontMul(a, b) = (a * b / R) mod Q per lane, computed with full 32-bit products:
// a*b + inv*Q is formed in two halves of 4 lanes and the high 16 bits of each lane are extracted.
// The zero and one arguments are unused by the multiplication itself on ARM64.
#define VEC128_MONTGOMERY_MUL_UINT16( res, a, b, bTimesNegQInvModR, Q, zero, one, tmp1, tmp2 ) \
    do { \
        /* tmp1 = inv = a *low bTimesNegQInvModR */ \
        (tmp1) = vmulq_u16( (a), (bTimesNegQInvModR) ); \
        /* tmp2 = a*b [lanes 0-3] */ \
        (tmp2) = vmull_u16( vget_low_u16( (a) ), vget_low_u16( (b) ) ); \
        /* res  = a*b [lanes 4-7] */ \
        (res) = vmull_high_u16( (a), (b) ); \
        /* tmp2 = a*b + inv*Q [lanes 0-3] */ \
        (tmp2) = vmlal_u16( (tmp2), vget_low_u16( (tmp1) ), vget_low_u16( (Q) ) ); \
        /* res  = a*b + inv*Q [lanes 4-7] */ \
        (res) = vmlal_high_u16( (res), (tmp1), (Q) ); \
        /* res  = (a*b + inv*Q) >> 16: the odd 16-bit halves of the 32-bit sums */ \
        (res) = vuzp2q_u16( (tmp2), (res) ); \
        /* res  = (a*b + inv*Q >> 16) mod Q */ \
        VEC128_MOD_SUB_UINT16( res, res, Q, Q, zero, tmp1 ); \
    } while( 0 )

#endif
337
338
FORCEINLINE
339
VOID
340
SYMCRYPT_CALL
341
SymCryptMlKemPolyElementNTTLayerVec128(
342
    _Inout_ PSYMCRYPT_MLKEM_POLYELEMENT  peSrc,
343
            UINT32                       k,
344
            UINT32                       len )
345
0
{
346
0
    UINT32 start, j;
347
0
    VEC128_TYPE_UINT16 vc0, vc1, vTmp0, vTmp1, vc1Twiddle, vTwiddleFactor, vTwiddleFactorMont, vQ, vZero, vOne;
348
349
0
    SYMCRYPT_ASSERT( len >= 2 );
350
351
0
    vQ = VEC128_SET_UINT16( SYMCRYPT_MLKEM_Q );
352
0
    vZero = VEC128_SET_UINT16( 0 );
353
0
    vOne = VEC128_SET_UINT16( 1 );
354
355
0
    for( start=0; start<256; start+=(2*len) )
356
0
    {
357
0
        vTwiddleFactor     = VEC128_SET_UINT16( MlKemZetaBitRevTimesR[k] );
358
0
        vTwiddleFactorMont = VEC128_SET_UINT16( MlKemZetaBitRevTimesRTimesNegQInvModR[k] );
359
0
        k++;
360
0
        for( j=0; j<len; j+=8 )
361
0
        {
362
0
            if( len >= 8 )
363
0
            {
364
0
                vc0 = VEC128_LOAD_UINT16( &(peSrc->coeffs[start+j]    ) );
365
0
                vc1 = VEC128_LOAD_UINT16( &(peSrc->coeffs[start+j+len]) );
366
0
            }
367
0
            else if ( len == 4 )
368
0
            {
369
0
                vc0 = VEC64_LOAD_UINT16( &(peSrc->coeffs[start+j]    ) );
370
0
                vc1 = VEC64_LOAD_UINT16( &(peSrc->coeffs[start+j+len]) );
371
0
            }
372
0
            else /*if ( len == 2 )*/
373
0
            {
374
0
                vc0 = VEC32_LOAD_UINT16( &(peSrc->coeffs[start+j]    ) );
375
0
                vc1 = VEC32_LOAD_UINT16( &(peSrc->coeffs[start+j+len]) );
376
0
            }
377
378
            // c1TimesTwiddle = twiddleFactor * c1 mod Q;
379
0
            VEC128_MONTGOMERY_MUL_UINT16( vc1Twiddle, vc1, vTwiddleFactor, vTwiddleFactorMont, vQ, vZero, vOne, vTmp0, vTmp1 );
380
            // c1 = c0 - c1TimesTwiddle mod Q
381
0
            VEC128_MOD_SUB_UINT16( vc1, vc0, vc1Twiddle, vQ, vZero, vTmp0 );
382
            // c0 = c0 + c1TimesTwiddle mod Q
383
0
            VEC128_MOD_ADD_UINT16( vc0, vc0, vc1Twiddle, vQ, vTmp1 );
384
385
0
            if( len >= 8 )
386
0
            {
387
0
                VEC128_STORE_UINT16( &(peSrc->coeffs[start+j]    ), vc0 );
388
0
                VEC128_STORE_UINT16( &(peSrc->coeffs[start+j+len]), vc1 );
389
0
            }
390
0
            else if ( len == 4 )
391
0
            {
392
0
                VEC64_STORE_UINT16( &(peSrc->coeffs[start+j]    ), vc0 );
393
0
                VEC64_STORE_UINT16( &(peSrc->coeffs[start+j+len]), vc1 );
394
0
            }
395
0
            else /*if ( len == 2 )*/
396
0
            {
397
0
                VEC32_STORE_UINT16( &(peSrc->coeffs[start+j]    ), vc0 );
398
0
                VEC32_STORE_UINT16( &(peSrc->coeffs[start+j+len]), vc1 );
399
0
            }
400
0
        }
401
0
    }
402
0
}
403
404
FORCEINLINE
405
VOID
406
SYMCRYPT_CALL
407
SymCryptMlKemPolyElementINTTLayerVec128(
408
    _Inout_ PSYMCRYPT_MLKEM_POLYELEMENT  peSrc,
409
            UINT32                       k,
410
            UINT32                       len )
411
0
{
412
0
    UINT32 start, j;
413
0
    VEC128_TYPE_UINT16 vc0, vc1, vTmp0, vTmp1, vTmp2, vTwiddleFactor, vTwiddleFactorMont, vQ, vZero, vOne;
414
415
0
    SYMCRYPT_ASSERT( len >= 2 );
416
417
0
    vQ = VEC128_SET_UINT16( SYMCRYPT_MLKEM_Q );
418
0
    vZero = VEC128_SET_UINT16( 0 );
419
0
    vOne = VEC128_SET_UINT16( 1 );
420
421
0
    for( start=0; start<256; start+=(2*len) )
422
0
    {
423
0
        vTwiddleFactor     = VEC128_SET_UINT16( MlKemZetaBitRevTimesR[k] );
424
0
        vTwiddleFactorMont = VEC128_SET_UINT16( MlKemZetaBitRevTimesRTimesNegQInvModR[k] );
425
0
        k--;
426
0
        for( j=0; j<len; j+=8 )
427
0
        {
428
0
            if( len >= 8 )
429
0
            {
430
0
                vc0 = VEC128_LOAD_UINT16( &(peSrc->coeffs[start+j]    ) );
431
0
                vc1 = VEC128_LOAD_UINT16( &(peSrc->coeffs[start+j+len]) );
432
0
            }
433
0
            else if ( len == 4 )
434
0
            {
435
0
                vc0 = VEC64_LOAD_UINT16( &(peSrc->coeffs[start+j]    ) );
436
0
                vc1 = VEC64_LOAD_UINT16( &(peSrc->coeffs[start+j+len]) );
437
0
            }
438
0
            else /*if ( len == 2 )*/
439
0
            {
440
0
                vc0 = VEC32_LOAD_UINT16( &(peSrc->coeffs[start+j]    ) );
441
0
                vc1 = VEC32_LOAD_UINT16( &(peSrc->coeffs[start+j+len]) );
442
0
            }
443
444
            // tmp = c0 + c1 mod Q
445
0
            VEC128_MOD_ADD_UINT16( vTmp2, vc0, vc1, vQ, vTmp0 );
446
            // c1 = c1 - c0 mod Q
447
0
            VEC128_MOD_SUB_UINT16( vc1, vc1, vc0, vQ, vZero, vTmp1 );
448
            // c1 = twiddleFactor * c1;
449
0
            VEC128_MONTGOMERY_MUL_UINT16( vc1, vc1, vTwiddleFactor, vTwiddleFactorMont, vQ, vZero, vOne, vTmp0, vTmp1 );
450
451
0
            if( len >= 8 )
452
0
            {
453
0
                VEC128_STORE_UINT16( &(peSrc->coeffs[start+j]    ), vTmp2 );
454
0
                VEC128_STORE_UINT16( &(peSrc->coeffs[start+j+len]), vc1 );
455
0
            }
456
0
            else if ( len == 4 )
457
0
            {
458
0
                VEC64_STORE_UINT16( &(peSrc->coeffs[start+j]    ), vTmp2 );
459
0
                VEC64_STORE_UINT16( &(peSrc->coeffs[start+j+len]), vc1 );
460
0
            }
461
0
            else /*if ( len == 2 )*/
462
0
            {
463
0
                VEC32_STORE_UINT16( &(peSrc->coeffs[start+j]    ), vTmp2 );
464
0
                VEC32_STORE_UINT16( &(peSrc->coeffs[start+j+len]), vc1 );
465
0
            }
466
0
        }
467
0
    }
468
0
}
469
470
#endif
471
472
FORCEINLINE
473
UINT32
474
SYMCRYPT_CALL
475
SymCryptMlKemModAdd(
476
    UINT32 a,
477
    UINT32 b )
478
0
{
479
0
    UINT32 res;
480
481
0
    SYMCRYPT_ASSERT( a < SYMCRYPT_MLKEM_Q );
482
0
    SYMCRYPT_ASSERT( b < SYMCRYPT_MLKEM_Q );
483
484
0
    res = a + b - SYMCRYPT_MLKEM_Q;
485
0
    SYMCRYPT_ASSERT( ((res >> 16) == 0) || ((res >> 16) == 0xffff) );
486
0
    res = res + (SYMCRYPT_MLKEM_Q & (res >> 16));
487
0
    SYMCRYPT_ASSERT( res < SYMCRYPT_MLKEM_Q );
488
489
0
    return res;
490
0
}
491
492
FORCEINLINE
493
UINT32
494
SYMCRYPT_CALL
495
SymCryptMlKemModSub(
496
    UINT32 a,
497
    UINT32 b )
498
0
{
499
0
    UINT32 res;
500
501
0
    SYMCRYPT_ASSERT( a < 2*SYMCRYPT_MLKEM_Q );
502
0
    SYMCRYPT_ASSERT( b <= SYMCRYPT_MLKEM_Q );
503
504
0
    res = a - b;
505
0
    SYMCRYPT_ASSERT( ((res >> 16) == 0) || ((res >> 16) == 0xffff) );
506
0
    res = res + (SYMCRYPT_MLKEM_Q & (res >> 16));
507
0
    SYMCRYPT_ASSERT( res < SYMCRYPT_MLKEM_Q );
508
509
0
    return res;
510
0
}
511
512
FORCEINLINE
513
UINT32
514
SYMCRYPT_CALL
515
SymCryptMlKemMontMul(
516
    UINT32 a,
517
    UINT32 b,
518
    UINT32 bMont )
519
0
{
520
0
    UINT32 res, inv;
521
522
0
    SYMCRYPT_ASSERT( a < SYMCRYPT_MLKEM_Q );
523
0
    SYMCRYPT_ASSERT( b < SYMCRYPT_MLKEM_Q );
524
0
    SYMCRYPT_ASSERT( bMont <= SYMCRYPT_MLKEM_Rmask );
525
0
    SYMCRYPT_ASSERT( bMont == ((b * SYMCRYPT_MLKEM_NegQInvModR) & SYMCRYPT_MLKEM_Rmask) );
526
527
0
    res = a * b;
528
0
    inv = (a * bMont) & SYMCRYPT_MLKEM_Rmask;
529
0
    res += inv * SYMCRYPT_MLKEM_Q;
530
0
    SYMCRYPT_ASSERT( (res & SYMCRYPT_MLKEM_Rmask) == 0 );
531
0
    res = res >> SYMCRYPT_MLKEM_Rlog2;
532
533
0
    return SymCryptMlKemModSub( res, SYMCRYPT_MLKEM_Q );
534
0
}
535
536
VOID
537
SYMCRYPT_CALL
538
SymCryptMlKemPolyElementNTTLayerC(
539
    _Inout_ PSYMCRYPT_MLKEM_POLYELEMENT  peSrc,
540
            UINT32                       k,
541
            UINT32                       len )
542
0
{
543
0
    UINT32 start, j;
544
0
    UINT32 twiddleFactor, twiddleFactorMont, c0, c1, c1TimesTwiddle;
545
546
0
    for( start=0; start<256; start+=(2*len) )
547
0
    {
548
0
        twiddleFactor = MlKemZetaBitRevTimesR[k];
549
0
        twiddleFactorMont = MlKemZetaBitRevTimesRTimesNegQInvModR[k];
550
0
        k++;
551
0
        for( j=0; j<len; j++ )
552
0
        {
553
0
            c0 = peSrc->coeffs[start+j];
554
0
            SYMCRYPT_ASSERT( c0 < SYMCRYPT_MLKEM_Q );
555
0
            c1 = peSrc->coeffs[start+j+len];
556
0
            SYMCRYPT_ASSERT( c1 < SYMCRYPT_MLKEM_Q );
557
558
0
            c1TimesTwiddle = SymCryptMlKemMontMul( c1, twiddleFactor, twiddleFactorMont );
559
0
            c1 = SymCryptMlKemModSub( c0, c1TimesTwiddle );
560
0
            c0 = SymCryptMlKemModAdd( c0, c1TimesTwiddle );
561
562
0
            peSrc->coeffs[start+j]      = (UINT16) c0;
563
0
            peSrc->coeffs[start+j+len]  = (UINT16) c1;
564
0
        }
565
0
    }
566
0
}
567
568
VOID
569
SYMCRYPT_CALL
570
SymCryptMlKemPolyElementINTTLayerC(
571
    _Inout_ PSYMCRYPT_MLKEM_POLYELEMENT  peSrc,
572
            UINT32                       k,
573
            UINT32                       len )
574
0
{
575
0
    UINT32 start, j;
576
0
    UINT32 twiddleFactor, twiddleFactorMont, c0, c1, tmp;
577
578
0
    for( start=0; start<256; start+=(2*len) )
579
0
    {
580
0
        twiddleFactor = MlKemZetaBitRevTimesR[k];
581
0
        twiddleFactorMont = MlKemZetaBitRevTimesRTimesNegQInvModR[k];
582
0
        k--;
583
0
        for( j=0; j<len; j++ )
584
0
        {
585
0
            c0 = peSrc->coeffs[start+j];
586
0
            SYMCRYPT_ASSERT( c0 < SYMCRYPT_MLKEM_Q );
587
0
            c1 = peSrc->coeffs[start+j+len];
588
0
            SYMCRYPT_ASSERT( c1 < SYMCRYPT_MLKEM_Q );
589
590
0
            tmp = SymCryptMlKemModAdd( c0, c1 );
591
0
            c1 = SymCryptMlKemModSub( c1, c0 );
592
0
            c1 = SymCryptMlKemMontMul( c1, twiddleFactor, twiddleFactorMont );
593
594
0
            peSrc->coeffs[start+j]      = (UINT16) tmp;
595
0
            peSrc->coeffs[start+j+len]  = (UINT16) c1;
596
0
        }
597
0
    }
598
0
}
599
600
FORCEINLINE
601
VOID
602
SYMCRYPT_CALL
603
SymCryptMlKemPolyElementNTTLayer(
604
    _Inout_ PSYMCRYPT_MLKEM_POLYELEMENT  peSrc,
605
            UINT32                       k,
606
            UINT32                       len )
607
0
{
608
#if SYMCRYPT_CPU_X86
609
    SYMCRYPT_EXTENDED_SAVE_DATA  SaveData;
610
    if( SYMCRYPT_CPU_FEATURES_PRESENT( SYMCRYPT_CPU_FEATURE_SSE2 ) && SymCryptSaveXmm( &SaveData ) == SYMCRYPT_NO_ERROR )
611
    {
612
        SymCryptMlKemPolyElementNTTLayerVec128( peSrc, k, len );
613
        SymCryptRestoreXmm( &SaveData );
614
    } else {
615
        SymCryptMlKemPolyElementNTTLayerC( peSrc, k, len );
616
    }
617
#elif SYMCRYPT_CPU_AMD64
618
0
    if( SYMCRYPT_CPU_FEATURES_PRESENT( SYMCRYPT_CPU_FEATURE_SSE2 ) )
619
0
    {
620
0
        SymCryptMlKemPolyElementNTTLayerVec128( peSrc, k, len );
621
0
    } else {
622
0
        SymCryptMlKemPolyElementNTTLayerC( peSrc, k, len );
623
0
    }
624
#elif SYMCRYPT_CPU_ARM64
625
    if( SYMCRYPT_CPU_FEATURES_PRESENT( SYMCRYPT_CPU_FEATURE_NEON ) )
626
    {
627
        SymCryptMlKemPolyElementNTTLayerVec128( peSrc, k, len );
628
    } else {
629
        SymCryptMlKemPolyElementNTTLayerC( peSrc, k, len );
630
    }
631
#else
632
    SymCryptMlKemPolyElementNTTLayerC( peSrc, k, len );
633
#endif
634
0
}
635
636
FORCEINLINE
637
VOID
638
SYMCRYPT_CALL
639
SymCryptMlKemPolyElementINTTLayer(
640
    _Inout_ PSYMCRYPT_MLKEM_POLYELEMENT  peSrc,
641
            UINT32                       k,
642
            UINT32                       len )
643
0
{
644
#if SYMCRYPT_CPU_X86
645
    SYMCRYPT_EXTENDED_SAVE_DATA  SaveData;
646
    if( SYMCRYPT_CPU_FEATURES_PRESENT( SYMCRYPT_CPU_FEATURE_SSE2 ) && SymCryptSaveXmm( &SaveData ) == SYMCRYPT_NO_ERROR )
647
    {
648
        SymCryptMlKemPolyElementINTTLayerVec128( peSrc, k, len );
649
        SymCryptRestoreXmm( &SaveData );
650
    } else {
651
        SymCryptMlKemPolyElementINTTLayerC( peSrc, k, len );
652
    }
653
#elif SYMCRYPT_CPU_AMD64
654
0
    if( SYMCRYPT_CPU_FEATURES_PRESENT( SYMCRYPT_CPU_FEATURE_SSE2 ) )
655
0
    {
656
0
        SymCryptMlKemPolyElementINTTLayerVec128( peSrc, k, len );
657
0
    } else {
658
0
        SymCryptMlKemPolyElementINTTLayerC( peSrc, k, len );
659
0
    }
660
#elif SYMCRYPT_CPU_ARM64
661
    if( SYMCRYPT_CPU_FEATURES_PRESENT( SYMCRYPT_CPU_FEATURE_NEON ) )
662
    {
663
        SymCryptMlKemPolyElementINTTLayerVec128( peSrc, k, len );
664
    } else {
665
        SymCryptMlKemPolyElementINTTLayerC( peSrc, k, len );
666
    }
667
#else
668
    SymCryptMlKemPolyElementINTTLayerC( peSrc, k, len );
669
#endif
670
0
}
671
672
VOID
673
SYMCRYPT_CALL
674
SymCryptMlKemPolyElementMulAndAccumulate(
675
    _In_    PCSYMCRYPT_MLKEM_POLYELEMENT            peSrc1,
676
    _In_    PCSYMCRYPT_MLKEM_POLYELEMENT            peSrc2,
677
    _Inout_ PSYMCRYPT_MLKEM_POLYELEMENT_ACCUMULATOR paDst )
678
0
{
679
0
    UINT32 i;
680
0
    UINT32 a0, a1, b0, b1, c0, c1;
681
0
    UINT32 a0b0, a1b1, a0b1, a1b0, a1b1zetapow, inv;
682
683
0
    for( i=0; i<(SYMCRYPT_MLWE_POLYNOMIAL_COEFFICIENTS / 2); i++ )
684
0
    {
685
0
        a0 = peSrc1->coeffs[(2*i)  ];
686
0
        SYMCRYPT_ASSERT( a0 < SYMCRYPT_MLKEM_Q );
687
0
        a1 = peSrc1->coeffs[(2*i)+1];
688
0
        SYMCRYPT_ASSERT( a1 < SYMCRYPT_MLKEM_Q );
689
690
0
        b0 = peSrc2->coeffs[(2*i)  ];
691
0
        SYMCRYPT_ASSERT( b0 < SYMCRYPT_MLKEM_Q );
692
0
        b1 = peSrc2->coeffs[(2*i)+1];
693
0
        SYMCRYPT_ASSERT( b1 < SYMCRYPT_MLKEM_Q );
694
695
0
        c0 = paDst->coeffs[(2*i)  ];
696
0
        SYMCRYPT_ASSERT( c0 <= 3*((3328*3328) + (3494*3312)) );
697
0
        c1 = paDst->coeffs[(2*i)+1];
698
0
        SYMCRYPT_ASSERT( c1 <= 3*((3328*3328) + (3494*3312)) );
699
700
        // multiplication results in range [0, 3328*3328]
701
0
        a0b0 = a0 * b0;
702
0
        a1b1 = a1 * b1;
703
0
        a0b1 = a0 * b1;
704
0
        a1b0 = a1 * b0;
705
706
        // we need a1*b1*zetaTwoTimesBitRevPlus1TimesR[i]
707
        // eagerly reduce a1*b1 with montgomery reduction
708
        // a1b1 = red(a1*b1) -> range [0,3494]
709
        //   (3494 is maximum result of first step of montgomery reduction of x*y for x,y in [0,3328])
710
        // we do not need to do final reduction yet
711
0
        inv = (a1b1 * SYMCRYPT_MLKEM_NegQInvModR) & SYMCRYPT_MLKEM_Rmask;
712
0
        a1b1 = (a1b1 + (inv * SYMCRYPT_MLKEM_Q)) >> SYMCRYPT_MLKEM_Rlog2; // in range [0, 3494]
713
0
        SYMCRYPT_ASSERT( a1b1 <= 3494 );
714
715
        // now multiply a1b1 by power of zeta
716
0
        a1b1zetapow = a1b1 * zetaTwoTimesBitRevPlus1TimesR[i];
717
718
        // sum pairs of products
719
0
        a0b0 += a1b1zetapow;    // a0*b0 + red(a1*b1)*zetapower in range [0, 3328*3328 + 3494*3312]
720
0
        SYMCRYPT_ASSERT( a0b0 <= (3328*3328) + (3494*3312) );
721
0
        a0b1 += a1b0;           // a0*b1 + a1*b0                in range [0, 2*3328*3328]
722
0
        SYMCRYPT_ASSERT( a0b1 <= 2*3328*3328 );
723
724
        // We sum at most 4 pairs of products into an accumulator in ML-KEM
725
0
        C_ASSERT( SYMCRYPT_MLKEM_MATRIX_MAX_NROWS <= 4 );
726
0
        c0 += a0b0; // in range [0,4*3328*3328 + 4*3494*3312]
727
0
        SYMCRYPT_ASSERT( c0 < (4*3328*3328) + (4*3494*3312) );
728
0
        c1 += a0b1; // in range [0,5*3328*3328 + 3*3494*3312]
729
0
        SYMCRYPT_ASSERT( c1 < (5*3328*3328) + (3*3494*3312) );
730
731
0
        paDst->coeffs[(2*i)  ] = c0;
732
0
        paDst->coeffs[(2*i)+1] = c1;
733
0
    }
734
0
}
735
736
VOID
737
SYMCRYPT_CALL
738
SymCryptMlKemMontgomeryReduceAndAddPolyElementAccumulatorToPolyElement(
739
    _Inout_ PSYMCRYPT_MLKEM_POLYELEMENT_ACCUMULATOR paSrc,
740
    _Inout_ PSYMCRYPT_MLKEM_POLYELEMENT             peDst )
741
0
{
742
0
    UINT32 i;
743
0
    UINT32 a, c, inv;
744
745
0
    for( i=0; i<SYMCRYPT_MLWE_POLYNOMIAL_COEFFICIENTS; i++ )
746
0
    {
747
0
        a = paSrc->coeffs[i];
748
0
        SYMCRYPT_ASSERT( a <= 4*((3328*3328) + (3494*3312)) );
749
0
        paSrc->coeffs[i] = 0;
750
751
0
        c = peDst->coeffs[i];
752
0
        SYMCRYPT_ASSERT( c < SYMCRYPT_MLKEM_Q );
753
754
        // montgomery reduce sum of products
755
0
        inv = (a * SYMCRYPT_MLKEM_NegQInvModR) & SYMCRYPT_MLKEM_Rmask;
756
0
        a = (a + (inv * SYMCRYPT_MLKEM_Q)) >> SYMCRYPT_MLKEM_Rlog2; // in range [0, 4711]
757
0
        SYMCRYPT_ASSERT( a <= 4711 );
758
759
        // add destination
760
0
        c += a;
761
0
        SYMCRYPT_ASSERT( c <= 8039 );
762
763
        // subtraction and conditional additions for constant time range reduction
764
0
        c -= 2*SYMCRYPT_MLKEM_Q;           // in range [-2Q, 1381]
765
0
        SYMCRYPT_ASSERT( (c >= ((UINT32)(-2*SYMCRYPT_MLKEM_Q))) || (c < 1381) );
766
0
        c += SYMCRYPT_MLKEM_Q & (c >> 16); // in range [-Q, Q-1]
767
0
        SYMCRYPT_ASSERT( (c >= ((UINT32)-SYMCRYPT_MLKEM_Q)) || (c < SYMCRYPT_MLKEM_Q) );
768
0
        c += SYMCRYPT_MLKEM_Q & (c >> 16); // in range [0, Q-1]
769
0
        SYMCRYPT_ASSERT( c < SYMCRYPT_MLKEM_Q );
770
771
0
        peDst->coeffs[i] = (UINT16) c;
772
0
    }
773
0
}
774
775
VOID
776
SYMCRYPT_CALL
777
SymCryptMlKemPolyElementMulR(
778
    _In_    PCSYMCRYPT_MLKEM_POLYELEMENT    peSrc,
779
    _Out_   PSYMCRYPT_MLKEM_POLYELEMENT     peDst )
780
0
{
781
0
    UINT32 i;
782
0
    for( i=0; i<SYMCRYPT_MLWE_POLYNOMIAL_COEFFICIENTS; i++ )
783
0
    {
784
0
        peDst->coeffs[i] = (UINT16) SymCryptMlKemMontMul(
785
0
            peSrc->coeffs[i], SYMCRYPT_MLKEM_Rsqr, SYMCRYPT_MLKEM_RsqrTimesNegQInvModR );
786
0
    }
787
0
}
788
789
VOID
790
SYMCRYPT_CALL
791
SymCryptMlKemPolyElementAdd(
792
    _In_    PCSYMCRYPT_MLKEM_POLYELEMENT peSrc1,
793
    _In_    PCSYMCRYPT_MLKEM_POLYELEMENT peSrc2,
794
    _Out_   PSYMCRYPT_MLKEM_POLYELEMENT  peDst )
795
0
{
796
0
    UINT32 i;
797
0
    for( i=0; i<SYMCRYPT_MLWE_POLYNOMIAL_COEFFICIENTS; i++ )
798
0
    {
799
0
        peDst->coeffs[i] = (UINT16) SymCryptMlKemModAdd( peSrc1->coeffs[i], peSrc2->coeffs[i] );
800
0
    }
801
0
}
802
803
VOID
804
SYMCRYPT_CALL
805
SymCryptMlKemPolyElementSub(
806
    _In_    PCSYMCRYPT_MLKEM_POLYELEMENT peSrc1,
807
    _In_    PCSYMCRYPT_MLKEM_POLYELEMENT peSrc2,
808
    _Out_   PSYMCRYPT_MLKEM_POLYELEMENT  peDst )
809
0
{
810
0
    UINT32 i;
811
0
    for( i=0; i<SYMCRYPT_MLWE_POLYNOMIAL_COEFFICIENTS; i++ )
812
0
    {
813
0
        peDst->coeffs[i] = (UINT16) SymCryptMlKemModSub( peSrc1->coeffs[i], peSrc2->coeffs[i] );
814
0
    }
815
0
}
816
817
//
// In-place forward NTT of peSrc: 7 layers with half-length len = 128, 64, ..., 2.
// The layer with half-length len has 128/len blocks and starts at twiddle index k = 128/len, so the
// layers together read MlKemZetaBitRevTimesR[1..127] in order.
// The layers are written out with constant arguments rather than looped over; presumably this lets
// each FORCEINLINE layer call be specialized for its len (TODO confirm intent before refactoring).
// Coefficients must be < Q on entry (asserted in the C layers) and are < Q on exit.
//
VOID
818
SYMCRYPT_CALL
819
SymCryptMlKemPolyElementNTT(
820
    _Inout_ PSYMCRYPT_MLKEM_POLYELEMENT  peSrc )
821
0
{
822
0
    SymCryptMlKemPolyElementNTTLayer( peSrc,  1, 128 );
823
0
    SymCryptMlKemPolyElementNTTLayer( peSrc,  2,  64 );
824
0
    SymCryptMlKemPolyElementNTTLayer( peSrc,  4,  32 );
825
0
    SymCryptMlKemPolyElementNTTLayer( peSrc,  8,  16 );
826
0
    SymCryptMlKemPolyElementNTTLayer( peSrc, 16,   8 );
827
0
    SymCryptMlKemPolyElementNTTLayer( peSrc, 32,   4 );
828
0
    SymCryptMlKemPolyElementNTTLayer( peSrc, 64,   2 );
829
0
}
830
831
// INTTFixupTimesRsqr = R^2 * 3303 = (3303<<32) mod Q
832
// 3303 constant is fixup from FIPS 203
833
// Multiplied by R^2 to additionally multiply coefficients by R after montgomery reduction
834
const UINT32 SYMCRYPT_MLKEM_INTTFixupTimesRsqr = 1441;
835
const UINT32 SYMCRYPT_MLKEM_INTTFixupTimesRsqrTimesNegQInvModR = 10079;
836
837
VOID
838
SYMCRYPT_CALL
839
SymCryptMlKemPolyElementINTTAndMulR(
840
    _Inout_ PSYMCRYPT_MLKEM_POLYELEMENT  peSrc )
841
0
{
842
0
    UINT32 i;
843
844
0
    SymCryptMlKemPolyElementINTTLayer( peSrc, 127,   2 );
845
0
    SymCryptMlKemPolyElementINTTLayer( peSrc,  63,   4 );
846
0
    SymCryptMlKemPolyElementINTTLayer( peSrc,  31,   8 );
847
0
    SymCryptMlKemPolyElementINTTLayer( peSrc,  15,  16 );
848
0
    SymCryptMlKemPolyElementINTTLayer( peSrc,   7,  32 );
849
0
    SymCryptMlKemPolyElementINTTLayer( peSrc,   3,  64 );
850
0
    SymCryptMlKemPolyElementINTTLayer( peSrc,   1, 128 );
851
852
0
    for( i=0; i<SYMCRYPT_MLWE_POLYNOMIAL_COEFFICIENTS; i++)
853
0
    {
854
0
        peSrc->coeffs[i] = (UINT16) SymCryptMlKemMontMul(
855
0
            peSrc->coeffs[i], SYMCRYPT_MLKEM_INTTFixupTimesRsqr, SYMCRYPT_MLKEM_INTTFixupTimesRsqrTimesNegQInvModR );
856
0
    }
857
0
}
858
859
// ((1<<35) / SYMCRYPT_MLKEM_Q)
860
//
861
// 1<<35 is the smallest power of 2 s.t. the constant has sufficient precision to round
862
// all inputs correctly in compression for all nBitsPerCoefficient < 12. A smaller
863
// constant could be used for smaller nBitsPerCoefficient for a small performance gain
864
//
865
const UINT32 SYMCRYPT_MLKEM_COMPRESS_MULCONSTANT = 0x9d7dbb;
866
const UINT32 SYMCRYPT_MLKEM_COMPRESS_SHIFTCONSTANT = 35;
867
868
VOID
869
SYMCRYPT_CALL
870
SymCryptMlKemPolyElementCompressAndEncode(
871
    _In_    PCSYMCRYPT_MLKEM_POLYELEMENT    peSrc,
872
            UINT32                          nBitsPerCoefficient,
873
    _Out_writes_bytes_(nBitsPerCoefficient*(SYMCRYPT_MLWE_POLYNOMIAL_COEFFICIENTS / 8))
874
            PBYTE                           pbDst )
875
0
{
876
0
    UINT32 i;
877
0
    UINT64 multiplication;
878
0
    UINT32 coefficient;
879
0
    UINT32 nBitsInCoefficient;
880
0
    UINT32 bitsToEncode;
881
0
    UINT32 nBitsToEncode;
882
0
    UINT32 cbDstWritten = 0;
883
0
    UINT32 accumulator = 0;
884
0
    UINT32 nBitsInAccumulator = 0;
885
886
0
    SYMCRYPT_ASSERT( nBitsPerCoefficient >  0  );
887
0
    SYMCRYPT_ASSERT( nBitsPerCoefficient <= 12 );
888
889
0
    for( i=0; i<SYMCRYPT_MLWE_POLYNOMIAL_COEFFICIENTS; i++ )
890
0
    {
891
0
        nBitsInCoefficient = nBitsPerCoefficient;
892
0
        coefficient = peSrc->coeffs[i]; // in range [0, Q-1]
893
0
        SYMCRYPT_ASSERT( coefficient < SYMCRYPT_MLKEM_Q );
894
895
        // first compress the coefficient
896
        // when nBitsPerCoefficient < 12 we compress per Compress_d in FIPS 203;
897
0
        if(nBitsPerCoefficient < 12)
898
0
        {
899
            // Multiply by 2^(nBitsPerCoefficient+1) / Q by multiplying by constant and shifting right
900
0
            multiplication = SYMCRYPT_MUL32x32TO64(coefficient, SYMCRYPT_MLKEM_COMPRESS_MULCONSTANT);
901
0
            coefficient = (UINT32) (multiplication >> (SYMCRYPT_MLKEM_COMPRESS_SHIFTCONSTANT-(nBitsPerCoefficient+1)));
902
903
            // add "half" to round to nearest integer
904
0
            coefficient++;
905
906
            // final divide by two to get multiplication by 2^nBitsPerCoefficient / Q
907
0
            coefficient >>= 1;                              // in range [0, 2^nBitsPerCoefficient]
908
0
            SYMCRYPT_ASSERT(coefficient <= (1UL<<nBitsPerCoefficient));
909
910
            // modular reduction by masking
911
0
            coefficient &= (1UL<<nBitsPerCoefficient)-1;    // in range [0, 2^nBitsPerCoefficient - 1]
912
0
            SYMCRYPT_ASSERT(coefficient <  (1UL<<nBitsPerCoefficient));
913
0
        }
914
915
        // encode the coefficient
916
        // simple loop to add bits to accumulator and write accumulator to output
917
0
        do
918
0
        {
919
0
            nBitsToEncode = SYMCRYPT_MIN(nBitsInCoefficient, 32-nBitsInAccumulator);
920
921
0
            bitsToEncode = coefficient & ((1UL<<nBitsToEncode)-1);
922
0
            coefficient >>= nBitsToEncode;
923
0
            nBitsInCoefficient -= nBitsToEncode;
924
925
0
            accumulator |= (bitsToEncode << nBitsInAccumulator);
926
0
            nBitsInAccumulator += nBitsToEncode;
927
0
            if(nBitsInAccumulator == 32)
928
0
            {
929
0
                SYMCRYPT_STORE_LSBFIRST32( pbDst+cbDstWritten, accumulator );
930
0
                cbDstWritten += 4;
931
0
                accumulator = 0;
932
0
                nBitsInAccumulator = 0;
933
0
            }
934
0
        } while( nBitsInCoefficient > 0 );
935
0
    }
936
937
0
    SYMCRYPT_ASSERT(nBitsInAccumulator == 0);
938
0
    SYMCRYPT_ASSERT(cbDstWritten == (nBitsPerCoefficient*(SYMCRYPT_MLWE_POLYNOMIAL_COEFFICIENTS / 8)));
939
0
}
940
941
SYMCRYPT_ERROR
942
SYMCRYPT_CALL
943
SymCryptMlKemPolyElementDecodeAndDecompress(
944
    _In_reads_bytes_(nBitsPerCoefficient*(SYMCRYPT_MLWE_POLYNOMIAL_COEFFICIENTS / 8))
945
            PCBYTE                      pbSrc,
946
            UINT32                      nBitsPerCoefficient,
947
    _Out_   PSYMCRYPT_MLKEM_POLYELEMENT peDst )
948
0
{
949
0
    UINT32 i;
950
0
    UINT32 coefficient;
951
0
    UINT32 nBitsInCoefficient;
952
0
    UINT32 bitsToDecode;
953
0
    UINT32 nBitsToDecode;
954
0
    UINT32 cbSrcRead = 0;
955
0
    UINT32 accumulator = 0;
956
0
    UINT32 nBitsInAccumulator = 0;
957
958
0
    SYMCRYPT_ASSERT( nBitsPerCoefficient >  0  );
959
0
    SYMCRYPT_ASSERT( nBitsPerCoefficient <= 12 );
960
961
0
    for( i=0; i<SYMCRYPT_MLWE_POLYNOMIAL_COEFFICIENTS; i++ )
962
0
    {
963
0
        coefficient = 0;
964
0
        nBitsInCoefficient = 0;
965
966
        // first gather and decode bits from pbSrc
967
0
        do
968
0
        {
969
0
            if(nBitsInAccumulator == 0)
970
0
            {
971
0
                accumulator = SYMCRYPT_LOAD_LSBFIRST32( pbSrc+cbSrcRead );
972
0
                cbSrcRead += 4;
973
0
                nBitsInAccumulator = 32;
974
0
            }
975
976
0
            nBitsToDecode = SYMCRYPT_MIN(nBitsPerCoefficient-nBitsInCoefficient, nBitsInAccumulator);
977
0
            SYMCRYPT_ASSERT(nBitsToDecode <= nBitsInAccumulator);
978
979
0
            bitsToDecode = accumulator & ((1UL<<nBitsToDecode)-1);
980
0
            accumulator >>= nBitsToDecode;
981
0
            nBitsInAccumulator -= nBitsToDecode;
982
983
0
            coefficient |= (bitsToDecode << nBitsInCoefficient);
984
0
            nBitsInCoefficient += nBitsToDecode;
985
0
        } while( nBitsPerCoefficient > nBitsInCoefficient );
986
0
        SYMCRYPT_ASSERT(nBitsInCoefficient == nBitsPerCoefficient);
987
988
        // decompress the coefficient
989
        // when nBitsPerCoefficient < 12 we decompress per Decompress_d in FIPS 203
990
        // otherwise we perform input validation per 203 6.2 Input validation 2 (Modulus check)
991
0
        if(nBitsPerCoefficient < 12)
992
0
        {
993
            // Multiply by Q / 2^(nBitsPerCoefficient-1) by multiplying by constant and shifting right
994
0
            coefficient *= SYMCRYPT_MLKEM_Q;
995
0
            coefficient >>= (nBitsPerCoefficient-1);
996
997
            // add "half" to round to nearest integer
998
0
            coefficient++;
999
1000
            // final divide by two to get multiplication by Q / 2^nBitsPerCoefficient
1001
0
            coefficient >>= 1;  // in range [0, Q]
1002
1003
            // modular reduction by conditional subtraction
1004
0
            coefficient = SymCryptMlKemModSub( coefficient, SYMCRYPT_MLKEM_Q );
1005
0
            SYMCRYPT_ASSERT( coefficient < SYMCRYPT_MLKEM_Q );
1006
0
        }
1007
0
        else if( coefficient > SYMCRYPT_MLKEM_Q )
1008
0
        {
1009
            // input validation failure - this can happen with a malformed or corrupt encapsulation
1010
            // or decapsulation key, but this validation failure only triggers on public data; we
1011
            // do not need to be constant time
1012
0
            return SYMCRYPT_INVALID_BLOB;
1013
0
        }
1014
1015
0
        peDst->coeffs[i] = (UINT16) coefficient;
1016
0
    }
1017
1018
0
    SYMCRYPT_ASSERT(nBitsInAccumulator == 0);
1019
0
    SYMCRYPT_ASSERT(cbSrcRead == (nBitsPerCoefficient*(SYMCRYPT_MLWE_POLYNOMIAL_COEFFICIENTS / 8)));
1020
1021
0
    return SYMCRYPT_NO_ERROR;
1022
0
}
1023
1024
VOID
1025
SYMCRYPT_CALL
1026
SymCryptMlKemPolyElementSampleNTTFromShake128(
1027
    _Inout_ PSYMCRYPT_SHAKE128_STATE    pState,
1028
    _Out_   PSYMCRYPT_MLKEM_POLYELEMENT  peDst )
1029
0
{
1030
0
    UINT32 i=0;
1031
0
    BYTE shakeOutputBuf[3*8]; // Keccak likes extracting multiples of 8-bytes
1032
0
    UINT32 currBufIndex = sizeof(shakeOutputBuf);
1033
0
    UINT16 sample0, sample1;
1034
1035
0
    while( i<SYMCRYPT_MLWE_POLYNOMIAL_COEFFICIENTS )
1036
0
    {
1037
0
        SYMCRYPT_ASSERT(currBufIndex <= sizeof(shakeOutputBuf));
1038
0
        if( currBufIndex == sizeof(shakeOutputBuf) )
1039
0
        {
1040
0
            SymCryptShake128Extract(pState, shakeOutputBuf, sizeof(shakeOutputBuf), FALSE);
1041
0
            currBufIndex = 0;
1042
0
        }
1043
1044
0
        sample0 = SYMCRYPT_LOAD_LSBFIRST16( shakeOutputBuf+currBufIndex ) & 0xfff;
1045
0
        sample1 = SYMCRYPT_LOAD_LSBFIRST16( shakeOutputBuf+currBufIndex+1 ) >> 4;
1046
0
        currBufIndex += 3;
1047
1048
0
        peDst->coeffs[i] = sample0;
1049
0
        i += sample0 < SYMCRYPT_MLKEM_Q;
1050
1051
0
        if( i<SYMCRYPT_MLWE_POLYNOMIAL_COEFFICIENTS )
1052
0
        {
1053
0
            peDst->coeffs[i] = sample1;
1054
0
            i += sample1 < SYMCRYPT_MLKEM_Q;
1055
0
        }
1056
0
    }
1057
0
}
1058
1059
VOID
1060
SYMCRYPT_CALL
1061
SymCryptMlKemPolyElementSampleCBDFromBytes(
1062
    _In_reads_bytes_(eta*2*(SYMCRYPT_MLWE_POLYNOMIAL_COEFFICIENTS / 8) + 1)
1063
                    PCBYTE                      pbSrc,
1064
    _In_range_(2,3) UINT32                      eta,
1065
    _Out_           PSYMCRYPT_MLKEM_POLYELEMENT peDst )
1066
0
{
1067
0
    UINT32 i, j;
1068
0
    UINT32 sampleBits;
1069
0
    UINT32 coefficient;
1070
1071
0
    SYMCRYPT_ASSERT((eta == 2) || (eta == 3));
1072
0
    if( eta == 3 )
1073
0
    {
1074
0
        for( i=0; i<SYMCRYPT_MLWE_POLYNOMIAL_COEFFICIENTS; i+=4 )
1075
0
        {
1076
            // unconditionally load 4 bytes into sampleBits, but only treat the load
1077
            // as being 3 bytes (24-bits -> 4 coefficients) for eta==3 to align to
1078
            // byte boundaries. Source buffer must be 1 byte larger than shake output
1079
0
            sampleBits = SYMCRYPT_LOAD_LSBFIRST32( pbSrc );
1080
0
            pbSrc += 3;
1081
1082
            // sum bit samples - each consecutive slice of eta bits is summed together
1083
0
            sampleBits = (sampleBits&0x249249) + ((sampleBits>>1)&0x249249) + ((sampleBits>>2)&0x249249);
1084
1085
0
            for( j=0; j<4; j++ )
1086
0
            {
1087
                // each coefficient is formed by taking the difference of two consecutive slices of eta bits
1088
                // the first eta bits are positive, the second eta bits are negative
1089
0
                coefficient = sampleBits & 0x3f;
1090
0
                sampleBits >>= 6;
1091
0
                coefficient = (coefficient&3) - (coefficient>>3);
1092
0
                SYMCRYPT_ASSERT((coefficient >= ((UINT32)-3)) || (coefficient <= 3));
1093
1094
0
                coefficient = coefficient + (SYMCRYPT_MLKEM_Q & (coefficient >> 16));     // in range [0, Q-1]
1095
0
                SYMCRYPT_ASSERT( coefficient < SYMCRYPT_MLKEM_Q );
1096
1097
0
                peDst->coeffs[i+j] = (UINT16) coefficient;
1098
0
            }
1099
0
        }
1100
0
    }
1101
0
    else
1102
0
    {
1103
0
        for( i=0; i<SYMCRYPT_MLWE_POLYNOMIAL_COEFFICIENTS; i+=8 )
1104
0
        {
1105
            // unconditionally load 4 bytes (32-bits -> 8 coefficients) into sampleBits
1106
0
            sampleBits = SYMCRYPT_LOAD_LSBFIRST32( pbSrc );
1107
0
            pbSrc += 4;
1108
1109
            // sum bit samples - each consecutive slice of eta bits is summed together
1110
0
            sampleBits = (sampleBits&0x55555555) + ((sampleBits>>1)&0x55555555);
1111
1112
0
            for( j=0; j<8; j++ )
1113
0
            {
1114
                // each coefficient is formed by taking the difference of two consecutive slices of eta bits
1115
                // the first eta bits are positive, the second eta bits are negative
1116
0
                coefficient = sampleBits & 0xf;
1117
0
                sampleBits >>= 4;
1118
0
                coefficient = (coefficient&3) - (coefficient>>2);
1119
0
                SYMCRYPT_ASSERT((coefficient >= ((UINT32)-2)) || (coefficient <= 2));
1120
1121
0
                coefficient = coefficient + (SYMCRYPT_MLKEM_Q & (coefficient >> 16));     // in range [0, Q-1]
1122
0
                SYMCRYPT_ASSERT( coefficient < SYMCRYPT_MLKEM_Q );
1123
1124
0
                peDst->coeffs[i+j] = (UINT16) coefficient;
1125
0
            }
1126
0
        }
1127
0
    }
1128
0
}
1129
1130
VOID
1131
SYMCRYPT_CALL
1132
SymCryptMlKemMatrixTranspose(
1133
    _Inout_ PSYMCRYPT_MLKEM_MATRIX  pmSrc )
1134
0
{
1135
0
    UINT32 i, j;
1136
0
    PSYMCRYPT_MLKEM_POLYELEMENT swap;
1137
0
    const UINT32 nRows = pmSrc->nRows;
1138
1139
0
    SYMCRYPT_ASSERT( nRows >  0 );
1140
0
    SYMCRYPT_ASSERT( nRows <= SYMCRYPT_MLKEM_MATRIX_MAX_NROWS );
1141
1142
0
    for( i=0; i<nRows; i++ )
1143
0
    {
1144
0
        for( j=i+1; j<nRows; j++ )
1145
0
        {
1146
0
            swap = pmSrc->apPolyElements[(i*nRows) + j];
1147
0
            pmSrc->apPolyElements[(i*nRows) + j] = pmSrc->apPolyElements[(j*nRows) + i];
1148
0
            pmSrc->apPolyElements[(j*nRows) + i] = swap;
1149
0
        }
1150
0
    }
1151
0
}
1152
1153
VOID
1154
SYMCRYPT_CALL
1155
SymCryptMlKemMatrixVectorMontMulAndAdd(
1156
    _In_    PCSYMCRYPT_MLKEM_MATRIX                 pmSrc1,
1157
    _In_    PCSYMCRYPT_MLKEM_VECTOR                 pvSrc2,
1158
    _Inout_ PSYMCRYPT_MLKEM_VECTOR                  pvDst,
1159
    _Inout_ PSYMCRYPT_MLKEM_POLYELEMENT_ACCUMULATOR paTmp )
1160
0
{
1161
0
    UINT32 i, j;
1162
0
    const UINT32 nRows = pmSrc1->nRows;
1163
0
    PCSYMCRYPT_MLKEM_POLYELEMENT peSrc1, peSrc2;
1164
0
    PSYMCRYPT_MLKEM_POLYELEMENT  peDst;
1165
1166
0
    SYMCRYPT_ASSERT( nRows >  0 );
1167
0
    SYMCRYPT_ASSERT( nRows <= SYMCRYPT_MLKEM_MATRIX_MAX_NROWS );
1168
0
    SYMCRYPT_ASSERT( pvSrc2->nRows == nRows );
1169
0
    SYMCRYPT_ASSERT( pvDst->nRows == nRows );
1170
1171
    // Zero paTmp
1172
0
    SymCryptWipeKnownSize( paTmp, SYMCRYPT_INTERNAL_MLKEM_SIZEOF_POLYRINGELEMENT_ACCUMULATOR );
1173
1174
0
    for( i=0; i<nRows; i++ )
1175
0
    {
1176
0
        for( j=0; j<nRows; j++ )
1177
0
        {
1178
0
            peSrc1 = pmSrc1->apPolyElements[(i*nRows) + j];
1179
0
            peSrc2 = SYMCRYPT_INTERNAL_MLKEM_VECTOR_ELEMENT( j, pvSrc2 );
1180
0
            SymCryptMlKemPolyElementMulAndAccumulate( peSrc1, peSrc2, paTmp );
1181
0
        }
1182
1183
        // write accumulator to dest and zero accumulator
1184
0
        peDst  = SYMCRYPT_INTERNAL_MLKEM_VECTOR_ELEMENT( i, pvDst );
1185
0
        SymCryptMlKemMontgomeryReduceAndAddPolyElementAccumulatorToPolyElement( paTmp, peDst );
1186
0
    }
1187
0
}
1188
1189
VOID
1190
SYMCRYPT_CALL
1191
SymCryptMlKemVectorMontDotProduct(
1192
    _In_    PCSYMCRYPT_MLKEM_VECTOR                 pvSrc1,
1193
    _In_    PCSYMCRYPT_MLKEM_VECTOR                 pvSrc2,
1194
    _Inout_ PSYMCRYPT_MLKEM_POLYELEMENT             peDst,
1195
    _Inout_ PSYMCRYPT_MLKEM_POLYELEMENT_ACCUMULATOR paTmp )
1196
0
{
1197
0
    UINT32 i;
1198
0
    const UINT32 nRows = pvSrc1->nRows;
1199
0
    PCSYMCRYPT_MLKEM_POLYELEMENT peSrc1, peSrc2;
1200
1201
0
    SYMCRYPT_ASSERT( nRows >  0 );
1202
0
    SYMCRYPT_ASSERT( nRows <= SYMCRYPT_MLKEM_MATRIX_MAX_NROWS );
1203
0
    SYMCRYPT_ASSERT( pvSrc2->nRows == nRows );
1204
1205
    // Zero paTmp and peDst
1206
0
    SymCryptWipeKnownSize( paTmp, SYMCRYPT_INTERNAL_MLKEM_SIZEOF_POLYRINGELEMENT_ACCUMULATOR );
1207
0
    SymCryptWipeKnownSize( peDst, SYMCRYPT_INTERNAL_MLKEM_SIZEOF_POLYRINGELEMENT );
1208
1209
0
    for( i=0; i<nRows; i++ )
1210
0
    {
1211
0
        peSrc1 = SYMCRYPT_INTERNAL_MLKEM_VECTOR_ELEMENT( i, pvSrc1 );
1212
0
        peSrc2 = SYMCRYPT_INTERNAL_MLKEM_VECTOR_ELEMENT( i, pvSrc2 );
1213
0
        SymCryptMlKemPolyElementMulAndAccumulate( peSrc1, peSrc2, paTmp );
1214
0
    }
1215
1216
    // write accumulator to dest and zero accumulator
1217
0
    SymCryptMlKemMontgomeryReduceAndAddPolyElementAccumulatorToPolyElement( paTmp, peDst );
1218
0
}
1219
1220
VOID
1221
SYMCRYPT_CALL
1222
SymCryptMlKemVectorSetZero(
1223
    _Inout_ PSYMCRYPT_MLKEM_VECTOR  pvSrc )
1224
0
{
1225
0
    const UINT32 nRows = pvSrc->nRows;
1226
1227
0
    SYMCRYPT_ASSERT( nRows >  0 );
1228
0
    SYMCRYPT_ASSERT( nRows <= SYMCRYPT_MLKEM_MATRIX_MAX_NROWS );
1229
1230
0
    SymCryptWipe( (PBYTE) SYMCRYPT_INTERNAL_MLKEM_VECTOR_ELEMENT( 0, pvSrc ), nRows*SYMCRYPT_INTERNAL_MLKEM_SIZEOF_POLYRINGELEMENT );
1231
0
}
1232
1233
VOID
1234
SYMCRYPT_CALL
1235
SymCryptMlKemVectorMulR(
1236
    _In_    PCSYMCRYPT_MLKEM_VECTOR pvSrc,
1237
    _Out_   PSYMCRYPT_MLKEM_VECTOR  pvDst )
1238
0
{
1239
0
    UINT32 i;
1240
0
    const UINT32 nRows = pvSrc->nRows;
1241
1242
0
    SYMCRYPT_ASSERT( nRows >  0 );
1243
0
    SYMCRYPT_ASSERT( nRows <= SYMCRYPT_MLKEM_MATRIX_MAX_NROWS );
1244
0
    SYMCRYPT_ASSERT( pvDst->nRows == nRows );
1245
1246
0
    for( i=0; i<nRows; i++ )
1247
0
    {
1248
0
        SymCryptMlKemPolyElementMulR(
1249
0
            SYMCRYPT_INTERNAL_MLKEM_VECTOR_ELEMENT( i, pvSrc ),
1250
0
            SYMCRYPT_INTERNAL_MLKEM_VECTOR_ELEMENT( i, pvDst ) );
1251
0
    }
1252
0
}
1253
1254
VOID
1255
SYMCRYPT_CALL
1256
SymCryptMlKemVectorAdd(
1257
    _In_    PCSYMCRYPT_MLKEM_VECTOR pvSrc1,
1258
    _In_    PCSYMCRYPT_MLKEM_VECTOR pvSrc2,
1259
    _Out_   PSYMCRYPT_MLKEM_VECTOR  pvDst )
1260
0
{
1261
0
    UINT32 i;
1262
0
    const UINT32 nRows = pvSrc1->nRows;
1263
0
    PCSYMCRYPT_MLKEM_POLYELEMENT peSrc1, peSrc2;
1264
0
    PSYMCRYPT_MLKEM_POLYELEMENT  peDst;
1265
1266
0
    SYMCRYPT_ASSERT( nRows >  0 );
1267
0
    SYMCRYPT_ASSERT( nRows <= SYMCRYPT_MLKEM_MATRIX_MAX_NROWS );
1268
0
    SYMCRYPT_ASSERT( pvSrc2->nRows == nRows );
1269
0
    SYMCRYPT_ASSERT( pvDst->nRows == nRows );
1270
1271
0
    for( i=0; i<nRows; i++ )
1272
0
    {
1273
0
        peSrc1 = SYMCRYPT_INTERNAL_MLKEM_VECTOR_ELEMENT( i, pvSrc1 );
1274
0
        peSrc2 = SYMCRYPT_INTERNAL_MLKEM_VECTOR_ELEMENT( i, pvSrc2 );
1275
0
        peDst  = SYMCRYPT_INTERNAL_MLKEM_VECTOR_ELEMENT( i, pvDst );
1276
0
        SymCryptMlKemPolyElementAdd( peSrc1, peSrc2, peDst );
1277
0
    }
1278
0
}
1279
1280
VOID
1281
SYMCRYPT_CALL
1282
SymCryptMlKemVectorSub(
1283
    _In_    PCSYMCRYPT_MLKEM_VECTOR pvSrc1,
1284
    _In_    PCSYMCRYPT_MLKEM_VECTOR pvSrc2,
1285
    _Out_   PSYMCRYPT_MLKEM_VECTOR  pvDst )
1286
0
{
1287
0
    UINT32 i;
1288
0
    const UINT32 nRows = pvSrc1->nRows;
1289
0
    PCSYMCRYPT_MLKEM_POLYELEMENT peSrc1, peSrc2;
1290
0
    PSYMCRYPT_MLKEM_POLYELEMENT  peDst;
1291
1292
0
    SYMCRYPT_ASSERT( nRows >  0 );
1293
0
    SYMCRYPT_ASSERT( nRows <= SYMCRYPT_MLKEM_MATRIX_MAX_NROWS );
1294
0
    SYMCRYPT_ASSERT( pvSrc2->nRows == nRows );
1295
0
    SYMCRYPT_ASSERT( pvDst->nRows == nRows );
1296
1297
0
    for( i=0; i<nRows; i++ )
1298
0
    {
1299
0
        peSrc1 = SYMCRYPT_INTERNAL_MLKEM_VECTOR_ELEMENT( i, pvSrc1 );
1300
0
        peSrc2 = SYMCRYPT_INTERNAL_MLKEM_VECTOR_ELEMENT( i, pvSrc2 );
1301
0
        peDst  = SYMCRYPT_INTERNAL_MLKEM_VECTOR_ELEMENT( i, pvDst );
1302
0
        SymCryptMlKemPolyElementSub( peSrc1, peSrc2, peDst );
1303
0
    }
1304
0
}
1305
1306
VOID
1307
SYMCRYPT_CALL
1308
SymCryptMlKemVectorNTT(
1309
    _Inout_ PSYMCRYPT_MLKEM_VECTOR  pvSrc )
1310
0
{
1311
0
    UINT32 i;
1312
0
    const UINT32 nRows = pvSrc->nRows;
1313
1314
0
    SYMCRYPT_ASSERT( nRows >  0 );
1315
0
    SYMCRYPT_ASSERT( nRows <= SYMCRYPT_MLKEM_MATRIX_MAX_NROWS );
1316
1317
0
    for( i=0; i<nRows; i++ )
1318
0
    {
1319
0
        SymCryptMlKemPolyElementNTT( SYMCRYPT_INTERNAL_MLKEM_VECTOR_ELEMENT( i, pvSrc ) );
1320
0
    }
1321
0
}
1322
1323
VOID
1324
SYMCRYPT_CALL
1325
SymCryptMlKemVectorINTTAndMulR(
1326
    _Inout_ PSYMCRYPT_MLKEM_VECTOR  pvSrc )
1327
0
{
1328
0
    UINT32 i;
1329
0
    const UINT32 nRows = pvSrc->nRows;
1330
1331
0
    SYMCRYPT_ASSERT( nRows >  0 );
1332
0
    SYMCRYPT_ASSERT( nRows <= SYMCRYPT_MLKEM_MATRIX_MAX_NROWS );
1333
1334
0
    for( i=0; i<nRows; i++ )
1335
0
    {
1336
0
        SymCryptMlKemPolyElementINTTAndMulR( SYMCRYPT_INTERNAL_MLKEM_VECTOR_ELEMENT( i, pvSrc ) );
1337
0
    }
1338
0
}
1339
1340
VOID
1341
SYMCRYPT_CALL
1342
SymCryptMlKemVectorCompressAndEncode(
1343
    _In_                        PCSYMCRYPT_MLKEM_VECTOR pvSrc,
1344
                                UINT32                  nBitsPerCoefficient,
1345
    _Out_writes_bytes_(cbDst)   PBYTE                   pbDst,
1346
                                SIZE_T                  cbDst )
1347
0
{
1348
0
    UINT32 i;
1349
0
    const UINT32 nRows = pvSrc->nRows;
1350
0
    PCSYMCRYPT_MLKEM_POLYELEMENT peSrc;
1351
1352
0
    SYMCRYPT_ASSERT( nRows >  0 );
1353
0
    SYMCRYPT_ASSERT( nRows <= SYMCRYPT_MLKEM_MATRIX_MAX_NROWS );
1354
0
    SYMCRYPT_ASSERT( nBitsPerCoefficient >  0  );
1355
0
    SYMCRYPT_ASSERT( nBitsPerCoefficient <= 12 );
1356
0
    SYMCRYPT_ASSERT( cbDst == nRows*nBitsPerCoefficient*(SYMCRYPT_MLWE_POLYNOMIAL_COEFFICIENTS / 8) );
1357
1358
0
    UNREFERENCED_PARAMETER( cbDst );
1359
1360
0
    for( i=0; i<nRows; i++ )
1361
0
    {
1362
0
        peSrc  = SYMCRYPT_INTERNAL_MLKEM_VECTOR_ELEMENT( i, pvSrc );
1363
0
        SymCryptMlKemPolyElementCompressAndEncode( peSrc, nBitsPerCoefficient, pbDst );
1364
0
        pbDst += nBitsPerCoefficient*(SYMCRYPT_MLWE_POLYNOMIAL_COEFFICIENTS / 8);
1365
0
    }
1366
0
}
1367
1368
SYMCRYPT_ERROR
1369
SYMCRYPT_CALL
1370
SymCryptMlKemVectorDecodeAndDecompress(
1371
    _In_reads_bytes_(cbSrc) PCBYTE                  pbSrc,
1372
                            SIZE_T                  cbSrc,
1373
                            UINT32                  nBitsPerCoefficient,
1374
    _Out_                   PSYMCRYPT_MLKEM_VECTOR  pvDst )
1375
0
{
1376
0
    SYMCRYPT_ERROR scError = SYMCRYPT_NO_ERROR;
1377
0
    UINT32 i;
1378
0
    const UINT32 nRows = pvDst->nRows;
1379
0
    PSYMCRYPT_MLKEM_POLYELEMENT peDst;
1380
1381
0
    SYMCRYPT_ASSERT( nRows >  0 );
1382
0
    SYMCRYPT_ASSERT( nRows <= SYMCRYPT_MLKEM_MATRIX_MAX_NROWS );
1383
0
    SYMCRYPT_ASSERT( nBitsPerCoefficient >  0  );
1384
0
    SYMCRYPT_ASSERT( nBitsPerCoefficient <= 12 );
1385
0
    SYMCRYPT_ASSERT( cbSrc == nRows*nBitsPerCoefficient*(SYMCRYPT_MLWE_POLYNOMIAL_COEFFICIENTS / 8) );
1386
1387
0
    UNREFERENCED_PARAMETER( cbSrc );
1388
1389
0
    for( i=0; i<nRows; i++ )
1390
0
    {
1391
0
        peDst  = SYMCRYPT_INTERNAL_MLKEM_VECTOR_ELEMENT( i, pvDst );
1392
0
        scError = SymCryptMlKemPolyElementDecodeAndDecompress( pbSrc, nBitsPerCoefficient, peDst );
1393
0
        if( scError != SYMCRYPT_NO_ERROR )
1394
0
        {
1395
0
            goto cleanup;
1396
0
        }
1397
0
        pbSrc += nBitsPerCoefficient*(SYMCRYPT_MLWE_POLYNOMIAL_COEFFICIENTS / 8);
1398
0
    }
1399
1400
0
cleanup:
1401
0
    return scError;
1402
0
}