Coverage Report

Created: 2024-11-21 07:03

/src/libgcrypt/cipher/sntrup761.c
Line
Count
Source (jump to first uncovered line)
1
/* sntrup761.c  -  Streamlined NTRU Prime sntrup761 key-encapsulation method
2
 * Copyright (C) 2023 Simon Josefsson <simon@josefsson.org>
3
 *
4
 * This file is part of Libgcrypt.
5
 *
6
 * Libgcrypt is free software; you can redistribute it and/or modify
7
 * it under the terms of the GNU Lesser General Public License as
8
 * published by the Free Software Foundation; either version 2.1 of
9
 * the License, or (at your option) any later version.
10
 *
11
 * Libgcrypt is distributed in the hope that it will be useful,
12
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14
 * GNU Lesser General Public License for more details.
15
 *
16
 * You should have received a copy of the GNU Lesser General Public
17
 * License along with this program; if not, see <https://www.gnu.org/licenses/>.
18
 * SPDX-License-Identifier: LGPL-2.1-or-later
19
 *
20
 * For a description of the algorithm, see:
21
 *   https://ntruprime.cr.yp.to/
22
 */
23
24
/*
25
 * Derived from public domain source, written by (in alphabetical order):
26
 * - Daniel J. Bernstein
27
 * - Chitchanok Chuengsatiansup
28
 * - Tanja Lange
29
 * - Christine van Vredendaal
30
 */
31
32
#ifdef HAVE_CONFIG_H
33
#include <config.h>
34
#endif
35
36
#include "sntrup761.h"
37
38
/* from supercop-20201130/crypto_sort/int32/portable4/int32_minmax.inc */
39
0
/* Branch-free compare-and-swap of two int32_t lvalues: afterwards (a)
   holds the smaller and (b) the larger of the two values.  The exchange
   is decided by arithmetic on 64-bit differences, with no conditional.
   Each argument is evaluated several times, so neither may have side
   effects.  Arguments are parenthesized so that any lvalue expression
   substitutes safely. */
#define int32_MINMAX(a,b) \
do { \
  int64_t ab = (int64_t)(b) ^ (int64_t)(a); \
  int64_t c = (int64_t)(b) - (int64_t)(a); \
  c ^= ab & (c ^ (b)); \
  c >>= 31; \
  c &= ab; \
  (a) ^= c; \
  (b) ^= c; \
} while(0)
49
50
/* from supercop-20201130/crypto_sort/int32/portable4/sort.c */
51
/* Sort the n int32_t values at ARRAY into ascending order, in place.
   Every exchange goes through int32_MINMAX, and the sequence of
   compare-and-swap steps depends only on n, never on the values.
   Arrays with fewer than two elements are left untouched. */
static void
crypto_sort_int32 (void *array, long long n)
{
  int32_t *v = array;
  long long big, stride, span, s, lo, k;

  if (n < 2)
    return;

  /* big = smallest power of two with 2*big >= n */
  for (big = 1; big < n - big; big += big)
    ;

  for (stride = big; stride >= 1; stride >>= 1)
    {
      /* Compare each element with the one `stride` positions later,
         working through blocks of 2*stride. */
      lo = 0;
      while (lo + 2 * stride <= n)
        {
          for (k = lo; k < lo + stride; ++k)
            int32_MINMAX (v[k], v[k + stride]);
          lo += 2 * stride;
        }
      for (k = lo; k < n - stride; ++k)
        int32_MINMAX (v[k], v[k + stride]);

      lo = 0;
      k = 0;
      for (span = big; span > stride; span >>= 1)
        {
          if (k != lo)
            for (;;)
              {
                int32_t t;

                if (k == n - span)
                  goto done;
                t = v[k + stride];
                for (s = span; s > stride; s >>= 1)
                  int32_MINMAX (t, v[k + s]);
                v[k + stride] = t;
                ++k;
                if (k == lo + stride)
                  {
                    lo += 2 * stride;
                    break;
                  }
              }
          while (lo + stride <= n - span)
            {
              for (k = lo; k < lo + stride; ++k)
                {
                  int32_t t = v[k + stride];

                  for (s = span; s > stride; s >>= 1)
                    int32_MINMAX (t, v[k + s]);
                  v[k + stride] = t;
                }
              lo += 2 * stride;
            }
          /* here lo + stride > n - span */
          for (k = lo; k < n - span; ++k)
            {
              int32_t t = v[k + stride];

              for (s = span; s > stride; s >>= 1)
                int32_MINMAX (t, v[k + s]);
              v[k + stride] = t;
            }
        done:;
        }
    }
}
123
124
/* from supercop-20201130/crypto_sort/uint32/useint32/sort.c */
125
126
/* can save time by vectorizing xor loops */
127
/* can save time by integrating xor loops with int32_sort */
128
129
/* Sort n uint32_t values in place, ascending.  Flipping the top bit maps
   unsigned order onto signed order, so the values are biased, sorted
   with crypto_sort_int32, then unbiased again. */
