Coverage Report

Created: 2024-11-21 07:03

/src/SymCrypt/lib/mlkem.c
Line
Count
Source (jump to first uncovered line)
1
//
2
// mlkem.c   ML-KEM related functionality
3
//
4
// Copyright (c) Microsoft Corporation. Licensed under the MIT license.
5
//
6
7
#include "precomp.h"
8
9
const SYMCRYPT_MLKEM_INTERNAL_PARAMS SymCryptMlKemInternalParamsMlKem512 =
10
{
11
    .params         = SYMCRYPT_MLKEM_PARAMS_MLKEM512,
12
    .cbPolyElement  = SYMCRYPT_INTERNAL_MLKEM_SIZEOF_POLYRINGELEMENT,
13
    .nRows          = 2,
14
    .cbVector       = sizeof(SYMCRYPT_MLKEM_VECTOR) + (2*SYMCRYPT_INTERNAL_MLKEM_SIZEOF_POLYRINGELEMENT),
15
    .cbMatrix       = sizeof(SYMCRYPT_MLKEM_MATRIX) + (2*2*SYMCRYPT_INTERNAL_MLKEM_SIZEOF_POLYRINGELEMENT),
16
    .nEta1          = 3,
17
    .nEta2          = 2,
18
    .nBitsOfU       = 10,
19
    .nBitsOfV       = 4,
20
};
21
22
const SYMCRYPT_MLKEM_INTERNAL_PARAMS SymCryptMlKemInternalParamsMlKem768 =
23
{
24
    .params         = SYMCRYPT_MLKEM_PARAMS_MLKEM768,
25
    .cbPolyElement  = SYMCRYPT_INTERNAL_MLKEM_SIZEOF_POLYRINGELEMENT,
26
    .nRows          = 3,
27
    .cbVector       = sizeof(SYMCRYPT_MLKEM_VECTOR) + (3*SYMCRYPT_INTERNAL_MLKEM_SIZEOF_POLYRINGELEMENT),
28
    .cbMatrix       = sizeof(SYMCRYPT_MLKEM_MATRIX) + (3*3*SYMCRYPT_INTERNAL_MLKEM_SIZEOF_POLYRINGELEMENT),
29
    .nEta1          = 2,
30
    .nEta2          = 2,
31
    .nBitsOfU       = 10,
32
    .nBitsOfV       = 4,
33
};
34
35
const SYMCRYPT_MLKEM_INTERNAL_PARAMS SymCryptMlKemInternalParamsMlKem1024 =
36
{
37
    .params         = SYMCRYPT_MLKEM_PARAMS_MLKEM1024,
38
    .cbPolyElement  = SYMCRYPT_INTERNAL_MLKEM_SIZEOF_POLYRINGELEMENT,
39
    .nRows          = 4,
40
    .cbVector       = sizeof(SYMCRYPT_MLKEM_VECTOR) + (4*SYMCRYPT_INTERNAL_MLKEM_SIZEOF_POLYRINGELEMENT),
41
    .cbMatrix       = sizeof(SYMCRYPT_MLKEM_MATRIX) + (4*4*SYMCRYPT_INTERNAL_MLKEM_SIZEOF_POLYRINGELEMENT),
42
    .nEta1          = 2,
43
    .nEta2          = 2,
44
    .nBitsOfU       = 11,
45
    .nBitsOfV       = 5,
46
};
47
48
static
49
SYMCRYPT_ERROR
50
SYMCRYPT_CALL
51
SymCryptMlKemkeyGetInternalParamsFromParams(
52
            SYMCRYPT_MLKEM_PARAMS           params,
53
    _Out_   PSYMCRYPT_MLKEM_INTERNAL_PARAMS pInternalParams )
54
0
{
55
0
    SYMCRYPT_ERROR scError = SYMCRYPT_NO_ERROR;
56
57
0
    switch( params )
58
0
    {
59
0
        case SYMCRYPT_MLKEM_PARAMS_MLKEM512:
60
0
            *pInternalParams = SymCryptMlKemInternalParamsMlKem512;
61
0
            break;
62
0
        case SYMCRYPT_MLKEM_PARAMS_MLKEM768:
63
0
            *pInternalParams = SymCryptMlKemInternalParamsMlKem768;
64
0
            break;
65
0
        case SYMCRYPT_MLKEM_PARAMS_MLKEM1024:
66
0
            *pInternalParams = SymCryptMlKemInternalParamsMlKem1024;
67
0
            break;
68
0
        default:
69
0
            scError = SYMCRYPT_INVALID_ARGUMENT;
70
0
            goto cleanup;
71
0
    }
72
73
0
cleanup:
74
0
    return scError;
75
0
}
76
77
static
78
PSYMCRYPT_MLKEMKEY
79
SYMCRYPT_CALL
80
SymCryptMlKemkeyInitialize(
81
    _In_                        PCSYMCRYPT_MLKEM_INTERNAL_PARAMS    pInternalParams,
82
    _Out_writes_bytes_(cbKey)   PBYTE                               pbKey,
83
                                UINT32                              cbKey )
84
0
{
85
0
    PSYMCRYPT_MLKEMKEY pRes = NULL;
86
0
    PSYMCRYPT_MLKEMKEY pKey = (PSYMCRYPT_MLKEMKEY)pbKey;
87
0
    PBYTE pbCurr = pbKey + sizeof(SYMCRYPT_MLKEMKEY);
88
89
0
    SymCryptWipeKnownSize( pbKey, cbKey );
90
91
0
    pKey->fAlgorithmInfo = 0;
92
0
    pKey->params = *pInternalParams;
93
0
    pKey->cbTotalSize = cbKey;
94
0
    pKey->hasPrivateSeed = FALSE;
95
0
    pKey->hasPrivateKey = FALSE;
96
97
0
    pKey->pmAtranspose = SymCryptMlKemMatrixCreate( pbCurr, pInternalParams->cbMatrix, pInternalParams->nRows );
98
0
    if( pKey->pmAtranspose == NULL )
99
0
    {
100
0
        goto cleanup;
101
0
    }
102
0
    pbCurr += pInternalParams->cbMatrix;
103
104
0
    pKey->pvt = SymCryptMlKemVectorCreate( pbCurr, pInternalParams->cbVector, pInternalParams->nRows );
105
0
    if( pKey->pvt == NULL )
106
0
    {
107
0
        goto cleanup;
108
0
    }
109
0
    pbCurr += pInternalParams->cbVector;
110
111
0
    pKey->pvs = SymCryptMlKemVectorCreate( pbCurr, pInternalParams->cbVector, pInternalParams->nRows );
112
0
    if( pKey->pvs == NULL )
113
0
    {
114
0
        goto cleanup;
115
0
    }
116
0
    pbCurr += pInternalParams->cbVector;
117
118
0
    SYMCRYPT_ASSERT( pbCurr == (pbKey + cbKey) );
119
120
0
    SYMCRYPT_SET_MAGIC( pKey );
121
122
0
    pRes = pKey;
123
124
0
cleanup:
125
0
    return pRes;
126
0
}
127
128
PSYMCRYPT_MLKEMKEY
129
SYMCRYPT_CALL
130
SymCryptMlKemkeyAllocate(
131
    SYMCRYPT_MLKEM_PARAMS   params )
132
0
{
133
0
    SYMCRYPT_ERROR scError = SYMCRYPT_NO_ERROR;
134
0
    PBYTE  pbKey = NULL;
135
0
    UINT32 cbKey;
136
0
    SYMCRYPT_MLKEM_INTERNAL_PARAMS internalParams;
137
138
0
    PSYMCRYPT_MLKEMKEY pKey = NULL;
139
140
0
    scError = SymCryptMlKemkeyGetInternalParamsFromParams(params, &internalParams);
141
0
    if( scError != SYMCRYPT_NO_ERROR )
142
0
    {
143
0
        goto cleanup;
144
0
    }
145
146
0
    cbKey = sizeof(SYMCRYPT_MLKEMKEY) + internalParams.cbMatrix + (2*internalParams.cbVector);
147
148
0
    pbKey = SymCryptCallbackAlloc( cbKey );
149
0
    if ( pbKey == NULL )
150
0
    {
151
0
        goto cleanup;
152
0
    }
153
154
0
    pKey = SymCryptMlKemkeyInitialize( &internalParams, pbKey, cbKey );
155
0
    if ( pKey == NULL )
156
0
    {
157
0
        goto cleanup;
158
0
    }
159
160
0
    pbKey = NULL;
161
162
0
cleanup:
163
0
    if ( pbKey != NULL )
164
0
    {
165
0
        SymCryptCallbackFree( pbKey );
166
0
    }
167
168
0
    return pKey;
169
0
}
170
171
VOID
172
SYMCRYPT_CALL
173
SymCryptMlKemkeyFree(
174
    _Inout_ PSYMCRYPT_MLKEMKEY  pkMlKemkey )
