Coverage Report

Created: 2025-07-01 06:04

/src/hpn-ssh/xmss_fast.c
Line
Count
Source (jump to first uncovered line)
1
/* $OpenBSD: xmss_fast.c,v 1.3 2018/03/22 07:06:11 markus Exp $ */
2
/*
3
xmss_fast.c version 20160722
4
Andreas Hülsing
5
Joost Rijneveld
6
Public domain.
7
*/
8
9
#include "includes.h"
10
#ifdef WITH_XMSS
11
12
#include <stdlib.h>
13
#include <string.h>
14
#ifdef HAVE_STDINT_H
15
# include <stdint.h>
16
#endif
17
18
#include "xmss_fast.h"
19
#include "crypto_api.h"
20
#include "xmss_wots.h"
21
#include "xmss_hash.h"
22
23
#include "xmss_commons.h"
24
#include "xmss_hash_address.h"
25
// For testing
26
#include "stdio.h"
27
28
29
30
/**
31
 * Used for pseudorandom keygeneration,
32
 * generates the seed for the WOTS keypair at address addr
33
 *
34
 * takes n byte sk_seed and returns n byte seed using 32 byte address addr.
35
 */
36
static void get_seed(unsigned char *seed, const unsigned char *sk_seed, int n, uint32_t addr[8])
{
  /* Byte-serialized copy of the 8-word address that is fed to prf(). */
  unsigned char addr_bytes[32];

  /*
   * The chain address, hash address and key-and-mask fields have to be zero
   * for seed derivation.  They are cleared directly in the caller's addr.
   */
  setChainADRS(addr, 0);
  setHashADRS(addr, 0);
  setKeyAndMask(addr, 0);

  /* Derive the n-byte seed from the serialized address and sk_seed. */
  addr_to_byte(addr_bytes, addr);
  prf(seed, addr_bytes, sk_seed, n);
}
47
48
/**
49
 * Initialize xmss params struct
50
 * parameter names are the same as in the draft
51
 * parameter k is K as used in the BDS algorithm
52
 */
53
int xmss_set_params(xmss_params *params, int n, int h, int w, int k)
{
  wots_params wots_par;

  /* Accept only H > K >= 2 with an even difference H - K. */
  if (!(k < h && k >= 2 && (h - k) % 2 == 0)) {
    fprintf(stderr, "For BDS traversal, H - K must be even, with H > K >= 2!\n");
    return 1;
  }

  params->h = h;
  params->n = n;
  params->k = k;

  /* Build the WOTS parameters in a local and copy them into the struct. */
  wots_set_params(&wots_par, n, w);
  params->wots_par = wots_par;

  return 0;
}
67
68
/**
69
 * Initialize BDS state struct
70
 * parameter names are the same as used in the description of the BDS traversal
71
 */
72
void xmss_set_bds_state(bds_state *state, unsigned char *stack, int stackoffset, unsigned char *stacklevels, unsigned char *auth, unsigned char *keep, treehash_inst *treehash, unsigned char *retain, int next_leaf)
{
  /* Shared stack: node storage, per-node levels and current fill. */
  state->stack       = stack;
  state->stacklevels = stacklevels;
  state->stackoffset = stackoffset;

  /* Caller-provided buffers that the traversal code reads and writes. */
  state->auth     = auth;
  state->keep     = keep;
  state->retain   = retain;
  state->treehash = treehash;

  /* Index of the next leaf to be processed for this state. */
  state->next_leaf = next_leaf;
}
83
84
/**
85
 * Initialize xmssmt_params struct
86
 * parameter names are the same as in the draft
87
 *
88
 * Especially h is the total tree height, i.e. the XMSS trees have height h/d
89
 */
90
int xmssmt_set_params(xmssmt_params *params, int n, int h, int d, int w, int k)
{
  xmss_params xmss_par;

  /*
   * Reject a non-positive layer count before it is used as a divisor:
   * h % d and h / d are undefined for d == 0.
   */
  if (d <= 0) {
    fprintf(stderr, "d must be positive!\n");
    return 1;
  }
  if (h % d) {
    fprintf(stderr, "d must divide h without remainder!\n");
    return 1;
  }

  params->h = h;
  params->d = d;
  params->n = n;
  /* Number of bytes needed to store an h-bit index. */
  params->index_len = (h + 7) / 8;

  /* Every layer uses an XMSS tree of height h/d; returns 1 if that is rejected. */
  if (xmss_set_params(&xmss_par, n, (h/d), w, k)) {
    return 1;
  }
  params->xmss_par = xmss_par;
  return 0;
}
107
108
/**
109
 * Computes a leaf from a WOTS public key using an L-tree.
110
 */
111
static void l_tree(unsigned char *leaf, unsigned char *wots_pk, const xmss_params *params, const unsigned char *pub_seed, uint32_t addr[8])
112
0
{
113
0
  unsigned int l = params->wots_par.len;
114
0
  unsigned int n = params->n;
115
0
  uint32_t i = 0;
116
0
  uint32_t height = 0;
117
0
  uint32_t bound;
118
119
  //ADRS.setTreeHeight(0);
120
0
  setTreeHeight(addr, height);
121
  
122
0
  while (l > 1) {
123
0
     bound = l >> 1; //floor(l / 2);
124
0
     for (i = 0; i < bound; i++) {
125
       //ADRS.setTreeIndex(i);
126
0
       setTreeIndex(addr, i);
127
       //wots_pk[i] = RAND_HASH(pk[2i], pk[2i + 1], SEED, ADRS);
128
0
       hash_h(wots_pk+i*n, wots_pk+i*2*n, pub_seed, addr, n);
129
0
     }
130
     //if ( l % 2 == 1 ) {
131
0
     if (l & 1) {
132
       //pk[floor(l / 2) + 1] = pk[l];
133
0
       memcpy(wots_pk+(l>>1)*n, wots_pk+(l-1)*n, n);
134
       //l = ceil(l / 2);
135
0
       l=(l>>1)+1;
136
0
     }
137
0
     else {
138
       //l = ceil(l / 2);
139
0
       l=(l>>1);
140
0
     }
141
     //ADRS.setTreeHeight(ADRS.getTreeHeight() + 1);
142
0
     height++;
143
0
     setTreeHeight(addr, height);
144
0
   }
145
   //return pk[0];
146
0
   memcpy(leaf, wots_pk, n);
147
0
}
148
149
/**
150
 * Computes the leaf at a given address. First generates the WOTS key pair, then computes leaf using l_tree. As this happens position independent, we only require that addr encodes the right ltree-address.
151
 */
