Coverage Report

Created: 2025-06-22 06:56

/src/openssl/crypto/slh_dsa/slh_fors.c
Line
Count
Source (jump to first uncovered line)
1
/*
2
 * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
3
 *
4
 * Licensed under the Apache License 2.0 (the "License").  You may not use
5
 * this file except in compliance with the License.  You can obtain a copy
6
 * in the file LICENSE in the source distribution or at
7
 * https://www.openssl.org/source/license.html
8
 */
9
10
#include <string.h>
11
#include <openssl/crypto.h>
12
#include "slh_dsa_local.h"
13
#include "slh_dsa_key.h"
14
15
/* k = 14, 17, 22, 33, 35 (the number of FORS trees) */
#define SLH_MAX_K           35

/*
 * a = 6, 8, 9, 12 or 14 is the height of each FORS tree, i.e. each of the
 * k trees has 2^a leaves.
 *
 * When recomputing a FORS public key only the k tree roots (one n-byte value
 * per tree) are collected before being hashed together, so the roots buffer
 * needs k * n bytes; the tree height plays no part in its size.
 */
#define SLH_MAX_ROOTS       (SLH_MAX_K * SLH_MAX_N)

static void slh_base_2b(const uint8_t *in, uint32_t b, uint32_t *out, size_t out_len);
24
25
/**
26
 * @brief Generate FORS secret values
27
 * See FIPS 205 Section 8.1 Algorithm 14.
28
 *
29
 * @param ctx Contains SLH_DSA algorithm functions and constants.
30
 * @param sk_seed A private key seed of size |n|
31
 * @param pk_seed A public key seed of size |n|
32
 * @param adrs An ADRS object containing the layer address of zero, with the
33
 *             tree address and key pair address set to the index of the WOTS+
34
 *             key within the XMSS tree that signs the FORS key.
35
 * @param id The index of the FORS secret value within the sets of FORS trees.
36
 *               (which must be < 2^(hm - height)
37
 * @param pk_out The generated FORS secret value of size |n|
38
 * @param pk_out_len The maximum size of |pk_out|
39
 * @returns 1 on success, or 0 on error.
40
 */
41
static int slh_fors_sk_gen(SLH_DSA_HASH_CTX *ctx, const uint8_t *sk_seed,
42
                           const uint8_t *pk_seed, uint8_t *adrs, uint32_t id,
43
                           uint8_t *pk_out, size_t pk_out_len)
44
0
{
45
0
    const SLH_DSA_KEY *key = ctx->key;
46
0
    SLH_ADRS_DECLARE(sk_adrs);
47
0
    SLH_ADRS_FUNC_DECLARE(key, adrsf);
48
49
0
    adrsf->copy(sk_adrs, adrs);
50
0
    adrsf->set_type_and_clear(sk_adrs, SLH_ADRS_TYPE_FORS_PRF);
51
0
    adrsf->copy_keypair_address(sk_adrs, adrs);
52
0
    adrsf->set_tree_index(sk_adrs, id);
53
0
    return key->hash_func->PRF(ctx, pk_seed, sk_seed, sk_adrs, pk_out, pk_out_len);
54
0
}
55
56
/**
57
 * @brief Computes the nodes of a Merkle tree.
58
 * See FIPS 205 Section 8.2 Algorithm 18
59
 *
60
 * The leaf nodes are hashes of FORS secret values.
61
 * Each parent node is a hash of its 2 children.
62
 * Note this is a recursive function.
63
 *
64
 * @param ctx Contains SLH_DSA algorithm functions and constants.
65
 * @param sk_seed A SLH_DSA private key seed of size |n|
66
 * @param pk_seed A SLH_DSA public key seed of size |n|
67
 * @param adrs The ADRS object must have a layer address of zero, and the
68
 *             tree address set to the XMSS tree that signs the FORS key,
69
 *             the type set to FORS_TREE, and the keypair address set to the
70
 *             index of the WOTS+ key that signs the FORS key.
71
 * @param node_id The target node index
72
 * @param height The target node height
73
 * @param node The returned hash for a node of size|n|
74
 * @param node_len The maximum size of |node|
75
 * @returns 1 on success, or 0 on error.
76
 */
77
static int slh_fors_node(SLH_DSA_HASH_CTX *ctx, const uint8_t *sk_seed,
78
                         const uint8_t *pk_seed, uint8_t *adrs, uint32_t node_id,
79
                         uint32_t height, uint8_t *node, size_t node_len)
