Coverage Report

Created: 2025-12-04 06:33

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/openssl36/crypto/ml_dsa/ml_dsa_sign.c
Line
Count
Source
1
/*
2
 * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
3
 *
4
 * Licensed under the Apache License 2.0 (the "License").  You may not use
5
 * this file except in compliance with the License.  You can obtain a copy
6
 * in the file LICENSE in the source distribution or at
7
 * https://www.openssl.org/source/license.html
8
 */
9
10
#include <openssl/core_dispatch.h>
11
#include <openssl/core_names.h>
12
#include <openssl/params.h>
13
#include <openssl/rand.h>
14
#include <openssl/err.h>
15
#include <openssl/proverr.h>
16
#include "internal/common.h"
17
#include "ml_dsa_local.h"
18
#include "ml_dsa_key.h"
19
#include "ml_dsa_matrix.h"
20
#include "ml_dsa_sign.h"
21
#include "ml_dsa_hash.h"
22
23
/* Collision strength lambda of ML-DSA-87; c_tilde buffers are lambda / 4 bytes */
#define ML_DSA_MAX_LAMBDA (256)
24
25
/*
26
 * @brief Initialize a Signature object by pointing all of its objects to
27
 * preallocated blocks. The values passed for hint, z and
28
 * c_tilde values are not owned/freed by the |sig| object.
29
 *
30
 * @param sig The ML_DSA_SIG to initialize.
31
 * @param hint A preallocated array of |k| polynomial blocks
32
 * @param k The number of |hint| polynomials
33
 * @param z A preallocated array of |l| polynomial blocks
34
 * @param l The number of |z| polynomials
35
 * @param c_tilde A preallocated buffer
36
 * @param c_tilde_len The size of |c_tilde|
37
 */
38
static void signature_init(ML_DSA_SIG *sig,
39
                           POLY *hint, uint32_t k, POLY *z, uint32_t l,
40
                           uint8_t *c_tilde, size_t c_tilde_len)
41
1.18k
{
42
1.18k
    vector_init(&sig->z, z, l);
43
1.18k
    vector_init(&sig->hint, hint, k);
44
1.18k
    sig->c_tilde = c_tilde;
45
1.18k
    sig->c_tilde_len = c_tilde_len;
46
1.18k
}
47
48
/*
49
 * @brief: Auxiliary functions to compute ML-DSA's MU.
50
 * This combines the steps of creating M' and concatenating it
51
 * to the Public Key Hash to obtain MU.
52
 * See FIPS 204 Algorithm 2 Step 10 (and algorithm 3 Step 5) as
53
 * well as Algorithm 7 Step 6 (and algorithm 8 Step 7)
54
 *
55
 * ML_DSA pure signatures are encoded as M' = 00 || ctx_len || ctx || msg
56
 * Where ctx is the empty string by default and ctx_len <= 255.
57
 * The message is appended to the encoded context.
58
 * Finally a public key hash is prepended, and the whole is hashed
59
 * to derive the mu value.
60
 *
61
 * @param key: A public or private ML-DSA key;
62
 * @param encode: if not set, assumes that M' is provided raw and the
63
 * following parameters are ignored.
64
 * @param ctx An optional context to add to the message encoding.
65
 * @param ctx_len The size of |ctx|. It must be in the range 0..255
66
 * @returns an EVP_MD_CTX if the operation is successful, NULL otherwise.
67
 */
68
69
EVP_MD_CTX *ossl_ml_dsa_mu_init(const ML_DSA_KEY *key, int encode,
70
                                const uint8_t *ctx, size_t ctx_len)