152
static void gen_leaf_wots(unsigned char *leaf, const unsigned char *sk_seed, const xmss_params *params, const unsigned char *pub_seed, uint32_t ltree_addr[8], uint32_t ots_addr[8])
153
0
{
154
0
  unsigned char seed[params->n];
155
0
  unsigned char pk[params->wots_par.keysize];
156
157
0
  get_seed(seed, sk_seed, params->n, ots_addr);
158
0
  wots_pkgen(pk, seed, &(params->wots_par), pub_seed, ots_addr);
159
160
0
  l_tree(leaf, pk, params, pub_seed, ltree_addr);
161
0
}
162
163
0
/*
 * Returns the smallest level among the topmost treehash->stackusage entries
 * of the shared stack, or params->h if that instance has no entries there.
 */
static int treehash_minheight_on_stack(bds_state* state, const xmss_params *params, const treehash_inst *treehash) {
  unsigned int lowest = params->h;
  unsigned int depth;

  for (depth = 0; depth < treehash->stackusage; depth++) {
    unsigned int level = state->stacklevels[state->stackoffset - depth - 1];
    if (level < lowest) {
      lowest = level;
    }
  }
  return lowest;
}
172
173
/**
174
 * Merkle's TreeHash algorithm. The address only needs to initialize the first 78 bits of addr. Everything else will be set by treehash.
175
 * Currently only used for key generation.
176
 *
177
 */
178
/*
 * Builds the subtree of the given height whose leftmost leaf is `index`,
 * writes its root to `node`, and fills the BDS state (auth, treehash nodes,
 * retain) from the nodes produced along the way.
 *
 * Only the first 12 bytes of addr are copied into the working addresses;
 * the remaining fields are set here before use.
 */
static void treehash_setup(unsigned char *node, int height, int index, bds_state *state, const unsigned char *sk_seed, const xmss_params *params, const unsigned char *pub_seed, const uint32_t addr[8])
{
  unsigned int idx = index;
  unsigned int n = params->n;
  unsigned int h = params->h;
  unsigned int k = params->k;
  // use three different addresses because at this point we use all three formats in parallel
  uint32_t ots_addr[8];
  uint32_t ltree_addr[8];
  uint32_t  node_addr[8];
  // only copy layer and tree address parts
  memcpy(ots_addr, addr, 12);
  // type = ots
  setType(ots_addr, 0);
  memcpy(ltree_addr, addr, 12);
  setType(ltree_addr, 1);
  memcpy(node_addr, addr, 12);
  setType(node_addr, 2);

  uint32_t lastnode, i;
  // local treehash stack: at most height+1 nodes, with their levels
  unsigned char stack[(height+1)*n];
  unsigned int stacklevels[height+1];
  unsigned int stackoffset=0;
  unsigned int nodeh;

  lastnode = idx+(1<<height);

  // mark all h-k treehash instances as finished and empty
  for (i = 0; i < h-k; i++) {
    state->treehash[i].h = i;
    state->treehash[i].completed = 1;
    state->treehash[i].stackusage = 0;
  }

  // i counts leaves relative to the start of the subtree, idx is the absolute leaf index
  i = 0;
  for (; idx < lastnode; idx++) {
    setLtreeADRS(ltree_addr, idx);
    setOTSADRS(ots_addr, idx);
    // push the new leaf onto the stack
    gen_leaf_wots(stack+stackoffset*n, sk_seed, params, pub_seed, ltree_addr, ots_addr);
    stacklevels[stackoffset] = 0;
    stackoffset++;
    if (h - k > 0 && i == 3) {
      // The leaf just pushed lives in slot stackoffset-1 (stackoffset was
      // already incremented); slot stackoffset has not been written yet.
      memcpy(state->treehash[0].node, stack+(stackoffset-1)*n, n);
    }
    // merge while the two topmost nodes are on the same level
    while (stackoffset>1 && stacklevels[stackoffset-1] == stacklevels[stackoffset-2])
    {
      nodeh = stacklevels[stackoffset-1];
      if (i >> nodeh == 1) {
        // node with index 1 on level nodeh goes into the authentication path
        memcpy(state->auth + nodeh*n, stack+(stackoffset-1)*n, n);
      }
      else {
        if (nodeh < h - k && i >> nodeh == 3) {
          // node with index 3 on a level below h-k seeds the treehash instance
          memcpy(state->treehash[nodeh].node, stack+(stackoffset-1)*n, n);
        }
        else if (nodeh >= h - k) {
          // nodes on the levels from h-k upwards are stored in retain
          memcpy(state->retain + ((1 << (h - 1 - nodeh)) + nodeh - h + (((i >> nodeh) - 3) >> 1)) * n, stack+(stackoffset-1)*n, n);
        }
      }
      setTreeHeight(node_addr, stacklevels[stackoffset-1]);
      setTreeIndex(node_addr, (idx >> (stacklevels[stackoffset-1]+1)));
      // the two topmost nodes are adjacent in stack, so they form the 2n-byte hash input
      hash_h(stack+(stackoffset-2)*n, stack+(stackoffset-2)*n, pub_seed,
          node_addr, n);
      stacklevels[stackoffset-2]++;
      stackoffset--;
    }
    i++;
  }

  // the bottom stack slot now holds the subtree root
  memcpy(node, stack, n);
}
248
249
0
/*
 * Advances one treehash instance by a single leaf: generates the leaf at
 * treehash->next_idx, merges it with same-level nodes from the shared stack,
 * and either completes the instance or pushes the result back on the stack.
 */