static void
crypto_sort_uint32 (void *array, long long n)
{
  const uint32_t flip = 0x80000000;
  uint32_t *u = array;
  long long k;

  for (k = 0; k < n; k++)
    u[k] ^= flip;
  crypto_sort_int32 (array, n);
  for (k = 0; k < n; k++)
    u[k] ^= flip;
}
140
141
/* from supercop-20201130/crypto_kem/sntrup761/ref/uint32.c */
142
143
/*
144
CPU division instruction typically takes time depending on x.
145
This software is designed to take time independent of x.
146
Time still varies depending on m; user must ensure that m is constant.
147
Time also varies on CPUs where multiplication is variable-time.
148
There could be more CPU issues.
149
There could also be compiler issues.
150
*/
151
152
/* Divide x by m: store the quotient in *q and the remainder in *r.
   Requires 0 < m < 16384 (not checked).  The only hardware division is
   the one that computes floor(2^31 / m); the rest is multiply/shift and
   masked correction, so timing is meant to be independent of x (but not
   of m, which the caller must keep constant). */
static void
uint32_divmod_uint14 (uint32_t * q, uint16_t * r, uint32_t x, uint16_t m)
{
  /* recip*m <= 2^31 <= recip*m + m - 1 */
  const uint32_t recip = (uint32_t) 0x80000000 / m;
  uint32_t quot, part, mask;

  /* First estimate; afterwards x <= 49146. */
  part = (uint32_t) (((uint64_t) x * recip) >> 31);
  x -= part * m;
  quot = part;

  /* Second estimate; afterwards x <= m. */
  part = (uint32_t) (((uint64_t) x * recip) >> 31);
  x -= part * m;
  quot += part;

  /* Final step: subtract m once more, and undo it (adding m back and
     taking 1 off the quotient) when the subtraction wrapped. */
  x -= m;
  quot += 1;
  mask = -(x >> 31);
  x += mask & (uint32_t) m;
  quot += mask;

  *q = quot;
  *r = x;
}
199
200
201
/* Return x mod m using uint32_divmod_uint14 (requires 0 < m < 16384);
   the quotient is computed and discarded. */
static uint16_t
uint32_mod_uint14 (uint32_t x, uint16_t m)
{
  uint32_t unused_quotient;
  uint16_t remainder;

  uint32_divmod_uint14 (&unused_quotient, &remainder, x, m);
  return remainder;
}
209
210
/* from supercop-20201130/crypto_kem/sntrup761/ref/int32.c */
211
212
/* Signed division by a small positive modulus: store in *r the
   remainder in [0, m) and in *q the matching (floor) quotient of x by m.
   Requires 0 < m < 16384.  x is biased by 2^31 into unsigned range; the
   quotient and remainder of the bias itself are then subtracted. */
static void
int32_divmod_uint14 (int32_t * q, uint16_t * r, int32_t x, uint16_t m)
{
  uint32_t quot, bias_quot, borrow;
  uint16_t rem, bias_rem;

  uint32_divmod_uint14 (&quot, &rem, 0x80000000 + (uint32_t) x, m);
  uint32_divmod_uint14 (&bias_quot, &bias_rem, 0x80000000, m);

  rem -= bias_rem;
  quot -= bias_quot;

  /* A wrapped (negative) remainder has bit 15 set: add m back and take
     one off the quotient, without branching. */
  borrow = -(uint32_t) (rem >> 15);
  rem += borrow & m;
  quot += borrow;

  *r = rem;
  *q = quot;
}
229
230
231
/* Return the remainder of signed x modulo m, in [0, m), computed with
   int32_divmod_uint14 (requires 0 < m < 16384).  The quotient is
   discarded. */
static uint16_t
int32_mod_uint14 (int32_t x, uint16_t m)
{
  int32_t discarded_quotient;
  uint16_t remainder;

  int32_divmod_uint14 (&discarded_quotient, &remainder, x, m);
  return remainder;
}
239
240
/* from supercop-20201130/crypto_kem/sntrup761/ref/paramsmenu.h */
241
0
#define p 761              /* number of coefficients in each polynomial */
#define q 4591             /* modulus used by Fq arithmetic */
#define Rounded_bytes 1007 /* encoded size of a rounded polynomial */
#define Rq_bytes 1158      /* encoded size of a polynomial mod q */
#define w 286              /* required number of nonzero coefficients */
246
247
/* from supercop-20201130/crypto_kem/sntrup761/ref/Decode.h */
248
249
/* Decode(R,s,M,len) */
250
/* assumes 0 < M[i] < 16384 */
251
/* produces 0 <= R[i] < M[i] */
252
253
/* from supercop-20201130/crypto_kem/sntrup761/ref/Decode.c */
254
255
/* Decode(R,S,M,len): recover len values out[i] with 0 <= out[i] < M[i]
   from the byte string S.  Assumes 0 < M[i] < 16384.  Adjacent pairs of
   moduli are merged (consuming 0, 1 or 2 low-order bytes per pair), the
   half-length problem is decoded recursively, and each pair is then
   split back apart by division.  Nothing is written when len < 1. */