71
844
{
72
844
    EVP_MD_CTX *md_ctx;
73
844
    uint8_t itb[2];
74
75
844
    if (key == NULL)
76
0
        return NULL;
77
78
844
    md_ctx = EVP_MD_CTX_new();
79
844
    if (md_ctx == NULL)
80
0
        return NULL;
81
82
    /* H(.. */
83
844
    if (!EVP_DigestInit_ex2(md_ctx, key->shake256_md, NULL))
84
0
        goto err;
85
    /* ..pk (= key->tr) */
86
844
    if (!EVP_DigestUpdate(md_ctx, key->tr, sizeof(key->tr)))
87
0
        goto err;
88
    /* M' = .. */
89
844
    if (encode) {
90
844
        if (ctx_len > ML_DSA_MAX_CONTEXT_STRING_LEN)
91
0
            goto err;
92
        /* IntegerToBytes(0, 1) .. */
93
844
        itb[0] = 0;
94
        /* || IntegerToBytes(|ctx|, 1) || .. */
95
844
        itb[1] = (uint8_t)ctx_len;
96
844
        if (!EVP_DigestUpdate(md_ctx, itb, 2))
97
0
            goto err;
98
        /* ctx || .. */
99
844
        if (!EVP_DigestUpdate(md_ctx, ctx, ctx_len))
100
0
            goto err;
101
        /* .. msg) will follow in update and final functions */
102
844
    }
103
104
844
    return md_ctx;
105
106
0
err:
107
0
    EVP_MD_CTX_free(md_ctx);
108
0
    return NULL;
109
844
}
110
111
/*
112
 * @brief: updates the internal ML-DSA hash with an additional message chunk.
113
 *
114
 * @param md_ctx: The hashing context
115
 * @param msg: The next message chunk
116
 * @param msg_len: The length of the msg buffer to process
117
 * @returns 1 on success, 0 on error
118
 */
119
int ossl_ml_dsa_mu_update(EVP_MD_CTX *md_ctx, const uint8_t *msg, size_t msg_len)
120
844
{
121
844
    return EVP_DigestUpdate(md_ctx, msg, msg_len);
122
844
}
123
124
/*
125
 * @brief: finalizes the internal ML-DSA hash
126
 *
127
 * @param md_ctx: The hashing context
128
 * @param mu: The output buffer for Mu
129
 * @param mu_len: The size of the output buffer
130
 * @returns 1 on success, 0 on error
131
 */
132
int ossl_ml_dsa_mu_finalize(EVP_MD_CTX *md_ctx, uint8_t *mu, size_t mu_len)
133
844
{
134
844
    if (!ossl_assert(mu_len == ML_DSA_MU_BYTES)) {
135
0
        ERR_raise(ERR_LIB_PROV, PROV_R_BAD_LENGTH);
136
0
        return 0;
137
0
    }
138
844
    return EVP_DigestSqueeze(md_ctx, mu, mu_len);
139
844
}
140
141
/*
142
 * @brief FIPS 204, Algorithm 7, ML-DSA.Sign_internal()
143
 *
144
 * This algorithm is decomposed in 2 steps, a set of functions to compute mu
145
 * and then the actual signing function.
146
 *
147
 * @param priv: The private ML-DSA key
148
 * @param mu: The pre-computed mu hash
149
 * @param mu_len: The length of the mu buffer
150
 * @param rnd: The random buffer
151
 * @param rnd_len: The length of the random buffer
152
 * @param out_sig: The output signature buffer
153
 * @returns 1 on success, 0 on error
154
 */
155
static int ml_dsa_sign_internal(const ML_DSA_KEY *priv,
156
                                const uint8_t *mu, size_t mu_len,
157
                                const uint8_t *rnd, size_t rnd_len,
158
                                uint8_t *out_sig)