static void treehash_update(treehash_inst *treehash, bds_state *state, const unsigned char *sk_seed, const xmss_params *params, const unsigned char *pub_seed, const uint32_t addr[8]) {
  int n = params->n;

  uint32_t ots_addr[8];
  uint32_t ltree_addr[8];
  uint32_t node_addr[8];

  /* Take over the first 12 bytes of addr, then select the address type. */
  memcpy(ots_addr, addr, 12);
  setType(ots_addr, 0);
  memcpy(ltree_addr, addr, 12);
  setType(ltree_addr, 1);
  memcpy(node_addr, addr, 12);
  setType(node_addr, 2);

  setLtreeADRS(ltree_addr, treehash->next_idx);
  setOTSADRS(ots_addr, treehash->next_idx);

  /* pair holds [left || right]; the running node lives in the first half. */
  unsigned char pair[2 * n];
  unsigned int level = 0;
  gen_leaf_wots(pair, sk_seed, params, pub_seed, ltree_addr, ots_addr);

  while (treehash->stackusage > 0 && state->stacklevels[state->stackoffset-1] == level) {
    /* Running node becomes the right half, the popped stack top the left half. */
    memcpy(pair + n, pair, n);
    memcpy(pair, state->stack + (state->stackoffset-1)*n, n);
    setTreeHeight(node_addr, level);
    setTreeIndex(node_addr, (treehash->next_idx >> (level+1)));
    hash_h(pair, pair, pub_seed, node_addr, n);
    level++;
    treehash->stackusage--;
    state->stackoffset--;
  }

  if (level != treehash->h) {
    /* Not at the target height yet: push the node and move to the next leaf. */
    memcpy(state->stack + state->stackoffset*n, pair, n);
    treehash->stackusage++;
    state->stacklevels[state->stackoffset] = level;
    state->stackoffset++;
    treehash->next_idx++;
  }
  else {
    /* Target height reached: store the node and mark the instance done. */
    memcpy(treehash->node, pair, n);
    treehash->completed = 1;
  }
}
292
293
/**
294
 * Computes a root node given a leaf and an authpath
295
 */
296
static void validate_authpath(unsigned char *root, const unsigned char *leaf, unsigned long leafidx, const unsigned char *authpath, const xmss_params *params, const unsigned char *pub_seed, uint32_t addr[8])
297
0
{
298
0
  unsigned int n = params->n;
299
300
0
  uint32_t i, j;
301
0
  unsigned char buffer[2*n];
302
303
  // If leafidx is odd (last bit = 1), current path element is a right child and authpath has to go to the left.
304
  // Otherwise, it is the other way around
305
0
  if (leafidx & 1) {
306
0
    for (j = 0; j < n; j++)
307
0
      buffer[n+j] = leaf[j];
308
0
    for (j = 0; j < n; j++)
309
0
      buffer[j] = authpath[j];
310
0
  }
311
0
  else {
312
0
    for (j = 0; j < n; j++)
313
0
      buffer[j] = leaf[j];
314
0
    for (j = 0; j < n; j++)
315
0
      buffer[n+j] = authpath[j];
316
0
  }
317
0
  authpath += n;
318
319
0
  for (i=0; i < params->h-1; i++) {
320
0
    setTreeHeight(addr, i);
321
0
    leafidx >>= 1;
322
0
    setTreeIndex(addr, leafidx);
323
0
    if (leafidx&1) {
324
0
      hash_h(buffer+n, buffer, pub_seed, addr, n);
325
0
      for (j = 0; j < n; j++)
326
0
        buffer[j] = authpath[j];
327
0
    }
328
0
    else {
329
0
      hash_h(buffer, buffer, pub_seed, addr, n);
330
0
      for (j = 0; j < n; j++)
331
0
        buffer[j+n] = authpath[j];
332
0
    }
333
0
    authpath += n;
334
0
  }
335
0
  setTreeHeight(addr, (params->h-1));
336
0
  leafidx >>= 1;
337
0
  setTreeIndex(addr, leafidx);
338
0
  hash_h(root, buffer, pub_seed, addr, n);
339
0
}
340
341
/**
342
 * Performs up to `updates` treehash updates, each on the instance that needs it the most.
343
 * Returns the number of updates left unused (stops early when no instance needs an update)
344
 **/
345
0
static char bds_treehash_update(bds_state *state, unsigned int updates, const unsigned char *sk_seed, const xmss_params *params, unsigned char *pub_seed, const uint32_t addr[8]) {
  unsigned int h = params->h;
  unsigned int k = params->k;
  unsigned int remaining = updates; /* update budget not spent yet */
  unsigned int chosen, chosen_low, low;
  uint32_t t;

  while (remaining > 0) {
    /* Find the unfinished instance with the lowest "low" value; the first one wins ties. */
    chosen_low = h;
    chosen = h - k;
    for (t = 0; t < h - k; t++) {
      if (state->treehash[t].completed) {
        low = h;
      }
      else if (state->treehash[t].stackusage == 0) {
        low = t;
      }
      else {
        low = treehash_minheight_on_stack(state, params, &(state->treehash[t]));
      }
      if (low < chosen_low) {
        chosen = t;
        chosen_low = low;
      }
    }

    /* No instance qualified: leave the rest of the budget unused. */
    if (chosen == h - k) {
      break;
    }

    treehash_update(&(state->treehash[chosen]), state, sk_seed, params, pub_seed, addr);
    remaining--;
  }
  return remaining;
}
378
379
/**
380
 * Updates the state (typically NEXT_i) by adding a leaf and updating the stack
381
 * Returns 1 if all leaf nodes have already been processed
382
 **/
383
0
/*
 * Adds leaf state->next_leaf to the state's stack, merges equal-level nodes
 * and records auth / treehash / retain nodes on the way.
 * Returns 1 without doing anything if next_leaf already equals 1 << h,
 * otherwise 0.
 */