static void
Decode (uint16_t * out, const unsigned char *S, const uint16_t * M,
        long long len)
{
  long long half, pairs, k;

  if (len < 1)
    return;

  if (len == 1)
    {
      /* Base case: the value occupies 0, 1 or 2 little-endian bytes. */
      if (M[0] == 1)
        *out = 0;
      else if (M[0] <= 256)
        *out = uint32_mod_uint14 (S[0], M[0]);
      else
        *out = uint32_mod_uint14 (S[0] + (((uint16_t) S[1]) << 8), M[0]);
      return;
    }

  half = (len + 1) / 2;
  pairs = len / 2;
  {
    uint16_t R2[half];
    uint16_t M2[half];
    uint16_t low[pairs];
    uint32_t scale[pairs];

    for (k = 0; k < pairs; ++k)
      {
        uint32_t m = M[2 * k] * (uint32_t) M[2 * k + 1];

        if (m > 256 * 16383)
          {
            /* two low-order bytes were stored for this pair */
            scale[k] = 256 * 256;
            low[k] = S[0] + 256 * S[1];
            S += 2;
            M2[k] = (((m + 255) >> 8) + 255) >> 8;
          }
        else if (m >= 16384)
          {
            /* one low-order byte */
            scale[k] = 256;
            low[k] = S[0];
            S += 1;
            M2[k] = (m + 255) >> 8;
          }
        else
          {
            /* no bytes; the merged modulus is kept whole */
            scale[k] = 1;
            low[k] = 0;
            M2[k] = m;
          }
      }
    /* an odd trailing modulus passes through unchanged */
    if (len & 1)
      M2[pairs] = M[len - 1];

    Decode (R2, S, M2, half);

    for (k = 0; k < pairs; ++k)
      {
        uint32_t combined = low[k] + scale[k] * R2[k];
        uint32_t hi;
        uint16_t lo;

        uint32_divmod_uint14 (&hi, &lo, combined, M[2 * k]);
        hi = uint32_mod_uint14 (hi, M[2 * k + 1]); /* only needed for invalid inputs */
        *out++ = lo;
        *out++ = hi;
      }
    if (len & 1)
      *out = R2[pairs];
  }
}
317
318
/* from supercop-20201130/crypto_kem/sntrup761/ref/Encode.h */
319
320
/* Encode(s,R,M,len) */
321
/* assumes 0 <= R[i] < M[i] < 16384 */
322
323
/* from supercop-20201130/crypto_kem/sntrup761/ref/Encode.c */
324
325
/* Encode(out,R,M,len): pack len values with 0 <= R[i] < M[i] < 16384
   into bytes at out.  Adjacent pairs are merged into one value with
   modulus M[i]*M[i+1]; low-order bytes are emitted until that modulus
   drops below 16384, and the half-length problem is encoded
   recursively.  Nothing is written when len < 1. */
static void
Encode (unsigned char *out, const uint16_t * R, const uint16_t * M,
        long long len)
{
  long long half, pairs, k;

  if (len < 1)
    return;

  if (len == 1)
    {
      /* Base case: emit little-endian bytes while the modulus still
         exceeds 1 after rounding it up by 256 each time. */
      uint16_t val = R[0];
      uint16_t mod;

      for (mod = M[0]; mod > 1; mod = (mod + 255) >> 8)
        {
          *out++ = val;
          val >>= 8;
        }
      return;
    }

  half = (len + 1) / 2;
  pairs = len / 2;
  {
    uint16_t R2[half];
    uint16_t M2[half];

    for (k = 0; k < pairs; ++k)
      {
        uint32_t m0 = M[2 * k];
        uint32_t val = R[2 * k] + R[2 * k + 1] * m0;
        uint32_t mod = M[2 * k + 1] * m0;

        for (; mod >= 16384; mod = (mod + 255) >> 8)
          {
            *out++ = val;
            val >>= 8;
          }
        R2[k] = val;
        M2[k] = mod;
      }
    /* an odd trailing value passes through unchanged */
    if (len & 1)
      {
        R2[pairs] = R[len - 1];
        M2[pairs] = M[len - 1];
      }
    Encode (out, R2, M2, half);
  }
}
368
369
/* from supercop-20201130/crypto_kem/sntrup761/ref/kem.c */
370
371
/* ----- masks */
372
373
/* return -1 if x!=0; else return 0 (computed without branching) */
static int
int16_t_nonzero_mask (int16_t x)
{
  uint16_t u = x;               /* 0, else 1...65535 */
  uint32_t v = u;               /* 0, else 1...65535 */
  v = -v;                       /* 0, else 2^32-65535...2^32-1 */
  v >>= 31;                     /* 0, else 1 */
  /* Negate as int: converting the uint32_t value 2^32-1 to int is
     implementation-defined, whereas -(int) v is exactly 0 or -1. */
  return -(int) v;              /* 0, else -1 */
}
383
384
/* Return -1 when x is negative and 0 otherwise, without branching: the
   sign bit of x is isolated and negated.  (With gcc -fwrapv, x>>15
   would compile to the CPU's arithmetic right shift instead.) */
