Coverage Report

Created: 2023-06-07 07:11

/src/boringssl/crypto/kyber/kyber.c
Line
Count
Source (jump to first uncovered line)
1
/* Copyright (c) 2023, Google Inc.
2
 *
3
 * Permission to use, copy, modify, and/or distribute this software for any
4
 * purpose with or without fee is hereby granted, provided that the above
5
 * copyright notice and this permission notice appear in all copies.
6
 *
7
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
10
 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12
 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13
 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */
14
15
#include <openssl/kyber.h>
16
17
#include <assert.h>
18
#include <stdlib.h>
19
20
#include <openssl/bytestring.h>
21
#include <openssl/rand.h>
22
23
#include "../internal.h"
24
#include "./internal.h"
25
26
27
// See
// https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf

// DEGREE is the number of coefficients per polynomial (n in the spec).
#define DEGREE 256
// RANK is the module dimension k. (RANK of 3 matches the Kyber768 parameter
// set — TODO confirm against the header's parameter constants.)
#define RANK 3

// kBarrettMultiplier/kBarrettShift implement Barrett reduction mod kPrime:
// floor(2^24 / 3329) = 5039.
static const size_t kBarrettMultiplier = 5039;
static const unsigned kBarrettShift = 24;
// kPrime is the field modulus q.
static const uint16_t kPrime = 3329;
// kLog2Prime is the number of bits needed to encode one reduced coefficient.
static const int kLog2Prime = 12;
static const uint16_t kHalfPrime = (/*kPrime=*/3329 - 1) / 2;
// kDU and kDV are the ciphertext compression bit-widths (d_u, d_v).
static const int kDU = 10;
static const int kDV = 4;
// kInverseDegree is 128^-1 mod 3329; 128 because kPrime does not have a 512th
// root of unity.
static const uint16_t kInverseDegree = 3303;
// Byte sizes of a fully-encoded vector and a kDU-compressed vector.
static const size_t kEncodedVectorSize =
    (/*kLog2Prime=*/12 * DEGREE / 8) * RANK;
static const size_t kCompressedVectorSize = /*kDU=*/10 * RANK * DEGREE / 8;
46
47
// A scalar is a polynomial of DEGREE coefficients over GF(kPrime).
typedef struct scalar {
  // On every function entry and exit, 0 <= c < kPrime.
  uint16_t c[DEGREE];
} scalar;

// A vector is RANK scalars.
typedef struct vector {
  scalar v[RANK];
} vector;

// A matrix is a RANK x RANK grid of scalars.
typedef struct matrix {
  scalar v[RANK][RANK];
} matrix;
59
60
// This bit of Python will be referenced in some of the following comments:
//
// p = 3329
//
// def bitreverse(i):
//     ret = 0
//     for n in range(7):
//         bit = i & 1
//         ret <<= 1
//         ret |= bit
//         i >>= 1
//     return ret

