Coverage Report

Created: 2025-11-16 06:40

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/openssl35/crypto/ml_dsa/ml_dsa_sign.c
Line
Count
Source
1
/*
2
 * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
3
 *
4
 * Licensed under the Apache License 2.0 (the "License").  You may not use
5
 * this file except in compliance with the License.  You can obtain a copy
6
 * in the file LICENSE in the source distribution or at
7
 * https://www.openssl.org/source/license.html
8
 */
9
10
#include <openssl/core_dispatch.h>
11
#include <openssl/core_names.h>
12
#include <openssl/params.h>
13
#include <openssl/rand.h>
14
#include "ml_dsa_local.h"
15
#include "ml_dsa_key.h"
16
#include "ml_dsa_matrix.h"
17
#include "ml_dsa_sign.h"
18
#include "ml_dsa_hash.h"
19
20
#define ML_DSA_MAX_LAMBDA 256 /* bit strength for ML-DSA-87 */
21
22
/*
23
 * @brief Initialize a Signature object by pointing all of its objects to
24
 * preallocated blocks. The values passed for hint, z and
25
 * c_tilde values are not owned/freed by the |sig| object.
26
 *
27
 * @param sig The ML_DSA_SIG to initialize.
28
 * @param hint A preallocated array of |k| polynomial blocks
29
 * @param k The number of |hint| polynomials
30
 * @param z A preallocated array of |l| polynomial blocks
31
 * @param l The number of |z| polynomials
32
 * @param c_tilde A preallocated buffer
33
 * @param c_tilde_len The size of |c_tilde|
34
 */
35
static void signature_init(ML_DSA_SIG *sig,
36
                           POLY *hint, uint32_t k, POLY *z, uint32_t l,
37
                           uint8_t *c_tilde, size_t c_tilde_len)
38
766
{
39
766
    vector_init(&sig->z, z, l);
40
766
    vector_init(&sig->hint, hint, k);
41
766
    sig->c_tilde = c_tilde;
42
766
    sig->c_tilde_len = c_tilde_len;
43
766
}
44
45
/*
46
 * FIPS 204, Algorithm 7, ML-DSA.Sign_internal()
47
 * @returns 1 on success and 0 on failure.
48
 */
49
static int ml_dsa_sign_internal(const ML_DSA_KEY *priv, int msg_is_mu,
50
                                const uint8_t *encoded_msg,
51
                                size_t encoded_msg_len,
52
                                const uint8_t *rnd, size_t rnd_len,
53
                                uint8_t *out_sig)