static int
int16_t_negative_mask (int16_t x)
{
  const uint16_t sign_bit = (uint16_t) x >> 15;   /* 1 if x < 0, else 0 */

  return -(int) sign_bit;
}
394
395
/* ----- arithmetic mod 3 */
396
397
/* Small coefficients (elements of F3, entries of short polynomials) are
   stored as signed 8-bit integers. */
typedef int8_t small;
398
399
/* F3 is always represented as -1,0,1 */
400
/* so ZZ_fromF3 is a no-op */
401
402
/* x must not be close to top int16_t */
403
static small
404
F3_freeze (int16_t x)
405
0
{
406
0
  return int32_mod_uint14 (x + 1, 3) - 1;
407
0
}
408
409
/* ----- arithmetic mod q */
410
411
0
/* Half of q-1: Fq elements are kept in the centered range -q12...q12. */
#define q12 ((q-1)/2)
/* An element of Fq, always represented as -q12...q12,
   so ZZ_fromFq is a no-op. */
typedef int16_t Fq;
415
416
/* x must not be close to top int32 */
417
static Fq
418
Fq_freeze (int32_t x)
419
0
{
420
0
  return int32_mod_uint14 (x + q12, q) - q12;
421
0
}
422
423
static Fq
424
Fq_recip (Fq a1)
425
0
{
426
0
  int i = 1;
427
0
  Fq ai = a1;
428
429
0
  while (i < q - 2)
430
0
    {
431
0
      ai = Fq_freeze (a1 * (int32_t) ai);
432
0
      i += 1;
433
0
    }
434
0
  return ai;
435
0
}
436
437
/* ----- small polynomials */
438
439
/* 0 if Weightw_is(r), else -1 */
440
static int
441
Weightw_mask (small * r)
442
0
{
443
0
  int weight = 0;
444
0
  int i;
445
446
0
  for (i = 0; i < p; ++i)
447
0
    weight += r[i] & 1;
448
0
  return int16_t_nonzero_mask (weight - w);
449
0
}
450
451
/* R3_fromR(R_fromRq(r)) */
452
static void
453
R3_fromRq (small * out, const Fq * r)
454
0
{
455
0
  int i;
456
0
  for (i = 0; i < p; ++i)
457
0
    out[i] = F3_freeze (r[i]);
458
0
}
459
460
/* h = f*g in the ring R3 */
461
static void
462
R3_mult (small * h, const small * f, const small * g)
463
0
{
464
0
  small fg[p + p - 1];
465
0
  small result;
466
0
  int i, j;
467
468
0
  for (i = 0; i < p; ++i)
469
0
    {
470
0
      result = 0;
471
0
      for (j = 0; j <= i; ++j)
472
0
  result = F3_freeze (result + f[j] * g[i - j]);
473
0
      fg[i] = result;
474
0
    }
475
0
  for (i = p; i < p + p - 1; ++i)
476
0
    {
477
0
      result = 0;
478
0
      for (j = i - p + 1; j < p; ++j)
479
0
  result = F3_freeze (result + f[j] * g[i - j]);
480
0
      fg[i] = result;
481
0
    }
482
483
0
  for (i = p + p - 2; i >= p; --i)
484
0
    {
485
0
      fg[i - p] = F3_freeze (fg[i - p] + fg[i]);
486
0
      fg[i - p + 1] = F3_freeze (fg[i - p + 1] + fg[i]);
487
0
    }
488
489
0
  for (i = 0; i < p; ++i)
490
0
    h[i] = fg[i];
491
0
}
492
493
/* returns 0 if recip succeeded; else -1 */
494
static int
495
R3_recip (small * out, const small * in)
496
0
{
497
0
  small f[p + 1], g[p + 1], v[p + 1], r[p + 1];
498
0
  int i, loop, delta;
499
0
  int sign, swap, t;
500
501
0
  for (i = 0; i < p + 1; ++i)
502
0
    v[i] = 0;
503
0
  for (i = 0; i < p + 1; ++i)
504
0
    r[i] = 0;
505
0
  r[0] = 1;
506
0
  for (i = 0; i < p; ++i)
507
0
    f[i] = 0;
508
0
  f[0] = 1;
509
0
  f[p - 1] = f[p] = -1;
510
0
  for (i = 0; i < p; ++i)
511
0
    g[p - 1 - i] = in[i];
512
0
  g[p] = 0;
513
514
0
  delta = 1;
515
516
0
  for (loop = 0; loop < 2 * p - 1; ++loop)
517
0
    {
518
0
      for (i = p; i > 0; --i)
519
0
  v[i] = v[i - 1];
520
0
      v[0] = 0;
521
522
0
      sign = -g[0] * f[0];
523
0
      swap = int16_t_negative_mask (-delta) & int16_t_nonzero_mask (g[0]);
524
0
      delta ^= swap & (delta ^ -delta);
525
0
      delta += 1;
526
527
0
      for (i = 0; i < p + 1; ++i)
528
0
  {
529
0
    t = swap & (f[i] ^ g[i]);
530
0
    f[i] ^= t;
531
0
    g[i] ^= t;
532
0
    t = swap & (v[i] ^ r[i]);
533
0
    v[i] ^= t;
534
0
    r[i] ^= t;
535
0
  }
536
537
0
      for (i = 0; i < p + 1; ++i)
538
0
  g[i] = F3_freeze (g[i] + sign * f[i]);
539
0
      for (i = 0; i < p + 1; ++i)
540
0
  r[i] = F3_freeze (r[i] + sign * v[i]);
541
542
0
      for (i = 0; i < p; ++i)
543
0
  g[i] = g[i + 1];
544
0
      g[p] = 0;
545
0
    }
546
547
0
  sign = f[0];
548
0
  for (i = 0; i < p; ++i)
549
0
    out[i] = sign * v[p - 1 - i];
550
551
0
  return int16_t_nonzero_mask (delta);
552
0
}
553
554
/* ----- polynomials mod q */
555
556
/* h = f*g in the ring Rq */
557
static void
558
Rq_mult_small (Fq * h, const Fq * f, const small * g)
559
0
{
560
0
  Fq fg[p + p - 1];
561
0
  Fq result;
562
0
  int i, j;
563
564
0
  for (i = 0; i < p; ++i)
565
0
    {
566
0
      result = 0;
567
0
      for (j = 0; j <= i; ++j)
568
0
  result = Fq_freeze (result + f[j] * (int32_t) g[i - j]);
569
0
      fg[i] = result;
570
0
    }
571
0
  for (i = p; i < p + p - 1; ++i)
572
0
    {
573
0
      result = 0;
574
0
      for (j = i - p + 1; j < p; ++j)
575
0
  result = Fq_freeze (result + f[j] * (int32_t) g[i - j]);
576
0
      fg[i] = result;
577
0
    }
578
579
0
  for (i = p + p - 2; i >= p; --i)
580
0
    {
581
0
      fg[i - p] = Fq_freeze (fg[i - p] + fg[i]);
582
0
      fg[i - p + 1] = Fq_freeze (fg[i - p + 1] + fg[i]);
583
0
    }
584
585
0
  for (i = 0; i < p; ++i)
586
0
    h[i] = fg[i];
587
0
}
588
589
/* h = 3f in Rq */
590
static void
591
Rq_mult3 (Fq * h, const Fq * f)
592
0
{
593
0
  int i;
594
595
0
  for (i = 0; i < p; ++i)
596
0
    h[i] = Fq_freeze (3 * f[i]);
597
0
}
598
599
/* out = 1/(3*in) in Rq */
600
/* returns 0 if recip succeeded; else -1 */
601
static int
602
Rq_recip3 (Fq * out, const small * in)
603
0
{
604
0
  Fq f[p + 1], g[p + 1], v[p + 1], r[p + 1];
605
0
  int i, loop, delta;
606
0
  int swap, t;
607
0
  int32_t f0, g0;
608
0
  Fq scale;
609
610
0
  for (i = 0; i < p + 1; ++i)
611
0
    v[i] = 0;
612
0
  for (i = 0; i < p + 1; ++i)
613
0
    r[i] = 0;
614
0
  r[0] = Fq_recip (3);
615
0
  for (i = 0; i < p; ++i)
616
0
    f[i] = 0;
617
0
  f[0] = 1;
618
0
  f[p - 1] = f[p] = -1;
619
0
  for (i = 0; i < p; ++i)
620
0
    g[p - 1 - i] = in[i];
621
0
  g[p] = 0;
622
623
0
  delta = 1;
624
625
0
  for (loop = 0; loop < 2 * p - 1; ++loop)
626
0
    {
627
0
      for (i = p; i > 0; --i)
628
0
  v[i] = v[i - 1];
629
0
      v[0] = 0;
630
631
0
      swap = int16_t_negative_mask (-delta) & int16_t_nonzero_mask (g[0]);
632
0
      delta ^= swap & (delta ^ -delta);
633
0
      delta += 1;
634
635
0
      for (i = 0; i < p + 1; ++i)
636
0
  {
637
0
    t = swap & (f[i] ^ g[i]);
638
0
    f[i] ^= t;
639
0
    g[i] ^= t;
640
0
    t = swap & (v[i] ^ r[i]);
641
0
    v[i] ^= t;
642
0
    r[i] ^= t;
643
0
  }
644
645
0
      f0 = f[0];
646
0
      g0 = g[0];
647
0
      for (i = 0; i < p + 1; ++i)
648
0
  g[i] = Fq_freeze (f0 * g[i] - g0 * f[i]);
649
0
      for (i = 0; i < p + 1; ++i)
650
0
  r[i] = Fq_freeze (f0 * r[i] - g0 * v[i]);
651
652
0
      for (i = 0; i < p; ++i)
653
0
  g[i] = g[i + 1];
654
0
      g[p] = 0;
655
0
    }
656
657
0
  scale = Fq_recip (f[0]);
658
0
  for (i = 0; i < p; ++i)
659
0
    out[i] = Fq_freeze (scale * (int32_t) v[p - 1 - i]);
660
661
0
  return int16_t_nonzero_mask (delta);
662
0
}
663
664
/* ----- rounded polynomials mod q */
665
666
static void
667
Round (Fq * out, const Fq * a)
668
0
{
669
0
  int i;
670
0
  for (i = 0; i < p; ++i)
671
0
    out[i] = a[i] - F3_freeze (a[i]);
672
0
}
673
674
/* ----- sorting to generate short polynomial */
675
676
static void
677
Short_fromlist (small * out, const uint32_t * in)
678
0
{
679
0
  uint32_t L[p];
680
0
  int i;
681
682
0
  for (i = 0; i < w; ++i)
683
0
    L[i] = in[i] & (uint32_t) - 2;
684
0
  for (i = w; i < p; ++i)
685
0
    L[i] = (in[i] & (uint32_t) - 3) | 1;
686
0
  crypto_sort_uint32 (L, p);
687
0
  for (i = 0; i < p; ++i)
688
0
    out[i] = (L[i] & 3) - 1;
689
0
}
690
691
/* ----- underlying hash function */
692
693
0
/* Number of bytes produced by Hash_prefix. */
#define Hash_bytes 32
694
695
/* out = Hash_b(in): the first 32 bytes of crypto_hash_sha512 applied to
   the single byte b followed by the inlen bytes at in.
   e.g., b = 0 means out = Hash0(in). */