80
0
{
81
0
    int ret = 0;
82
0
    const SLH_DSA_KEY *key = ctx->key;
83
0
    uint8_t sk[SLH_MAX_N], lnode[SLH_MAX_N], rnode[SLH_MAX_N];
84
0
    uint32_t n = key->params->n;
85
86
0
    SLH_ADRS_FUNC_DECLARE(key, adrsf);
87
88
0
    if (height == 0) {
89
        /* Gets here for leaf nodes */
90
0
        if (!slh_fors_sk_gen(ctx, sk_seed, pk_seed, adrs, node_id, sk, sizeof(sk)))
91
0
            return 0;
92
0
        adrsf->set_tree_height(adrs, 0);
93
0
        adrsf->set_tree_index(adrs, node_id);
94
0
        ret = key->hash_func->F(ctx, pk_seed, adrs, sk, n, node, node_len);
95
0
        OPENSSL_cleanse(sk, n);
96
0
        return ret;
97
0
    } else {
98
0
        if (!slh_fors_node(ctx, sk_seed, pk_seed, adrs, 2 * node_id, height - 1,
99
0
                           lnode, sizeof(rnode))
100
0
                || !slh_fors_node(ctx, sk_seed, pk_seed, adrs, 2 * node_id + 1,
101
0
                                  height - 1, rnode, sizeof(rnode)))
102
0
            return 0;
103
0
        adrsf->set_tree_height(adrs, height);
104
0
        adrsf->set_tree_index(adrs, node_id);
105
0
        if (!key->hash_func->H(ctx, pk_seed, adrs, lnode, rnode, node, node_len))
106
0
            return 0;
107
0
    }
108
0
    return 1;
109
0
}
110
111
/**
112
 * @brief Generate an FORS signature
113
 * See FIPS 205 Section 8.3 Algorithm 16
114
 *
115
 * A FORS signature has a size of (k * (1 + a) * n) bytes
116
 * There are k trees, each of which have a private key value of size |n| followed
117
 * by an authentication path of size |a| (where each path is size |n|)
118
 *
119
 * @param ctx Contains SLH_DSA algorithm functions and constants.
120
 * @param md A message digest of size |(k * a + 7) / 8| bytes to sign
121
 * @param sk_seed A private key seed of size |n|
122
 * @param pk_seed A public key seed of size |n|
123
 * @param adrs The ADRS object must have a layer address of zero, and the
124
 *             tree address set to the XMSS tree that signs the FORS key,
125
 *             the type set to FORS_TREE, and the keypair address set to the
126
 *             index of the WOTS+ key that signs the FORS key.
127
 * @param sig_wpkt A WPACKET object to write the generated XMSS signature to
128
 * @param sig_len  The size of |sig| which is (2 * n + 3) * n + tree_height * n.
129
 * @returns 1 on success, or 0 on error.
130
 */
131
int ossl_slh_fors_sign(SLH_DSA_HASH_CTX *ctx, const uint8_t *md,
132
                       const uint8_t *sk_seed, const uint8_t *pk_seed,
133
                       uint8_t *adrs, WPACKET *sig_wpkt)