159
422
{
160
422
    int ret = 0;
161
422
    const ML_DSA_PARAMS *params = priv->params;
162
422
    EVP_MD_CTX *md_ctx = NULL;
163
422
    uint32_t k = (uint32_t)params->k, l = (uint32_t)params->l;
164
422
    uint32_t gamma1 = params->gamma1, gamma2 = params->gamma2;
165
422
    uint8_t *alloc = NULL, *w1_encoded;
166
422
    size_t alloc_len, w1_encoded_len;
167
422
    size_t num_polys_sig_k = 2 * k;
168
422
    size_t num_polys_k = 5 * k;
169
422
    size_t num_polys_l = 3 * l;
170
422
    size_t num_polys_k_by_l = k * l;
171
422
    POLY *p, *c_ntt;
172
422
    VECTOR s1_ntt, s2_ntt, t0_ntt, w, w1, cs1, cs2, y;
173
422
    MATRIX a_ntt;
174
422
    ML_DSA_SIG sig;
175
422
    uint8_t rho_prime[ML_DSA_RHO_PRIME_BYTES];
176
422
    uint8_t c_tilde[ML_DSA_MAX_LAMBDA / 4];
177
422
    size_t c_tilde_len = params->bit_strength >> 2;
178
422
    size_t kappa;
179
180
422
    if (mu_len != ML_DSA_MU_BYTES) {
181
0
        ERR_raise(ERR_LIB_PROV, PROV_R_BAD_LENGTH);
182
0
        return 0;
183
0
    }
184
185
    /*
186
     * Allocate a single blob for most of the variable size temporary variables.
187
     * Mostly used for VECTOR POLYNOMIALS (every POLY is 1K).
188
     */
189
422
    w1_encoded_len = k * (gamma2 == ML_DSA_GAMMA2_Q_MINUS1_DIV88 ? 192 : 128);
190
422
    alloc_len = w1_encoded_len
191
422
        + sizeof(*p) * (1 + num_polys_k + num_polys_l
192
422
                        + num_polys_k_by_l + num_polys_sig_k);
193
422
    alloc = OPENSSL_malloc(alloc_len);
194
422
    if (alloc == NULL)
195
0
        return 0;
196
422
    md_ctx = EVP_MD_CTX_new();
197
422
    if (md_ctx == NULL)
198
0
        goto err;
199
200
422
    w1_encoded = alloc;
201
    /* Init the temp vectors to point to the allocated polys blob */
202
422
    p = (POLY *)(w1_encoded + w1_encoded_len);
203
422
    c_ntt = p++;
204
422
    matrix_init(&a_ntt, p, k, l);
205
422
    p += num_polys_k_by_l;
206
422
    vector_init(&s2_ntt, p, k);
207
422
    vector_init(&t0_ntt, s2_ntt.poly + k, k);
208
422
    vector_init(&w, t0_ntt.poly + k, k);
209
422
    vector_init(&w1, w.poly + k, k);
210
422
    vector_init(&cs2, w1.poly + k, k);
211
422
    p += num_polys_k;
212
422
    vector_init(&s1_ntt, p, l);
213
422
    vector_init(&y, p + l, l);
214
422
    vector_init(&cs1, p + 2 * l, l);
215
422
    p += num_polys_l;
216
422
    signature_init(&sig, p, k, p + k, l, c_tilde, c_tilde_len);
217
    /* End of the allocated blob setup */
218
219
422
    if (!matrix_expand_A(md_ctx, priv->shake128_md, priv->rho, &a_ntt))
220
0
        goto err;
221
222
422
    if (!shake_xof_3(md_ctx, priv->shake256_md, priv->K, sizeof(priv->K),
223
422
                     rnd, rnd_len, mu, mu_len,
224
422
                     rho_prime, sizeof(rho_prime)))
225
0
        goto err;
226
227
422
    vector_copy(&s1_ntt, &priv->s1);
228
422
    vector_ntt(&s1_ntt);
229
422
    vector_copy(&s2_ntt, &priv->s2);
230
422
    vector_ntt(&s2_ntt);
231
422
    vector_copy(&t0_ntt, &priv->t0);
232
422
    vector_ntt(&t0_ntt);
233
234
    /*
235
     * kappa must not exceed 2^16. But the probability of it
236
     * exceeding even 1000 iterations is vanishingly small.
237
     */
238
2.00k
    for (kappa = 0; ; kappa += l) {
239
2.00k
        VECTOR *y_ntt = &cs1;
240
2.00k
        VECTOR *r0 = &w1;
241
2.00k
        VECTOR *ct0 = &w1;
242
2.00k
        uint32_t z_max, r0_max, ct0_max, h_ones;
243
244
2.00k
        vector_expand_mask(&y, rho_prime, sizeof(rho_prime), (uint32_t)kappa,
245
2.00k
                           gamma1, md_ctx, priv->shake256_md);
246
2.00k
        vector_copy(y_ntt, &y);
247
2.00k
        vector_ntt(y_ntt);
248
249
2.00k
        matrix_mult_vector(&a_ntt, y_ntt, &w);
250
2.00k
        vector_ntt_inverse(&w);
251
252
2.00k
        vector_high_bits(&w, gamma2, &w1);
253
2.00k
        ossl_ml_dsa_w1_encode(&w1, gamma2, w1_encoded, w1_encoded_len);
254
255
2.00k
        if (!shake_xof_2(md_ctx, priv->shake256_md, mu, mu_len,
256
2.00k
                         w1_encoded, w1_encoded_len, c_tilde, c_tilde_len))
257
0
            break;
258
259
2.00k
        if (!poly_sample_in_ball_ntt(c_ntt, c_tilde, (int)c_tilde_len,
260
2.00k
                                     md_ctx, priv->shake256_md, params->tau))
261
0
            break;
262
263
2.00k
        vector_mult_scalar(&s1_ntt, c_ntt, &cs1);
264
2.00k
        vector_ntt_inverse(&cs1);
265
2.00k
        vector_mult_scalar(&s2_ntt, c_ntt, &cs2);
266
2.00k
        vector_ntt_inverse(&cs2);
267
268
2.00k
        vector_add(&y, &cs1, &sig.z);
269
270
        /* r0 = lowbits(w - cs2) */
271
2.00k
        vector_sub(&w, &cs2, r0);
272
2.00k
        vector_low_bits(r0, gamma2, r0);
273
274
        /*
275
         * Leaking that the signature is rejected is fine as the next attempt at a
276
         * signature will be (indistinguishable from) independent of this one.
277
         */
278
2.00k
        z_max = vector_max(&sig.z);
279
2.00k
        r0_max = vector_max_signed(r0);
280
2.00k
        if (value_barrier_32(constant_time_ge(z_max, gamma1 - params->beta)
281
2.00k
                             | constant_time_ge(r0_max, gamma2 - params->beta)))
282
1.58k
            continue;
283
284
427
        vector_mult_scalar(&t0_ntt, c_ntt, ct0);
285
427
        vector_ntt_inverse(ct0);
286
427
        vector_make_hint(ct0, &cs2, &w, gamma2, &sig.hint);
287
288
427
        ct0_max = vector_max(ct0);
289
427
        h_ones = (uint32_t)vector_count_ones(&sig.hint);
290
        /* Same reasoning applies to the leak as above */
291
427
        if (value_barrier_32(constant_time_ge(ct0_max, gamma2)
292
427
                             | constant_time_lt(params->omega, h_ones)))
293
5
            continue;
294
422
        ret = ossl_ml_dsa_sig_encode(&sig, params, out_sig);
295
422
        break;
296
427
    }
297
422
err:
298
422
    EVP_MD_CTX_free(md_ctx);
299
422
    OPENSSL_clear_free(alloc, alloc_len);
300
422
    OPENSSL_cleanse(rho_prime, sizeof(rho_prime));
301
422
    return ret;
302
422
}
303
304
/*
305
 * @brief FIPS 204, Algorithm 8, ML-DSA.Verify_internal().
306
 *
307
 * This algorithm is decomposed in 2 steps, a set of functions to compute mu
308
 * and then the actual verification function.
309
 *
310
 * @param pub: The public ML-DSA key
311
 * @param mu: The pre-computed mu hash
312
 * @param mu_len: The length of the mu buffer
313
 * @param sig_enc: The encoded signature to be verified
314
 * @param sig_enc_len: the encoded csignature length
315
 * @returns 1 on success, 0 on error
316
 */