static void
Hash_prefix (unsigned char *out, int b, const unsigned char *in, int inlen)
{
  unsigned char buf[inlen + 1];
  unsigned char digest[64];
  int k;

  buf[0] = b;
  for (k = 1; k <= inlen; ++k)
    buf[k] = in[k - 1];
  crypto_hash_sha512 (digest, buf, inlen + 1);
  for (k = 0; k < 32; ++k)
    out[k] = digest[k];
}
710
711
/* ----- higher-level randomness */
712
713
static uint32_t
714
urandom32 (void *random_ctx, sntrup761_random_func * random)
715
0
{
716
0
  unsigned char c[4];
717
0
  uint32_t out[4];
718
719
0
  random (random_ctx, 4, c);
720
0
  out[0] = (uint32_t) c[0];
721
0
  out[1] = ((uint32_t) c[1]) << 8;
722
0
  out[2] = ((uint32_t) c[2]) << 16;
723
0
  out[3] = ((uint32_t) c[3]) << 24;
724
0
  return out[0] + out[1] + out[2] + out[3];
725
0
}
726
727
static void
728
Short_random (small * out, void *random_ctx, sntrup761_random_func * random)
729
0
{
730
0
  uint32_t L[p];
731
0
  int i;
732
733
0
  for (i = 0; i < p; ++i)
734
0
    L[i] = urandom32 (random_ctx, random);
735
0
  Short_fromlist (out, L);
736
0
}
737
738
static void
739
Small_random (small * out, void *random_ctx, sntrup761_random_func * random)
740
0
{
741
0
  int i;
742
743
0
  for (i = 0; i < p; ++i)
744
0
    out[i] = (((urandom32 (random_ctx, random) & 0x3fffffff) * 3) >> 30) - 1;
745
0
}
746
747
/* ----- Streamlined NTRU Prime Core */
748
749
/* h,(f,ginv) = KeyGen() */
750
static void
751
KeyGen (Fq * h, small * f, small * ginv, void *random_ctx,
752
  sntrup761_random_func * random)
