Coverage Report

Created: 2024-06-18 06:23

/src/hpn-ssh/sshkey-xmss.c
Line
Count
Source (jump to first uncovered line)
1
/* $OpenBSD: sshkey-xmss.c,v 1.12 2022/10/28 00:39:29 djm Exp $ */
2
/*
3
 * Copyright (c) 2017 Markus Friedl.  All rights reserved.
4
 *
5
 * Redistribution and use in source and binary forms, with or without
6
 * modification, are permitted provided that the following conditions
7
 * are met:
8
 * 1. Redistributions of source code must retain the above copyright
9
 *    notice, this list of conditions and the following disclaimer.
10
 * 2. Redistributions in binary form must reproduce the above copyright
11
 *    notice, this list of conditions and the following disclaimer in the
12
 *    documentation and/or other materials provided with the distribution.
13
 *
14
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16
 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17
 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18
 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19
 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21
 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23
 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24
 */
25
26
#include "includes.h"
27
#ifdef WITH_XMSS
28
29
#include <sys/types.h>
30
#include <sys/uio.h>
31
32
#include <stdio.h>
33
#include <string.h>
34
#include <unistd.h>
35
#include <fcntl.h>
36
#include <errno.h>
37
#ifdef HAVE_SYS_FILE_H
38
# include <sys/file.h>
39
#endif
40
41
#include "ssh2.h"
42
#include "ssherr.h"
43
#include "sshbuf.h"
44
#include "cipher.h"
45
#include "sshkey.h"
46
#include "sshkey-xmss.h"
47
#include "atomicio.h"
48
#include "log.h"
49
50
#include "xmss_fast.h"
51
52
/* opaque internal XMSS state */
53
0
#define XMSS_MAGIC    "xmss-state-v1"
54
0
#define XMSS_CIPHERNAME   "aes256-gcm@openssh.com"
55
struct ssh_xmss_state {
56
  xmss_params params;
57
  u_int32_t n, w, h, k;
58
59
  bds_state bds;
60
  u_char    *stack;
61
  u_int32_t stackoffset;
62
  u_char    *stacklevels;
63
  u_char    *auth;
64
  u_char    *keep;
65
  u_char    *th_nodes;
66
  u_char    *retain;
67
  treehash_inst *treehash;
68
69
  u_int32_t idx;    /* state read from file */
70
  u_int32_t maxidx;   /* restricted # of signatures */
71
  int   have_state; /* .state file exists */
72
  int   lockfd;   /* locked in sshkey_xmss_get_state() */
73
  u_char    allow_update; /* allow sshkey_xmss_update_state() */
74
  char    *enc_ciphername;/* encrypt state with cipher */
75
  u_char    *enc_keyiv; /* encrypt state with key */
76
  u_int32_t enc_keyiv_len;  /* length of enc_keyiv */
77
};
78
79
/* Helpers defined later in this file. */
int  sshkey_xmss_init_bds_state(struct sshkey *key);
int  sshkey_xmss_init_enc_key(struct sshkey *k, const char *ciphername);
void   sshkey_xmss_free_bds(struct sshkey *key);
int  sshkey_xmss_get_state_from_file(struct sshkey *k, const char *filename,
      int *have_file, int printerror);
int  sshkey_xmss_encrypt_state(const struct sshkey *k, struct sshbuf *b,
      struct sshbuf **retp);
int  sshkey_xmss_decrypt_state(const struct sshkey *k, struct sshbuf *encoded,
      struct sshbuf **retp);
int  sshkey_xmss_serialize_enc_key(const struct sshkey *k, struct sshbuf *b);
int  sshkey_xmss_deserialize_enc_key(struct sshkey *k, struct sshbuf *b);

/*
 * Log at error level, but only when the enclosing function has a
 * non-zero `printerror` variable in scope.
 */
#define PRINT(...) do { if (printerror) sshlog(__FILE__, __func__, __LINE__, \
    0, SYSLOG_LEVEL_ERROR, NULL, __VA_ARGS__); } while (0)
93
94
int
95
sshkey_xmss_init(struct sshkey *key, const char *name)
96
403
{
97
403
  struct ssh_xmss_state *state;
98
99
403
  if (key->xmss_state != NULL)
100
0
    return SSH_ERR_INVALID_FORMAT;
101
403
  if (name == NULL)
102
0
    return SSH_ERR_INVALID_FORMAT;
103
403
  state = calloc(sizeof(struct ssh_xmss_state), 1);
104
403
  if (state == NULL)
105
0
    return SSH_ERR_ALLOC_FAIL;
106
403
  if (strcmp(name, XMSS_SHA2_256_W16_H10_NAME) == 0) {
107
162
    state->n = 32;
108
162
    state->w = 16;
109
162
    state->h = 10;
110
241
  } else if (strcmp(name, XMSS_SHA2_256_W16_H16_NAME) == 0) {
111
45
    state->n = 32;
112
45
    state->w = 16;
113
45
    state->h = 16;
114
196
  } else if (strcmp(name, XMSS_SHA2_256_W16_H20_NAME) == 0) {
115
46
    state->n = 32;
116
46
    state->w = 16;
117
46
    state->h = 20;
118
150
  } else {
119
150
    free(state);
120
150
    return SSH_ERR_KEY_TYPE_UNKNOWN;
121
150
  }
122
253
  if ((key->xmss_name = strdup(name)) == NULL) {
123
0
    free(state);
124
0
    return SSH_ERR_ALLOC_FAIL;
125
0
  }
126
253
  state->k = 2; /* XXX hardcoded */
127
253
  state->lockfd = -1;
128
253
  if (xmss_set_params(&state->params, state->n, state->h, state->w,
129
253
      state->k) != 0) {
130
0
    free(state);
131
0
    return SSH_ERR_INVALID_FORMAT;
132
0
  }
133
253
  key->xmss_state = state;
134
253
  return 0;
135
253
}
136
137
void
138
sshkey_xmss_free_state(struct sshkey *key)
139
409
{
140
409
  struct ssh_xmss_state *state = key->xmss_state;
141
142
409
  sshkey_xmss_free_bds(key);
143
409
  if (state) {
144
253
    if (state->enc_keyiv) {
145
0
      explicit_bzero(state->enc_keyiv, state->enc_keyiv_len);
146
0
      free(state->enc_keyiv);
147
0
    }
148
253
    free(state->enc_ciphername);
149
253
    free(state);
150
253
  }
151
409
  key->xmss_state = NULL;
152
409
}
153
154
0
/*
 * Magic string written first by sshkey_xmss_serialize_state() and
 * checked by sshkey_xmss_deserialize_state().
 */
#define SSH_XMSS_K2_MAGIC "k=2"