static char bds_state_update(bds_state *state, const unsigned char *sk_seed, const xmss_params *params, unsigned char *pub_seed, const uint32_t addr[8]) {
  uint32_t ltree_addr[8];
  uint32_t node_addr[8];
  uint32_t ots_addr[8];

  int n = params->n;
  int h = params->h;
  int k = params->k;

  int nodeh;
  int idx = state->next_leaf;
  if (idx == 1 << h) {
    return 1;
  }

  // only copy layer and tree address parts
  memcpy(ots_addr, addr, 12);
  // type = ots
  setType(ots_addr, 0);
  memcpy(ltree_addr, addr, 12);
  setType(ltree_addr, 1);
  memcpy(node_addr, addr, 12);
  setType(node_addr, 2);

  setOTSADRS(ots_addr, idx);
  setLtreeADRS(ltree_addr, idx);

  // push the new leaf onto the state's stack
  gen_leaf_wots(state->stack+state->stackoffset*n, sk_seed, params, pub_seed, ltree_addr, ots_addr);

  state->stacklevels[state->stackoffset] = 0;
  state->stackoffset++;
  if (h - k > 0 && idx == 3) {
    // The leaf just pushed lives in slot stackoffset-1 (stackoffset was
    // already incremented); slot stackoffset does not hold this leaf.
    memcpy(state->treehash[0].node, state->stack+(state->stackoffset-1)*n, n);
  }
  // merge while the two topmost nodes are on the same level
  while (state->stackoffset>1 && state->stacklevels[state->stackoffset-1] == state->stacklevels[state->stackoffset-2]) {
    nodeh = state->stacklevels[state->stackoffset-1];
    if (idx >> nodeh == 1) {
      // node with index 1 on level nodeh goes into the authentication path
      memcpy(state->auth + nodeh*n, state->stack+(state->stackoffset-1)*n, n);
    }
    else {
      if (nodeh < h - k && idx >> nodeh == 3) {
        // node with index 3 on a level below h-k seeds the treehash instance
        memcpy(state->treehash[nodeh].node, state->stack+(state->stackoffset-1)*n, n);
      }
      else if (nodeh >= h - k) {
        // nodes on the levels from h-k upwards are stored in retain
        memcpy(state->retain + ((1 << (h - 1 - nodeh)) + nodeh - h + (((idx >> nodeh) - 3) >> 1)) * n, state->stack+(state->stackoffset-1)*n, n);
      }
    }
    setTreeHeight(node_addr, state->stacklevels[state->stackoffset-1]);
    setTreeIndex(node_addr, (idx >> (state->stacklevels[state->stackoffset-1]+1)));
    // the two topmost nodes are adjacent in the stack, so they form the 2n-byte hash input
    hash_h(state->stack+(state->stackoffset-2)*n, state->stack+(state->stackoffset-2)*n, pub_seed, node_addr, n);

    state->stacklevels[state->stackoffset-2]++;
    state->stackoffset--;
  }
  state->next_leaf++;
  return 0;
}
440
441
/**
442
 * Returns the auth path for node leaf_idx and computes the auth path for the
443
 * next leaf node, using the algorithm described by Buchmann, Dahmen and Szydlo
444
 * in "Post Quantum Cryptography", Springer 2009.
445
 */
446
static void bds_round(bds_state *state, const unsigned long leaf_idx, const unsigned char *sk_seed, const xmss_params *params, unsigned char *pub_seed, uint32_t addr[8])
{
  unsigned int lvl;
  unsigned int n = params->n;
  unsigned int h = params->h;
  unsigned int k = params->k;

  unsigned int tau;
  unsigned int first_leaf;
  unsigned int row_base, row_pos;
  unsigned int restart_levels;
  unsigned char pair[2 * n];

  uint32_t ots_addr[8];
  uint32_t ltree_addr[8];
  uint32_t node_addr[8];

  /* Take over the first 12 bytes of addr, then select the address type. */
  memcpy(ots_addr, addr, 12);
  setType(ots_addr, 0);
  memcpy(ltree_addr, addr, 12);
  setType(ltree_addr, 1);
  memcpy(node_addr, addr, 12);
  setType(node_addr, 2);

  /* tau = position of the lowest zero bit of leaf_idx, or h if bits 0..h-1 are all set. */
  tau = 0;
  while (tau < h && ((leaf_idx >> tau) & 1)) {
    tau++;
  }

  if (tau > 0) {
    /* Save both hash inputs for level tau before keep is refreshed below. */
    memcpy(pair,     state->auth + (tau-1) * n, n);
    memcpy(pair + n, state->keep + ((tau-1) >> 1) * n, n);
  }
  if (!((leaf_idx >> (tau + 1)) & 1) && (tau < h - 1)) {
    memcpy(state->keep + (tau >> 1)*n, state->auth + tau*n, n);
  }

  if (tau != 0) {
    /* New auth node on level tau is the hash of the two saved nodes. */
    setTreeHeight(node_addr, (tau-1));
    setTreeIndex(node_addr, leaf_idx >> tau);
    hash_h(state->auth + tau * n, pair, pub_seed, node_addr, n);

    /* Refill the auth nodes below tau from the treehash instances or retain. */
    for (lvl = 0; lvl < tau; lvl++) {
      if (lvl < h - k) {
        memcpy(state->auth + lvl * n, state->treehash[lvl].node, n);
      }
      else {
        row_base = (1 << (h - 1 - lvl)) + lvl - h;
        row_pos = ((leaf_idx >> lvl) - 1) >> 1;
        memcpy(state->auth + lvl * n, state->retain + (row_base + row_pos) * n, n);
      }
    }

    /* Restart the treehash instances below min(tau, h-k) whose start leaf is still inside the tree. */
    restart_levels = (tau < h - k) ? tau : (h - k);
    for (lvl = 0; lvl < restart_levels; lvl++) {
      first_leaf = leaf_idx + 1 + 3 * (1 << lvl);
      if (first_leaf < 1U << h) {
        state->treehash[lvl].h = lvl;
        state->treehash[lvl].next_idx = first_leaf;
        state->treehash[lvl].completed = 0;
        state->treehash[lvl].stackusage = 0;
      }
    }
  }
  else {
    /* tau == 0: the level-0 auth node is the leaf at leaf_idx itself. */
    setLtreeADRS(ltree_addr, leaf_idx);
    setOTSADRS(ots_addr, leaf_idx);
    gen_leaf_wots(state->auth, sk_seed, params, pub_seed, ltree_addr, ots_addr);
  }
}
516
517
/*
518
 * Generates a XMSS key pair for a given parameter set.
519
 * Format sk: [(32bit) idx || SK_SEED || SK_PRF || PUB_SEED || root]
520
 * Format pk: [root || PUB_SEED] omitting algo oid.
521
 */