175
0
{
176
0
    SYMCRYPT_CHECK_MAGIC( pkMlKemkey );
177
178
0
    SymCryptWipe( (PBYTE) pkMlKemkey, pkMlKemkey->cbTotalSize );
179
180
0
    SymCryptCallbackFree( pkMlKemkey );
181
0
}
182
183
0
// Bytes in a vector of _nRows polynomials encoded with 12 bits per coefficient
// (256 coefficients * 12 bits / 8 = 384 bytes per polynomial).
// The argument is parenthesized so that an expression argument such as (k + 1)
// expands to 384 * (k + 1) rather than 384 * k + 1.
#define SYMCRYPT_MLKEM_SIZEOF_ENCODED_UNCOMPRESSED_VECTOR(_nRows)   (384UL * (_nRows))

// Private seed format: d and z are each 32 bytes
#define SYMCRYPT_MLKEM_SIZEOF_FORMAT_PRIVATE_SEED               (2*32)

// Decapsulation key format:
//   s and t are encoded uncompressed vectors;
//   public seed, H(encapsulation key) and z are each 32 bytes
#define SYMCRYPT_MLKEM_SIZEOF_FORMAT_DECAPSULATION_KEY(_nRows)  ((2*SYMCRYPT_MLKEM_SIZEOF_ENCODED_UNCOMPRESSED_VECTOR(_nRows)) + (3*32))

// Encapsulation key format:
//   t is an encoded uncompressed vector; public seed is 32 bytes
#define SYMCRYPT_MLKEM_SIZEOF_FORMAT_ENCAPSULATION_KEY(_nRows)  (SYMCRYPT_MLKEM_SIZEOF_ENCODED_UNCOMPRESSED_VECTOR(_nRows) + 32)
193
194
SYMCRYPT_ERROR
195
SYMCRYPT_CALL
196
SymCryptMlKemSizeofKeyFormatFromParams(
197
            SYMCRYPT_MLKEM_PARAMS       params,
198
            SYMCRYPT_MLKEMKEY_FORMAT    mlKemkeyFormat,
199
    _Out_   SIZE_T*                     pcbKeyFormat )
200
0
{
201
0
    SYMCRYPT_ERROR scError = SYMCRYPT_NO_ERROR;
202
0
    SYMCRYPT_MLKEM_INTERNAL_PARAMS internalParams;
203
204
0
    if( mlKemkeyFormat == SYMCRYPT_MLKEMKEY_FORMAT_NULL )
205
0
    {
206
0
        scError = SYMCRYPT_INCOMPATIBLE_FORMAT;
207
0
        goto cleanup;
208
0
    }
209
210
0
    scError = SymCryptMlKemkeyGetInternalParamsFromParams(params, &internalParams);
211
0
    if( scError != SYMCRYPT_NO_ERROR )
212
0
    {
213
0
        goto cleanup;
214
0
    }
215
216
0
    switch( mlKemkeyFormat )
217
0
    {
218
0
        case SYMCRYPT_MLKEMKEY_FORMAT_PRIVATE_SEED:
219
0
            *pcbKeyFormat = SYMCRYPT_MLKEM_SIZEOF_FORMAT_PRIVATE_SEED;
220
0
            break;
221
222
0
        case SYMCRYPT_MLKEMKEY_FORMAT_DECAPSULATION_KEY:
223
0
            *pcbKeyFormat = SYMCRYPT_MLKEM_SIZEOF_FORMAT_DECAPSULATION_KEY(internalParams.nRows);
224
0
            break;
225
226
0
        case SYMCRYPT_MLKEMKEY_FORMAT_ENCAPSULATION_KEY:
227
0
            *pcbKeyFormat = SYMCRYPT_MLKEM_SIZEOF_FORMAT_ENCAPSULATION_KEY(internalParams.nRows);
228
0
            break;
229
230
0
        default:
231
0
            scError = SYMCRYPT_INVALID_ARGUMENT;
232
0
            goto cleanup;
233
0
    }
234
235
0
cleanup:
236
0
    return scError;
237
0
}
238
239
SYMCRYPT_ERROR
240
SYMCRYPT_CALL
241
SymCryptMlKemSizeofCiphertextFromParams(
242
            SYMCRYPT_MLKEM_PARAMS       params,
243
    _Out_   SIZE_T*                     pcbCiphertext )
244
0
{
245
0
    SYMCRYPT_ERROR scError = SYMCRYPT_NO_ERROR;
246
0
    SYMCRYPT_MLKEM_INTERNAL_PARAMS internalParams;
247
0
    SIZE_T cbU, cbV;
248
249
0
    scError = SymCryptMlKemkeyGetInternalParamsFromParams(params, &internalParams);
250
0
    if( scError != SYMCRYPT_NO_ERROR )
251
0
    {
252
0
        goto cleanup;
253
0
    }
254
255
    // u vector encoded with nBitsOfU * SYMCRYPT_MLWE_POLYNOMIAL_COEFFICIENTS bits per polynomial
256
0
    cbU = ((SIZE_T)internalParams.nRows) * internalParams.nBitsOfU * (SYMCRYPT_MLWE_POLYNOMIAL_COEFFICIENTS / 8);
257
    // v polynomial encoded with nBitsOfV * SYMCRYPT_MLWE_POLYNOMIAL_COEFFICIENTS bits
258
0
    cbV = ((SIZE_T)internalParams.nBitsOfV) * (SYMCRYPT_MLWE_POLYNOMIAL_COEFFICIENTS / 8);
259
0
    *pcbCiphertext = cbU + cbV;
260
261
0
    SYMCRYPT_ASSERT( (internalParams.params != SYMCRYPT_MLKEM_PARAMS_MLKEM512)  || ((cbU + cbV) == SYMCRYPT_MLKEM_CIPHERTEXT_SIZE_MLKEM512)  );
262
0
    SYMCRYPT_ASSERT( (internalParams.params != SYMCRYPT_MLKEM_PARAMS_MLKEM768)  || ((cbU + cbV) == SYMCRYPT_MLKEM_CIPHERTEXT_SIZE_MLKEM768)  );
263
0
    SYMCRYPT_ASSERT( (internalParams.params != SYMCRYPT_MLKEM_PARAMS_MLKEM1024) || ((cbU + cbV) == SYMCRYPT_MLKEM_CIPHERTEXT_SIZE_MLKEM1024) );
264
265
0
cleanup:
266
0
    return scError;
267
0
}
268
269
static
270
VOID
271
SYMCRYPT_CALL
272
SymCryptMlKemkeyExpandPublicMatrixFromPublicSeed(
273
    _Inout_ PSYMCRYPT_MLKEMKEY                                  pkMlKemkey,
274
    _Inout_ PSYMCRYPT_MLKEM_INTERNAL_COMPUTATION_TEMPORARIES    pCompTemps )