54
212
{
55
212
    int ret = 0;
56
212
    const ML_DSA_PARAMS *params = priv->params;
57
212
    EVP_MD_CTX *md_ctx = NULL;
58
212
    uint32_t k = params->k, l = params->l;
59
212
    uint32_t gamma1 = params->gamma1, gamma2 = params->gamma2;
60
212
    uint8_t *alloc = NULL, *w1_encoded;
61
212
    size_t alloc_len, w1_encoded_len;
62
212
    size_t num_polys_sig_k = 2 * k;
63
212
    size_t num_polys_k = 5 * k;
64
212
    size_t num_polys_l = 3 * l;
65
212
    size_t num_polys_k_by_l = k * l;
66
212
    POLY *polys = NULL, *p, *c_ntt;
67
212
    VECTOR s1_ntt, s2_ntt, t0_ntt, w, w1, cs1, cs2, y;
68
212
    MATRIX a_ntt;
69
212
    ML_DSA_SIG sig;
70
212
    uint8_t mu[ML_DSA_MU_BYTES], *mu_ptr = mu;
71
212
    const size_t mu_len = sizeof(mu);
72
212
    uint8_t rho_prime[ML_DSA_RHO_PRIME_BYTES];
73
212
    uint8_t c_tilde[ML_DSA_MAX_LAMBDA / 4];
74
212
    size_t c_tilde_len = params->bit_strength >> 2;
75
212
    size_t kappa;
76
77
    /*
78
     * Allocate a single blob for most of the variable size temporary variables.
79
     * Mostly used for VECTOR POLYNOMIALS (every POLY is 1K).
80
     */
81
212
    w1_encoded_len = k * (gamma2 == ML_DSA_GAMMA2_Q_MINUS1_DIV88 ? 192 : 128);
82
212
    alloc_len = w1_encoded_len
83
212
        + sizeof(*polys) * (1 + num_polys_k + num_polys_l
84
212
                            + num_polys_k_by_l + num_polys_sig_k);
85
212
    alloc = OPENSSL_malloc(alloc_len);
86
212
    if (alloc == NULL)
87
0
        return 0;
88
212
    md_ctx = EVP_MD_CTX_new();
89
212
    if (md_ctx == NULL)
90
0
        goto err;
91
92
212
    w1_encoded = alloc;
93
    /* Init the temp vectors to point to the allocated polys blob */
94
212
    p = (POLY *)(w1_encoded + w1_encoded_len);
95
212
    c_ntt = p++;
96
212
    matrix_init(&a_ntt, p, k, l);
97
212
    p += num_polys_k_by_l;
98
212
    vector_init(&s2_ntt, p, k);
99
212
    vector_init(&t0_ntt, s2_ntt.poly + k, k);
100
212
    vector_init(&w, t0_ntt.poly + k, k);
101
212
    vector_init(&w1, w.poly + k, k);
102
212
    vector_init(&cs2, w1.poly + k, k);
103
212
    p += num_polys_k;
104
212
    vector_init(&s1_ntt, p, l);
105
212
    vector_init(&y, p + l, l);
106
212
    vector_init(&cs1, p + 2 * l, l);
107
212
    p += num_polys_l;
108
212
    signature_init(&sig, p, k, p + k, l, c_tilde, c_tilde_len);
109
    /* End of the allocated blob setup */
110
111
212
    if (!matrix_expand_A(md_ctx, priv->shake128_md, priv->rho, &a_ntt))
112
0
        goto err;
113
212
    if (msg_is_mu) {
114
0
        if (encoded_msg_len != mu_len)
115
0
            goto err;
116
0
        mu_ptr = (uint8_t *)encoded_msg;
117
212
    } else {
118
212
        if (!shake_xof_2(md_ctx, priv->shake256_md, priv->tr, sizeof(priv->tr),
119
212
                         encoded_msg, encoded_msg_len, mu_ptr, mu_len))
120
0
            goto err;
121
212
    }
122
212
    if (!shake_xof_3(md_ctx, priv->shake256_md, priv->K, sizeof(priv->K),
123
212
                     rnd, rnd_len, mu_ptr, mu_len,
124
212
                     rho_prime, sizeof(rho_prime)))
125
0
        goto err;
126
127
212
    vector_copy(&s1_ntt, &priv->s1);
128
212
    vector_ntt(&s1_ntt);
129
212
    vector_copy(&s2_ntt, &priv->s2);
130
212
    vector_ntt(&s2_ntt);
131
212
    vector_copy(&t0_ntt, &priv->t0);
132
212
    vector_ntt(&t0_ntt);
133
134
    /*
135
     * kappa must not exceed 2^16. But the probability of it
136
     * exceeding even 1000 iterations is vanishingly small.
137
     */
138
982
    for (kappa = 0; ; kappa += l) {
139
982
        VECTOR *y_ntt = &cs1;
140
982
        VECTOR *r0 = &w1;
141
982
        VECTOR *ct0 = &w1;
142
982
        uint32_t z_max, r0_max, ct0_max, h_ones;
143
144
982
        vector_expand_mask(&y, rho_prime, sizeof(rho_prime), kappa,
145
982
                           gamma1, md_ctx, priv->shake256_md);
146
982
        vector_copy(y_ntt, &y);
147
982
        vector_ntt(y_ntt);
148
149
982
        matrix_mult_vector(&a_ntt, y_ntt, &w);
150
982
        vector_ntt_inverse(&w);
151
152
982
        vector_high_bits(&w, gamma2, &w1);
153
982
        ossl_ml_dsa_w1_encode(&w1, gamma2, w1_encoded, w1_encoded_len);
154
155
982
        if (!shake_xof_2(md_ctx, priv->shake256_md, mu_ptr, mu_len,
156
982
                         w1_encoded, w1_encoded_len, c_tilde, c_tilde_len))
157
0
            break;
158
159
982
        if (!poly_sample_in_ball_ntt(c_ntt, c_tilde, c_tilde_len,
160
982
                                     md_ctx, priv->shake256_md, params->tau))
161
0
            break;
162
163
982
        vector_mult_scalar(&s1_ntt, c_ntt, &cs1);
164
982
        vector_ntt_inverse(&cs1);
165
982
        vector_mult_scalar(&s2_ntt, c_ntt, &cs2);
166
982
        vector_ntt_inverse(&cs2);
167
168
982
        vector_add(&y, &cs1, &sig.z);
169
170
        /* r0 = lowbits(w - cs2) */
171
982
        vector_sub(&w, &cs2, r0);
172
982
        vector_low_bits(r0, gamma2, r0);
173
174
        /*
175
         * Leaking that the signature is rejected is fine as the next attempt at a
176
         * signature will be (indistinguishable from) independent of this one.
177
         */
178
982
        z_max = vector_max(&sig.z);
179
982
        r0_max = vector_max_signed(r0);
180
982
        if (value_barrier_32(constant_time_ge(z_max, gamma1 - params->beta)
181
982
                             | constant_time_ge(r0_max, gamma2 - params->beta)))
182
769
            continue;
183
184
213
        vector_mult_scalar(&t0_ntt, c_ntt, ct0);
185
213
        vector_ntt_inverse(ct0);
186
213
        vector_make_hint(ct0, &cs2, &w, gamma2, &sig.hint);
187
188
213
        ct0_max = vector_max(ct0);
189
213
        h_ones = vector_count_ones(&sig.hint);
190
        /* Same reasoning applies to the leak as above */
191
213
        if (value_barrier_32(constant_time_ge(ct0_max, gamma2)
192
213
                             | constant_time_lt(params->omega, h_ones)))
193
1
            continue;
194
212
        ret = ossl_ml_dsa_sig_encode(&sig, params, out_sig);
195
212
        break;
196
213
    }
197
212
err:
198
212
    EVP_MD_CTX_free(md_ctx);
199
212
    OPENSSL_clear_free(alloc, alloc_len);
200
212
    OPENSSL_cleanse(rho_prime, sizeof(rho_prime));
201
212
    return ret;
202
212
}
203
204
/*
205
 * See FIPS 204, Algorithm 8, ML-DSA.Verify_internal().
206
 */