317
static int ml_dsa_verify_internal(const ML_DSA_KEY *pub,
318
                                  const uint8_t *mu, size_t mu_len,
319
                                  const uint8_t *sig_enc,
320
                                  size_t sig_enc_len)
321
216
{
322
216
    int ret = 0;
323
216
    uint8_t *alloc = NULL, *w1_encoded;
324
216
    POLY *p, *c_ntt;
325
216
    MATRIX a_ntt;
326
216
    VECTOR az_ntt, ct1_ntt, *z_ntt, *w1, *w_approx;
327
216
    ML_DSA_SIG sig;
328
216
    const ML_DSA_PARAMS *params = pub->params;
329
216
    uint32_t k = (uint32_t)pub->params->k;
330
216
    uint32_t l = (uint32_t)pub->params->l;
331
216
    uint32_t gamma2 = params->gamma2;
332
216
    size_t w1_encoded_len;
333
216
    size_t num_polys_sig = k + l;
334
216
    size_t num_polys_k = 2 * k;
335
216
    size_t num_polys_l = 1 * l;
336
216
    size_t num_polys_k_by_l = k * l;
337
216
    uint8_t c_tilde[ML_DSA_MAX_LAMBDA / 4];
338
216
    uint8_t c_tilde_sig[ML_DSA_MAX_LAMBDA / 4];
339
216
    EVP_MD_CTX *md_ctx = NULL;
340
216
    size_t c_tilde_len = params->bit_strength >> 2;
341
216
    uint32_t z_max;
342
343
216
    if (mu_len != ML_DSA_MU_BYTES) {
344
0
        ERR_raise(ERR_LIB_PROV, PROV_R_BAD_LENGTH);
345
0
        return 0;
346
0
    }
347
348
349
    /* Allocate space for all the POLYNOMIALS used by temporary VECTORS */
350
216
    w1_encoded_len = k * (gamma2 == ML_DSA_GAMMA2_Q_MINUS1_DIV88 ? 192 : 128);
351
216
    alloc = OPENSSL_malloc(w1_encoded_len
352
216
                           + sizeof(*p) * (1 + num_polys_k
353
216
                                           + num_polys_l
354
216
                                           + num_polys_k_by_l
355
216
                                           + num_polys_sig));
356
216
    if (alloc == NULL)
357
0
        return 0;
358
216
    md_ctx = EVP_MD_CTX_new();
359
216
    if (md_ctx == NULL)
360
0
        goto err;
361
362
216
    w1_encoded = alloc;
363
    /* Init the temp vectors to point to the allocated polys blob */
364
216
    p = (POLY *)(w1_encoded + w1_encoded_len);
365
216
    c_ntt = p++;
366
216
    matrix_init(&a_ntt, p, k, l);
367
216
    p += num_polys_k_by_l;
368
216
    signature_init(&sig, p, k, p + k, l, c_tilde_sig, c_tilde_len);
369
216
    p += num_polys_sig;
370
216
    vector_init(&az_ntt, p, k);
371
216
    vector_init(&ct1_ntt, p + k, k);
372
373
216
    if (!ossl_ml_dsa_sig_decode(&sig, sig_enc, sig_enc_len, pub->params)
374
216
            || !matrix_expand_A(md_ctx, pub->shake128_md, pub->rho, &a_ntt))
375
0
        goto err;
376
377
    /* Compute verifiers challenge c_ntt = NTT(SampleInBall(c_tilde)) */
378
216
    if (!poly_sample_in_ball_ntt(c_ntt, c_tilde_sig, (int)c_tilde_len,
379
216
                                 md_ctx, pub->shake256_md, params->tau))
380
0
        goto err;
381
382
    /* ct1_ntt = NTT(c) * NTT(t1 * 2^d) */
383
216
    vector_scale_power2_round_ntt(&pub->t1, &ct1_ntt);
384
216
    vector_mult_scalar(&ct1_ntt, c_ntt, &ct1_ntt);
385
386
    /* compute z_max early in order to reuse sig.z */
387
216
    z_max = vector_max(&sig.z);
388
389
    /* w_approx = NTT_inverse(A * NTT(z) - ct1_ntt) */
390
216
    z_ntt = &sig.z;
391
216
    vector_ntt(z_ntt);
392
216
    matrix_mult_vector(&a_ntt, z_ntt, &az_ntt);
393
216
    w_approx = &az_ntt;
394
216
    vector_sub(&az_ntt, &ct1_ntt, w_approx);
395
216
    vector_ntt_inverse(w_approx);
396
397
    /* compute w1_encoded */
398
216
    w1 = w_approx;
399
216
    vector_use_hint(&sig.hint, w_approx, gamma2, w1);
400
216
    ossl_ml_dsa_w1_encode(w1, gamma2, w1_encoded, w1_encoded_len);
401
402
216
    if (!shake_xof_3(md_ctx, pub->shake256_md, mu, mu_len,
403
216
                     w1_encoded, w1_encoded_len, NULL, 0, c_tilde, c_tilde_len))
404
0
        goto err;
405
406
216
    ret = (z_max < (uint32_t)(params->gamma1 - params->beta))
407
216
        && memcmp(c_tilde, sig.c_tilde, c_tilde_len) == 0;
408
216
err:
409
216
    OPENSSL_free(alloc);
410
216
    EVP_MD_CTX_free(md_ctx);
411
216
    return ret;
412
216
}
413
414
/**
415
 * See FIPS 204 Section 5.2 Algorithm 2 ML-DSA.Sign()
416
 *
417
 * @returns 1 on success, or 0 on error.
418
 */