275
0
{
276
0
    UINT32 i, j;
277
0
    BYTE coordinates[2];
278
279
0
    PSYMCRYPT_SHAKE128_STATE pShakeStateBase = &pCompTemps->hashState0.shake128State;
280
0
    PSYMCRYPT_SHAKE128_STATE pShakeStateWork = &pCompTemps->hashState1.shake128State;
281
0
    const UINT32 nRows = pkMlKemkey->params.nRows;
282
283
0
    SymCryptShake128Init( pShakeStateBase );
284
0
    SymCryptShake128Append( pShakeStateBase, pkMlKemkey->publicSeed, sizeof(pkMlKemkey->publicSeed) );
285
286
0
    for( i=0; i<nRows; i++ )
287
0
    {
288
0
        coordinates[1] = (BYTE)i;
289
0
        for( j=0; j<nRows; j++ )
290
0
        {
291
0
            coordinates[0] = (BYTE)j;
292
0
            SymCryptShake128StateCopy( pShakeStateBase, pShakeStateWork );
293
0
            SymCryptShake128Append( pShakeStateWork, coordinates, sizeof(coordinates) );
294
295
0
            SymCryptMlKemPolyElementSampleNTTFromShake128( pShakeStateWork, pkMlKemkey->pmAtranspose->apPolyElements[(i*nRows)+j] );
296
0
        }
297
0
    }
298
299
    // no need to wipe; everything computed here is always public
300
0
}
301
302
static
303
VOID
304
SYMCRYPT_CALL
305
SymCryptMlKemkeyComputeEncapsulationKeyHash(
306
    _Inout_ PSYMCRYPT_MLKEMKEY                                  pkMlKemkey,
307
    _Inout_ PSYMCRYPT_MLKEM_INTERNAL_COMPUTATION_TEMPORARIES    pCompTemps,
308
            SIZE_T                                              cbEncodedVector )
309
0
{
310
0
    PSYMCRYPT_SHA3_256_STATE pState = &pCompTemps->hashState0.sha3_256State;
311
312
0
    SymCryptSha3_256Init( pState );
313
0
    SymCryptSha3_256Append( pState, pkMlKemkey->encodedT, cbEncodedVector );
314
0
    SymCryptSha3_256Append( pState, pkMlKemkey->publicSeed, sizeof(pkMlKemkey->publicSeed) );
315
0
    SymCryptSha3_256Result( pState, pkMlKemkey->encapsKeyHash );
316
0
}
317
318
static
319
VOID
320
SYMCRYPT_CALL
321
SymCryptMlKemkeyExpandFromPrivateSeed(
322
    _Inout_ PSYMCRYPT_MLKEMKEY                                  pkMlKemkey,
323
    _Inout_ PSYMCRYPT_MLKEM_INTERNAL_COMPUTATION_TEMPORARIES    pCompTemps )
324
0
{
325
0
    BYTE privateSeedHash[SYMCRYPT_SHA3_512_RESULT_SIZE];
326
0
    BYTE CBDSampleBuffer[3*64 + 1];
327
0
    PSYMCRYPT_MLKEM_VECTOR pvTmp;
328
0
    PSYMCRYPT_MLKEM_POLYELEMENT_ACCUMULATOR paTmp;
329
0
    PSYMCRYPT_SHAKE256_STATE pShakeStateBase = &pCompTemps->hashState0.shake256State;
330
0
    PSYMCRYPT_SHAKE256_STATE pShakeStateWork = &pCompTemps->hashState1.shake256State;
331
0
    UINT32 i;
332
0
    const UINT32 nRows = pkMlKemkey->params.nRows;
333
0
    const UINT32 nEta1 = pkMlKemkey->params.nEta1;
334
0
    const SIZE_T cbEncodedVector = SYMCRYPT_MLKEM_SIZEOF_ENCODED_UNCOMPRESSED_VECTOR(nRows);
335
0
    const UINT32 cbPolyElement = pkMlKemkey->params.cbPolyElement;
336
0
    const UINT32 cbVector = pkMlKemkey->params.cbVector;
337
338
0
    SYMCRYPT_ASSERT( pkMlKemkey->hasPrivateSeed );
339
0
    SYMCRYPT_ASSERT( (nEta1 == 2) || (nEta1 == 3) );
340
0
    SYMCRYPT_ASSERT( cbEncodedVector <= sizeof(pkMlKemkey->encodedT) );
341
342
0
    pvTmp = SymCryptMlKemVectorCreate( pCompTemps->abVectorBuffer0, cbVector, nRows );
343
0
    SYMCRYPT_ASSERT( pvTmp != NULL );
344
0
    paTmp = SymCryptMlKemPolyElementAccumulatorCreate( pCompTemps->abPolyElementAccumulatorBuffer, 2*cbPolyElement );
345
0
    SYMCRYPT_ASSERT( paTmp != NULL );
346
347
    // (rho || sigma) = G(d || k)
348
    // use CBDSampleBuffer to concatenate the private seed and encoding of nRows
349
0
    memcpy( CBDSampleBuffer, pkMlKemkey->privateSeed, sizeof(pkMlKemkey->privateSeed) );
350
0
    CBDSampleBuffer[sizeof(pkMlKemkey->privateSeed)] = (BYTE) nRows;
351
0
    SymCryptSha3_512( CBDSampleBuffer, sizeof(pkMlKemkey->privateSeed)+1, privateSeedHash );
352
353
    // copy public seed
354
0
    memcpy( pkMlKemkey->publicSeed, privateSeedHash, sizeof(pkMlKemkey->publicSeed) );
355
356
    // generate A from public seed
357
0
    SymCryptMlKemkeyExpandPublicMatrixFromPublicSeed( pkMlKemkey, pCompTemps );
358
359
    // Initialize pShakeStateBase with sigma
360
0
    SymCryptShake256Init( pShakeStateBase );
361
0
    SymCryptShake256Append( pShakeStateBase, privateSeedHash+sizeof(pkMlKemkey->publicSeed), 32 );
362
363
    // Expand s in place
364
0
    for( i=0; i<nRows; i++ )
365
0
    {
366
0
        CBDSampleBuffer[0] = (BYTE) i;
367
0
        SymCryptShake256StateCopy( pShakeStateBase, pShakeStateWork );
368
0
        SymCryptShake256Append( pShakeStateWork, CBDSampleBuffer, 1 );
369
370
0
        SymCryptShake256Extract( pShakeStateWork, CBDSampleBuffer, 64ul*nEta1, FALSE );
371
372
0
        SymCryptMlKemPolyElementSampleCBDFromBytes( CBDSampleBuffer, nEta1, SYMCRYPT_INTERNAL_MLKEM_VECTOR_ELEMENT(i, pkMlKemkey->pvs) );
373
0
    }
374
    // Expand e in t, ready for multiply-add
375
0
    for( i=0; i<nRows; i++ )
376
0
    {
377
0
        CBDSampleBuffer[0] = (BYTE) (nRows+i);
378
0
        SymCryptShake256StateCopy( pShakeStateBase, pShakeStateWork );
379
0
        SymCryptShake256Append( pShakeStateWork, CBDSampleBuffer, 1 );
380
381
0
        SymCryptShake256Extract( pShakeStateWork, CBDSampleBuffer, 64ul*nEta1, FALSE );
382
383
0
        SymCryptMlKemPolyElementSampleCBDFromBytes( CBDSampleBuffer, nEta1, SYMCRYPT_INTERNAL_MLKEM_VECTOR_ELEMENT(i, pkMlKemkey->pvt) );
384
0
    }
385
386
    // Perform NTT on s and e
387
0
    SymCryptMlKemVectorNTT( pkMlKemkey->pvs );
388
0
    SymCryptMlKemVectorNTT( pkMlKemkey->pvt );
389
390
    // pvTmp = s .* R
391
0
    SymCryptMlKemVectorMulR( pkMlKemkey->pvs, pvTmp );
392
393
    // t = ((A o (s .* R)) ./ R) + e = A o s + e
394
0
    SymCryptMlKemMatrixVectorMontMulAndAdd( pkMlKemkey->pmAtranspose, pvTmp, pkMlKemkey->pvt, paTmp );
395
396
    // transpose A
397
0
    SymCryptMlKemMatrixTranspose( pkMlKemkey->pmAtranspose );
398
399
    // precompute byte-encoding of public vector t
400
0
    SymCryptMlKemVectorCompressAndEncode( pkMlKemkey->pvt, 12, pkMlKemkey->encodedT, cbEncodedVector );
401
402
    // precompute hash of encapsulation key blob
403
0
    SymCryptMlKemkeyComputeEncapsulationKeyHash( pkMlKemkey, pCompTemps, cbEncodedVector );
404
405
    // Cleanup!
406
0
    SymCryptWipeKnownSize( privateSeedHash, sizeof(privateSeedHash) );
407
0
    SymCryptWipeKnownSize( CBDSampleBuffer, sizeof(CBDSampleBuffer) );
408
0
}
409
410
SYMCRYPT_ERROR
411
SYMCRYPT_CALL
412
SymCryptMlKemkeySetValue(
413
    _In_reads_bytes_( cbSrc )   PCBYTE                      pbSrc,
414
                                SIZE_T                      cbSrc,
415
                                SYMCRYPT_MLKEMKEY_FORMAT    mlKemkeyFormat,
416
                                UINT32                      flags,
417
    _Inout_                     PSYMCRYPT_MLKEMKEY          pkMlKemkey )