/*
 * Sizes of the BDS buffers. `x` is a pointer to an object with members
 * n, h and k (struct ssh_xmss_state in this file). The argument is
 * parenthesised so that any pointer expression, e.g. `&obj`, expands
 * correctly.
 */
#define num_stack(x)    ((((x)->h) + 1) * ((x)->n))
#define num_stacklevels(x)  (((x)->h) + 1)
#define num_auth(x)   (((x)->h) * ((x)->n))
#define num_keep(x)   ((((x)->h) >> 1) * ((x)->n))
#define num_th_nodes(x)   ((((x)->h) - ((x)->k)) * ((x)->n))
#define num_retain(x)   (((1ULL << ((x)->k)) - ((x)->k) - 1) * ((x)->n))
#define num_treehash(x)   (((x)->h) - ((x)->k))
162
163
int
164
sshkey_xmss_init_bds_state(struct sshkey *key)
165
0
{
166
0
  struct ssh_xmss_state *state = key->xmss_state;
167
0
  u_int32_t i;
168
169
0
  state->stackoffset = 0;
170
0
  if ((state->stack = calloc(num_stack(state), 1)) == NULL ||
171
0
      (state->stacklevels = calloc(num_stacklevels(state), 1))== NULL ||
172
0
      (state->auth = calloc(num_auth(state), 1)) == NULL ||
173
0
      (state->keep = calloc(num_keep(state), 1)) == NULL ||
174
0
      (state->th_nodes = calloc(num_th_nodes(state), 1)) == NULL ||
175
0
      (state->retain = calloc(num_retain(state), 1)) == NULL ||
176
0
      (state->treehash = calloc(num_treehash(state),
177
0
      sizeof(treehash_inst))) == NULL) {
178
0
    sshkey_xmss_free_bds(key);
179
0
    return SSH_ERR_ALLOC_FAIL;
180
0
  }
181
0
  for (i = 0; i < state->h - state->k; i++)
182
0
    state->treehash[i].node = &state->th_nodes[state->n*i];
183
0
  xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
184
0
      state->stacklevels, state->auth, state->keep, state->treehash,
185
0
      state->retain, 0);
186
0
  return 0;
187
0
}
188
189
void
190
sshkey_xmss_free_bds(struct sshkey *key)
191
409
{
192
409
  struct ssh_xmss_state *state = key->xmss_state;
193
194
409
  if (state == NULL)
195
156
    return;
196
253
  free(state->stack);
197
253
  free(state->stacklevels);
198
253
  free(state->auth);
199
253
  free(state->keep);
200
253
  free(state->th_nodes);
201
253
  free(state->retain);
202
253
  free(state->treehash);
203
253
  state->stack = NULL;
204
253
  state->stacklevels = NULL;
205
253
  state->auth = NULL;
206
253
  state->keep = NULL;
207
253
  state->th_nodes = NULL;
208
253
  state->retain = NULL;
209
253
  state->treehash = NULL;
210
253
}
211
212
void *
213
sshkey_xmss_params(const struct sshkey *key)
214
177
{
215
177
  struct ssh_xmss_state *state = key->xmss_state;
216
217
177
  if (state == NULL)
218
0
    return NULL;
219
177
  return &state->params;
220
177
}
221
222
void *
223
sshkey_xmss_bds_state(const struct sshkey *key)
224
0
{
225
0
  struct ssh_xmss_state *state = key->xmss_state;
226
227
0
  if (state == NULL)
228
0
    return NULL;
229
0
  return &state->bds;
230
0
}
231
232
int
233
sshkey_xmss_siglen(const struct sshkey *key, size_t *lenp)
234
177
{
235
177
  struct ssh_xmss_state *state = key->xmss_state;
236
237
177
  if (lenp == NULL)
238
0
    return SSH_ERR_INVALID_ARGUMENT;
239
177
  if (state == NULL)
240
0
    return SSH_ERR_INVALID_FORMAT;
241
177
  *lenp = 4 + state->n +
242
177
      state->params.wots_par.keysize +
243
177
      state->h * state->n;
244
177
  return 0;
245
177
}
246
247
size_t
248
sshkey_xmss_pklen(const struct sshkey *key)
249
654
{
250
654
  struct ssh_xmss_state *state = key->xmss_state;
251
252
654
  if (state == NULL)
253
156
    return 0;
254
498
  return state->n * 2;
255
654
}
256
257
size_t
258
sshkey_xmss_sklen(const struct sshkey *key)
259
409
{
260
409
  struct ssh_xmss_state *state = key->xmss_state;
261
262
409
  if (state == NULL)
263
156
    return 0;
264
253
  return state->n * 4 + 4;
265
409
}
266
267
int
268
sshkey_xmss_init_enc_key(struct sshkey *k, const char *ciphername)
269
0
{
270
0
  struct ssh_xmss_state *state = k->xmss_state;
271
0
  const struct sshcipher *cipher;
272
0
  size_t keylen = 0, ivlen = 0;
273
274
0
  if (state == NULL)
275
0
    return SSH_ERR_INVALID_ARGUMENT;
276
0
  if ((cipher = cipher_by_name(ciphername)) == NULL)
277
0
    return SSH_ERR_INTERNAL_ERROR;
278
0
  if ((state->enc_ciphername = strdup(ciphername)) == NULL)
279
0
    return SSH_ERR_ALLOC_FAIL;
280
0
  keylen = cipher_keylen(cipher);
281
0
  ivlen = cipher_ivlen(cipher);
282
0
  state->enc_keyiv_len = keylen + ivlen;
283
0
  if ((state->enc_keyiv = calloc(state->enc_keyiv_len, 1)) == NULL) {
284
0
    free(state->enc_ciphername);
285
0
    state->enc_ciphername = NULL;
286
0
    return SSH_ERR_ALLOC_FAIL;
287
0
  }
288
0
  arc4random_buf(state->enc_keyiv, state->enc_keyiv_len);
289
0
  return 0;
290
0
}
291
292
int
293
sshkey_xmss_serialize_enc_key(const struct sshkey *k, struct sshbuf *b)
294
0
{
295
0
  struct ssh_xmss_state *state = k->xmss_state;
296
0
  int r;
297
298
0
  if (state == NULL || state->enc_keyiv == NULL ||
299
0
      state->enc_ciphername == NULL)
300
0
    return SSH_ERR_INVALID_ARGUMENT;
301
0
  if ((r = sshbuf_put_cstring(b, state->enc_ciphername)) != 0 ||
302
0
      (r = sshbuf_put_string(b, state->enc_keyiv,
303
0
      state->enc_keyiv_len)) != 0)
304
0
    return r;
305
0
  return 0;
306
0
}
307
308
int
309
sshkey_xmss_deserialize_enc_key(struct sshkey *k, struct sshbuf *b)
310
0
{
311
0
  struct ssh_xmss_state *state = k->xmss_state;
312
0
  size_t len;
313
0
  int r;
314
315
0
  if (state == NULL)
316
0
    return SSH_ERR_INVALID_ARGUMENT;
317
0
  if ((r = sshbuf_get_cstring(b, &state->enc_ciphername, NULL)) != 0 ||
318
0
      (r = sshbuf_get_string(b, &state->enc_keyiv, &len)) != 0)
319
0
    return r;
320
0
  state->enc_keyiv_len = len;
321
0
  return 0;
322
0
}
323
324
int
325
sshkey_xmss_serialize_pk_info(const struct sshkey *k, struct sshbuf *b,
326
    enum sshkey_serialize_rep opts)