753
0
{
754
0
  small g[p];
755
0
  Fq finv[p];
756
757
0
  for (;;)
758
0
    {
759
0
      Small_random (g, random_ctx, random);
760
0
      if (R3_recip (ginv, g) == 0)
761
0
  break;
762
0
    }
763
0
  Short_random (f, random_ctx, random);
764
0
  Rq_recip3 (finv, f);    /* always works */
765
0
  Rq_mult_small (h, finv, g);
766
0
}
767
768
/* c = Encrypt(r,h) */
769
static void
770
Encrypt (Fq * c, const small * r, const Fq * h)
771
0
{
772
0
  Fq hr[p];
773
774
0
  Rq_mult_small (hr, h, r);
775
0
  Round (c, hr);
776
0
}
777
778
/* r = Decrypt(c,(f,ginv)) */
779
static void
780
Decrypt (small * r, const Fq * c, const small * f, const small * ginv)
781
0
{
782
0
  Fq cf[p];
783
0
  Fq cf3[p];
784
0
  small e[p];
785
0
  small ev[p];
786
0
  int mask;
787
0
  int i;
788
789
0
  Rq_mult_small (cf, c, f);
790
0
  Rq_mult3 (cf3, cf);
791
0
  R3_fromRq (e, cf3);
792
0
  R3_mult (ev, e, ginv);
793
794
0
  mask = Weightw_mask (ev); /* 0 if weight w, else -1 */
795
0
  for (i = 0; i < w; ++i)
796
0
    r[i] = ((ev[i] ^ 1) & ~mask) ^ 1;
797
0
  for (i = w; i < p; ++i)
798
0
    r[i] = ev[i] & ~mask;
799
0
}
800
801
/* ----- encoding small polynomials (including short polynomials) */
802
803
0
/* Bytes needed to pack p small coefficients at four per byte. */
#define Small_bytes ((p+3)/4)
804
805
/* these are the only functions that rely on p mod 4 = 1 */
806
807
static void
808
Small_encode (unsigned char *s, const small * f)
809
0
{
810
0
  small x;
811
0
  int i;
812
813
0
  for (i = 0; i < p / 4; ++i)
814
0
    {
815
0
      x = *f++ + 1;
816
0
      x += (*f++ + 1) << 2;
817
0
      x += (*f++ + 1) << 4;
818
0
      x += (*f++ + 1) << 6;
819
0
      *s++ = x;
820
0
    }
821
0
  x = *f++ + 1;
822
0
  *s++ = x;
823
0
}
824
825
static void
826
Small_decode (small * f, const unsigned char *s)
827
0
{
828
0
  unsigned char x;
829
0
  int i;
830
831
0
  for (i = 0; i < p / 4; ++i)
832
0
    {
833
0
      x = *s++;
834
0
      *f++ = ((small) (x & 3)) - 1;
835
0
      x >>= 2;
836
0
      *f++ = ((small) (x & 3)) - 1;
837
0
      x >>= 2;
838
0
      *f++ = ((small) (x & 3)) - 1;
839
0
      x >>= 2;
840
0
      *f++ = ((small) (x & 3)) - 1;
841
0
    }
842
0
  x = *s++;
843
0
  *f++ = ((small) (x & 3)) - 1;
844
0
}
845
846
/* ----- encoding general polynomials */
847
848
static void
849
Rq_encode (unsigned char *s, const Fq * r)
850
0
{
851
0
  uint16_t R[p], M[p];
852
0
  int i;
853
854
0
  for (i = 0; i < p; ++i)
855
0
    R[i] = r[i] + q12;
856
0
  for (i = 0; i < p; ++i)
857
0
    M[i] = q;
858
0
  Encode (s, R, M, p);
859
0
}
860
861
static void
862
Rq_decode (Fq * r, const unsigned char *s)
863
0
{
864
0
  uint16_t R[p], M[p];
865
0
  int i;
866
867
0
  for (i = 0; i < p; ++i)
868
0
    M[i] = q;
869
0
  Decode (R, s, M, p);
870
0
  for (i = 0; i < p; ++i)
871
0
    r[i] = ((Fq) R[i]) - q12;
872
0
}
873
874
/* ----- encoding rounded polynomials */
875
876
static void
877
Rounded_encode (unsigned char *s, const Fq * r)
878
0
{
879
0
  uint16_t R[p], M[p];
880
0
  int i;
881
882
0
  for (i = 0; i < p; ++i)
883
0
    R[i] = ((r[i] + q12) * 10923) >> 15;
884
0
  for (i = 0; i < p; ++i)
885
0
    M[i] = (q + 2) / 3;
886
0
  Encode (s, R, M, p);
887
0
}
888
889
static void
890
Rounded_decode (Fq * r, const unsigned char *s)
891
0
{
892
0
  uint16_t R[p], M[p];
893
0
  int i;
894
895
0
  for (i = 0; i < p; ++i)
896
0
    M[i] = (q + 2) / 3;
897
0
  Decode (R, s, M, p);
898
0
  for (i = 0; i < p; ++i)
899
0
    r[i] = R[i] * 3 - q12;
900
0
}
901
902
/* ----- Streamlined NTRU Prime Core plus encoding */
903
904
/* Plaintext inputs are short polynomials; as an array type, Inputs is
   passed by reference. */