418
0
{
419
0
    SYMCRYPT_ERROR scError = SYMCRYPT_NO_ERROR;
420
0
    PCBYTE pbCurr = pbSrc;
421
0
    PSYMCRYPT_MLKEM_INTERNAL_COMPUTATION_TEMPORARIES pCompTemps = NULL;
422
0
    const UINT32 nRows = pkMlKemkey->params.nRows;
423
0
    const SIZE_T cbEncodedVector = SYMCRYPT_MLKEM_SIZEOF_ENCODED_UNCOMPRESSED_VECTOR( nRows );
424
425
    // Ensure only allowed flags are specified
426
0
    UINT32 allowedFlags = SYMCRYPT_FLAG_KEY_NO_FIPS | SYMCRYPT_FLAG_KEY_MINIMAL_VALIDATION;
427
428
0
    if ( ( flags & ~allowedFlags ) != 0 )
429
0
    {
430
0
        scError = SYMCRYPT_INVALID_ARGUMENT;
431
0
        goto cleanup;
432
0
    }
433
434
    // Check that minimal validation flag only specified with no fips
435
0
    if ( ( ( flags & SYMCRYPT_FLAG_KEY_NO_FIPS ) == 0 ) &&
436
0
         ( ( flags & SYMCRYPT_FLAG_KEY_MINIMAL_VALIDATION ) != 0 ) )
437
0
    {
438
0
        scError = SYMCRYPT_INVALID_ARGUMENT;
439
0
        goto cleanup;
440
0
    }
441
442
0
    if( mlKemkeyFormat == SYMCRYPT_MLKEMKEY_FORMAT_NULL )
443
0
    {
444
0
        scError = SYMCRYPT_INVALID_ARGUMENT;
445
0
        goto cleanup;
446
0
    }
447
448
0
    if( ( flags & SYMCRYPT_FLAG_KEY_NO_FIPS ) == 0 )
449
0
    {
450
        // Ensure ML-KEM algorithm selftest is run before first use of ML-KEM algorithms;
451
        // notably _before_ first full KeyGen
452
0
        SYMCRYPT_RUN_SELFTEST_ONCE(
453
0
            SymCryptMlKemSelftest,
454
0
            SYMCRYPT_SELFTEST_ALGORITHM_MLKEM);
455
0
    }
456
457
0
    pCompTemps = SymCryptCallbackAlloc( sizeof(SYMCRYPT_MLKEM_INTERNAL_COMPUTATION_TEMPORARIES) );
458
0
    if( pCompTemps == NULL )
459
0
    {
460
0
        scError = SYMCRYPT_MEMORY_ALLOCATION_FAILURE;
461
0
        goto cleanup;
462
0
    }
463
464
0
    if( mlKemkeyFormat == SYMCRYPT_MLKEMKEY_FORMAT_PRIVATE_SEED )
465
0
    {
466
0
        if( cbSrc != SYMCRYPT_MLKEM_SIZEOF_FORMAT_PRIVATE_SEED )
467
0
        {
468
0
            scError = SYMCRYPT_WRONG_KEY_SIZE;
469
0
            goto cleanup;
470
0
        }
471
472
0
        pkMlKemkey->hasPrivateSeed = TRUE;
473
0
        memcpy( pkMlKemkey->privateSeed, pbCurr, sizeof(pkMlKemkey->privateSeed) );
474
0
        pbCurr += sizeof(pkMlKemkey->privateSeed);
475
476
0
        pkMlKemkey->hasPrivateKey = TRUE;
477
0
        memcpy( pkMlKemkey->privateRandom, pbCurr, sizeof(pkMlKemkey->privateRandom) );
478
0
        pbCurr += sizeof(pkMlKemkey->privateRandom);
479
480
0
        SymCryptMlKemkeyExpandFromPrivateSeed( pkMlKemkey, pCompTemps );
481
0
    }
482
0
    else if( mlKemkeyFormat == SYMCRYPT_MLKEMKEY_FORMAT_DECAPSULATION_KEY )
483
0
    {
484
0
        if( cbSrc != SYMCRYPT_MLKEM_SIZEOF_FORMAT_DECAPSULATION_KEY( nRows ) )
485
0
        {
486
0
            scError = SYMCRYPT_WRONG_KEY_SIZE;
487
0
            goto cleanup;
488
0
        }
489
490
        // decode s
491
0
        scError = SymCryptMlKemVectorDecodeAndDecompress( pbCurr, cbEncodedVector, 12, pkMlKemkey->pvs );
492
0
        if( scError != SYMCRYPT_NO_ERROR )
493
0
        {
494
0
            goto cleanup;
495
0
        }
496
0
        pbCurr += cbEncodedVector;
497
498
        // copy t and decode t
499
0
        memcpy( pkMlKemkey->encodedT, pbCurr, cbEncodedVector );
500
0
        pbCurr += cbEncodedVector;
501
0
        scError = SymCryptMlKemVectorDecodeAndDecompress( pkMlKemkey->encodedT, cbEncodedVector, 12, pkMlKemkey->pvt );
502
0
        if( scError != SYMCRYPT_NO_ERROR )
503
0
        {
504
0
            goto cleanup;
505
0
        }
506
507
        // copy public seed and expand public matrix
508
0
        memcpy( pkMlKemkey->publicSeed, pbCurr, sizeof(pkMlKemkey->publicSeed) );
509
0
        pbCurr += sizeof(pkMlKemkey->publicSeed);
510
0
        SymCryptMlKemkeyExpandPublicMatrixFromPublicSeed( pkMlKemkey, pCompTemps );
511
512
        // transpose A
513
0
        SymCryptMlKemMatrixTranspose( pkMlKemkey->pmAtranspose );
514
515
        // copy hash of encapsulation key
516
0
        memcpy( pkMlKemkey->encapsKeyHash, pbCurr, sizeof(pkMlKemkey->encapsKeyHash) );
517
0
        pbCurr += sizeof(pkMlKemkey->encapsKeyHash);
518
519
        // copy private random
520
0
        memcpy( pkMlKemkey->privateRandom, pbCurr, sizeof(pkMlKemkey->privateRandom) );
521
0
        pbCurr += sizeof(pkMlKemkey->privateRandom);
522
523
0
        pkMlKemkey->hasPrivateSeed = FALSE;
524
0
        pkMlKemkey->hasPrivateKey  = TRUE;
525
0
    }
526
0
    else if( mlKemkeyFormat == SYMCRYPT_MLKEMKEY_FORMAT_ENCAPSULATION_KEY )
527
0
    {
528
0
        if( cbSrc != SYMCRYPT_MLKEM_SIZEOF_FORMAT_ENCAPSULATION_KEY( nRows ) )
529
0
        {
530
0
            scError = SYMCRYPT_WRONG_KEY_SIZE;
531
0
            goto cleanup;
532
0
        }
533
534
        // copy t and decode t
535
0
        memcpy( pkMlKemkey->encodedT, pbCurr, cbEncodedVector );
536
0
        pbCurr += cbEncodedVector;
537
0
        scError = SymCryptMlKemVectorDecodeAndDecompress( pkMlKemkey->encodedT, cbEncodedVector, 12, pkMlKemkey->pvt );
538
0
        if( scError != SYMCRYPT_NO_ERROR )
539
0
        {
540
0
            goto cleanup;
541
0
        }
542
543
        // copy public seed and expand public matrix
544
0
        memcpy( pkMlKemkey->publicSeed, pbCurr, sizeof(pkMlKemkey->publicSeed) );
545
0
        pbCurr += sizeof(pkMlKemkey->publicSeed);
546
0
        SymCryptMlKemkeyExpandPublicMatrixFromPublicSeed( pkMlKemkey, pCompTemps );
547
548
        // transpose A
549
0
        SymCryptMlKemMatrixTranspose( pkMlKemkey->pmAtranspose );
550
551
        // precompute hash of encapsulation key blob
552
0
        SymCryptMlKemkeyComputeEncapsulationKeyHash( pkMlKemkey, pCompTemps, cbEncodedVector );
553
554
0
        pkMlKemkey->hasPrivateSeed = FALSE;
555
0
        pkMlKemkey->hasPrivateKey  = FALSE;
556
0
    }
557
0
    else
558
0
    {
559
0
        scError = SYMCRYPT_NOT_IMPLEMENTED;
560
0
        goto cleanup;
561
0
    }
562
563
0
    SYMCRYPT_ASSERT( pbCurr == pbSrc + cbSrc );
564
565
0
cleanup:
566
0
    if( pCompTemps != NULL )
567
0
    {
568
0
        SymCryptWipe( pCompTemps, sizeof(*pCompTemps) );
569
0
        SymCryptCallbackFree( pCompTemps );
570
0
    }
571
572
0
    return scError;
573
0
}
574
575
576
SYMCRYPT_ERROR
577
SYMCRYPT_CALL
578
SymCryptMlKemkeyGetValue(
579
    _In_                        PCSYMCRYPT_MLKEMKEY         pkMlKemkey,
580
    _Out_writes_bytes_( cbDst ) PBYTE                       pbDst,
581
                                SIZE_T                      cbDst,
582
                                SYMCRYPT_MLKEMKEY_FORMAT    mlKemkeyFormat,
583
                                UINT32                      flags )