207
static int ml_dsa_verify_internal(const ML_DSA_KEY *pub, int msg_is_mu,
208
                                  const uint8_t *msg_enc, size_t msg_enc_len,
209
                                  const uint8_t *sig_enc, size_t sig_enc_len)
210
212
{
211
212
    int ret = 0;
212
212
    uint8_t *alloc = NULL, *w1_encoded;
213
212
    POLY *polys = NULL, *p, *c_ntt;
214
212
    MATRIX a_ntt;
215
212
    VECTOR az_ntt, ct1_ntt, *z_ntt, *w1, *w_approx;
216
212
    ML_DSA_SIG sig;
217
212
    const ML_DSA_PARAMS *params = pub->params;
218
212
    uint32_t k = pub->params->k;
219
212
    uint32_t l = pub->params->l;
220
212
    uint32_t gamma2 = params->gamma2;
221
212
    size_t w1_encoded_len;
222
212
    size_t num_polys_sig = k + l;
223
212
    size_t num_polys_k = 2 * k;
224
212
    size_t num_polys_l = 1 * l;
225
212
    size_t num_polys_k_by_l = k * l;
226
212
    uint8_t mu[ML_DSA_MU_BYTES], *mu_ptr = mu;
227
212
    const size_t mu_len = sizeof(mu);
228
212
    uint8_t c_tilde[ML_DSA_MAX_LAMBDA / 4];
229
212
    uint8_t c_tilde_sig[ML_DSA_MAX_LAMBDA / 4];
230
212
    EVP_MD_CTX *md_ctx = NULL;
231
212
    size_t c_tilde_len = params->bit_strength >> 2;
232
212
    uint32_t z_max;
233
234
    /* Allocate space for all the POLYNOMIALS used by temporary VECTORS */
235
212
    w1_encoded_len = k * (gamma2 == ML_DSA_GAMMA2_Q_MINUS1_DIV88 ? 192 : 128);
236
212
    alloc = OPENSSL_malloc(w1_encoded_len
237
212
                           + sizeof(*polys) * (1 + num_polys_k
238
212
                                               + num_polys_l
239
212
                                               + num_polys_k_by_l
240
212
                                               + num_polys_sig));
241
212
    if (alloc == NULL)
242
0
        return 0;
243
212
    md_ctx = EVP_MD_CTX_new();
244
212
    if (md_ctx == NULL)
245
0
        goto err;
246
247
212
    w1_encoded = alloc;
248
    /* Init the temp vectors to point to the allocated polys blob */
249
212
    p = (POLY *)(w1_encoded + w1_encoded_len);
250
212
    c_ntt = p++;
251
212
    matrix_init(&a_ntt, p, k, l);
252
212
    p += num_polys_k_by_l;
253
212
    signature_init(&sig, p, k, p + k, l, c_tilde_sig, c_tilde_len);
254
212
    p += num_polys_sig;
255
212
    vector_init(&az_ntt, p, k);
256
212
    vector_init(&ct1_ntt, p + k, k);
257
258
212
    if (!ossl_ml_dsa_sig_decode(&sig, sig_enc, sig_enc_len, pub->params)
259
212
            || !matrix_expand_A(md_ctx, pub->shake128_md, pub->rho, &a_ntt))
260
0
        goto err;
261
212
    if (msg_is_mu) {
262
0
        if (msg_enc_len != mu_len)
263
0
            goto err;
264
0
        mu_ptr = (uint8_t *)msg_enc;
265
212
    } else {
266
212
        if (!shake_xof_2(md_ctx, pub->shake256_md, pub->tr, sizeof(pub->tr),
267
212
                         msg_enc, msg_enc_len, mu_ptr, mu_len))
268
0
            goto err;
269
212
    }
270
    /* Compute verifiers challenge c_ntt = NTT(SampleInBall(c_tilde) */
271
212
    if (!poly_sample_in_ball_ntt(c_ntt, c_tilde_sig, c_tilde_len,
272
212
                                 md_ctx, pub->shake256_md, params->tau))
273
0
        goto err;
274
275
    /* ct1_ntt = NTT(c) * NTT(t1 * 2^d) */
276
212
    vector_scale_power2_round_ntt(&pub->t1, &ct1_ntt);
277
212
    vector_mult_scalar(&ct1_ntt, c_ntt, &ct1_ntt);
278
279
    /* compute z_max early in order to reuse sig.z */
280
212
    z_max = vector_max(&sig.z);
281
282
    /* w_approx = NTT_inverse(A * NTT(z) - ct1_ntt) */
283
212
    z_ntt = &sig.z;
284
212
    vector_ntt(z_ntt);
285
212
    matrix_mult_vector(&a_ntt, z_ntt, &az_ntt);
286
212
    w_approx = &az_ntt;
287
212
    vector_sub(&az_ntt, &ct1_ntt, w_approx);
288
212
    vector_ntt_inverse(w_approx);
289
290
    /* compute w1_encoded */
291
212
    w1 = w_approx;
292
212
    vector_use_hint(&sig.hint, w_approx, gamma2, w1);
293
212
    ossl_ml_dsa_w1_encode(w1, gamma2, w1_encoded, w1_encoded_len);
294
295
212
    if (!shake_xof_3(md_ctx, pub->shake256_md, mu_ptr, mu_len,
296
212
                     w1_encoded, w1_encoded_len, NULL, 0, c_tilde, c_tilde_len))
297
0
        goto err;
298
299
212
    ret = (z_max < (uint32_t)(params->gamma1 - params->beta))
300
212
        && memcmp(c_tilde, sig.c_tilde, c_tilde_len) == 0;
301
212
err:
302
212
    OPENSSL_free(alloc);
303
212
    EVP_MD_CTX_free(md_ctx);
304
212
    return ret;
305
212
}
306
307
/**
308
 * @brief Encode a message
309
 * See FIPS 204 Algorithm 2 Step 10 (and algorithm 3 Step 5).
310
 *
311
 * ML_DSA pure signatures are encoded as M' = 00 || ctx_len || ctx || msg
312
 * Where ctx is the empty string by default and ctx_len <= 255.
313
 *
314
 * Note this code could be shared with SLH_DSA
315
 *
316
 * @param msg A message to encode
317
 * @param msg_len The size of |msg|
318
 * @param ctx An optional context to add to the message encoding.
319
 * @param ctx_len The size of |ctx|. It must be in the range 0..255
320
 * @param encode Use the Pure signature encoding if this is 1, and dont encode
321
 *               if this value is 0.
322
 * @param tmp A small buffer that may be used if the message is small.
323
 * @param tmp_len The size of |tmp|
324
 * @param out_len The size of the returned encoded buffer.
325
 * @returns A buffer containing the encoded message. If the passed in
326
 * |tmp| buffer is big enough to hold the encoded message then it returns |tmp|
327
 * otherwise it allocates memory which must be freed by the caller. If |encode|
328
 * is 0 then it returns |msg|. NULL is returned if there is a failure.
329
 */