134
0
{
135
0
    const SLH_DSA_KEY *key = ctx->key;
136
0
    uint32_t tree_id, layer, s, tree_offset;
137
0
    uint32_t ids[SLH_MAX_K];
138
0
    const SLH_DSA_PARAMS *params = key->params;
139
0
    uint32_t n = params->n;
140
0
    uint32_t k = params->k; /* number of trees */
141
0
    uint32_t a = params->a;
142
0
    uint32_t two_power_a = (1 << a); /* this is t in FIPS 205 */
143
0
    uint32_t tree_id_times_two_power_a = 0;
144
0
    uint8_t out[SLH_MAX_N];
145
146
    /*
147
     * Split md into k a-bit values e.g with k = 14, a = 12
148
     * ids[0..13] = 12 bits each of md
149
     */
150
0
    slh_base_2b(md, a, ids, k);
151
152
0
    for (tree_id = 0; tree_id < k; ++tree_id) {
153
        /* Get the tree[i] leaf id */
154
0
        uint32_t node_id = ids[tree_id]; /* |id| = |a| bits */
155
156
        /*
157
         * Give each of the k trees a unique range at each level.
158
         * e.g. If we have 4096 leaf nodes (2^a = 2^12) for each tree
159
         * tree i will use indexes from 4096 * i + (0..4095) for its bottom level.
160
         * For the next level up from the bottom there would be 2048 nodes
161
         * (so tree i uses indexes 2048 * i + (0...2047) for this level)
162
         */
163
0
        tree_offset = tree_id_times_two_power_a;
164
165
0
        if (!slh_fors_sk_gen(ctx, sk_seed, pk_seed, adrs,
166
0
                             node_id + tree_id_times_two_power_a, out, sizeof(out))
167
0
                || !WPACKET_memcpy(sig_wpkt, out, n))
168
0
            return 0;
169
170
        /*
171
         * Traverse from the bottom of the tree (layer = 0)
172
         * up to the root (layer = a - 1).
173
         * NOTE: This is a really inefficient way of doing this, since at
174
         * layer a - 1 it calculates most of the hashes of the entire tree as
175
         * well as all the leaf nodes. So it is calculating nodes multiple times.
176
         */
177
0
        for (layer = 0; layer < a; ++layer) {
178
0
            s = node_id ^ 1; /* XOR gets the index of the other child in a binary tree */
179
0
            if (!slh_fors_node(ctx, sk_seed, pk_seed, adrs,
180
0
                               s + tree_offset, layer, out, sizeof(out)))
181
0
                return 0;
182
0
            node_id >>= 1; /* Get the parent node id */
183
0
            tree_offset >>= 1; /* Each layer up has half as many nodes */
184
0
            if (!WPACKET_memcpy(sig_wpkt, out, n))
185
0
                return 0;
186
0
        }
187
0
        tree_id_times_two_power_a += two_power_a;
188
0
    }
189
0
    return 1;
190
0
}
191
192
/**
193
 * @brief Compute a candidate FORS public key from a message and signature.
194
 * See FIPS 205 Section 8.4 Algorithm 17.
195
 *
196
 * A FORS signature has a size of (k * (a + 1) * n) bytes
197
 *
198
 * @param ctx Contains SLH_DSA algorithm functions and constants.
199
 * @param fors_sig_rpkt A PACKET object to read a FORS signature from
200
 * @param md A message digest of size (k * a / 8) bytes
201
 * @param pk_seed A public key seed of size |n|
202
 * @param adrs The ADRS object must have a layer address of zero, and the
203
 *             tree address set to the XMSS tree that signs the FORS key,
204
 *             the type set to FORS_TREE, and the keypair address set to the
205
 *             index of the WOTS+ key that signs the FORS key.
206
 * @param pk_out The returned candidate FORS public key of size |n|
207
 * @param pk_out_len The maximum size of |pk_out|
208
 * @returns 1 on success, or 0 on error.
209
 */
210
int ossl_slh_fors_pk_from_sig(SLH_DSA_HASH_CTX *ctx, PACKET *fors_sig_rpkt,
211
                              const uint8_t *md, const uint8_t *pk_seed,
212
                              uint8_t *adrs, uint8_t *pk_out, size_t pk_out_len)