584
0
{
585
0
    SYMCRYPT_ERROR scError = SYMCRYPT_NO_ERROR;
586
0
    PBYTE pbCurr = pbDst;
587
0
    const UINT32 nRows = pkMlKemkey->params.nRows;
588
0
    const SIZE_T cbEncodedVector = SYMCRYPT_MLKEM_SIZEOF_ENCODED_UNCOMPRESSED_VECTOR( nRows );
589
590
0
    UNREFERENCED_PARAMETER( flags );
591
592
0
    if( mlKemkeyFormat == SYMCRYPT_MLKEMKEY_FORMAT_NULL )
593
0
    {
594
0
        scError = SYMCRYPT_INVALID_ARGUMENT;
595
0
        goto cleanup;
596
0
    }
597
598
0
    if( mlKemkeyFormat == SYMCRYPT_MLKEMKEY_FORMAT_PRIVATE_SEED )
599
0
    {
600
0
        if( cbDst != SYMCRYPT_MLKEM_SIZEOF_FORMAT_PRIVATE_SEED )
601
0
        {
602
0
            scError = SYMCRYPT_WRONG_KEY_SIZE;
603
0
            goto cleanup;
604
0
        }
605
606
0
        if( !pkMlKemkey->hasPrivateSeed )
607
0
        {
608
0
            scError = SYMCRYPT_INCOMPATIBLE_FORMAT;
609
0
            goto cleanup;
610
0
        }
611
612
0
        memcpy( pbCurr, pkMlKemkey->privateSeed, sizeof(pkMlKemkey->privateSeed) );
613
0
        pbCurr += sizeof(pkMlKemkey->privateSeed);
614
615
0
        memcpy( pbCurr, pkMlKemkey->privateRandom, sizeof(pkMlKemkey->privateRandom) );
616
0
        pbCurr += sizeof(pkMlKemkey->privateRandom);
617
0
    }
618
0
    else if( mlKemkeyFormat == SYMCRYPT_MLKEMKEY_FORMAT_DECAPSULATION_KEY )
619
0
    {
620
0
        if( cbDst != SYMCRYPT_MLKEM_SIZEOF_FORMAT_DECAPSULATION_KEY( nRows ) )
621
0
        {
622
0
            scError = SYMCRYPT_INVALID_ARGUMENT;
623
0
            goto cleanup;
624
0
        }
625
626
0
        if( !pkMlKemkey->hasPrivateKey )
627
0
        {
628
0
            scError = SYMCRYPT_INVALID_ARGUMENT;
629
0
            goto cleanup;
630
0
        }
631
632
        // We don't precompute byte-encoding of private key as exporting decapsulation key is not a critical path operation
633
        // All other fields are kept in memory
634
0
        SymCryptMlKemVectorCompressAndEncode( pkMlKemkey->pvs, 12, pbCurr, cbEncodedVector );
635
0
        pbCurr += cbEncodedVector;
636
637
0
        memcpy( pbCurr, pkMlKemkey->encodedT, cbEncodedVector );
638
0
        pbCurr += cbEncodedVector;
639
640
0
        memcpy( pbCurr, pkMlKemkey->publicSeed, sizeof(pkMlKemkey->publicSeed) );
641
0
        pbCurr += sizeof(pkMlKemkey->publicSeed);
642
643
0
        memcpy( pbCurr, pkMlKemkey->encapsKeyHash, sizeof(pkMlKemkey->encapsKeyHash) );
644
0
        pbCurr += sizeof(pkMlKemkey->encapsKeyHash);
645
646
0
        memcpy( pbCurr, pkMlKemkey->privateRandom, sizeof(pkMlKemkey->privateRandom) );
647
0
        pbCurr += sizeof(pkMlKemkey->privateRandom);
648
0
    }
649
0
    else if( mlKemkeyFormat == SYMCRYPT_MLKEMKEY_FORMAT_ENCAPSULATION_KEY )
650
0
    {
651
0
        if( cbDst != SYMCRYPT_MLKEM_SIZEOF_FORMAT_ENCAPSULATION_KEY( nRows ) )
652
0
        {
653
0
            scError = SYMCRYPT_INVALID_ARGUMENT;
654
0
            goto cleanup;
655
0
        }
656
657
0
        memcpy( pbCurr, pkMlKemkey->encodedT, cbEncodedVector );
658
0
        pbCurr += cbEncodedVector;
659
660
0
        memcpy( pbCurr, pkMlKemkey->publicSeed, sizeof(pkMlKemkey->publicSeed) );
661
0
        pbCurr += sizeof(pkMlKemkey->publicSeed);
662
0
    }
663
0
    else
664
0
    {
665
0
        scError = SYMCRYPT_NOT_IMPLEMENTED;
666
0
        goto cleanup;
667
0
    }
668
669
0
    SYMCRYPT_ASSERT( pbCurr == pbDst + cbDst );
670
671
0
cleanup:
672
0
    return scError;
673
0
}
674
675
676
SYMCRYPT_ERROR
677
SYMCRYPT_CALL
678
SymCryptMlKemkeyGenerate(
679
    _Inout_                     PSYMCRYPT_MLKEMKEY  pkMlKemkey,
680
                                UINT32              flags )