// Precomputed forward NTT twiddle factors.
// kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)]
static const uint16_t kNTTRoots[128] = {
    1,    1729, 2580, 3289, 2642, 630,  1897, 848,  1062, 1919, 193,  797,
    2786, 3260, 569,  1746, 296,  2447, 1339, 1476, 3046, 56,   2240, 1333,
    1426, 2094, 535,  2882, 2393, 2879, 1974, 821,  289,  331,  3253, 1756,
    1197, 2304, 2277, 2055, 650,  1977, 2513, 632,  2865, 33,   1320, 1915,
    2319, 1435, 807,  452,  1438, 2868, 1534, 2402, 2647, 2617, 1481, 648,
    2474, 3110, 1227, 910,  17,   2761, 583,  2649, 1637, 723,  2288, 1100,
    1409, 2662, 3281, 233,  756,  2156, 3015, 3050, 1703, 1651, 2789, 1789,
    1847, 952,  1461, 2687, 939,  2308, 2437, 2388, 733,  2337, 268,  641,
    1584, 2298, 2037, 3220, 375,  2549, 2090, 1645, 1063, 319,  2773, 757,
    2099, 561,  2466, 2594, 2804, 1092, 403,  1026, 1143, 2150, 2775, 886,
    1722, 1212, 1874, 1029, 2110, 2935, 885,  2154,
};
87
88
// Precomputed inverse NTT twiddle factors.
// kInverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)]
static const uint16_t kInverseNTTRoots[128] = {
    1,    1600, 40,   749,  2481, 1432, 2699, 687,  1583, 2760, 69,   543,
    2532, 3136, 1410, 2267, 2508, 1355, 450,  936,  447,  2794, 1235, 1903,
    1996, 1089, 3273, 283,  1853, 1990, 882,  3033, 2419, 2102, 219,  855,
    2681, 1848, 712,  682,  927,  1795, 461,  1891, 2877, 2522, 1894, 1010,
    1414, 2009, 3296, 464,  2697, 816,  1352, 2679, 1274, 1052, 1025, 2132,
    1573, 76,   2998, 3040, 1175, 2444, 394,  1219, 2300, 1455, 2117, 1607,
    2443, 554,  1179, 2186, 2303, 2926, 2237, 525,  735,  863,  2768, 1230,
    2572, 556,  3010, 2266, 1684, 1239, 780,  2954, 109,  1292, 1031, 1745,
    2688, 3061, 992,  2596, 941,  892,  1021, 2390, 642,  1868, 2377, 1482,
    1540, 540,  1678, 1626, 279,  314,  1173, 2573, 3096, 48,   667,  1920,
    2229, 1041, 2606, 1692, 680,  2746, 568,  3312,
};
102
103
// Precomputed roots used as the quadratic non-residues in |scalar_mult|.
// kModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)]
static const uint16_t kModRoots[128] = {
    17,   3312, 2761, 568,  583,  2746, 2649, 680,  1637, 1692, 723,  2606,
    2288, 1041, 1100, 2229, 1409, 1920, 2662, 667,  3281, 48,   233,  3096,
    756,  2573, 2156, 1173, 3015, 314,  3050, 279,  1703, 1626, 1651, 1678,
    2789, 540,  1789, 1540, 1847, 1482, 952,  2377, 1461, 1868, 2687, 642,
    939,  2390, 2308, 1021, 2437, 892,  2388, 941,  733,  2596, 2337, 992,
    268,  3061, 641,  2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
    375,  2954, 2549, 780,  2090, 1239, 1645, 1684, 1063, 2266, 319,  3010,
    2773, 556,  757,  2572, 2099, 1230, 561,  2768, 2466, 863,  2594, 735,
    2804, 525,  1092, 2237, 403,  2926, 1026, 2303, 1143, 2186, 2150, 1179,
    2775, 554,  886,  2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
    2110, 1219, 2935, 394,  885,  2444, 2154, 1175,
};
117
118
// reduce_once reduces 0 <= x < 2*kPrime, mod kPrime.
static uint16_t reduce_once(uint16_t x) {
  assert(x < 2 * kPrime);
  const uint16_t subtracted = x - kPrime;
  // |mask| is all-ones when |subtracted| underflowed (i.e. x < kPrime) and
  // all-zeros otherwise; the select below is branchless to stay constant-time.
  uint16_t mask = 0u - (subtracted >> 15);
  // On Aarch64, omitting a |value_barrier_u16| results in a 2x speedup of Kyber
  // overall and Clang still produces constant-time code using `csel`. On other
  // platforms & compilers on godbolt that we care about, this code also
  // produces constant-time output.
  return (mask & x) | (~mask & subtracted);
}
129
130
// constant time reduce x mod kPrime using Barrett reduction. x must be less
// than kPrime + 2×kPrime².
static uint16_t reduce(uint32_t x) {
  assert(x < kPrime + 2u * kPrime * kPrime);
  // |quotient| approximates x / kPrime to within one, so |remainder| lands in
  // [0, 2*kPrime) and a single conditional subtraction finishes the job.
  uint64_t product = (uint64_t)x * kBarrettMultiplier;
  uint32_t quotient = (uint32_t)(product >> kBarrettShift);
  uint32_t remainder = x - quotient * kPrime;
  return reduce_once(remainder);
}
139
140
0
// scalar_zero sets every coefficient of |out| to zero.
static void scalar_zero(scalar *out) {
  OPENSSL_memset(out, 0, sizeof(scalar));
}
141
142
0
// vector_zero sets every coefficient of every entry of |out| to zero.
static void vector_zero(vector *out) {
  OPENSSL_memset(out, 0, sizeof(vector));
}
143
144
// In place number theoretic transform of a given scalar.
// Note that Kyber's kPrime 3329 does not have a 512th root of unity, so this
// transform leaves off the last iteration of the usual FFT code, with the 128
// relevant roots of unity being stored in |kNTTRoots|. This means the output
// should be seen as 128 elements in GF(3329^2), with the coefficients of the
// elements being consecutive entries in |s->c|.
static void scalar_ntt(scalar *s) {
  int offset = DEGREE;
  // `int` is used here because using `size_t` throughout caused a ~5% slowdown
  // with Clang 14 on Aarch64.
  for (int step = 1; step < DEGREE / 2; step <<= 1) {
    // Each pass halves the butterfly span and doubles the number of
    // sub-transforms.
    offset >>= 1;
    int k = 0;
    for (int i = 0; i < step; i++) {
      const uint32_t step_root = kNTTRoots[i + step];
      for (int j = k; j < k + offset; j++) {
        // Butterfly: (even, odd) -> (even + root*odd, even - root*odd),
        // everything mod kPrime.
        uint16_t odd = reduce(step_root * s->c[j + offset]);
        uint16_t even = s->c[j];
        s->c[j] = reduce_once(odd + even);
        // kPrime is added first so the argument to |reduce_once| cannot
        // underflow.
        s->c[j + offset] = reduce_once(even - odd + kPrime);
      }
      k += 2 * offset;
    }
  }
}
169
170
0
// vector_ntt applies the in-place NTT to each entry of |a|.
static void vector_ntt(vector *a) {
  for (int idx = 0; idx < RANK; idx++) {
    scalar *entry = &a->v[idx];
    scalar_ntt(entry);
  }
}
175
176
// In place inverse number theoretic transform of a given scalar, with pairs of
// entries of s->v being interpreted as elements of GF(3329^2). Just as with the
// number theoretic transform, this leaves off the first step of the normal iFFT
// to account for the fact that 3329 does not have a 512th root of unity, using
// the precomputed 128 roots of unity stored in |kInverseNTTRoots|.
static void scalar_inverse_ntt(scalar *s) {
  int step = DEGREE / 2;
  // `int` is used here because using `size_t` throughout caused a ~5% slowdown
  // with Clang 14 on Aarch64.
  for (int offset = 2; offset < DEGREE; offset <<= 1) {
    step >>= 1;
    int k = 0;
    for (int i = 0; i < step; i++) {
      uint32_t step_root = kInverseNTTRoots[i + step];
      for (int j = k; j < k + offset; j++) {
        // Inverse butterfly: sum, and twiddled difference. kPrime is added so
        // the difference cannot underflow before reduction.
        uint16_t odd = s->c[j + offset];
        uint16_t even = s->c[j];
        s->c[j] = reduce_once(odd + even);
        s->c[j + offset] = reduce(step_root * (even - odd + kPrime));
      }
      k += 2 * offset;
    }
  }
  // Scale by 128^-1 to undo the factor accumulated across the forward and
  // inverse transform passes.
  for (int i = 0; i < DEGREE; i++) {
    s->c[i] = reduce(s->c[i] * kInverseDegree);
  }
}
203
204
0
// vector_inverse_ntt applies the in-place inverse NTT to each entry of |a|.
static void vector_inverse_ntt(vector *a) {
  for (int idx = 0; idx < RANK; idx++) {
    scalar *entry = &a->v[idx];
    scalar_inverse_ntt(entry);
  }
}
209
210
0
// scalar_add sets |lhs| to |lhs| + |rhs|, coefficient-wise mod kPrime.
static void scalar_add(scalar *lhs, const scalar *rhs) {
  for (int idx = 0; idx < DEGREE; idx++) {
    const uint16_t sum = lhs->c[idx] + rhs->c[idx];
    lhs->c[idx] = reduce_once(sum);
  }
}
215
216
0
// scalar_sub sets |lhs| to |lhs| - |rhs|, coefficient-wise mod kPrime. Adding
// kPrime keeps the intermediate value non-negative for |reduce_once|.
static void scalar_sub(scalar *lhs, const scalar *rhs) {
  for (int idx = 0; idx < DEGREE; idx++) {
    const uint16_t diff = lhs->c[idx] + kPrime - rhs->c[idx];
    lhs->c[idx] = reduce_once(diff);
  }
}
221
222
// Multiplying two scalars in the number theoretically transformed state. Since
// 3329 does not have a 512th root of unity, this means we have to interpret
// the 2*ith and (2*i+1)th entries of the scalar as elements of GF(3329)[X]/(X^2
// - 17^(2*bitreverse(i)+1)) The value of 17^(2*bitreverse(i)+1) mod 3329 is
// stored in the precomputed |kModRoots| table. Note that our Barrett transform
// only allows us to multipy two reduced numbers together, so we need some
// intermediate reduction steps, even if an uint64_t could hold 3 multiplied
// numbers.
static void scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs) {
  for (int i = 0; i < DEGREE / 2; i++) {
    // Schoolbook product of (a + b·X) and (c + d·X) modulo
    // (X² - kModRoots[i]).
    uint32_t real_real = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i];
    uint32_t img_img = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i + 1];
    uint32_t real_img = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i + 1];
    uint32_t img_real = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i];
    // |img_img| must be reduced before multiplying by kModRoots[i] so the sum
    // stays within the documented input range of |reduce|.
    out->c[2 * i] =
        reduce(real_real + (uint32_t)reduce(img_img) * kModRoots[i]);
    out->c[2 * i + 1] = reduce(img_real + real_img);
  }
}
241
242
0
// vector_add sets |lhs| to |lhs| + |rhs|, entry-wise.
static void vector_add(vector *lhs, const vector *rhs) {
  for (int idx = 0; idx < RANK; idx++) {
    scalar_add(&lhs->v[idx], &rhs->v[idx]);
  }
}
247
248
0
// matrix_mult sets |out| to |m| × |a|, with all values in the NTT domain.
static void matrix_mult(vector *out, const matrix *m, const vector *a) {
  vector_zero(out);
  for (int row = 0; row < RANK; row++) {
    for (int col = 0; col < RANK; col++) {
      scalar term;
      scalar_mult(&term, &m->v[row][col], &a->v[col]);
      scalar_add(&out->v[row], &term);
    }
  }
}
258
259
// matrix_mult_transpose sets |out| to |m|ᵀ × |a|, with all values in the NTT
// domain.
static void matrix_mult_transpose(vector *out, const matrix *m,
                                  const vector *a) {
  vector_zero(out);
  for (int row = 0; row < RANK; row++) {
    for (int col = 0; col < RANK; col++) {
      scalar term;
      // Index |m| as its transpose.
      scalar_mult(&term, &m->v[col][row], &a->v[col]);
      scalar_add(&out->v[row], &term);
    }
  }
}
270
271
// scalar_inner_product sets |out| to the inner product of |lhs| and |rhs| in
// the NTT domain.
static void scalar_inner_product(scalar *out, const vector *lhs,
                                 const vector *rhs) {
  scalar_zero(out);
  for (int idx = 0; idx < RANK; idx++) {
    scalar term;
    scalar_mult(&term, &lhs->v[idx], &rhs->v[idx]);
    scalar_add(out, &term);
  }
}
280
281
// Algorithm 1 of the Kyber spec. Rejection samples a Keccak stream to get
// uniformly distributed elements. This is used for matrix expansion and only
// operates on public inputs.
static void scalar_from_keccak_vartime(scalar *out,
                                       struct BORINGSSL_keccak_st *keccak_ctx) {
  assert(keccak_ctx->offset == 0);
  assert(keccak_ctx->rate_bytes == 168);
  static_assert(168 % 3 == 0, "block and coefficient boundaries do not align");

  int done = 0;
  while (done < DEGREE) {
    uint8_t block[168];
    BORINGSSL_keccak_squeeze(keccak_ctx, block, sizeof(block));
    // Every 3 bytes yield two 12-bit candidates; candidates >= kPrime are
    // rejected so the accepted values are uniform in [0, kPrime).
    for (size_t i = 0; i < sizeof(block) && done < DEGREE; i += 3) {
      uint16_t d1 = block[i] + 256 * (block[i + 1] % 16);
      uint16_t d2 = block[i + 1] / 16 + 16 * block[i + 2];
      if (d1 < kPrime) {
        out->c[done++] = d1;
      }
      if (d2 < kPrime && done < DEGREE) {
        out->c[done++] = d2;
      }
    }
  }
}
306
307
// Algorithm 2 of the Kyber spec, with eta fixed to two and the PRF call
// included. Creates binominally distributed elements by sampling 2*|eta| bits,
// and setting the coefficient to the count of the first bits minus the count of
// the second bits, resulting in a centered binomial distribution. Since eta is
// two this gives -2/2 with a probability of 1/16, -1/1 with probability 1/4,
// and 0 with probability 3/8.
static void scalar_centered_binomial_distribution_eta_2_with_prf(
    scalar *out, const uint8_t input[33]) {
  uint8_t entropy[128];
  static_assert(sizeof(entropy) == 2 * /*kEta=*/2 * DEGREE / 8, "");
  BORINGSSL_keccak(entropy, sizeof(entropy), input, 33, boringssl_shake256);

  // Each entropy byte supplies the 2*eta = 4 bits for two coefficients.
  for (int i = 0; i < DEGREE; i += 2) {
    uint8_t byte = entropy[i / 2];

    // Start from kPrime so the subtraction below cannot underflow; the result
    // is brought back into [0, kPrime) by |reduce_once|.
    uint16_t value = kPrime;
    value += (byte & 1) + ((byte >> 1) & 1);
    value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
    out->c[i] = reduce_once(value);

    byte >>= 4;
    value = kPrime;
    value += (byte & 1) + ((byte >> 1) & 1);
    value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
    out->c[i + 1] = reduce_once(value);
  }
}
334
335
// Generates a secret vector by using
336
// |scalar_centered_binomial_distribution_eta_2_with_prf|, using the given seed
337
// appending and incrementing |counter| for entry of the vector.
338
static void vector_generate_secret_eta_2(vector *out, uint8_t *counter,
339
0
                                         const uint8_t seed[32]) {
340
0
  uint8_t input[33];
341
0
  OPENSSL_memcpy(input, seed, 32);
342
0
  for (int i = 0; i < RANK; i++) {
343
0
    input[32] = (*counter)++;
344
0
    scalar_centered_binomial_distribution_eta_2_with_prf(&out->v[i], input);
345
0
  }
346
0
}
347
348
// Expands the matrix of a seed for key generation and for encaps-CPA.
349
0
static void matrix_expand(matrix *out, const uint8_t rho[32]) {
350
0
  uint8_t input[34];
351
0
  OPENSSL_memcpy(input, rho, 32);
352
0
  for (int i = 0; i < RANK; i++) {
353
0
    for (int j = 0; j < RANK; j++) {
354
0
      input[32] = i;
355
0
      input[33] = j;
356
0
      struct BORINGSSL_keccak_st keccak_ctx;
357
0
      BORINGSSL_keccak_init(&keccak_ctx, input, sizeof(input),
358
0
                            boringssl_shake128);
359
0
      scalar_from_keccak_vartime(&out->v[i][j], &keccak_ctx);
360
0
    }
361
0
  }
362
0
}
363
364
// kMasks[i] has its low i+1 bits set, for extracting bit chunks below.
static const uint8_t kMasks[8] = {0x01, 0x03, 0x07, 0x0f,
                                  0x1f, 0x3f, 0x7f, 0xff};