522
int xmss_keypair(unsigned char *pk, unsigned char *sk, bds_state *state, xmss_params *params)
{
  unsigned int n = params->n;
  uint32_t addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};

  /* The 4-byte index at the start of sk begins at 0. */
  memset(sk, 0, 4);

  /* Random SK_SEED, SK_PRF and PUB_SEED (n bytes each) follow the index. */
  randombytes(sk+4, 3*n);

  /* The second half of pk is PUB_SEED. */
  memcpy(pk+n, sk+4+2*n, n);

  /* Compute the root into the first half of pk and set up the BDS state. */
  treehash_setup(pk, params->h, 0, state, sk+4, params, sk+4+2*n, addr);

  /* The root is also stored at the end of sk. */
  memcpy(sk+4+3*n, pk, n);
  return 0;
}
543
544
/**
545
 * Signs a message.
546
 * Returns
547
 * 1. an array containing the signature followed by the message AND
548
 * 2. an updated secret key!
549
 *
550
 */
551
int xmss_sign(unsigned char *sk, bds_state *state, unsigned char *sig_msg, unsigned long long *sig_msg_len, const unsigned char *msg, unsigned long long msglen, const xmss_params *params)
{
  unsigned int h = params->h;
  unsigned int n = params->n;
  unsigned int k = params->k;
  uint16_t pos = 0;

  /* sk layout: [idx (4 bytes) || SK_SEED || SK_PRF || PUB_SEED || root] */
  unsigned long idx = ((unsigned long)sk[0] << 24) | ((unsigned long)sk[1] << 16) | ((unsigned long)sk[2] << 8) | sk[3];
  unsigned long next_idx;
  unsigned char sk_seed[n];
  memcpy(sk_seed, sk+4, n);
  unsigned char sk_prf[n];
  memcpy(sk_prf, sk+4+n, n);
  unsigned char pub_seed[n];
  memcpy(pub_seed, sk+4+2*n, n);

  /* The index as a 32-byte string, used as PRF input. */
  unsigned char idx_bytes_32[32];
  to_byte(idx_bytes_32, idx, 32);

  unsigned char hash_key[3*n];

  /*
   * Store idx + 1 in sk before anything is signed.  The caller is
   * responsible for persisting the updated secret key.
   */
  next_idx = idx + 1;
  sk[0] = (next_idx >> 24) & 255;
  sk[1] = (next_idx >> 16) & 255;
  sk[2] = (next_idx >> 8) & 255;
  sk[3] = next_idx & 255;

  unsigned char R[n];
  unsigned char msg_h[n];
  unsigned char ots_seed[n];
  uint32_t ots_addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};

  /* ---- Message hashing ---- */

  /* Pseudorandom value R from the index and SK_PRF. */
  prf(R, idx_bytes_32, sk_prf, n);
  /* hash_key = R || root || idx (idx encoded in n bytes) */
  memcpy(hash_key, R, n);
  memcpy(hash_key+n, sk+4+3*n, n);
  to_byte(hash_key+2*n, idx, n);
  /* Message digest under that key. */
  h_msg(msg_h, msg, msglen, hash_key, 3*n, n);

  /* ---- Signature assembly: idx || R || WOTS sig || auth path || msg ---- */
  *sig_msg_len = 0;

  /* Index, big-endian, 4 bytes. */
  sig_msg[0] = (idx >> 24) & 255;
  sig_msg[1] = (idx >> 16) & 255;
  sig_msg[2] = (idx >> 8) & 255;
  sig_msg[3] = idx & 255;
  sig_msg += 4;
  *sig_msg_len += 4;

  /* R, n bytes. */
  for (pos = 0; pos < n; pos++) {
    sig_msg[pos] = R[pos];
  }
  sig_msg += n;
  *sig_msg_len += n;

  /* WOTS signature on the message digest, using the OTS key at leaf idx. */
  setType(ots_addr, 0);
  setOTSADRS(ots_addr, idx);
  get_seed(ots_seed, sk_seed, n, ots_addr);
  wots_sign(sig_msg, msg_h, ots_seed, &(params->wots_par), pub_seed, ots_addr);
  sig_msg += params->wots_par.keysize;
  *sig_msg_len += params->wots_par.keysize;

  /* The auth path for this leaf was already computed during the previous round. */
  memcpy(sig_msg, state->auth, h*n);

  /* Prepare the auth path for the next leaf, unless this was the last one. */
  if (idx < (1U << h) - 1) {
    bds_round(state, idx, sk_seed, params, pub_seed, ots_addr);
    bds_treehash_update(state, (h - k) >> 1, sk_seed, params, pub_seed, ots_addr);
  }

/* TODO: save key/bds state here! */

  sig_msg += params->h*n;
  *sig_msg_len += params->h*n;

  /* NOTE(review): the secret copies on the stack are not wiped here. */

  /* The message itself is appended after the signature. */
  memcpy(sig_msg, msg, msglen);
  *sig_msg_len += msglen;

  return 0;
}
659
660
/**
661
 * Verifies a given message signature pair under a given public key.
662
 */