681
0
{
682
0
    SYMCRYPT_ERROR scError = SYMCRYPT_NO_ERROR;
683
0
    BYTE privateSeed[SYMCRYPT_MLKEM_SIZEOF_FORMAT_PRIVATE_SEED];
684
685
    // Ensure only allowed flags are specified
686
0
    UINT32 allowedFlags = SYMCRYPT_FLAG_KEY_NO_FIPS;
687
688
0
    if ( ( flags & ~allowedFlags ) != 0 )
689
0
    {
690
0
        scError = SYMCRYPT_INVALID_ARGUMENT;
691
0
        goto cleanup;
692
0
    }
693
694
0
    scError = SymCryptCallbackRandom( privateSeed, sizeof(privateSeed) );
695
0
    if( scError != SYMCRYPT_NO_ERROR )
696
0
    {
697
0
        goto cleanup;
698
0
    }
699
700
0
    scError = SymCryptMlKemkeySetValue( privateSeed, sizeof(privateSeed), SYMCRYPT_MLKEMKEY_FORMAT_PRIVATE_SEED, flags, pkMlKemkey );
701
0
    if( scError != SYMCRYPT_NO_ERROR )
702
0
    {
703
0
        goto cleanup;
704
0
    }
705
706
    // SymCryptMlKemkeySetValue ensures the self-test is run before
707
    // first operational use of MlKem
708
709
    // Awaiting feedback from NIST for discussion from PQC forum and CMUF
710
    // before implementing costly PCT on ML-KEM key generation which is
711
    // not expected by FIPS 203
712
713
0
cleanup:
714
0
    SymCryptWipeKnownSize( privateSeed, sizeof(privateSeed) );
715
716
0
    return scError;
717
0
}
718
719
SYMCRYPT_ERROR
720
SYMCRYPT_CALL
721
SymCryptMlKemEncapsulateInternal(
722
    _In_    PCSYMCRYPT_MLKEMKEY                                 pkMlKemkey,
723
    _Out_writes_bytes_( cbAgreedSecret )
724
            PBYTE                                               pbAgreedSecret,
725
            SIZE_T                                              cbAgreedSecret,
726
    _Out_writes_bytes_( cbCiphertext )
727
            PBYTE                                               pbCiphertext,
728
            SIZE_T                                              cbCiphertext,
729
    _In_reads_bytes_( SYMCRYPT_MLKEM_SIZEOF_ENCAPS_RANDOM )
730
            PCBYTE                                              pbRandom,
731
    _Inout_ PSYMCRYPT_MLKEM_INTERNAL_COMPUTATION_TEMPORARIES    pCompTemps )
732
0
{
733
0
    BYTE CBDSampleBuffer[3*64 + 1];
734
0
    SYMCRYPT_ERROR scError = SYMCRYPT_NO_ERROR;
735
0
    PSYMCRYPT_MLKEM_VECTOR pvrInner;
736
0
    PSYMCRYPT_MLKEM_VECTOR pvTmp;
737
0
    PSYMCRYPT_MLKEM_POLYELEMENT peTmp0, peTmp1;
738
0
    PSYMCRYPT_MLKEM_POLYELEMENT_ACCUMULATOR paTmp;
739
0
    PSYMCRYPT_SHA3_512_STATE pHashState = &pCompTemps->hashState0.sha3_512State;
740
0
    PSYMCRYPT_SHAKE256_STATE pShakeBaseState = &pCompTemps->hashState0.shake256State;
741
0
    PSYMCRYPT_SHAKE256_STATE pShakeWorkState = &pCompTemps->hashState1.shake256State;
742
0
    SIZE_T cbU, cbV;
743
0
    UINT32 i;
744
0
    const UINT32 nRows = pkMlKemkey->params.nRows;
745
0
    const UINT32 nBitsOfU = pkMlKemkey->params.nBitsOfU;
746
0
    const UINT32 nBitsOfV = pkMlKemkey->params.nBitsOfV;
747
0
    const UINT32 nEta1 = pkMlKemkey->params.nEta1;
748
0
    const UINT32 nEta2 = pkMlKemkey->params.nEta2;
749
0
    const UINT32 cbPolyElement = pkMlKemkey->params.cbPolyElement;
750
0
    const UINT32 cbVector = pkMlKemkey->params.cbVector;
751
752
    // u vector encoded with nBitsOfU * SYMCRYPT_MLWE_POLYNOMIAL_COEFFICIENTS bits per polynomial
753
0
    cbU = nRows * nBitsOfU * (SYMCRYPT_MLWE_POLYNOMIAL_COEFFICIENTS / 8);
754
    // v polynomial encoded with nBitsOfV * SYMCRYPT_MLWE_POLYNOMIAL_COEFFICIENTS bits
755
0
    cbV = nBitsOfV * (SYMCRYPT_MLWE_POLYNOMIAL_COEFFICIENTS / 8);
756
757
0
    if( (cbAgreedSecret != SYMCRYPT_MLKEM_SIZEOF_AGREED_SECRET) ||
758
0
        (cbCiphertext != cbU + cbV) )
759
0
    {
760
0
        scError = SYMCRYPT_INVALID_ARGUMENT;
761
0
        goto cleanup;
762
0
    }
763
764
0
    pvrInner = SymCryptMlKemVectorCreate( pCompTemps->abVectorBuffer0, cbVector, nRows );
765
0
    SYMCRYPT_ASSERT( pvrInner != NULL );
766
0
    pvTmp = SymCryptMlKemVectorCreate( pCompTemps->abVectorBuffer1, cbVector, nRows );
767
0
    SYMCRYPT_ASSERT( pvTmp != NULL );
768
0
    peTmp0 = SymCryptMlKemPolyElementCreate( pCompTemps->abPolyElementBuffer0, cbPolyElement );
769
0
    SYMCRYPT_ASSERT( peTmp0 != NULL );
770
0
    peTmp1 = SymCryptMlKemPolyElementCreate( pCompTemps->abPolyElementBuffer1, cbPolyElement );
771
0
    SYMCRYPT_ASSERT( peTmp1 != NULL );
772
0
    paTmp = SymCryptMlKemPolyElementAccumulatorCreate( pCompTemps->abPolyElementAccumulatorBuffer, 2*cbPolyElement );
773
0
    SYMCRYPT_ASSERT( paTmp != NULL );
774
775
    // CBDSampleBuffer = (K || rOuter) = SHA3-512(pbRandom || encapsKeyHash)
776
0
    SymCryptSha3_512Init( pHashState );
777
0
    SymCryptSha3_512Append( pHashState, pbRandom, SYMCRYPT_MLKEM_SIZEOF_ENCAPS_RANDOM );
778
0
    SymCryptSha3_512Append( pHashState, pkMlKemkey->encapsKeyHash, sizeof(pkMlKemkey->encapsKeyHash) );
779
0
    SymCryptSha3_512Result( pHashState, CBDSampleBuffer );
780
781
    // Write K to pbAgreedSecret
782
0
    memcpy( pbAgreedSecret, CBDSampleBuffer, SYMCRYPT_MLKEM_SIZEOF_AGREED_SECRET );
783
784
    // Initialize pShakeStateBase with rOuter
785
0
    SymCryptShake256Init( pShakeBaseState );
786
0
    SymCryptShake256Append( pShakeBaseState, CBDSampleBuffer+cbAgreedSecret, 32 );
787
788
    // Expand rInner vector
789
0
    for( i=0; i<nRows; i++ )
790
0
    {
791
0
        CBDSampleBuffer[0] = (BYTE) i;
792
0
        SymCryptShake256StateCopy( pShakeBaseState, pShakeWorkState );
793
0
        SymCryptShake256Append( pShakeWorkState, CBDSampleBuffer, 1 );
794
795
0
        SymCryptShake256Extract( pShakeWorkState, CBDSampleBuffer, 64ul*nEta1, FALSE );
796
797
0
        SymCryptMlKemPolyElementSampleCBDFromBytes( CBDSampleBuffer, nEta1, SYMCRYPT_INTERNAL_MLKEM_VECTOR_ELEMENT(i, pvrInner) );
798
0
    }
799
800
    // Perform NTT on rInner
801
0
    SymCryptMlKemVectorNTT( pvrInner );
802
803
    // Set pvTmp to 0
804
0
    SymCryptMlKemVectorSetZero( pvTmp );
805
806
    // pvTmp = (Atranspose o rInner) ./ R
807
0
    SymCryptMlKemMatrixVectorMontMulAndAdd( pkMlKemkey->pmAtranspose, pvrInner, pvTmp, paTmp );
808
809
    // pvTmp = INTT(Atranspose o rInner)
810
0
    SymCryptMlKemVectorINTTAndMulR( pvTmp );
811
812
    // Expand e1 and add it to pvTmp - do addition PolyElement-wise to reduce memory usage
813
0
    for( i=0; i<nRows; i++ )
814
0
    {
815
0
        CBDSampleBuffer[0] = (BYTE) (nRows+i);
816
0
        SymCryptShake256StateCopy( pShakeBaseState, pShakeWorkState );
817
0
        SymCryptShake256Append( pShakeWorkState, CBDSampleBuffer, 1 );
818
819
0
        SymCryptShake256Extract( pShakeWorkState, CBDSampleBuffer, 64ul*nEta2, FALSE );
820
821
0
        SymCryptMlKemPolyElementSampleCBDFromBytes( CBDSampleBuffer, nEta2, peTmp0 );
822
823
0
        SymCryptMlKemPolyElementAdd( SYMCRYPT_INTERNAL_MLKEM_VECTOR_ELEMENT(i, pvTmp), peTmp0, SYMCRYPT_INTERNAL_MLKEM_VECTOR_ELEMENT(i, pvTmp) );
824
0
    }
825
826
    // pvTmp = u = INTT(Atranspose o rInner) + e1
827
    // Compress and encode u into prefix of ciphertext
828
0
    SymCryptMlKemVectorCompressAndEncode( pvTmp, nBitsOfU, pbCiphertext, cbU );
829
830
    // peTmp0 = (t o r) ./ R
831
0
    SymCryptMlKemVectorMontDotProduct( pkMlKemkey->pvt, pvrInner, peTmp0, paTmp );
832
833
    // peTmp0 = INTT(t o r)
834
0
    SymCryptMlKemPolyElementINTTAndMulR( peTmp0 );
835
836
    // Expand e2 polynomial in peTmp1
837
0
    CBDSampleBuffer[0] = (BYTE) (2*nRows);
838
0
    SymCryptShake256StateCopy( pShakeBaseState, pShakeWorkState );
839
0
    SymCryptShake256Append( pShakeWorkState, CBDSampleBuffer, 1 );
840
841
0
    SymCryptShake256Extract( pShakeWorkState, CBDSampleBuffer, 64ul*nEta2, FALSE );
842
843
0
    SymCryptMlKemPolyElementSampleCBDFromBytes( CBDSampleBuffer, nEta2, peTmp1 );
844
845
    // peTmp = INTT(t o r) + e2
846
0
    SymCryptMlKemPolyElementAdd( peTmp0, peTmp1, peTmp0 );
847
848
    // peTmp1 = mu
849
0
    SymCryptMlKemPolyElementDecodeAndDecompress( pbRandom, 1, peTmp1 );
850
851
    // peTmp0 = v = INTT(t o r) + e2 + mu
852
0
    SymCryptMlKemPolyElementAdd( peTmp0, peTmp1, peTmp0 );
853
854
    // Compress and encode v into remainder of ciphertext
855
0
    SymCryptMlKemPolyElementCompressAndEncode( peTmp0, nBitsOfV, pbCiphertext+cbU );
856
857
0
cleanup:
858
0
    SymCryptWipeKnownSize( CBDSampleBuffer, sizeof(CBDSampleBuffer) );
859
860
0
    return scError;
861
0
}
862
863
864
SYMCRYPT_ERROR
865
SYMCRYPT_CALL
866
SymCryptMlKemEncapsulateEx(
867
    _In_                                    PCSYMCRYPT_MLKEMKEY pkMlKemkey,
868
    _In_reads_bytes_( cbRandom )            PCBYTE              pbRandom,
869
                                            SIZE_T              cbRandom,
870
    _Out_writes_bytes_( cbAgreedSecret )    PBYTE               pbAgreedSecret,
871
                                            SIZE_T              cbAgreedSecret,
872
    _Out_writes_bytes_( cbCiphertext )      PBYTE               pbCiphertext,
873
                                            SIZE_T              cbCiphertext )