327
0
{
328
0
  struct ssh_xmss_state *state = k->xmss_state;
329
0
  u_char have_info = 1;
330
0
  u_int32_t idx;
331
0
  int r;
332
333
0
  if (state == NULL)
334
0
    return SSH_ERR_INVALID_ARGUMENT;
335
0
  if (opts != SSHKEY_SERIALIZE_INFO)
336
0
    return 0;
337
0
  idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
338
0
  if ((r = sshbuf_put_u8(b, have_info)) != 0 ||
339
0
      (r = sshbuf_put_u32(b, idx)) != 0 ||
340
0
      (r = sshbuf_put_u32(b, state->maxidx)) != 0)
341
0
    return r;
342
0
  return 0;
343
0
}
344
345
int
346
sshkey_xmss_deserialize_pk_info(struct sshkey *k, struct sshbuf *b)
347
212
{
348
212
  struct ssh_xmss_state *state = k->xmss_state;
349
212
  u_char have_info;
350
212
  int r;
351
352
212
  if (state == NULL)
353
0
    return SSH_ERR_INVALID_ARGUMENT;
354
  /* optional */
355
212
  if (sshbuf_len(b) == 0)
356
182
    return 0;
357
30
  if ((r = sshbuf_get_u8(b, &have_info)) != 0)
358
0
    return r;
359
30
  if (have_info != 1)
360
20
    return SSH_ERR_INVALID_ARGUMENT;
361
10
  if ((r = sshbuf_get_u32(b, &state->idx)) != 0 ||
362
10
      (r = sshbuf_get_u32(b, &state->maxidx)) != 0)
363
4
    return r;
364
6
  return 0;
365
10
}
366
367
int
368
sshkey_xmss_generate_private_key(struct sshkey *k, int bits)
369
0
{
370
0
  int r;
371
0
  const char *name;
372
373
0
  if (bits == 10) {
374
0
    name = XMSS_SHA2_256_W16_H10_NAME;
375
0
  } else if (bits == 16) {
376
0
    name = XMSS_SHA2_256_W16_H16_NAME;
377
0
  } else if (bits == 20) {
378
0
    name = XMSS_SHA2_256_W16_H20_NAME;
379
0
  } else {
380
0
    name = XMSS_DEFAULT_NAME;
381
0
  }
382
0
  if ((r = sshkey_xmss_init(k, name)) != 0 ||
383
0
      (r = sshkey_xmss_init_bds_state(k)) != 0 ||
384
0
      (r = sshkey_xmss_init_enc_key(k, XMSS_CIPHERNAME)) != 0)
385
0
    return r;
386
0
  if ((k->xmss_pk = malloc(sshkey_xmss_pklen(k))) == NULL ||
387
0
      (k->xmss_sk = malloc(sshkey_xmss_sklen(k))) == NULL) {
388
0
    return SSH_ERR_ALLOC_FAIL;
389
0
  }
390
0
  xmss_keypair(k->xmss_pk, k->xmss_sk, sshkey_xmss_bds_state(k),
391
0
      sshkey_xmss_params(k));
392
0
  return 0;
393
0
}
394
395
int
396
sshkey_xmss_get_state_from_file(struct sshkey *k, const char *filename,
397
    int *have_file, int printerror)
398
0
{
399
0
  struct sshbuf *b = NULL, *enc = NULL;
400
0
  int ret = SSH_ERR_SYSTEM_ERROR, r, fd = -1;
401
0
  u_int32_t len;
402
0
  unsigned char buf[4], *data = NULL;
403
404
0
  *have_file = 0;
405
0
  if ((fd = open(filename, O_RDONLY)) >= 0) {
406
0
    *have_file = 1;
407
0
    if (atomicio(read, fd, buf, sizeof(buf)) != sizeof(buf)) {
408
0
      PRINT("corrupt state file: %s", filename);
409
0
      goto done;
410
0
    }
411
0
    len = PEEK_U32(buf);
412
0
    if ((data = calloc(len, 1)) == NULL) {
413
0
      ret = SSH_ERR_ALLOC_FAIL;
414
0
      goto done;
415
0
    }
416
0
    if (atomicio(read, fd, data, len) != len) {
417
0
      PRINT("cannot read blob: %s", filename);
418
0
      goto done;
419
0
    }
420
0
    if ((enc = sshbuf_from(data, len)) == NULL) {
421
0
      ret = SSH_ERR_ALLOC_FAIL;
422
0
      goto done;
423
0
    }
424
0
    sshkey_xmss_free_bds(k);
425
0
    if ((r = sshkey_xmss_decrypt_state(k, enc, &b)) != 0) {
426
0
      ret = r;
427
0
      goto done;
428
0
    }
429
0
    if ((r = sshkey_xmss_deserialize_state(k, b)) != 0) {
430
0
      ret = r;
431
0
      goto done;
432
0
    }
433
0
    ret = 0;
434
0
  }
435
0
done:
436
0
  if (fd != -1)
437
0
    close(fd);
438
0
  free(data);
439
0
  sshbuf_free(enc);
440
0
  sshbuf_free(b);
441
0
  return ret;
442
0
}
443
444
int
445
sshkey_xmss_get_state(const struct sshkey *k, int printerror)
446
0
{
447
0
  struct ssh_xmss_state *state = k->xmss_state;
448
0
  u_int32_t idx = 0;
449
0
  char *filename = NULL;
450
0
  char *statefile = NULL, *ostatefile = NULL, *lockfile = NULL;
451
0
  int lockfd = -1, have_state = 0, have_ostate, tries = 0;
452
0
  int ret = SSH_ERR_INVALID_ARGUMENT, r;
453
454
0
  if (state == NULL)
455
0
    goto done;
456
  /*
457
   * If maxidx is set, then we are allowed a limited number
458
   * of signatures, but don't need to access the disk.
459
   * Otherwise we need to deal with the on-disk state.
460
   */
461
0
  if (state->maxidx) {
462
    /* xmss_sk always contains the current state */
463
0
    idx = PEEK_U32(k->xmss_sk);
464
0
    if (idx < state->maxidx) {
465
0
      state->allow_update = 1;
466
0
      return 0;
467
0
    }
468
0
    return SSH_ERR_INVALID_ARGUMENT;
469
0
  }
470
0
  if ((filename = k->xmss_filename) == NULL)
471
0
    goto done;
472
0
  if (asprintf(&lockfile, "%s.lock", filename) == -1 ||
473
0
      asprintf(&statefile, "%s.state", filename) == -1 ||
474
0
      asprintf(&ostatefile, "%s.ostate", filename) == -1) {
475
0
    ret = SSH_ERR_ALLOC_FAIL;
476
0
    goto done;
477
0
  }
478
0
  if ((lockfd = open(lockfile, O_CREAT|O_RDONLY, 0600)) == -1) {
479
0
    ret = SSH_ERR_SYSTEM_ERROR;
480
0
    PRINT("cannot open/create: %s", lockfile);
481
0
    goto done;
482
0
  }
483
0
  while (flock(lockfd, LOCK_EX|LOCK_NB) == -1) {
484
0
    if (errno != EWOULDBLOCK) {
485
0
      ret = SSH_ERR_SYSTEM_ERROR;
486
0
      PRINT("cannot lock: %s", lockfile);
487
0
      goto done;
488
0
    }
489
0
    if (++tries > 10) {
490
0
      ret = SSH_ERR_SYSTEM_ERROR;
491
0
      PRINT("giving up on: %s", lockfile);
492
0
      goto done;
493
0
    }
494
0
    usleep(1000*100*tries);
495
0
  }
496
  /* XXX no longer const */
497
0
  if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k,
498
0
      statefile, &have_state, printerror)) != 0) {
499
0
    if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k,
500
0
        ostatefile, &have_ostate, printerror)) == 0) {
501
0
      state->allow_update = 1;
502
0
      r = sshkey_xmss_forward_state(k, 1);
503
0
      state->idx = PEEK_U32(k->xmss_sk);
504
0
      state->allow_update = 0;
505
0
    }