typedef small Inputs[p];
#define Inputs_random Short_random
#define Inputs_encode Small_encode
#define Inputs_bytes Small_bytes

/* Encoded sizes: ciphertext, ZKeyGen secret part (f and ginv), public key. */
#define Ciphertexts_bytes Rounded_bytes
#define SecretKeys_bytes (2*Small_bytes)
#define PublicKeys_bytes Rq_bytes
912
913
/* pk,sk = ZKeyGen() */
914
static void
915
ZKeyGen (unsigned char *pk, unsigned char *sk, void *random_ctx,
916
   sntrup761_random_func * random)
917
0
{
918
0
  Fq h[p];
919
0
  small f[p], v[p];
920
921
0
  KeyGen (h, f, v, random_ctx, random);
922
0
  Rq_encode (pk, h);
923
0
  Small_encode (sk, f);
924
0
  sk += Small_bytes;
925
0
  Small_encode (sk, v);
926
0
}
927
928
/* C = ZEncrypt(r,pk) */
929
static void
930
ZEncrypt (unsigned char *C, const Inputs r, const unsigned char *pk)
931
0
{
932
0
  Fq h[p];
933
0
  Fq c[p];
934
0
  Rq_decode (h, pk);
935
0
  Encrypt (c, r, h);
936
0
  Rounded_encode (C, c);
937
0
}
938
939
/* r = ZDecrypt(C,sk) */
940
static void
941
ZDecrypt (Inputs r, const unsigned char *C, const unsigned char *sk)
942
0
{
943
0
  small f[p], v[p];
944
0
  Fq c[p];
945
946
0
  Small_decode (f, sk);
947
0
  sk += Small_bytes;
948
0
  Small_decode (v, sk);
949
0
  Rounded_decode (c, C);
950
0
  Decrypt (r, c, f, v);
951
0
}
952
953
/* ----- confirmation hash */
954
955
0
/* Length of the confirmation hash appended to each ciphertext. */
#define Confirm_bytes 32
956
957
/* h = HashConfirm(r,pk,cache); cache is Hash4(pk) */
958
static void
959
HashConfirm (unsigned char *h, const unsigned char *r,
960
       /* const unsigned char *pk, */ const unsigned char *cache)