213
0
{
214
0
    const SLH_DSA_KEY *key = ctx->key;
215
0
    int ret = 0;
216
0
    uint32_t i, j, aoff = 0;
217
0
    uint32_t ids[SLH_MAX_K];
218
0
    const SLH_DSA_PARAMS *params = key->params;
219
0
    uint32_t a = params->a;
220
0
    uint32_t k = params->k;
221
0
    uint32_t n = params->n;
222
0
    uint32_t two_power_a = (1 << a);
223
0
    const uint8_t *sk, *authj; /* Pointers to |sig| buffer inside fors_sig_rpkt */
224
0
    uint8_t roots[SLH_MAX_ROOTS];
225
0
    size_t roots_len = 0; /* The size of |roots| */
226
0
    uint8_t *node0, *node1; /* Pointers into roots[] */
227
0
    WPACKET root_pkt, *wroot_pkt = &root_pkt; /* Points to |roots| buffer */
228
229
0
    SLH_ADRS_DECLARE(pk_adrs);
230
0
    SLH_ADRS_FUNC_DECLARE(key, adrsf);
231
0
    SLH_ADRS_FN_DECLARE(adrsf, set_tree_index);
232
0
    SLH_ADRS_FN_DECLARE(adrsf, set_tree_height);
233
0
    SLH_HASH_FUNC_DECLARE(key, hashf);
234
0
    SLH_HASH_FN_DECLARE(hashf, F);
235
0
    SLH_HASH_FN_DECLARE(hashf, H);
236
237
0
    if (!WPACKET_init_static_len(wroot_pkt, roots, sizeof(roots), 0))
238
0
        return 0;
239
240
    /* Split md into k a-bit values e.g ids[0..k-1] = 12 bits each of md */
241
0
    slh_base_2b(md, a, ids, k);
242
243
    /* Compute the roots of k Merkle trees */
244
0
    for (i = 0; i < k; ++i) {
245
0
        uint32_t id = ids[i];
246
0
        uint32_t node_id = id + aoff;
247
248
0
        set_tree_height(adrs, 0);
249
0
        set_tree_index(adrs, node_id);
250
251
        /* Regenerate the public key of the leaf */
252
0
        if (!PACKET_get_bytes(fors_sig_rpkt, &sk, n)
253
0
                || !WPACKET_allocate_bytes(wroot_pkt, n, &node0)
254
0
                || !F(ctx, pk_seed, adrs, sk, n, node0, n))
255
0
            goto err;
256
257
        /* This omits the copying of the nodes that the FIPS 205 code does */
258
0
        node1 = node0;
259
0
        for (j = 0; j < a; ++j) {
260
            /* Get this layers other child public key */
261
0
            if (!PACKET_get_bytes(fors_sig_rpkt, &authj, n))
262
0
                goto err;
263
            /* Hash the children together to get the parent nodes public key */
264
0
            set_tree_height(adrs, j + 1);
265
0
            if ((id & 1) == 0) {
266
0
                node_id >>= 1;
267
0
                set_tree_index(adrs, node_id);
268
0
                if (!H(ctx, pk_seed, adrs, node0, authj, node1, n))
269
0
                    goto err;
270
0
            } else {
271
0
                node_id = (node_id - 1) >> 1;
272
0
                set_tree_index(adrs, node_id);
273
0
                if (!H(ctx, pk_seed, adrs, authj, node0, node1, n))
274
0
                    goto err;
275
0
            }
276
0
            id >>= 1;
277
0
        }
278
0
        aoff += two_power_a;
279
0
    }
280
0
    if (!WPACKET_get_total_written(wroot_pkt, &roots_len))
281
0
        goto err;
282
283
    /* The public key is the hash of all the roots of the k trees */
284
0
    adrsf->copy(pk_adrs, adrs);
285
0
    adrsf->set_type_and_clear(pk_adrs, SLH_ADRS_TYPE_FORS_ROOTS);
286
0
    adrsf->copy_keypair_address(pk_adrs, adrs);
287
0
    ret = hashf->T(ctx, pk_seed, pk_adrs, roots, roots_len, pk_out, pk_out_len);
288
0
 err:
289
0
    if (!WPACKET_finish(wroot_pkt))
290
0
        ret = 0;
291
0
    return ret;
292
0
}
293
294
/**
 * @brief Convert a byte string into a base 2^b representation
 * See FIPS 205 Algorithm 4
 *
 * The input is read as a big-endian bit string and cut into |out_len|
 * consecutive b-bit unsigned integers. Only ceil(out_len * b / 8) bytes of
 * |in| are read.
 *
 * @param in An input byte stream with a size >= ceil(|out_len| * |b| / 8)
 * @param b The bit size to divide |in| into.
 *          For FORS this is the tree height a (6, 8, 9, 12 or 14).
 * @param out The array of returned base 2^b integers that represents the first
 *            |out_len| * |b| bits of |in|
 * @param out_len The number of elements to write to |out|
 */
static void slh_base_2b(const uint8_t *in, uint32_t b,
                        uint32_t *out, size_t out_len)
{
    const uint32_t mask = (1 << b) - 1;
    /*
     * Bit accumulator: only its lowest |avail| bits are still unconsumed.
     * Older bits are allowed to fall off the top when it is shifted.
     */
    uint32_t acc = 0;
    uint32_t avail = 0;
    size_t i;

    for (i = 0; i < out_len; i++) {
        /* Pull in whole bytes until a full b-bit digit is available */
        for (; avail < b; avail += 8)
            acc = (acc << 8) | *in++;
        avail -= b;
        out[i] = (acc >> avail) & mask;
    }
}