506
0
  }
507
0
  if (!have_state && !have_ostate) {
508
    /* check that bds state is initialized */
509
0
    if (state->bds.auth == NULL)
510
0
      goto done;
511
0
    PRINT("start from scratch idx 0: %u", state->idx);
512
0
  } else if (r != 0) {
513
0
    ret = r;
514
0
    goto done;
515
0
  }
516
0
  if (state->idx + 1 < state->idx) {
517
0
    PRINT("state wrap: %u", state->idx);
518
0
    goto done;
519
0
  }
520
0
  state->have_state = have_state;
521
0
  state->lockfd = lockfd;
522
0
  state->allow_update = 1;
523
0
  lockfd = -1;
524
0
  ret = 0;
525
0
done:
526
0
  if (lockfd != -1)
527
0
    close(lockfd);
528
0
  free(lockfile);
529
0
  free(statefile);
530
0
  free(ostatefile);
531
0
  return ret;
532
0
}
533
534
int
535
sshkey_xmss_forward_state(const struct sshkey *k, u_int32_t reserve)
536
0
{
537
0
  struct ssh_xmss_state *state = k->xmss_state;
538
0
  u_char *sig = NULL;
539
0
  size_t required_siglen;
540
0
  unsigned long long smlen;
541
0
  u_char data;
542
0
  int ret, r;
543
544
0
  if (state == NULL || !state->allow_update)
545
0
    return SSH_ERR_INVALID_ARGUMENT;
546
0
  if (reserve == 0)
547
0
    return SSH_ERR_INVALID_ARGUMENT;
548
0
  if (state->idx + reserve <= state->idx)
549
0
    return SSH_ERR_INVALID_ARGUMENT;
550
0
  if ((r = sshkey_xmss_siglen(k, &required_siglen)) != 0)
551
0
    return r;
552
0
  if ((sig = malloc(required_siglen)) == NULL)
553
0
    return SSH_ERR_ALLOC_FAIL;
554
0
  while (reserve-- > 0) {
555
0
    state->idx = PEEK_U32(k->xmss_sk);
556
0
    smlen = required_siglen;
557
0
    if ((ret = xmss_sign(k->xmss_sk, sshkey_xmss_bds_state(k),
558
0
        sig, &smlen, &data, 0, sshkey_xmss_params(k))) != 0) {
559
0
      r = SSH_ERR_INVALID_ARGUMENT;
560
0
      break;
561
0
    }
562
0
  }
563
0
  free(sig);
564
0
  return r;
565
0
}
566
567
int
568
sshkey_xmss_update_state(const struct sshkey *k, int printerror)
569
0
{
570
0
  struct ssh_xmss_state *state = k->xmss_state;
571
0
  struct sshbuf *b = NULL, *enc = NULL;
572
0
  u_int32_t idx = 0;
573
0
  unsigned char buf[4];
574
0
  char *filename = NULL;
575
0
  char *statefile = NULL, *ostatefile = NULL, *nstatefile = NULL;
576
0
  int fd = -1;
577
0
  int ret = SSH_ERR_INVALID_ARGUMENT;
578
579
0
  if (state == NULL || !state->allow_update)
580
0
    return ret;
581
0
  if (state->maxidx) {
582
    /* no update since the number of signatures is limited */
583
0
    ret = 0;
584
0
    goto done;
585
0
  }
586
0
  idx = PEEK_U32(k->xmss_sk);
587
0
  if (idx == state->idx) {
588
    /* no signature happened, no need to update */
589
0
    ret = 0;
590
0
    goto done;
591
0
  } else if (idx != state->idx + 1) {
592
0
    PRINT("more than one signature happened: idx %u state %u",
593
0
        idx, state->idx);
594
0
    goto done;
595
0
  }
596
0
  state->idx = idx;
597
0
  if ((filename = k->xmss_filename) == NULL)
598
0
    goto done;
599
0
  if (asprintf(&statefile, "%s.state", filename) == -1 ||
600
0
      asprintf(&ostatefile, "%s.ostate", filename) == -1 ||
601
0
      asprintf(&nstatefile, "%s.nstate", filename) == -1) {
602
0
    ret = SSH_ERR_ALLOC_FAIL;
603
0
    goto done;
604
0
  }
605
0
  unlink(nstatefile);
606
0
  if ((b = sshbuf_new()) == NULL) {
607
0
    ret = SSH_ERR_ALLOC_FAIL;
608
0
    goto done;
609
0
  }
610
0
  if ((ret = sshkey_xmss_serialize_state(k, b)) != 0) {
611
0
    PRINT("SERLIALIZE FAILED: %d", ret);
612
0
    goto done;
613
0
  }
614
0
  if ((ret = sshkey_xmss_encrypt_state(k, b, &enc)) != 0) {
615
0
    PRINT("ENCRYPT FAILED: %d", ret);
616
0
    goto done;
617
0
  }
618
0
  if ((fd = open(nstatefile, O_CREAT|O_WRONLY|O_EXCL, 0600)) == -1) {
619
0
    ret = SSH_ERR_SYSTEM_ERROR;
620
0
    PRINT("open new state file: %s", nstatefile);
621
0
    goto done;
622
0
  }
623
0
  POKE_U32(buf, sshbuf_len(enc));
624
0
  if (atomicio(vwrite, fd, buf, sizeof(buf)) != sizeof(buf)) {
625
0
    ret = SSH_ERR_SYSTEM_ERROR;
626
0
    PRINT("write new state file hdr: %s", nstatefile);
627
0
    close(fd);
628
0
    goto done;
629
0
  }
630
0
  if (atomicio(vwrite, fd, sshbuf_mutable_ptr(enc), sshbuf_len(enc)) !=
631
0
      sshbuf_len(enc)) {
632
0
    ret = SSH_ERR_SYSTEM_ERROR;
633
0
    PRINT("write new state file data: %s", nstatefile);
634
0
    close(fd);
635
0
    goto done;
636
0
  }
637
0
  if (fsync(fd) == -1) {
638
0
    ret = SSH_ERR_SYSTEM_ERROR;
639
0
    PRINT("sync new state file: %s", nstatefile);
640
0
    close(fd);
641
0
    goto done;
642
0
  }
643
0
  if (close(fd) == -1) {
644
0
    ret = SSH_ERR_SYSTEM_ERROR;
645
0
    PRINT("close new state file: %s", nstatefile);
646
0
    goto done;
647
0
  }
648
0
  if (state->have_state) {
649
0
    unlink(ostatefile);
650
0
    if (link(statefile, ostatefile)) {
651
0
      ret = SSH_ERR_SYSTEM_ERROR;
652
0
      PRINT("backup state %s to %s", statefile, ostatefile);
653
0
      goto done;
654
0
    }
655
0
  }
656
0
  if (rename(nstatefile, statefile) == -1) {
657
0
    ret = SSH_ERR_SYSTEM_ERROR;
658
0
    PRINT("rename %s to %s", nstatefile, statefile);
659
0
    goto done;
660
0
  }
661
0
  ret = 0;
662
0
done:
663
0
  if (state->lockfd != -1) {
664
0
    close(state->lockfd);
665
0
    state->lockfd = -1;
666
0
  }
667
0
  if (nstatefile)
668
0
    unlink(nstatefile);
669
0
  free(statefile);
670
0
  free(ostatefile);
671
0
  free(nstatefile);
672
0
  sshbuf_free(b);
673
0
  sshbuf_free(enc);
674
0
  return ret;
675
0
}
676
677
int
678
sshkey_xmss_serialize_state(const struct sshkey *k, struct sshbuf *b)
679
0
{
680
0
  struct ssh_xmss_state *state = k->xmss_state;
681
0
  treehash_inst *th;
682
0
  u_int32_t i, node;
683
0
  int r;
684
685
0
  if (state == NULL)
686
0
    return SSH_ERR_INVALID_ARGUMENT;
687
0
  if (state->stack == NULL)
688
0
    return SSH_ERR_INVALID_ARGUMENT;
689
0
  state->stackoffset = state->bds.stackoffset;  /* copy back */
690
0
  if ((r = sshbuf_put_cstring(b, SSH_XMSS_K2_MAGIC)) != 0 ||
691
0
      (r = sshbuf_put_u32(b, state->idx)) != 0 ||
692
0
      (r = sshbuf_put_string(b, state->stack, num_stack(state))) != 0 ||
693
0
      (r = sshbuf_put_u32(b, state->stackoffset)) != 0 ||
694
0
      (r = sshbuf_put_string(b, state->stacklevels, num_stacklevels(state))) != 0 ||
695
0
      (r = sshbuf_put_string(b, state->auth, num_auth(state))) != 0 ||
696
0
      (r = sshbuf_put_string(b, state->keep, num_keep(state))) != 0 ||
697
0
      (r = sshbuf_put_string(b, state->th_nodes, num_th_nodes(state))) != 0 ||
698
0
      (r = sshbuf_put_string(b, state->retain, num_retain(state))) != 0 ||
699
0
      (r = sshbuf_put_u32(b, num_treehash(state))) != 0)
700
0
    return r;
701
0
  for (i = 0; i < num_treehash(state); i++) {
702
0
    th = &state->treehash[i];
703
0
    node = th->node - state->th_nodes;
704
0
    if ((r = sshbuf_put_u32(b, th->h)) != 0 ||
705
0
        (r = sshbuf_put_u32(b, th->next_idx)) != 0 ||
706
0
        (r = sshbuf_put_u32(b, th->stackusage)) != 0 ||
707
0
        (r = sshbuf_put_u8(b, th->completed)) != 0 ||
708
0
        (r = sshbuf_put_u32(b, node)) != 0)
709
0
      return r;
710
0
  }
711
0
  return 0;
712
0
}
713
714
int
715
sshkey_xmss_serialize_state_opt(const struct sshkey *k, struct sshbuf *b,
716
    enum sshkey_serialize_rep opts)