419
int ossl_ml_dsa_sign(const ML_DSA_KEY *priv, int msg_is_mu,
420
                     const uint8_t *msg, size_t msg_len,
421
                     const uint8_t *context, size_t context_len,
422
                     const uint8_t *rand, size_t rand_len, int encode,
423
                     unsigned char *sig, size_t *sig_len, size_t sig_size)
424
844
{
425
844
    EVP_MD_CTX *md_ctx = NULL;
426
844
    uint8_t mu[ML_DSA_MU_BYTES];
427
844
    const uint8_t *mu_ptr = mu;
428
844
    size_t mu_len = sizeof(mu);
429
844
    int ret = 0;
430
431
844
    if (ossl_ml_dsa_key_get_priv(priv) == NULL)
432
0
        return 0;
433
434
844
    if (sig_len != NULL)
435
844
        *sig_len = priv->params->sig_len;
436
437
844
    if (sig == NULL)
438
422
        return (sig_len != NULL) ? 1 : 0;
439
440
422
    if (sig_size < priv->params->sig_len)
441
0
        return 0;
442
443
422
    if (msg_is_mu) {
444
0
        mu_ptr = msg;
445
0
        mu_len = msg_len;
446
422
    } else {
447
422
        md_ctx = ossl_ml_dsa_mu_init(priv, encode, context, context_len);
448
422
        if (md_ctx == NULL)
449
0
            return 0;
450
451
422
        if (!ossl_ml_dsa_mu_update(md_ctx, msg, msg_len))
452
0
            goto err;
453
454
422
        if (!ossl_ml_dsa_mu_finalize(md_ctx, mu, mu_len))
455
0
            goto err;
456
422
    }
457
458
422
    ret = ml_dsa_sign_internal(priv, mu_ptr, mu_len, rand, rand_len, sig);
459
460
422
err:
461
422
    EVP_MD_CTX_free(md_ctx);
462
422
    return ret;
463
422
}
464
465
/**
466
 * See FIPS 203 Section 5.3 Algorithm 3 ML-DSA.Verify()
467
 * @returns 1 on success, or 0 on error.
468
 */