874
0
{
875
0
    SYMCRYPT_ERROR scError = SYMCRYPT_NO_ERROR;
876
0
    PSYMCRYPT_MLKEM_INTERNAL_COMPUTATION_TEMPORARIES pCompTemps = NULL;
877
878
0
    if( cbRandom != SYMCRYPT_MLKEM_SIZEOF_ENCAPS_RANDOM )
879
0
    {
880
0
        scError = SYMCRYPT_INVALID_ARGUMENT;
881
0
        goto cleanup;
882
0
    }
883
884
0
    pCompTemps = SymCryptCallbackAlloc( sizeof(SYMCRYPT_MLKEM_INTERNAL_COMPUTATION_TEMPORARIES) );
885
0
    if( pCompTemps == NULL )
886
0
    {
887
0
        scError = SYMCRYPT_MEMORY_ALLOCATION_FAILURE;
888
0
        goto cleanup;
889
0
    }
890
891
0
    scError = SymCryptMlKemEncapsulateInternal(
892
0
        pkMlKemkey,
893
0
        pbAgreedSecret, cbAgreedSecret,
894
0
        pbCiphertext, cbCiphertext,
895
0
        pbRandom,
896
0
        pCompTemps );
897
898
0
cleanup:
899
0
    if( pCompTemps != NULL )
900
0
    {
901
0
        SymCryptWipe( pCompTemps, sizeof(*pCompTemps) );
902
0
        SymCryptCallbackFree( pCompTemps );
903
0
    }
904
905
0
    return scError;
906
0
}
907
908
SYMCRYPT_ERROR
909
SYMCRYPT_CALL
910
SymCryptMlKemEncapsulate(
911
    _In_                                    PCSYMCRYPT_MLKEMKEY pkMlKemkey,
912
    _Out_writes_bytes_( cbAgreedSecret )    PBYTE               pbAgreedSecret,
913
                                            SIZE_T              cbAgreedSecret,
914
    _Out_writes_bytes_( cbCiphertext )      PBYTE               pbCiphertext,
915
                                            SIZE_T              cbCiphertext )
916
0
{
917
0
    SYMCRYPT_ERROR scError = SYMCRYPT_NO_ERROR;
918
0
    BYTE pbm[SYMCRYPT_MLKEM_SIZEOF_ENCAPS_RANDOM];
919
920
0
    scError = SymCryptCallbackRandom( pbm, sizeof(pbm) );
921
0
    if( scError != SYMCRYPT_NO_ERROR )
922
0
    {
923
0
        goto cleanup;
924
0
    }
925
926
0
    scError = SymCryptMlKemEncapsulateEx(
927
0
        pkMlKemkey,
928
0
        pbm, sizeof(pbm),
929
0
        pbAgreedSecret, cbAgreedSecret,
930
0
        pbCiphertext, cbCiphertext );
931
932
0
cleanup:
933
0
    SymCryptWipeKnownSize( pbm, sizeof(pbm) );
934
935
0
    return scError;
936
0
}
937
938
SYMCRYPT_ERROR
939
SYMCRYPT_CALL
940
SymCryptMlKemDecapsulate(
941
    _In_                                    PCSYMCRYPT_MLKEMKEY pkMlKemkey,
942
    _In_reads_bytes_( cbCiphertext )        PCBYTE              pbCiphertext,
943
                                            SIZE_T              cbCiphertext,
944
    _Out_writes_bytes_( cbAgreedSecret )    PBYTE               pbAgreedSecret,
945
                                            SIZE_T              cbAgreedSecret )