366
367
0
// scalar_encode packs the DEGREE coefficients of |s| into |out| using |bits|
// bits per coefficient, filling bytes from the least-significant bit up.
static void scalar_encode(uint8_t *out, const scalar *s, int bits) {
  assert(bits <= (int)sizeof(*s->c) * 8 && bits != 1);

  // |out_byte| accumulates bits until a full byte can be written.
  uint8_t out_byte = 0;
  int out_byte_bits = 0;

  for (int i = 0; i < DEGREE; i++) {
    uint16_t element = s->c[i];
    int element_bits_done = 0;

    while (element_bits_done < bits) {
      int chunk_bits = bits - element_bits_done;
      int out_bits_remaining = 8 - out_byte_bits;
      if (chunk_bits >= out_bits_remaining) {
        // This chunk fills the current output byte; flush it.
        chunk_bits = out_bits_remaining;
        out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits;
        *out = out_byte;
        out++;
        out_byte_bits = 0;
        out_byte = 0;
      } else {
        out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits;
        out_byte_bits += chunk_bits;
      }

      element_bits_done += chunk_bits;
      element >>= chunk_bits;
    }
  }

  // Flush a final partially-filled byte, if any.
  if (out_byte_bits > 0) {
    *out = out_byte;
  }
}
401
402
// scalar_encode_1 is |scalar_encode| specialised for |bits| == 1.
403
0
static void scalar_encode_1(uint8_t out[32], const scalar *s) {
404
0
  for (int i = 0; i < DEGREE; i += 8) {
405
0
    uint8_t out_byte = 0;
406
0
    for (int j = 0; j < 8; j++) {
407
0
      out_byte |= (s->c[i + j] & 1) << j;
408
0
    }
409
0
    *out = out_byte;
410
0
    out++;
411
0
  }
412
0
}
413
414
// Encodes an entire vector into 32*|RANK|*|bits| bytes. Note that since 256
415
// (DEGREE) is divisible by 8, the individual vector entries will always fill a
416
// whole number of bytes, so we do not need to worry about bit packing here.
417
0
static void vector_encode(uint8_t *out, const vector *a, int bits) {
418
0
  for (int i = 0; i < RANK; i++) {
419
0
    scalar_encode(out + i * bits * DEGREE / 8, &a->v[i], bits);
420
0
  }
421
0
}
422
423
// scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
// |out|. It returns one on success and zero if any parsed value is >=
// |kPrime|.
static int scalar_decode(scalar *out, const uint8_t *in, int bits) {
  assert(bits <= (int)sizeof(*out->c) * 8 && bits != 1);

  // |in_byte| holds the not-yet-consumed bits of the current input byte.
  uint8_t in_byte = 0;
  int in_byte_bits_left = 0;

  for (int i = 0; i < DEGREE; i++) {
    uint16_t element = 0;
    int element_bits_done = 0;

    while (element_bits_done < bits) {
      // Refill from the next input byte once exhausted.
      if (in_byte_bits_left == 0) {
        in_byte = *in;
        in++;
        in_byte_bits_left = 8;
      }

      // Take as many bits as the element still needs, capped by what the
      // current byte still holds.
      int chunk_bits = bits - element_bits_done;
      if (chunk_bits > in_byte_bits_left) {
        chunk_bits = in_byte_bits_left;
      }

      element |= (in_byte & kMasks[chunk_bits - 1]) << element_bits_done;
      in_byte_bits_left -= chunk_bits;
      in_byte >>= chunk_bits;

      element_bits_done += chunk_bits;
    }

    // Reject non-canonical (unreduced) values.
    if (element >= kPrime) {
      return 0;
    }
    out->c[i] = element;
  }

  return 1;
}
463
464
// scalar_decode_1 is |scalar_decode| specialised for |bits| == 1.
465
0
static void scalar_decode_1(scalar *out, const uint8_t in[32]) {
466
0
  for (int i = 0; i < DEGREE; i += 8) {
467
0
    uint8_t in_byte = *in;
468
0
    in++;
469
0
    for (int j = 0; j < 8; j++) {
470
0
      out->c[i + j] = in_byte & 1;
471
0
      in_byte >>= 1;
472
0
    }
473
0
  }
474
0
}
475
476
// Decodes 32*|RANK|*|bits| bytes from |in| into |out|. It returns one on
477
// success or zero if any parsed value is >= |kPrime|.
478
0
static int vector_decode(vector *out, const uint8_t *in, int bits) {
479
0
  for (int i = 0; i < RANK; i++) {
480
0
    if (!scalar_decode(&out->v[i], in + i * bits * DEGREE / 8, bits)) {
481
0
      return 0;
482
0
    }
483
0
  }
484
0
  return 1;
485
0
}
486
487
// Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
// numbers close to each other together. The formula used is
// round(2^|bits|/kPrime*x) mod 2^|bits|.
// Uses Barrett reduction to achieve constant time. Since we need both the
// remainder (for rounding) and the quotient (as the result), we cannot use
// |reduce| here, but need to do the Barrett reduction directly.
static uint16_t compress(uint16_t x, int bits) {
  uint32_t shifted = (uint32_t)x << bits;
  uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
  uint32_t quotient = (uint32_t)(product >> kBarrettShift);
  uint32_t remainder = shifted - quotient * kPrime;

  // Adjust the quotient to round correctly:
  //   0 <= remainder <= kHalfPrime round to 0
  //   kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
  //   kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
  assert(remainder < 2u * kPrime);
  // |constant_time_lt_w| keeps the rounding branch-free.
  quotient += 1 & constant_time_lt_w(kHalfPrime, remainder);
  quotient += 1 & constant_time_lt_w(kPrime + kHalfPrime, remainder);
  // Reduce mod 2^bits by masking.
  return quotient & ((1 << bits) - 1);
}
508
509
// Decompresses |x| by using an equi-distant representative. The formula is
510
// round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us to
511
// implement this logic using only bit operations.
512
0
static uint16_t decompress(uint16_t x, int bits) {
513
0
  uint32_t product = (uint32_t)x * kPrime;
514
0
  uint32_t power = 1 << bits;
515
  // This is |product| % power, since |power| is a power of 2.
516
0
  uint32_t remainder = product & (power - 1);
517
  // This is |product| / power, since |power| is a power of 2.
518
0
  uint32_t lower = product >> bits;
519
  // The rounding logic works since the first half of numbers mod |power| have a
520
  // 0 as first bit, and the second half has a 1 as first bit, since |power| is
521
  // a power of 2. As a 12 bit number, |remainder| is always positive, so we
522
  // will shift in 0s for a right shift.
523
0
  return lower + (remainder >> (bits - 1));
524
0
}
525
526
0
// scalar_compress compresses each coefficient of |s| to |bits| bits in place.
static void scalar_compress(scalar *s, int bits) {
  for (int idx = 0; idx < DEGREE; idx++) {
    s->c[idx] = compress(s->c[idx], bits);
  }
}
531
532
0
// scalar_decompress expands each |bits|-bit coefficient of |s| in place.
static void scalar_decompress(scalar *s, int bits) {
  for (int idx = 0; idx < DEGREE; idx++) {
    s->c[idx] = decompress(s->c[idx], bits);
  }
}
537
538
0
// vector_compress compresses every entry of |a| to |bits| bits in place.
static void vector_compress(vector *a, int bits) {
  for (int idx = 0; idx < RANK; idx++) {
    scalar_compress(&a->v[idx], bits);
  }
}
543
544
0
// vector_decompress decompresses every entry of |a| from |bits| bits in place.
static void vector_decompress(vector *a, int bits) {
  for (int idx = 0; idx < RANK; idx++) {
    scalar_decompress(&a->v[idx], bits);
  }
}
549
550
// The expanded internal form of a Kyber public key: the vector t, the matrix
// seed rho, the hash of the encoded public key, and the matrix expanded from
// rho.
struct public_key {
  vector t;
  uint8_t rho[32];
  uint8_t public_key_hash[32];
  matrix m;
};
556
557
// public_key_from_external reinterprets the opaque external struct as the
// internal |struct public_key|. The static_asserts verify the external type is
// large and aligned enough for this cast to be valid.
static struct public_key *public_key_from_external(
    const struct KYBER_public_key *external) {
  static_assert(sizeof(struct KYBER_public_key) >= sizeof(struct public_key),
                "Kyber public key is too small");
  static_assert(alignof(struct KYBER_public_key) >= alignof(struct public_key),
                "Kyber public key align incorrect");
  return (struct public_key *)external;
}
565
566
// The expanded internal form of a Kyber private key: the embedded public key,
// the secret vector s, and the implicit-rejection secret used on FO failure.
struct private_key {
  struct public_key pub;
  vector s;
  uint8_t fo_failure_secret[32];
};
571
572
// private_key_from_external reinterprets the opaque external struct as the
// internal |struct private_key|; the static_asserts verify the cast is valid.
static struct private_key *private_key_from_external(
    const struct KYBER_private_key *external) {
  static_assert(sizeof(struct KYBER_private_key) >= sizeof(struct private_key),
                "Kyber private key too small");
  static_assert(
      alignof(struct KYBER_private_key) >= alignof(struct private_key),
      "Kyber private key align incorrect");
  return (struct private_key *)external;
}
581
582
// Calls |KYBER_generate_key_external_entropy| with random bytes from
583
// |RAND_bytes|.
584
void KYBER_generate_key(uint8_t out_encoded_public_key[KYBER_PUBLIC_KEY_BYTES],
585
0
                        struct KYBER_private_key *out_private_key) {
586
0
  uint8_t entropy[KYBER_GENERATE_KEY_ENTROPY];
587
0
  RAND_bytes(entropy, sizeof(entropy));
588
0
  KYBER_generate_key_external_entropy(out_encoded_public_key, out_private_key,
589
0
                                      entropy);
590
0
}
591
592
0
// kyber_marshal_public_key writes |pub| to |out| in the wire format: the
// encoded vector t followed by rho. Returns one on success and zero on
// allocation failure.
static int kyber_marshal_public_key(CBB *out, const struct public_key *pub) {
  uint8_t *t_output;
  if (!CBB_add_space(out, &t_output, kEncodedVectorSize)) {
    return 0;
  }
  vector_encode(t_output, &pub->t, kLog2Prime);
  return CBB_add_bytes(out, pub->rho, sizeof(pub->rho));
}
603
604
// Algorithms 4 and 7 of the Kyber spec. Algorithms are combined since key
// generation is not part of the FO transform, and the spec uses Algorithm 7 to
// specify the actual key format.
void KYBER_generate_key_external_entropy(
    uint8_t out_encoded_public_key[KYBER_PUBLIC_KEY_BYTES],
    struct KYBER_private_key *out_private_key,
    const uint8_t entropy[KYBER_GENERATE_KEY_ENTROPY]) {
  struct private_key *priv = private_key_from_external(out_private_key);
  // Derive rho (the matrix seed) and sigma (the noise seed) from the first 32
  // bytes of entropy.
  uint8_t hashed[64];
  BORINGSSL_keccak(hashed, sizeof(hashed), entropy, 32, boringssl_sha3_512);
  const uint8_t *const rho = hashed;
  const uint8_t *const sigma = hashed + 32;
  OPENSSL_memcpy(priv->pub.rho, hashed, sizeof(priv->pub.rho));
  matrix_expand(&priv->pub.m, rho);
  // Sample the secret s and error e; the shared counter domain-separates the
  // PRF calls.
  uint8_t counter = 0;
  vector_generate_secret_eta_2(&priv->s, &counter, sigma);
  vector_ntt(&priv->s);
  vector error;
  vector_generate_secret_eta_2(&error, &counter, sigma);
  vector_ntt(&error);
  // t = mᵀ × s + e, computed in the NTT domain.
  matrix_mult_transpose(&priv->pub.t, &priv->pub.m, &priv->s);
  vector_add(&priv->pub.t, &error);

  CBB cbb;
  CBB_init_fixed(&cbb, out_encoded_public_key, KYBER_PUBLIC_KEY_BYTES);
  if (!kyber_marshal_public_key(&cbb, &priv->pub)) {
    // Cannot fail: the CBB was initialised with exactly enough fixed space.
    abort();
  }

  BORINGSSL_keccak(priv->pub.public_key_hash, sizeof(priv->pub.public_key_hash),
                   out_encoded_public_key, KYBER_PUBLIC_KEY_BYTES,
                   boringssl_sha3_256);
  // The final 32 bytes of entropy become the implicit-rejection secret used
  // by |KYBER_decap| on FO failure.
  OPENSSL_memcpy(priv->fo_failure_secret, entropy + 32, 32);
}
638
639
void KYBER_public_from_private(struct KYBER_public_key *out_public_key,
640
0
                               const struct KYBER_private_key *private_key) {
641
0
  struct public_key *const pub = public_key_from_external(out_public_key);
642
0
  const struct private_key *const priv = private_key_from_external(private_key);
643
0
  *pub = priv->pub;
644
0
}
645
646
// Algorithm 5 of the Kyber spec. Encrypts a message with given randomness to
// the ciphertext in |out|. Without applying the Fujisaki-Okamoto transform this
// would not result in a CCA secure scheme, since lattice schemes are vulnerable
// to decryption failure oracles.
static void encrypt_cpa(uint8_t out[KYBER_CIPHERTEXT_BYTES],
                        const struct public_key *pub, const uint8_t message[32],
                        const uint8_t randomness[32]) {
  // Sample the ephemeral secret and the error vector; |counter| domain-
  // separates the PRF calls and must advance in this exact order.
  uint8_t counter = 0;
  vector secret;
  vector_generate_secret_eta_2(&secret, &counter, randomness);
  vector_ntt(&secret);
  vector error;
  vector_generate_secret_eta_2(&error, &counter, randomness);
  // Sample the scalar error term, continuing the same PRF counter.
  uint8_t input[33];
  OPENSSL_memcpy(input, randomness, 32);
  input[32] = counter;
  scalar scalar_error;
  scalar_centered_binomial_distribution_eta_2_with_prf(&scalar_error, input);
  // u = m × secret (out of the NTT domain) + error.
  vector u;
  matrix_mult(&u, &pub->m, &secret);
  vector_inverse_ntt(&u);
  vector_add(&u, &error);
  // v = <t, secret> + scalar_error + Decompress(message, 1).
  scalar v;
  scalar_inner_product(&v, &pub->t, &secret);
  scalar_inverse_ntt(&v);
  scalar_add(&v, &scalar_error);
  scalar expanded_message;
  scalar_decode_1(&expanded_message, message);
  scalar_decompress(&expanded_message, 1);
  scalar_add(&v, &expanded_message);
  // Compress and serialise (u, v) as the ciphertext.
  vector_compress(&u, kDU);
  vector_encode(out, &u, kDU);
  scalar_compress(&v, kDV);
  scalar_encode(out + kCompressedVectorSize, &v, kDV);
}
681
682
// Calls KYBER_encap_external_entropy| with random bytes from |RAND_bytes|
683
void KYBER_encap(uint8_t out_ciphertext[KYBER_CIPHERTEXT_BYTES],
684
                 uint8_t *out_shared_secret, size_t out_shared_secret_len,
685
0
                 const struct KYBER_public_key *public_key) {
686
0
  uint8_t entropy[KYBER_ENCAP_ENTROPY];
687
0
  RAND_bytes(entropy, KYBER_ENCAP_ENTROPY);
688
0
  KYBER_encap_external_entropy(out_ciphertext, out_shared_secret,
689
0
                               out_shared_secret_len, public_key, entropy);
690
0
}
691
692
// Algorithm 8 of the Kyber spec, safe for line 2 of the spec. The spec there
// hashes the output of the system's random number generator, since the FO
// transform will reveal it to the decrypting party. There is no reason to do
// this when a secure random number generator is used. When an insecure random
// number generator is used, the caller should switch to a secure one before
// calling this method.
void KYBER_encap_external_entropy(
    uint8_t out_ciphertext[KYBER_CIPHERTEXT_BYTES], uint8_t *out_shared_secret,
    size_t out_shared_secret_len, const struct KYBER_public_key *public_key,
    const uint8_t entropy[KYBER_ENCAP_ENTROPY]) {
  const struct public_key *pub = public_key_from_external(public_key);
  // (prekey || randomness) = SHA3-512(entropy || H(public key)).
  uint8_t input[64];
  OPENSSL_memcpy(input, entropy, KYBER_ENCAP_ENTROPY);
  OPENSSL_memcpy(input + KYBER_ENCAP_ENTROPY, pub->public_key_hash,
                 sizeof(input) - KYBER_ENCAP_ENTROPY);
  uint8_t prekey_and_randomness[64];
  BORINGSSL_keccak(prekey_and_randomness, sizeof(prekey_and_randomness), input,
                   sizeof(input), boringssl_sha3_512);
  encrypt_cpa(out_ciphertext, pub, entropy, prekey_and_randomness + 32);
  // Overwrite the randomness half with H(ciphertext), then derive the shared
  // secret from (prekey || H(ciphertext)) with SHAKE-256.
  BORINGSSL_keccak(prekey_and_randomness + 32, 32, out_ciphertext,
                   KYBER_CIPHERTEXT_BYTES, boringssl_sha3_256);
  BORINGSSL_keccak(out_shared_secret, out_shared_secret_len,
                   prekey_and_randomness, sizeof(prekey_and_randomness),
                   boringssl_shake256);
}
717
718
// Algorithm 6 of the Kyber spec.
719
static void decrypt_cpa(uint8_t out[32], const struct private_key *priv,
720
0
                        const uint8_t ciphertext[KYBER_CIPHERTEXT_BYTES]) {
721
0
  vector u;
722
0
  vector_decode(&u, ciphertext, kDU);
723
0
  vector_decompress(&u, kDU);
724
0
  vector_ntt(&u);
725
0
  scalar v;
726
0
  scalar_decode(&v, ciphertext + kCompressedVectorSize, kDV);
727
0
  scalar_decompress(&v, kDV);
728
0
  scalar mask;
729
0
  scalar_inner_product(&mask, &priv->s, &u);
730
0
  scalar_inverse_ntt(&mask);
731
0
  scalar_sub(&v, &mask);
732
0
  scalar_compress(&v, 1);
733
0
  scalar_encode_1(out, &v);
734
0
}
735
736
// Algorithm 9 of the Kyber spec, performing the FO transform by running
// encrypt_cpa on the decrypted message. The spec does not allow the decryption
// failure to be passed on to the caller, and instead returns a result that is
// deterministic but unpredictable to anyone without knowledge of the private
// key.
void KYBER_decap(uint8_t *out_shared_secret, size_t out_shared_secret_len,
                 const uint8_t ciphertext[KYBER_CIPHERTEXT_BYTES],
                 const struct KYBER_private_key *private_key) {
  const struct private_key *priv = private_key_from_external(private_key);
  uint8_t decrypted[64];
  decrypt_cpa(decrypted, priv, ciphertext);
  OPENSSL_memcpy(decrypted + 32, priv->pub.public_key_hash,
                 sizeof(decrypted) - 32);
  uint8_t prekey_and_randomness[64];
  BORINGSSL_keccak(prekey_and_randomness, sizeof(prekey_and_randomness),
                   decrypted, sizeof(decrypted), boringssl_sha3_512);
  // Re-encrypt with the recovered message and randomness; a valid ciphertext
  // must round-trip to itself exactly.
  uint8_t expected_ciphertext[KYBER_CIPHERTEXT_BYTES];
  encrypt_cpa(expected_ciphertext, &priv->pub, decrypted,
              prekey_and_randomness + 32);
  // |mask| is all-ones on a match and all-zeros otherwise; combined with the
  // constant-time select below, decryption failures are not observable to the
  // caller via control flow.
  uint8_t mask =
      constant_time_eq_int_8(CRYPTO_memcmp(ciphertext, expected_ciphertext,
                                           sizeof(expected_ciphertext)),
                             0);
  uint8_t input[64];
  for (int i = 0; i < 32; i++) {
    // On mismatch, substitute the per-key implicit-rejection secret.
    input[i] = constant_time_select_8(mask, prekey_and_randomness[i],
                                      priv->fo_failure_secret[i]);
  }
  BORINGSSL_keccak(input + 32, 32, ciphertext, KYBER_CIPHERTEXT_BYTES,
                   boringssl_sha3_256);
  BORINGSSL_keccak(out_shared_secret, out_shared_secret_len, input,
                   sizeof(input), boringssl_shake256);
}
769
770
// KYBER_marshal_public_key serialises |public_key| to |out|.
int KYBER_marshal_public_key(CBB *out,
                             const struct KYBER_public_key *public_key) {
  const struct public_key *pub = public_key_from_external(public_key);
  return kyber_marshal_public_key(out, pub);
}
774
775
// kyber_parse_public_key_no_hash parses |in| into |pub| but doesn't calculate
776
// the value of |pub->public_key_hash|.
777
0
static int kyber_parse_public_key_no_hash(struct public_key *pub, CBS *in) {
778
0
  CBS t_bytes;
779
0
  if (!CBS_get_bytes(in, &t_bytes, kEncodedVectorSize) ||
780
0
      !vector_decode(&pub->t, CBS_data(&t_bytes), kLog2Prime) ||
781
0
      !CBS_copy_bytes(in, pub->rho, sizeof(pub->rho))) {
782
0
    return 0;
783
0
  }
784
0
  matrix_expand(&pub->m, pub->rho);
785
0
  return 1;
786
0
}
787
788
0
// KYBER_parse_public_key parses a public key from |in|, which must be fully
// consumed, and computes |public_key_hash| over the encoded bytes.
int KYBER_parse_public_key(struct KYBER_public_key *public_key, CBS *in) {
  struct public_key *pub = public_key_from_external(public_key);
  const CBS bytes_for_hash = *in;
  if (!kyber_parse_public_key_no_hash(pub, in)) {
    return 0;
  }
  if (CBS_len(in) != 0) {
    return 0;
  }
  BORINGSSL_keccak(pub->public_key_hash, sizeof(pub->public_key_hash),
                   CBS_data(&bytes_for_hash), CBS_len(&bytes_for_hash),
                   boringssl_sha3_256);
  return 1;
}
799
800
// KYBER_marshal_private_key serialises |private_key| to |out|: the encoded
// secret vector s, the public key, the public key hash, and the FO failure
// secret.
int KYBER_marshal_private_key(CBB *out,
                              const struct KYBER_private_key *private_key) {
  const struct private_key *const priv = private_key_from_external(private_key);
  uint8_t *s_output;
  if (!CBB_add_space(out, &s_output, kEncodedVectorSize)) {
    return 0;
  }
  vector_encode(s_output, &priv->s, kLog2Prime);
  if (!kyber_marshal_public_key(out, &priv->pub)) {
    return 0;
  }
  if (!CBB_add_bytes(out, priv->pub.public_key_hash,
                     sizeof(priv->pub.public_key_hash))) {
    return 0;
  }
  return CBB_add_bytes(out, priv->fo_failure_secret,
                       sizeof(priv->fo_failure_secret));
}
817
818
int KYBER_parse_private_key(struct KYBER_private_key *out_private_key,
819
0
                            CBS *in) {
820
0
  struct private_key *const priv = private_key_from_external(out_private_key);
821
822
0
  CBS s_bytes;
823
0
  if (!CBS_get_bytes(in, &s_bytes, kEncodedVectorSize) ||
824
0
      !vector_decode(&priv->s, CBS_data(&s_bytes), kLog2Prime) ||
825
0
      !kyber_parse_public_key_no_hash(&priv->pub, in) ||
826
0
      !CBS_copy_bytes(in, priv->pub.public_key_hash,
827
0
                      sizeof(priv->pub.public_key_hash)) ||
828
0
      !CBS_copy_bytes(in, priv->fo_failure_secret,
829
0
                      sizeof(priv->fo_failure_secret)) ||
830
0
      CBS_len(in) != 0) {
831
0
    return 0;
832
0
  }
833
0
  return 1;
834
0
}