717
0
{
718
0
  struct ssh_xmss_state *state = k->xmss_state;
719
0
  int r = SSH_ERR_INVALID_ARGUMENT;
720
0
  u_char have_stack, have_filename, have_enc;
721
722
0
  if (state == NULL)
723
0
    return SSH_ERR_INVALID_ARGUMENT;
724
0
  if ((r = sshbuf_put_u8(b, opts)) != 0)
725
0
    return r;
726
0
  switch (opts) {
727
0
  case SSHKEY_SERIALIZE_STATE:
728
0
    r = sshkey_xmss_serialize_state(k, b);
729
0
    break;
730
0
  case SSHKEY_SERIALIZE_FULL:
731
0
    if ((r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
732
0
      return r;
733
0
    r = sshkey_xmss_serialize_state(k, b);
734
0
    break;
735
0
  case SSHKEY_SERIALIZE_SHIELD:
736
    /* all of stack/filename/enc are optional */
737
0
    have_stack = state->stack != NULL;
738
0
    if ((r = sshbuf_put_u8(b, have_stack)) != 0)
739
0
      return r;
740
0
    if (have_stack) {
741
0
      state->idx = PEEK_U32(k->xmss_sk); /* update */
742
0
      if ((r = sshkey_xmss_serialize_state(k, b)) != 0)
743
0
        return r;
744
0
    }
745
0
    have_filename = k->xmss_filename != NULL;
746
0
    if ((r = sshbuf_put_u8(b, have_filename)) != 0)
747
0
      return r;
748
0
    if (have_filename &&
749
0
        (r = sshbuf_put_cstring(b, k->xmss_filename)) != 0)
750
0
      return r;
751
0
    have_enc = state->enc_keyiv != NULL;
752
0
    if ((r = sshbuf_put_u8(b, have_enc)) != 0)
753
0
      return r;
754
0
    if (have_enc &&
755
0
        (r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
756
0
      return r;
757
0
    if ((r = sshbuf_put_u32(b, state->maxidx)) != 0 ||
758
0
        (r = sshbuf_put_u8(b, state->allow_update)) != 0)
759
0
      return r;
760
0
    break;
761
0
  case SSHKEY_SERIALIZE_DEFAULT:
762
0
    r = 0;
763
0
    break;
764
0
  default:
765
0
    r = SSH_ERR_INVALID_ARGUMENT;
766
0
    break;
767
0
  }
768
0
  return r;
769
0
}
770
771
int
772
sshkey_xmss_deserialize_state(struct sshkey *k, struct sshbuf *b)
773
0
{
774
0
  struct ssh_xmss_state *state = k->xmss_state;
775
0
  treehash_inst *th;
776
0
  u_int32_t i, lh, node;
777
0
  size_t ls, lsl, la, lk, ln, lr;
778
0
  char *magic;
779
0
  int r = SSH_ERR_INTERNAL_ERROR;
780
781
0
  if (state == NULL)
782
0
    return SSH_ERR_INVALID_ARGUMENT;
783
0
  if (k->xmss_sk == NULL)
784
0
    return SSH_ERR_INVALID_ARGUMENT;
785
0
  if ((state->treehash = calloc(num_treehash(state),
786
0
      sizeof(treehash_inst))) == NULL)
787
0
    return SSH_ERR_ALLOC_FAIL;
788
0
  if ((r = sshbuf_get_cstring(b, &magic, NULL)) != 0 ||
789
0
      (r = sshbuf_get_u32(b, &state->idx)) != 0 ||
790
0
      (r = sshbuf_get_string(b, &state->stack, &ls)) != 0 ||
791
0
      (r = sshbuf_get_u32(b, &state->stackoffset)) != 0 ||
792
0
      (r = sshbuf_get_string(b, &state->stacklevels, &lsl)) != 0 ||
793
0
      (r = sshbuf_get_string(b, &state->auth, &la)) != 0 ||
794
0
      (r = sshbuf_get_string(b, &state->keep, &lk)) != 0 ||
795
0
      (r = sshbuf_get_string(b, &state->th_nodes, &ln)) != 0 ||
796
0
      (r = sshbuf_get_string(b, &state->retain, &lr)) != 0 ||
797
0
      (r = sshbuf_get_u32(b, &lh)) != 0)
798
0
    goto out;
799
0
  if (strcmp(magic, SSH_XMSS_K2_MAGIC) != 0) {
800
0
    r = SSH_ERR_INVALID_ARGUMENT;
801
0
    goto out;
802
0
  }
803
  /* XXX check stackoffset */
804
0
  if (ls != num_stack(state) ||
805
0
      lsl != num_stacklevels(state) ||
806
0
      la != num_auth(state) ||
807
0
      lk != num_keep(state) ||
808
0
      ln != num_th_nodes(state) ||
809
0
      lr != num_retain(state) ||
810
0
      lh != num_treehash(state)) {
811
0
    r = SSH_ERR_INVALID_ARGUMENT;
812
0
    goto out;
813
0
  }
814
0
  for (i = 0; i < num_treehash(state); i++) {
815
0
    th = &state->treehash[i];
816
0
    if ((r = sshbuf_get_u32(b, &th->h)) != 0 ||
817
0
        (r = sshbuf_get_u32(b, &th->next_idx)) != 0 ||
818
0
        (r = sshbuf_get_u32(b, &th->stackusage)) != 0 ||
819
0
        (r = sshbuf_get_u8(b, &th->completed)) != 0 ||
820
0
        (r = sshbuf_get_u32(b, &node)) != 0)
821
0
      goto out;
822
0
    if (node < num_th_nodes(state))
823
0
      th->node = &state->th_nodes[node];
824
0
  }
825
0
  POKE_U32(k->xmss_sk, state->idx);
826
0
  xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
827
0
      state->stacklevels, state->auth, state->keep, state->treehash,
828
0
      state->retain, 0);
829
  /* success */
830
0
  r = 0;
831
0
 out:
832
0
  free(magic);
833
0
  return r;
834
0
}
835
836
int
837
sshkey_xmss_deserialize_state_opt(struct sshkey *k, struct sshbuf *b)
838
0
{
839
0
  struct ssh_xmss_state *state = k->xmss_state;
840
0
  enum sshkey_serialize_rep opts;
841
0
  u_char have_state, have_stack, have_filename, have_enc;
842
0
  int r;
843
844
0
  if ((r = sshbuf_get_u8(b, &have_state)) != 0)
845
0
    return r;
846
847
0
  opts = have_state;
848
0
  switch (opts) {
849
0
  case SSHKEY_SERIALIZE_DEFAULT:
850
0
    r = 0;
851
0
    break;
852
0
  case SSHKEY_SERIALIZE_SHIELD:
853
0
    if ((r = sshbuf_get_u8(b, &have_stack)) != 0)
854
0
      return r;
855
0
    if (have_stack &&
856
0
        (r = sshkey_xmss_deserialize_state(k, b)) != 0)
857
0
      return r;
858
0
    if ((r = sshbuf_get_u8(b, &have_filename)) != 0)
859
0
      return r;
860
0
    if (have_filename &&
861
0
        (r = sshbuf_get_cstring(b, &k->xmss_filename, NULL)) != 0)
862
0
      return r;
863
0
    if ((r = sshbuf_get_u8(b, &have_enc)) != 0)
864
0
      return r;
865
0
    if (have_enc &&
866
0
        (r = sshkey_xmss_deserialize_enc_key(k, b)) != 0)
867
0
      return r;
868
0
    if ((r = sshbuf_get_u32(b, &state->maxidx)) != 0 ||
869
0
        (r = sshbuf_get_u8(b, &state->allow_update)) != 0)
870
0
      return r;
871
0
    break;
872
0
  case SSHKEY_SERIALIZE_STATE:
873
0
    if ((r = sshkey_xmss_deserialize_state(k, b)) != 0)
874
0
      return r;
875
0
    break;
876
0
  case SSHKEY_SERIALIZE_FULL:
877
0
    if ((r = sshkey_xmss_deserialize_enc_key(k, b)) != 0 ||
878
0
        (r = sshkey_xmss_deserialize_state(k, b)) != 0)
879
0
      return r;
880
0
    break;
881
0
  default:
882
0
    r = SSH_ERR_INVALID_FORMAT;
883
0
    break;
884
0
  }
885
0
  return r;
886
0
}
887
888
int
889
sshkey_xmss_encrypt_state(const struct sshkey *k, struct sshbuf *b,
890
   struct sshbuf **retp)
891
0
{
892
0
  struct ssh_xmss_state *state = k->xmss_state;
893
0
  struct sshbuf *encrypted = NULL, *encoded = NULL, *padded = NULL;
894
0
  struct sshcipher_ctx *ciphercontext = NULL;
895
0
  const struct sshcipher *cipher;
896
0
  u_char *cp, *key, *iv = NULL;
897
0
  size_t i, keylen, ivlen, blocksize, authlen, encrypted_len, aadlen;
898
0
  int r = SSH_ERR_INTERNAL_ERROR;
899
900
0
  if (retp != NULL)
901
0
    *retp = NULL;
902
0
  if (state == NULL ||
903
0
      state->enc_keyiv == NULL ||
904
0
      state->enc_ciphername == NULL)
905
0
    return SSH_ERR_INTERNAL_ERROR;
906
  /*
907
   * chacha20-poly1305-mt@hpnssh.org and chacha20-poly1305@openssh.com
908
   * represent different implementations of the same cipher. For key
909
   * encryption purposes, they're equivalent, and the multithreaded
910
   * implementation is excessive. It can be assumed that references to the
911
   * multithreaded implementation in this context are unintentional, so
912
   * these checks should look for the serial implementation instead.
913
   *
914
   * Additionally, the following code is safe regardless of whether the
915
   * multithreaded implementation is enabled, so no #ifdefs are necessary.
916
   */
917
0
  if (strcmp(state->enc_ciphername, "chacha20-poly1305-mt@hpnssh.org")
918
0
      == 0) {
919
0
    if ((cipher = cipher_by_name("chacha20-poly1305@openssh.com"))
920
0
        == NULL) {
921
0
      r = SSH_ERR_INTERNAL_ERROR;
922
0
      goto out;
923
0
    }
924
0
  } else {
925
0
    if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
926
0
      r = SSH_ERR_INTERNAL_ERROR;
927
0
      goto out;
928
0
    }
929
0
  }
930
0
  blocksize = cipher_blocksize(cipher);
931
0
  keylen = cipher_keylen(cipher);
932
0
  ivlen = cipher_ivlen(cipher);
933
0
  authlen = cipher_authlen(cipher);
934
0
  if (state->enc_keyiv_len != keylen + ivlen) {
935
0
    r = SSH_ERR_INVALID_FORMAT;
936
0
    goto out;
937
0
  }
938
0
  key = state->enc_keyiv;
939
0
  if ((encrypted = sshbuf_new()) == NULL ||
940
0
      (encoded = sshbuf_new()) == NULL ||
941
0
      (padded = sshbuf_new()) == NULL ||
942
0
      (iv = malloc(ivlen)) == NULL) {
943
0
    r = SSH_ERR_ALLOC_FAIL;
944
0
    goto out;
945
0
  }
946
947
  /* replace first 4 bytes of IV with index to ensure uniqueness */
948
0
  memcpy(iv, key + keylen, ivlen);
949
0
  POKE_U32(iv, state->idx);
950
951
0
  if ((r = sshbuf_put(encoded, XMSS_MAGIC, sizeof(XMSS_MAGIC))) != 0 ||
952
0
      (r = sshbuf_put_u32(encoded, state->idx)) != 0)
953
0
    goto out;
954
955
  /* padded state will be encrypted */
956
0
  if ((r = sshbuf_putb(padded, b)) != 0)
957
0
    goto out;
958
0
  i = 0;
959
0
  while (sshbuf_len(padded) % blocksize) {
960
0
    if ((r = sshbuf_put_u8(padded, ++i & 0xff)) != 0)
961
0
      goto out;
962
0
  }
963
0
  encrypted_len = sshbuf_len(padded);
964
965
  /* header including the length of state is used as AAD */
966
0
  if ((r = sshbuf_put_u32(encoded, encrypted_len)) != 0)
967
0
    goto out;
968
0
  aadlen = sshbuf_len(encoded);
969
970
  /* concat header and state */
971
0
  if ((r = sshbuf_putb(encoded, padded)) != 0)
972
0
    goto out;
973
974
  /* reserve space for encryption of encoded data plus auth tag */
975
  /* encrypt at offset addlen */
976
0
  if ((r = sshbuf_reserve(encrypted,
977
0
      encrypted_len + aadlen + authlen, &cp)) != 0 ||
978
0
      (r = cipher_init(&ciphercontext, cipher, key, keylen, iv, ivlen, 0,
979
0
          CIPHER_ENCRYPT, CIPHER_SERIAL)) != 0 ||
980
0
      (r = cipher_crypt(ciphercontext, 0, cp, sshbuf_ptr(encoded),
981
0
      encrypted_len, aadlen, authlen)) != 0)
982
0
    goto out;
983
984
  /* success */
985
0
  r = 0;
986
0
 out:
987
0
  if (retp != NULL) {
988
0
    *retp = encrypted;
989
0
    encrypted = NULL;
990
0
  }
991
0
  sshbuf_free(padded);
992
0
  sshbuf_free(encoded);
993
0
  sshbuf_free(encrypted);
994
0
  cipher_free(ciphercontext);
995
0
  free(iv);
996
0
  return r;
997
0
}
998
999
int
1000
sshkey_xmss_decrypt_state(const struct sshkey *k, struct sshbuf *encoded,
1001
   struct sshbuf **retp)
1002
0
{
1003
0
  struct ssh_xmss_state *state = k->xmss_state;
1004
0
  struct sshbuf *copy = NULL, *decrypted = NULL;
1005
0
  struct sshcipher_ctx *ciphercontext = NULL;
1006
0
  const struct sshcipher *cipher = NULL;
1007
0
  u_char *key, *iv = NULL, *dp;
1008
0
  size_t keylen, ivlen, authlen, aadlen;
1009
0
  u_int blocksize, encrypted_len, index;
1010
0
  int r = SSH_ERR_INTERNAL_ERROR;
1011
1012
0
  if (retp != NULL)
1013
0
    *retp = NULL;
1014
0
  if (state == NULL ||
1015
0
      state->enc_keyiv == NULL ||
1016
0
      state->enc_ciphername == NULL)
1017
0
    return SSH_ERR_INTERNAL_ERROR;
1018
  /*
1019
   * chacha20-poly1305-mt@hpnssh.org and chacha20-poly1305@openssh.com
1020
   * represent different implementations of the same cipher. For key
1021
   * encryption purposes, they're equivalent, and the multithreaded
1022
   * implementation is excessive. It can be assumed that references to the
1023
   * multithreaded implementation in this context are unintentional, so
1024
   * these checks should look for the serial implementation instead.
1025
   *
1026
   * Additionally, the following code is safe regardless of whether the
1027
   * multithreaded implementation is enabled, so no #ifdefs are necessary.
1028
   */
1029
0
  if (strcmp(state->enc_ciphername, "chacha20-poly1305-mt@hpnssh.org")
1030
0
      == 0) {
1031
0
    if ((cipher = cipher_by_name("chacha20-poly1305@openssh.com"))
1032
0
        == NULL) {
1033
0
      r = SSH_ERR_INVALID_FORMAT;
1034
0
      goto out;
1035
0
    }
1036
0
  } else {
1037
0
    if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
1038
0
      r = SSH_ERR_INVALID_FORMAT;
1039
0
      goto out;
1040
0
    }
1041
0
  }
1042
0
  blocksize = cipher_blocksize(cipher);
1043
0
  keylen = cipher_keylen(cipher);
1044
0
  ivlen = cipher_ivlen(cipher);
1045
0
  authlen = cipher_authlen(cipher);
1046
0
  if (state->enc_keyiv_len != keylen + ivlen) {
1047
0
    r = SSH_ERR_INTERNAL_ERROR;
1048
0
    goto out;
1049
0
  }
1050
0
  key = state->enc_keyiv;
1051
1052
0
  if ((copy = sshbuf_fromb(encoded)) == NULL ||
1053
0
      (decrypted = sshbuf_new()) == NULL ||
1054
0
      (iv = malloc(ivlen)) == NULL) {
1055
0
    r = SSH_ERR_ALLOC_FAIL;
1056
0
    goto out;
1057
0
  }
1058
1059
  /* check magic */
1060
0
  if (sshbuf_len(encoded) < sizeof(XMSS_MAGIC) ||
1061
0
      memcmp(sshbuf_ptr(encoded), XMSS_MAGIC, sizeof(XMSS_MAGIC))) {
1062
0
    r = SSH_ERR_INVALID_FORMAT;
1063
0
    goto out;
1064
0
  }
1065
  /* parse public portion */
1066
0
  if ((r = sshbuf_consume(encoded, sizeof(XMSS_MAGIC))) != 0 ||
1067
0
      (r = sshbuf_get_u32(encoded, &index)) != 0 ||
1068
0
      (r = sshbuf_get_u32(encoded, &encrypted_len)) != 0)
1069
0
    goto out;
1070
1071
  /* check size of encrypted key blob */
1072
0
  if (encrypted_len < blocksize || (encrypted_len % blocksize) != 0) {
1073
0
    r = SSH_ERR_INVALID_FORMAT;
1074
0
    goto out;
1075
0
  }
1076
  /* check that an appropriate amount of auth data is present */
1077
0
  if (sshbuf_len(encoded) < authlen ||
1078
0
      sshbuf_len(encoded) - authlen < encrypted_len) {
1079
0
    r = SSH_ERR_INVALID_FORMAT;
1080
0
    goto out;
1081
0
  }
1082
1083
0
  aadlen = sshbuf_len(copy) - sshbuf_len(encoded);
1084
1085
  /* replace first 4 bytes of IV with index to ensure uniqueness */
1086
0
  memcpy(iv, key + keylen, ivlen);
1087
0
  POKE_U32(iv, index);
1088
1089
  /* decrypt private state of key */
1090
0
  if ((r = sshbuf_reserve(decrypted, aadlen + encrypted_len, &dp)) != 0 ||
1091
0
      (r = cipher_init(&ciphercontext, cipher, key, keylen, iv, ivlen, 0,
1092
0
          CIPHER_DECRYPT, CIPHER_SERIAL)) != 0 ||
1093
0
      (r = cipher_crypt(ciphercontext, 0, dp, sshbuf_ptr(copy),
1094
0
      encrypted_len, aadlen, authlen)) != 0)
1095
0
    goto out;
1096
1097
  /* there should be no trailing data */
1098
0
  if ((r = sshbuf_consume(encoded, encrypted_len + authlen)) != 0)
1099
0
    goto out;
1100
0
  if (sshbuf_len(encoded) != 0) {
1101
0
    r = SSH_ERR_INVALID_FORMAT;
1102
0
    goto out;
1103
0
  }
1104
1105
  /* remove AAD */
1106
0
  if ((r = sshbuf_consume(decrypted, aadlen)) != 0)
1107
0
    goto out;
1108
  /* XXX encrypted includes unchecked padding */
1109
1110
  /* success */
1111
0
  r = 0;
1112
0
  if (retp != NULL) {
1113
0
    *retp = decrypted;
1114
0
    decrypted = NULL;
1115
0
  }
1116
0
 out:
1117
0
  cipher_free(ciphercontext);
1118
0
  sshbuf_free(copy);
1119
0
  sshbuf_free(decrypted);
1120
0
  free(iv);
1121
0
  return r;
1122
0
}
1123
1124
u_int32_t
1125
sshkey_xmss_signatures_left(const struct sshkey *k)
1126
0
{
1127
0
  struct ssh_xmss_state *state = k->xmss_state;
1128
0
  u_int32_t idx;
1129
1130
0
  if (sshkey_type_plain(k->type) == KEY_XMSS && state &&
1131
0
      state->maxidx) {
1132
0
    idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
1133
0
    if (idx < state->maxidx)
1134
0
      return state->maxidx - idx;
1135
0
  }
1136
0
  return 0;
1137
0
}
1138
1139
int
1140
sshkey_xmss_enable_maxsign(struct sshkey *k, u_int32_t maxsign)
1141
0
{
1142
0
  struct ssh_xmss_state *state = k->xmss_state;
1143
1144
0
  if (sshkey_type_plain(k->type) != KEY_XMSS)
1145
0
    return SSH_ERR_INVALID_ARGUMENT;
1146
0
  if (maxsign == 0)
1147
0
    return 0;
1148
0
  if (state->idx + maxsign < state->idx)
1149
0
    return SSH_ERR_INVALID_ARGUMENT;
1150
0
  state->maxidx = state->idx + maxsign;
1151
0
  return 0;
1152
0
}
1153
#endif /* WITH_XMSS */