663
int xmss_sign_open(unsigned char *msg, unsigned long long *msglen, const unsigned char *sig_msg, unsigned long long sig_msg_len, const unsigned char *pk, const xmss_params *params)
664
0
{
665
0
  unsigned int n = params->n;
666
667
0
  unsigned long long i, m_len;
668
0
  unsigned long idx=0;
669
0
  unsigned char wots_pk[params->wots_par.keysize];
670
0
  unsigned char pkhash[n];
671
0
  unsigned char root[n];
672
0
  unsigned char msg_h[n];
673
0
  unsigned char hash_key[3*n];
674
675
0
  unsigned char pub_seed[n];
676
0
  memcpy(pub_seed, pk+n, n);
677
678
  // Init addresses
679
0
  uint32_t ots_addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
680
0
  uint32_t ltree_addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
681
0
  uint32_t node_addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
682
683
0
  setType(ots_addr, 0);
684
0
  setType(ltree_addr, 1);
685
0
  setType(node_addr, 2);
686
687
  // Extract index
688
0
  idx = ((unsigned long)sig_msg[0] << 24) | ((unsigned long)sig_msg[1] << 16) | ((unsigned long)sig_msg[2] << 8) | sig_msg[3];
689
0
  printf("verify:: idx = %lu\n", idx);
690
  
691
  // Generate hash key (R || root || idx)
692
0
  memcpy(hash_key, sig_msg+4,n);
693
0
  memcpy(hash_key+n, pk, n);
694
0
  to_byte(hash_key+2*n, idx, n);
695
  
696
0
  sig_msg += (n+4);
697
0
  sig_msg_len -= (n+4);
698
699
  // hash message 
700
0
  unsigned long long tmp_sig_len = params->wots_par.keysize+params->h*n;
701
0
  m_len = sig_msg_len - tmp_sig_len;
702
0
  h_msg(msg_h, sig_msg + tmp_sig_len, m_len, hash_key, 3*n, n);
703
704
  //-----------------------
705
  // Verify signature
706
  //-----------------------
707
708
  // Prepare Address
709
0
  setOTSADRS(ots_addr, idx);
710
  // Check WOTS signature
711
0
  wots_pkFromSig(wots_pk, sig_msg, msg_h, &(params->wots_par), pub_seed, ots_addr);
712
713
0
  sig_msg += params->wots_par.keysize;
714
0
  sig_msg_len -= params->wots_par.keysize;
715
716
  // Compute Ltree
717
0
  setLtreeADRS(ltree_addr, idx);
718
0
  l_tree(pkhash, wots_pk, params, pub_seed, ltree_addr);
719
720
  // Compute root
721
0
  validate_authpath(root, pkhash, idx, sig_msg, params, pub_seed, node_addr);
722
723
0
  sig_msg += params->h*n;
724
0
  sig_msg_len -= params->h*n;
725
726
0
  for (i = 0; i < n; i++)
727
0
    if (root[i] != pk[i])
728
0
      goto fail;
729
730
0
  *msglen = sig_msg_len;
731
0
  for (i = 0; i < *msglen; i++)
732
0
    msg[i] = sig_msg[i];
733
734
0
  return 0;
735
736
737
0
fail:
738
0
  *msglen = sig_msg_len;
739
0
  for (i = 0; i < *msglen; i++)
740
0
    msg[i] = 0;
741
0
  *msglen = -1;
742
0
  return -1;
743
0
}
744
745
/*
746
 * Generates a XMSSMT key pair for a given parameter set.
747
 * Format sk: [(ceil(h/8) byte) idx || SK_SEED || SK_PRF || PUB_SEED || root]
748
 * Format pk: [root || PUB_SEED] omitting algo oid.
749
 */
750
int xmssmt_keypair(unsigned char *pk, unsigned char *sk, bds_state *states, unsigned char *wots_sigs, xmssmt_params *params)
{
  unsigned int n = params->n;
  unsigned int layer;
  unsigned char ots_seed[params->n];
  uint32_t addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};

  /* The index_len-byte index at the start of sk begins at 0. */
  memset(sk, 0, params->index_len);

  /* Random SK_SEED, SK_PRF and PUB_SEED (n bytes each) follow the index. */
  randombytes(sk+params->index_len, 3*n);

  /* The second half of pk is PUB_SEED. */
  memcpy(pk+n, sk+params->index_len+2*n, n);

  /* Start with the address pointing at layer d-1. */
  setLayerADRS(addr, (params->d-1));

  /*
   * For every layer but the topmost: build its first tree (root lands in pk),
   * then sign that root with a WOTS key addressed on the layer above and keep
   * the signature in wots_sigs.
   */
  for (layer = 0; layer < params->d - 1; layer++) {
    treehash_setup(pk, params->xmss_par.h, 0, states + layer, sk+params->index_len, &(params->xmss_par), pk+n, addr);
    setLayerADRS(addr, (layer+1));
    get_seed(ots_seed, sk+params->index_len, n, addr);
    wots_sign(wots_sigs + layer*params->xmss_par.wots_par.keysize, pk, ots_seed, &(params->xmss_par.wots_par), pk+n, addr);
  }

  /* The last tree built yields the root that stays in pk and is copied into sk. */
  treehash_setup(pk, params->xmss_par.h, 0, states + layer, sk+params->index_len, &(params->xmss_par), pk+n, addr);
  memcpy(sk+params->index_len+3*n, pk, n);
  return 0;
}
779
780
/**
781
 * Signs a message.
782
 * Returns
783
 * 1. an array containing the signature followed by the message AND
784
 * 2. an updated secret key!
785
 *
786
 */