330
static uint8_t *msg_encode(const uint8_t *msg, size_t msg_len,
331
                           const uint8_t *ctx, size_t ctx_len, int encode,
332
                           uint8_t *tmp, size_t tmp_len, size_t *out_len)
333
424
{
334
424
    uint8_t *encoded = NULL;
335
424
    size_t encoded_len;
336
337
424
    if (encode == 0) {
338
        /* Raw message */
339
0
        *out_len = msg_len;
340
0
        return (uint8_t *)msg;
341
0
    }
342
424
    if (ctx_len > ML_DSA_MAX_CONTEXT_STRING_LEN)
343
0
        return NULL;
344
345
    /* Pure encoding */
346
424
    encoded_len = 1 + 1 + ctx_len + msg_len;
347
424
    *out_len = encoded_len;
348
424
    if (encoded_len <= tmp_len) {
349
424
        encoded = tmp;
350
424
    } else {
351
0
        encoded = OPENSSL_malloc(encoded_len);
352
0
        if (encoded == NULL)
353
0
            return NULL;
354
0
    }
355
424
    encoded[0] = 0;
356
424
    encoded[1] = (uint8_t)ctx_len;
357
424
    memcpy(&encoded[2], ctx, ctx_len);
358
424
    memcpy(&encoded[2 + ctx_len], msg, msg_len);
359
424
    return encoded;
360
424
}
361
362
/**
363
 * See FIPS 204 Section 5.2 Algorithm 2 ML-DSA.Sign()
364
 *
365
 * @returns 1 on success, or 0 on error.
366
 */