961
0
{
962
0
  unsigned char x[Hash_bytes * 2];
963
0
  int i;
964
965
0
  Hash_prefix (x, 3, r, Inputs_bytes);
966
0
  for (i = 0; i < Hash_bytes; ++i)
967
0
    x[Hash_bytes + i] = cache[i];
968
0
  Hash_prefix (h, 2, x, sizeof x);
969
0
}
970
971
/* ----- session-key hash */
972
973
/* k = HashSession(b,y,z) */
974
static void
975
HashSession (unsigned char *k, int b, const unsigned char *y,
976
       const unsigned char *z)
977
0
{
978
0
  unsigned char x[Hash_bytes + Ciphertexts_bytes + Confirm_bytes];
979
0
  int i;
980
981
0
  Hash_prefix (x, 3, y, Inputs_bytes);
982
0
  for (i = 0; i < Ciphertexts_bytes + Confirm_bytes; ++i)
983
0
    x[Hash_bytes + i] = z[i];
984
0
  Hash_prefix (k, b, x, sizeof x);
985
0
}
986
987
/* ----- Streamlined NTRU Prime */
988
989
/* pk,sk = KEM_KeyGen() */
990
void
991
sntrup761_keypair (unsigned char *pk, unsigned char *sk, void *random_ctx,
992
       sntrup761_random_func * random)
993
0
{
994
0
  int i;
995
996
0
  ZKeyGen (pk, sk, random_ctx, random);
997
0
  sk += SecretKeys_bytes;
998
0
  for (i = 0; i < PublicKeys_bytes; ++i)
999
0
    *sk++ = pk[i];
1000
0
  random (random_ctx, Inputs_bytes, sk);
1001
0
  sk += Inputs_bytes;
1002
0
  Hash_prefix (sk, 4, pk, PublicKeys_bytes);
1003
0
}
1004
1005
/* c,r_enc = Hide(r,pk,cache); cache is Hash4(pk) */
1006
static void
1007
Hide (unsigned char *c, unsigned char *r_enc, const Inputs r,
1008
      const unsigned char *pk, const unsigned char *cache)
1009
0
{
1010
0
  Inputs_encode (r_enc, r);
1011
0
  ZEncrypt (c, r, pk);
1012
0
  c += Ciphertexts_bytes;
1013
0
  HashConfirm (c, r_enc, cache);
1014
0
}
1015
1016
/* c,k = Encap(pk) */
1017
void
1018
sntrup761_enc (unsigned char *c, unsigned char *k, const unsigned char *pk,
1019
         void *random_ctx, sntrup761_random_func * random)
1020
0
{
1021
0
  Inputs r;
1022
0
  unsigned char r_enc[Inputs_bytes];
1023
0
  unsigned char cache[Hash_bytes];
1024
1025
0
  Hash_prefix (cache, 4, pk, PublicKeys_bytes);
1026
0
  Inputs_random (r, random_ctx, random);
1027
0
  Hide (c, r_enc, r, pk, cache);
1028
0
  HashSession (k, 1, r_enc, c);
1029
0
}
1030
1031
/* 0 if matching ciphertext+confirm, else -1 */
1032
static int
1033
Ciphertexts_diff_mask (const unsigned char *c, const unsigned char *c2)
1034
0
{
1035
0
  uint16_t differentbits = 0;
1036
0
  int len = Ciphertexts_bytes + Confirm_bytes;
1037
1038
0
  while (len-- > 0)
1039
0
    differentbits |= (*c++) ^ (*c2++);
1040
0
  return (1 & ((differentbits - 1) >> 8)) - 1;
1041
0
}
1042
1043
/* k = Decap(c,sk) */
1044
void
1045
sntrup761_dec (unsigned char *k, const unsigned char *c, const unsigned char *sk)
1046
0
{
1047
0
  const unsigned char *pk = sk + SecretKeys_bytes;
1048
0
  const unsigned char *rho = pk + PublicKeys_bytes;
1049
0
  const unsigned char *cache = rho + Inputs_bytes;
1050
0
  Inputs r;
1051
0
  unsigned char r_enc[Inputs_bytes];
1052
0
  unsigned char cnew[Ciphertexts_bytes + Confirm_bytes];
1053
0
  int mask;
1054
0
  int i;
1055
1056
0
  ZDecrypt (r, c, sk);
1057
0
  Hide (cnew, r_enc, r, pk, cache);
1058
0
  mask = Ciphertexts_diff_mask (c, cnew);
1059
0
  for (i = 0; i < Inputs_bytes; ++i)
1060
0
    r_enc[i] ^= mask & (r_enc[i] ^ rho[i]);
1061
0
  HashSession (k, 1 + mask, r_enc, c);
1062
0
}