787
int xmssmt_sign(unsigned char *sk, bds_state *states, unsigned char *wots_sigs, unsigned char *sig_msg, unsigned long long *sig_msg_len, const unsigned char *msg, unsigned long long msglen, const xmssmt_params *params)
{
  unsigned int n = params->n;

  unsigned int tree_h = params->xmss_par.h;
  unsigned int h = params->h;
  unsigned int k = params->xmss_par.k;
  unsigned int idx_len = params->index_len;
  uint64_t idx_tree;
  uint32_t idx_leaf;
  uint64_t layer, pos;
  int needswap_upto = -1;
  unsigned int updates;

  unsigned char sk_seed[n];
  unsigned char sk_prf[n];
  unsigned char pub_seed[n];
  /* Working buffers */
  unsigned char R[n];
  unsigned char msg_h[n];
  unsigned char hash_key[3*n];
  unsigned char ots_seed[n];
  uint32_t addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
  uint32_t ots_addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
  unsigned char idx_bytes_32[32];
  bds_state tmp;

  /* sk layout: [idx (idx_len bytes, big-endian) || SK_SEED || SK_PRF || PUB_SEED || root] */
  unsigned long long idx = 0;
  for (pos = 0; pos < idx_len; pos++) {
    idx |= ((unsigned long long)sk[pos]) << 8*(idx_len - 1 - pos);
  }

  memcpy(sk_seed, sk+idx_len, n);
  memcpy(sk_prf, sk+idx_len+n, n);
  memcpy(pub_seed, sk+idx_len+2*n, n);

  /*
   * Store idx + 1 in sk before anything is signed.  The caller is
   * responsible for persisting the updated secret key.
   */
  for (pos = 0; pos < idx_len; pos++) {
    sk[pos] = ((idx + 1) >> 8*(idx_len - 1 - pos)) & 255;
  }

  /* ---- Message hashing ---- */

  /* Pseudorandom value R from the index and SK_PRF. */
  to_byte(idx_bytes_32, idx, 32);
  prf(R, idx_bytes_32, sk_prf, n);
  /* hash_key = R || root || idx (idx encoded in n bytes) */
  memcpy(hash_key, R, n);
  memcpy(hash_key+n, sk+idx_len+3*n, n);
  to_byte(hash_key+2*n, idx, n);

  /* Message digest under that key. */
  h_msg(msg_h, msg, msglen, hash_key, 3*n, n);

  /* ---- Signature assembly ---- */
  *sig_msg_len = 0;

  /* Index, big-endian, idx_len bytes. */
  for (pos = 0; pos < idx_len; pos++) {
    sig_msg[pos] = (idx >> 8*(idx_len - 1 - pos)) & 255;
  }
  sig_msg += idx_len;
  *sig_msg_len += idx_len;

  /* R, n bytes. */
  for (pos = 0; pos < n; pos++) {
    sig_msg[pos] = R[pos];
  }
  sig_msg += n;
  *sig_msg_len += n;

  /* Layer 0 is handled separately: it signs the message digest. */
  setType(ots_addr, 0);
  idx_tree = idx >> tree_h;
  idx_leaf = (idx & ((1 << tree_h)-1));
  setLayerADRS(ots_addr, 0);
  setTreeADRS(ots_addr, idx_tree);
  setOTSADRS(ots_addr, idx_leaf);

  get_seed(ots_seed, sk_seed, n, ots_addr);
  wots_sign(sig_msg, msg_h, ots_seed, &(params->xmss_par.wots_par), pub_seed, ots_addr);

  sig_msg += params->xmss_par.wots_par.keysize;
  *sig_msg_len += params->xmss_par.wots_par.keysize;

  memcpy(sig_msg, states[0].auth, tree_h*n);
  sig_msg += tree_h*n;
  *sig_msg_len += tree_h*n;

  /* Upper layers: stored WOTS signature followed by that layer's auth path. */
  for (layer = 1; layer < params->d; layer++) {
    memcpy(sig_msg, wots_sigs + (layer-1)*params->xmss_par.wots_par.keysize, params->xmss_par.wots_par.keysize);
    sig_msg += params->xmss_par.wots_par.keysize;
    *sig_msg_len += params->xmss_par.wots_par.keysize;

    memcpy(sig_msg, states[layer].auth, tree_h*n);
    sig_msg += tree_h*n;
    *sig_msg_len += tree_h*n;
  }

  /* ---- State maintenance for the following signatures ---- */

  /* Update budget shared by all layers in this call. */
  updates = (tree_h - k) >> 1;

  setTreeADRS(addr, (idx_tree + 1));
  /* Mandatory update for NEXT_0 (does not count towards the budget) if NEXT_0 exists. */
  if ((1 + idx_tree) * (1 << tree_h) + idx_leaf < (1ULL << h)) {
    bds_state_update(&states[params->d], sk_seed, &(params->xmss_par), pub_seed, addr);
  }

  for (layer = 0; layer < params->d; layer++) {
    if (((idx + 1) & ((1ULL << ((layer+1)*tree_h)) - 1)) != 0) {
      /* Not at the end of a tree on this layer. */
      idx_leaf = (idx >> (tree_h * layer)) & ((1 << tree_h)-1);
      idx_tree = (idx >> (tree_h * (layer+1)));
      setLayerADRS(addr, layer);
      setTreeADRS(addr, idx_tree);
      /* Only the layer right above the last swapped one advances its auth path. */
      if (layer == (unsigned int) (needswap_upto + 1)) {
        bds_round(&states[layer], idx_leaf, sk_seed, &(params->xmss_par), pub_seed, addr);
      }
      updates = bds_treehash_update(&states[layer], updates, sk_seed, &(params->xmss_par), pub_seed, addr);
      setTreeADRS(addr, (idx_tree + 1));
      /* Spend one update on the NEXT tree of this layer if it exists. */
      if ((1 + idx_tree) * (1 << tree_h) + idx_leaf < (1ULL << (h - tree_h * layer))
          && layer > 0 && updates > 0 && states[params->d + layer].next_leaf < (1ULL << h)) {
        bds_state_update(&states[params->d + layer], sk_seed, &(params->xmss_par), pub_seed, addr);
        updates--;
      }
    }
    else if (idx < (1ULL << h) - 1) {
      /* End of the tree on this layer: swap the current state with the NEXT state. */
      memcpy(&tmp, states+params->d + layer, sizeof(bds_state));
      memcpy(states+params->d + layer, states + layer, sizeof(bds_state));
      memcpy(states + layer, &tmp, sizeof(bds_state));

      /* Sign the value at the bottom of the new state's stack on the layer above. */
      setLayerADRS(ots_addr, (layer+1));
      setTreeADRS(ots_addr, ((idx + 1) >> ((layer+2) * tree_h)));
      setOTSADRS(ots_addr, (((idx >> ((layer+1) * tree_h)) + 1) & ((1 << tree_h)-1)));

      get_seed(ots_seed, sk+params->index_len, n, ots_addr);
      wots_sign(wots_sigs + layer*params->xmss_par.wots_par.keysize, states[layer].stack, ots_seed, &(params->xmss_par.wots_par), pub_seed, ots_addr);

      /* Reset the swapped-out state so it can be reused as the NEXT state. */
      states[params->d + layer].stackoffset = 0;
      states[params->d + layer].next_leaf = 0;

      updates--; /* WOTS-signing counts as one update */
      needswap_upto = layer;
      for (pos = 0; pos < tree_h-k; pos++) {
        states[layer].treehash[pos].completed = 1;
      }
    }
  }

  /* NOTE(review): the secret copies on the stack are not wiped here. */

  /* The message itself is appended after the signature. */
  memcpy(sig_msg, msg, msglen);
  *sig_msg_len += msglen;

  return 0;
}
966
967
/**
968
 * Verifies a given message signature pair under a given public key.
969
 */