367
int ossl_ml_dsa_sign(const ML_DSA_KEY *priv, int msg_is_mu,
368
                     const uint8_t *msg, size_t msg_len,
369
                     const uint8_t *context, size_t context_len,
370
                     const uint8_t *rand, size_t rand_len, int encode,
371
                     unsigned char *sig, size_t *sig_len, size_t sig_size)
372
424
{
373
424
    int ret = 1;
374
424
    uint8_t m_tmp[1024], *m = m_tmp, *alloced_m = NULL;
375
424
    size_t m_len = 0;
376
377
424
    if (ossl_ml_dsa_key_get_priv(priv) == NULL)
378
0
        return 0;
379
424
    if (sig != NULL) {
380
212
        if (sig_size < priv->params->sig_len)
381
0
            return 0;
382
212
        if (msg_is_mu) {
383
0
            m = (uint8_t *)msg;
384
0
            m_len = msg_len;
385
212
        } else {
386
212
            m = msg_encode(msg, msg_len, context, context_len, encode,
387
212
                           m_tmp, sizeof(m_tmp), &m_len);
388
212
            if (m == NULL)
389
0
                return 0;
390
212
            if (m != msg && m != m_tmp)
391
0
                alloced_m = m;
392
212
        }
393
212
        ret = ml_dsa_sign_internal(priv, msg_is_mu, m, m_len, rand, rand_len, sig);
394
212
        OPENSSL_free(alloced_m);
395
212
    }
396
424
    if (sig_len != NULL)
397
424
        *sig_len = priv->params->sig_len;
398
424
    return ret;
399
424
}
400
401
/**
402
 * See FIPS 203 Section 5.3 Algorithm 3 ML-DSA.Verify()
403
 * @returns 1 on success, or 0 on error.
404
 */
405
int ossl_ml_dsa_verify(const ML_DSA_KEY *pub, int msg_is_mu,
406
                       const uint8_t *msg, size_t msg_len,
407
                       const uint8_t *context, size_t context_len, int encode,
408
                       const uint8_t *sig, size_t sig_len)
409
212
{
410
212
    uint8_t *m, *alloced_m = NULL;
411
212
    size_t m_len;
412
212
    uint8_t m_tmp[1024];
413
212
    int ret = 0;
414
415
212
    if (ossl_ml_dsa_key_get_pub(pub) == NULL)
416
0
        return 0;
417
418
212
    if (msg_is_mu) {
419
0
        m = (uint8_t *)msg;
420
0
        m_len = msg_len;
421
212
    } else {
422
212
        m = msg_encode(msg, msg_len, context, context_len, encode,
423
212
                       m_tmp, sizeof(m_tmp), &m_len);
424
212
        if (m == NULL)
425
0
            return 0;
426
212
        if (m != msg && m != m_tmp)
427
0
            alloced_m = m;
428
212
    }
429
430
212
    ret = ml_dsa_verify_internal(pub, msg_is_mu, m, m_len, sig, sig_len);
431
212
    OPENSSL_free(alloced_m);
432
212
    return ret;
433
212
}