946
0
{
947
0
    PSYMCRYPT_MLKEM_INTERNAL_COMPUTATION_TEMPORARIES pCompTemps = NULL;
948
0
    BYTE pbDecryptedRandom[SYMCRYPT_MLKEM_SIZEOF_ENCAPS_RANDOM];
949
0
    BYTE pbDecapsulatedSecret[SYMCRYPT_MLKEM_SIZEOF_AGREED_SECRET];
950
0
    BYTE pbImplicitRejectionSecret[SYMCRYPT_MLKEM_SIZEOF_AGREED_SECRET];
951
0
    PBYTE pbReadCiphertext, pbReencapsulatedCiphertext;
952
0
    BOOLEAN successfulReencrypt;
953
954
0
    PBYTE pbCurr;
955
0
    PBYTE pbAlloc = NULL;
956
0
    const SIZE_T cbAlloc = sizeof(SYMCRYPT_MLKEM_INTERNAL_COMPUTATION_TEMPORARIES) + (2*cbCiphertext);
957
958
0
    SYMCRYPT_ERROR scError = SYMCRYPT_NO_ERROR;
959
0
    SIZE_T cbU, cbV, cbCopy;
960
0
    PSYMCRYPT_MLKEM_VECTOR pvu;
961
0
    PSYMCRYPT_MLKEM_POLYELEMENT peTmp0, peTmp1;
962
0
    PSYMCRYPT_MLKEM_POLYELEMENT_ACCUMULATOR paTmp;
963
0
    PSYMCRYPT_SHAKE256_STATE pShakeState;
964
0
    const UINT32 nRows = pkMlKemkey->params.nRows;
965
0
    const UINT32 nBitsOfU = pkMlKemkey->params.nBitsOfU;
966
0
    const UINT32 nBitsOfV = pkMlKemkey->params.nBitsOfV;
967
0
    const UINT32 cbPolyElement = pkMlKemkey->params.cbPolyElement;
968
0
    const UINT32 cbVector = pkMlKemkey->params.cbVector;
969
970
    // u vector encoded with nBitsOfU * SYMCRYPT_MLWE_POLYNOMIAL_COEFFICIENTS bits per polynomial
971
0
    cbU = nRows * nBitsOfU * (SYMCRYPT_MLWE_POLYNOMIAL_COEFFICIENTS / 8);
972
    // v polynomial encoded with nBitsOfV * SYMCRYPT_MLWE_POLYNOMIAL_COEFFICIENTS bits
973
0
    cbV = nBitsOfV * (SYMCRYPT_MLWE_POLYNOMIAL_COEFFICIENTS / 8);
974
975
0
    if( (cbAgreedSecret != SYMCRYPT_MLKEM_SIZEOF_AGREED_SECRET) ||
976
0
        (cbCiphertext != cbU + cbV) ||
977
0
        !pkMlKemkey->hasPrivateKey )
978
0
    {
979
0
        scError = SYMCRYPT_INVALID_ARGUMENT;
980
0
        goto cleanup;
981
0
    }
982
983
0
    pbAlloc = SymCryptCallbackAlloc( cbAlloc );
984
0
    if( pbAlloc == NULL )
985
0
    {
986
0
        scError = SYMCRYPT_MEMORY_ALLOCATION_FAILURE;
987
0
        goto cleanup;
988
0
    }
989
0
    pbCurr = pbAlloc;
990
991
0
    pCompTemps = (PSYMCRYPT_MLKEM_INTERNAL_COMPUTATION_TEMPORARIES) pbCurr;
992
0
    pbCurr += sizeof(SYMCRYPT_MLKEM_INTERNAL_COMPUTATION_TEMPORARIES);
993
994
0
    pbReadCiphertext = pbCurr;
995
0
    pbCurr += cbCiphertext;
996
997
0
    pbReencapsulatedCiphertext = pbCurr;
998
0
    pbCurr += cbCiphertext;
999
1000
0
    SYMCRYPT_ASSERT( pbCurr == (pbAlloc + cbAlloc) );
1001
1002
    // Read the input ciphertext once to local pbReadCiphertext to ensure our view of ciphertext consistent
1003
0
    memcpy( pbReadCiphertext, pbCiphertext, cbCiphertext );
1004
1005
0
    pvu = SymCryptMlKemVectorCreate( pCompTemps->abVectorBuffer0, cbVector, nRows );
1006
0
    SYMCRYPT_ASSERT( pvu != NULL );
1007
0
    peTmp0 = SymCryptMlKemPolyElementCreate( pCompTemps->abPolyElementBuffer0, cbPolyElement );
1008
0
    SYMCRYPT_ASSERT( peTmp0 != NULL );
1009
0
    peTmp1 = SymCryptMlKemPolyElementCreate( pCompTemps->abPolyElementBuffer1, cbPolyElement );
1010
0
    SYMCRYPT_ASSERT( peTmp1 != NULL );
1011
0
    paTmp = SymCryptMlKemPolyElementAccumulatorCreate( pCompTemps->abPolyElementAccumulatorBuffer, 2*cbPolyElement );
1012
0
    SYMCRYPT_ASSERT( paTmp != NULL );
1013
1014
    // Decode and decompress u
1015
0
    scError = SymCryptMlKemVectorDecodeAndDecompress( pbReadCiphertext, cbU, nBitsOfU, pvu );
1016
0
    SYMCRYPT_ASSERT( scError == SYMCRYPT_NO_ERROR );
1017
1018
    // Perform NTT on u
1019
0
    SymCryptMlKemVectorNTT( pvu );
1020
1021
    // peTmp0 = (s o NTT(u)) ./ R
1022
0
    SymCryptMlKemVectorMontDotProduct( pkMlKemkey->pvs, pvu, peTmp0, paTmp );
1023
1024
    // peTmp0 = INTT(s o NTT(u))
1025
0
    SymCryptMlKemPolyElementINTTAndMulR( peTmp0 );
1026
1027
    // Decode and decompress v
1028
0
    scError = SymCryptMlKemPolyElementDecodeAndDecompress( pbReadCiphertext+cbU, nBitsOfV, peTmp1 );
1029
0
    SYMCRYPT_ASSERT( scError == SYMCRYPT_NO_ERROR );
1030
1031
    // peTmp0 = w = v - INTT(s o NTT(u))
1032
0
    SymCryptMlKemPolyElementSub( peTmp1, peTmp0, peTmp0 );
1033
1034
    // pbDecryptedRandom = m' = Encoding of w
1035
0
    SymCryptMlKemPolyElementCompressAndEncode( peTmp0, 1, pbDecryptedRandom );
1036
1037
    // Compute:
1038
    //  pbDecapsulatedSecret = K' = Decapsulated secret (without implicit rejection)
1039
    //  pbReencapsulatedCiphertext = c' = Ciphertext from re-encapsulating decrypted random value
1040
0
    scError = SymCryptMlKemEncapsulateInternal(
1041
0
        pkMlKemkey,
1042
0
        pbDecapsulatedSecret, sizeof(pbDecapsulatedSecret),
1043
0
        pbReencapsulatedCiphertext, cbCiphertext,
1044
0
        pbDecryptedRandom,
1045
0
        pCompTemps );
1046
0
    SYMCRYPT_ASSERT( scError == SYMCRYPT_NO_ERROR );
1047
1048
    // Compute the secret we will return if using implicit rejection
1049
    // pbImplicitRejectionSecret = K_bar = SHAKE256( z || c )
1050
0
    pShakeState = &pCompTemps->hashState0.shake256State;
1051
0
    SymCryptShake256Init( pShakeState );
1052
0
    SymCryptShake256Append( pShakeState, pkMlKemkey->privateRandom, sizeof(pkMlKemkey->privateRandom) );
1053
0
    SymCryptShake256Append( pShakeState, pbReadCiphertext, cbCiphertext );
1054
0
    SymCryptShake256Extract( pShakeState, pbImplicitRejectionSecret, sizeof(pbImplicitRejectionSecret), FALSE );
1055
1056
    // Constant time test if re-encryption successful
1057
0
    successfulReencrypt = SymCryptEqual( pbReencapsulatedCiphertext, pbReadCiphertext, cbCiphertext );
1058
1059
    // If not successful, perform side-channel-safe copy of Implicit Rejection secret over Decapsulated secret
1060
0
    cbCopy = (((SIZE_T)successfulReencrypt)-1) & SYMCRYPT_MLKEM_SIZEOF_AGREED_SECRET;
1061
0
    SymCryptScsCopy( pbImplicitRejectionSecret, cbCopy, pbDecapsulatedSecret, SYMCRYPT_MLKEM_SIZEOF_AGREED_SECRET );
1062
1063
    // Write agreed secret (with implicit rejection) to pbAgreedSecret
1064
0
    memcpy( pbAgreedSecret, pbDecapsulatedSecret, SYMCRYPT_MLKEM_SIZEOF_AGREED_SECRET );
1065
1066
0
cleanup:
1067
0
    if( pbAlloc != NULL )
1068
0
    {
1069
0
        SymCryptWipe( pbAlloc, cbAlloc );
1070
0
        SymCryptCallbackFree( pbAlloc );
1071
0
    }
1072
1073
0
    SymCryptWipeKnownSize( pbDecryptedRandom, sizeof(pbDecryptedRandom) );
1074
0
    SymCryptWipeKnownSize( pbDecapsulatedSecret, sizeof(pbDecapsulatedSecret) );
1075
0
    SymCryptWipeKnownSize( pbImplicitRejectionSecret, sizeof(pbImplicitRejectionSecret) );
1076
1077
0
    return scError;
1078
0
}