970
int xmssmt_sign_open(unsigned char *msg, unsigned long long *msglen, const unsigned char *sig_msg, unsigned long long sig_msg_len, const unsigned char *pk, const xmssmt_params *params)
971
0
{
972
0
  unsigned int n = params->n;
973
974
0
  unsigned int tree_h = params->xmss_par.h;
975
0
  unsigned int idx_len = params->index_len;
976
0
  uint64_t idx_tree;
977
0
  uint32_t idx_leaf;
978
979
0
  unsigned long long i, m_len;
980
0
  unsigned long long idx=0;
981
0
  unsigned char wots_pk[params->xmss_par.wots_par.keysize];
982
0
  unsigned char pkhash[n];
983
0
  unsigned char root[n];
984
0
  unsigned char msg_h[n];
985
0
  unsigned char hash_key[3*n];
986
987
0
  unsigned char pub_seed[n];
988
0
  memcpy(pub_seed, pk+n, n);
989
990
  // Init addresses
991
0
  uint32_t ots_addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
992
0
  uint32_t ltree_addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
993
0
  uint32_t node_addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
994
995
  // Extract index
996
0
  for (i = 0; i < idx_len; i++) {
997
0
    idx |= ((unsigned long long)sig_msg[i]) << (8*(idx_len - 1 - i));
998
0
  }
999
0
  printf("verify:: idx = %llu\n", idx);
1000
0
  sig_msg += idx_len;
1001
0
  sig_msg_len -= idx_len;
1002
  
1003
  // Generate hash key (R || root || idx)
1004
0
  memcpy(hash_key, sig_msg,n);
1005
0
  memcpy(hash_key+n, pk, n);
1006
0
  to_byte(hash_key+2*n, idx, n);
1007
1008
0
  sig_msg += n;
1009
0
  sig_msg_len -= n;
1010
  
1011
1012
  // hash message (recall, R is now on pole position at sig_msg
1013
0
  unsigned long long tmp_sig_len = (params->d * params->xmss_par.wots_par.keysize) + (params->h * n);
1014
0
  m_len = sig_msg_len - tmp_sig_len;
1015
0
  h_msg(msg_h, sig_msg + tmp_sig_len, m_len, hash_key, 3*n, n);
1016
1017
  
1018
  //-----------------------
1019
  // Verify signature
1020
  //-----------------------
1021
1022
  // Prepare Address
1023
0
  idx_tree = idx >> tree_h;
1024
0
  idx_leaf = (idx & ((1 << tree_h)-1));
1025
0
  setLayerADRS(ots_addr, 0);
1026
0
  setTreeADRS(ots_addr, idx_tree);
1027
0
  setType(ots_addr, 0);
1028
1029
0
  memcpy(ltree_addr, ots_addr, 12);
1030
0
  setType(ltree_addr, 1);
1031
1032
0
  memcpy(node_addr, ltree_addr, 12);
1033
0
  setType(node_addr, 2);
1034
  
1035
0
  setOTSADRS(ots_addr, idx_leaf);
1036
1037
  // Check WOTS signature
1038
0
  wots_pkFromSig(wots_pk, sig_msg, msg_h, &(params->xmss_par.wots_par), pub_seed, ots_addr);
1039
1040
0
  sig_msg += params->xmss_par.wots_par.keysize;
1041
0
  sig_msg_len -= params->xmss_par.wots_par.keysize;
1042
1043
  // Compute Ltree
1044
0
  setLtreeADRS(ltree_addr, idx_leaf);
1045
0
  l_tree(pkhash, wots_pk, &(params->xmss_par), pub_seed, ltree_addr);
1046
1047
  // Compute root
1048
0
  validate_authpath(root, pkhash, idx_leaf, sig_msg, &(params->xmss_par), pub_seed, node_addr);
1049
1050
0
  sig_msg += tree_h*n;
1051
0
  sig_msg_len -= tree_h*n;
1052
1053
0
  for (i = 1; i < params->d; i++) {
1054
    // Prepare Address
1055
0
    idx_leaf = (idx_tree & ((1 << tree_h)-1));
1056
0
    idx_tree = idx_tree >> tree_h;
1057
1058
0
    setLayerADRS(ots_addr, i);
1059
0
    setTreeADRS(ots_addr, idx_tree);
1060
0
    setType(ots_addr, 0);
1061
1062
0
    memcpy(ltree_addr, ots_addr, 12);
1063
0
    setType(ltree_addr, 1);
1064
1065
0
    memcpy(node_addr, ltree_addr, 12);
1066
0
    setType(node_addr, 2);
1067
1068
0
    setOTSADRS(ots_addr, idx_leaf);
1069
1070
    // Check WOTS signature
1071
0
    wots_pkFromSig(wots_pk, sig_msg, root, &(params->xmss_par.wots_par), pub_seed, ots_addr);
1072
1073
0
    sig_msg += params->xmss_par.wots_par.keysize;
1074
0
    sig_msg_len -= params->xmss_par.wots_par.keysize;
1075
1076
    // Compute Ltree
1077
0
    setLtreeADRS(ltree_addr, idx_leaf);
1078
0
    l_tree(pkhash, wots_pk, &(params->xmss_par), pub_seed, ltree_addr);
1079
1080
    // Compute root
1081
0
    validate_authpath(root, pkhash, idx_leaf, sig_msg, &(params->xmss_par), pub_seed, node_addr);
1082
1083
0
    sig_msg += tree_h*n;
1084
0
    sig_msg_len -= tree_h*n;
1085
1086
0
  }
1087
1088
0
  for (i = 0; i < n; i++)
1089
0
    if (root[i] != pk[i])
1090
0
      goto fail;
1091
1092
0
  *msglen = sig_msg_len;
1093
0
  for (i = 0; i < *msglen; i++)
1094
0
    msg[i] = sig_msg[i];
1095
1096
0
  return 0;
1097
1098
1099
0
fail:
1100
0
  *msglen = sig_msg_len;
1101
0
  for (i = 0; i < *msglen; i++)
1102
0
    msg[i] = 0;
1103
0
  *msglen = -1;
1104
0
  return -1;
1105
0
}
1106
#endif /* WITH_XMSS */