469
int ossl_ml_dsa_verify(const ML_DSA_KEY *pub, int msg_is_mu,
470
                       const uint8_t *msg, size_t msg_len,
471
                       const uint8_t *context, size_t context_len, int encode,
472
                       const uint8_t *sig, size_t sig_len)
473
422
{
474
422
    EVP_MD_CTX *md_ctx = NULL;
475
422
    uint8_t mu[ML_DSA_MU_BYTES];
476
422
    const uint8_t *mu_ptr = mu;
477
422
    size_t mu_len = sizeof(mu);
478
422
    int ret = 0;
479
480
422
    if (ossl_ml_dsa_key_get_pub(pub) == NULL)
481
0
        return 0;
482
483
422
    if (msg_is_mu) {
484
0
        mu_ptr = msg;
485
0
        mu_len = msg_len;
486
422
    } else {
487
422
        md_ctx = ossl_ml_dsa_mu_init(pub, encode, context, context_len);
488
422
        if (md_ctx == NULL)
489
0
            return 0;
490
491
422
        if (!ossl_ml_dsa_mu_update(md_ctx, msg, msg_len))
492
0
            goto err;
493
494
422
        if (!ossl_ml_dsa_mu_finalize(md_ctx, mu, mu_len))
495
0
            goto err;
496
422
    }
497
498
422
    ret = ml_dsa_verify_internal(pub, mu_ptr, mu_len, sig, sig_len);
499
422
err:
500
422
    EVP_MD_CTX_free(md_ctx);
501
422
    return ret;
502
422
}