Coverage Report

Created: 2026-04-15 06:25

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/boringssl/crypto/kyber/kyber.cc
Line
Count
Source
1
// Copyright 2023 The BoringSSL Authors
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     https://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
#include <assert.h>
16
#include <stdlib.h>
17
18
#include <openssl/bytestring.h>
19
#include <openssl/rand.h>
20
21
#include "../fipsmodule/keccak/internal.h"
22
#include "../internal.h"
23
#include "./internal.h"
24
25
26
// See
27
// https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf
28
29
using namespace bssl;
30
31
1.45k
static void prf(uint8_t *out, size_t out_len, const uint8_t in[33]) {
32
1.45k
  BORINGSSL_keccak(out, out_len, in, 33, boringssl_shake256);
33
1.45k
}
34
35
332
static void hash_h(uint8_t out[32], const uint8_t *in, size_t len) {
36
332
  BORINGSSL_keccak(out, 32, in, len, boringssl_sha3_256);
37
332
}
38
39
225
static void hash_g(uint8_t out[64], const uint8_t *in, size_t len) {
40
225
  BORINGSSL_keccak(out, 64, in, len, boringssl_sha3_512);
41
225
}
42
43
107
static void kdf(uint8_t *out, size_t out_len, const uint8_t *in, size_t len) {
44
107
  BORINGSSL_keccak(out, out_len, in, len, boringssl_shake256);
45
107
}
46
47
2.47M
#define DEGREE 256
48
13.9k
#define RANK 3
49
50
static const size_t kBarrettMultiplier = 5039;
51
static const unsigned kBarrettShift = 24;
52
static const uint16_t kPrime = 3329;
53
static const int kLog2Prime = 12;
54
static const uint16_t kHalfPrime = (/*kPrime=*/3329 - 1) / 2;
55
static const int kDU = 10;
56
static const int kDV = 4;
57
// kInverseDegree is 128^-1 mod 3329; 128 because kPrime does not have a 512th
58
// root of unity.
59
static const uint16_t kInverseDegree = 3303;
60
static const size_t kEncodedVectorSize =
61
    (/*kLog2Prime=*/12 * DEGREE / 8) * RANK;
62
static const size_t kCompressedVectorSize = /*kDU=*/10 * RANK * DEGREE / 8;
63
64
typedef struct scalar {
65
  // On every function entry and exit, 0 <= c < kPrime.
66
  uint16_t c[DEGREE];
67
} scalar;
68
69
typedef struct vector {
70
  scalar v[RANK];
71
} vector;
72
73
typedef struct matrix {
74
  scalar v[RANK][RANK];
75
} matrix;
76
77
// This bit of Python will be referenced in some of the following comments:
78
//
79
// p = 3329
80
//
81
// def bitreverse(i):
82
//     ret = 0
83
//     for n in range(7):
84
//         bit = i & 1
85
//         ret <<= 1
86
//         ret |= bit
87
//         i >>= 1
88
//     return ret
89
90
// kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)]
91
static const uint16_t kNTTRoots[128] = {
92
    1,    1729, 2580, 3289, 2642, 630,  1897, 848,  1062, 1919, 193,  797,
93
    2786, 3260, 569,  1746, 296,  2447, 1339, 1476, 3046, 56,   2240, 1333,
94
    1426, 2094, 535,  2882, 2393, 2879, 1974, 821,  289,  331,  3253, 1756,
95
    1197, 2304, 2277, 2055, 650,  1977, 2513, 632,  2865, 33,   1320, 1915,
96
    2319, 1435, 807,  452,  1438, 2868, 1534, 2402, 2647, 2617, 1481, 648,
97
    2474, 3110, 1227, 910,  17,   2761, 583,  2649, 1637, 723,  2288, 1100,
98
    1409, 2662, 3281, 233,  756,  2156, 3015, 3050, 1703, 1651, 2789, 1789,
99
    1847, 952,  1461, 2687, 939,  2308, 2437, 2388, 733,  2337, 268,  641,
100
    1584, 2298, 2037, 3220, 375,  2549, 2090, 1645, 1063, 319,  2773, 757,
101
    2099, 561,  2466, 2594, 2804, 1092, 403,  1026, 1143, 2150, 2775, 886,
102
    1722, 1212, 1874, 1029, 2110, 2935, 885,  2154,
103
};
104
105
// kInverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)]
106
static const uint16_t kInverseNTTRoots[128] = {
107
    1,    1600, 40,   749,  2481, 1432, 2699, 687,  1583, 2760, 69,   543,
108
    2532, 3136, 1410, 2267, 2508, 1355, 450,  936,  447,  2794, 1235, 1903,
109
    1996, 1089, 3273, 283,  1853, 1990, 882,  3033, 2419, 2102, 219,  855,
110
    2681, 1848, 712,  682,  927,  1795, 461,  1891, 2877, 2522, 1894, 1010,
111
    1414, 2009, 3296, 464,  2697, 816,  1352, 2679, 1274, 1052, 1025, 2132,
112
    1573, 76,   2998, 3040, 1175, 2444, 394,  1219, 2300, 1455, 2117, 1607,
113
    2443, 554,  1179, 2186, 2303, 2926, 2237, 525,  735,  863,  2768, 1230,
114
    2572, 556,  3010, 2266, 1684, 1239, 780,  2954, 109,  1292, 1031, 1745,
115
    2688, 3061, 992,  2596, 941,  892,  1021, 2390, 642,  1868, 2377, 1482,
116
    1540, 540,  1678, 1626, 279,  314,  1173, 2573, 3096, 48,   667,  1920,
117
    2229, 1041, 2606, 1692, 680,  2746, 568,  3312,
118
};
119
120
// kModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)]
121
static const uint16_t kModRoots[128] = {
122
    17,   3312, 2761, 568,  583,  2746, 2649, 680,  1637, 1692, 723,  2606,
123
    2288, 1041, 1100, 2229, 1409, 1920, 2662, 667,  3281, 48,   233,  3096,
124
    756,  2573, 2156, 1173, 3015, 314,  3050, 279,  1703, 1626, 1651, 1678,
125
    2789, 540,  1789, 1540, 1847, 1482, 952,  2377, 1461, 1868, 2687, 642,
126
    939,  2390, 2308, 1021, 2437, 892,  2388, 941,  733,  2596, 2337, 992,
127
    268,  3061, 641,  2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
128
    375,  2954, 2549, 780,  2090, 1239, 1645, 1684, 1063, 2266, 319,  3010,
129
    2773, 556,  757,  2572, 2099, 1230, 561,  2768, 2466, 863,  2594, 735,
130
    2804, 525,  1092, 2237, 403,  2926, 1026, 2303, 1143, 2186, 2150, 1179,
131
    2775, 554,  886,  2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
132
    2110, 1219, 2935, 394,  885,  2444, 2154, 1175,
133
};
134
135
// reduce_once reduces 0 <= x < 2*kPrime, mod kPrime.
136
5.37M
static uint16_t reduce_once(uint16_t x) {
137
5.37M
  declassify_assert(x < 2 * kPrime);
138
5.37M
  const uint16_t subtracted = x - kPrime;
139
5.37M
  uint16_t mask = 0u - (subtracted >> 15);
140
  // Although this is a constant-time select, we omit a value barrier here.
141
  // Value barriers impede auto-vectorization (likely because it forces the
142
  // value to transit through a general-purpose register). On AArch64, this is a
143
  // difference of 2x.
144
  //
145
  // We usually add value barriers to selects because Clang turns consecutive
146
  // selects with the same condition into a branch instead of CMOV/CSEL. This
147
  // condition does not occur in Kyber, so omitting it seems to be safe so far,
148
  // but see |scalar_centered_binomial_distribution_eta_2_with_prf|.
149
5.37M
  return (mask & x) | (~mask & subtracted);
150
5.37M
}
151
152
// constant time reduce x mod kPrime using Barrett reduction. x must be less
153
// than kPrime + 2×kPrime².
154
2.31M
static uint16_t reduce(uint32_t x) {
155
2.31M
  declassify_assert(x < kPrime + 2u * kPrime * kPrime);
156
2.31M
  uint64_t product = (uint64_t)x * kBarrettMultiplier;
157
2.31M
  uint32_t quotient = (uint32_t)(product >> kBarrettShift);
158
2.31M
  uint32_t remainder = x - quotient * kPrime;
159
2.31M
  return reduce_once(remainder);
160
2.31M
}
161
162
107
static void scalar_zero(scalar *out) { OPENSSL_memset(out, 0, sizeof(*out)); }
163
164
225
static void vector_zero(vector *out) { OPENSSL_memset(out, 0, sizeof(*out)); }
165
166
// In place number theoretic transform of a given scalar.
167
// Note that Kyber's kPrime 3329 does not have a 512th root of unity, so this
168
// transform leaves off the last iteration of the usual FFT code, with the 128
169
// relevant roots of unity being stored in |kNTTRoots|. This means the output
170
// should be seen as 128 elements in GF(3329^2), with the coefficients of the
171
// elements being consecutive entries in |s->c|.
172
1.02k
static void scalar_ntt(scalar *s) {
173
1.02k
  int offset = DEGREE;
174
  // `int` is used here because using `size_t` throughout caused a ~5% slowdown
175
  // with Clang 14 on Aarch64.
176
8.23k
  for (int step = 1; step < DEGREE / 2; step <<= 1) {
177
7.20k
    offset >>= 1;
178
7.20k
    int k = 0;
179
137k
    for (int i = 0; i < step; i++) {
180
130k
      const uint32_t step_root = kNTTRoots[i + step];
181
1.05M
      for (int j = k; j < k + offset; j++) {
182
921k
        uint16_t odd = reduce(step_root * s->c[j + offset]);
183
921k
        uint16_t even = s->c[j];
184
921k
        s->c[j] = reduce_once(odd + even);
185
921k
        s->c[j + offset] = reduce_once(even - odd + kPrime);
186
921k
      }
187
130k
      k += 2 * offset;
188
130k
    }
189
7.20k
  }
190
1.02k
}
191
192
343
static void vector_ntt(vector *a) {
193
1.37k
  for (int i = 0; i < RANK; i++) {
194
1.02k
    scalar_ntt(&a->v[i]);
195
1.02k
  }
196
343
}
197
198
// In place inverse number theoretic transform of a given scalar, with pairs of
199
// entries of s->v being interpreted as elements of GF(3329^2). Just as with the
200
// number theoretic transform, this leaves off the first step of the normal iFFT
201
// to account for the fact that 3329 does not have a 512th root of unity, using
202
// the precomputed 128 roots of unity stored in |kInverseNTTRoots|.
203
428
static void scalar_inverse_ntt(scalar *s) {
204
428
  int step = DEGREE / 2;
205
  // `int` is used here because using `size_t` throughout caused a ~5% slowdown
206
  // with Clang 14 on Aarch64.
207
3.42k
  for (int offset = 2; offset < DEGREE; offset <<= 1) {
208
2.99k
    step >>= 1;
209
2.99k
    int k = 0;
210
57.3k
    for (int i = 0; i < step; i++) {
211
54.3k
      uint32_t step_root = kInverseNTTRoots[i + step];
212
437k
      for (int j = k; j < k + offset; j++) {
213
383k
        uint16_t odd = s->c[j + offset];
214
383k
        uint16_t even = s->c[j];
215
383k
        s->c[j] = reduce_once(odd + even);
216
383k
        s->c[j + offset] = reduce(step_root * (even - odd + kPrime));
217
383k
      }
218
54.3k
      k += 2 * offset;
219
54.3k
    }
220
2.99k
  }
221
109k
  for (int i = 0; i < DEGREE; i++) {
222
109k
    s->c[i] = reduce(s->c[i] * kInverseDegree);
223
109k
  }
224
428
}
225
226
107
static void vector_inverse_ntt(vector *a) {
227
428
  for (int i = 0; i < RANK; i++) {
228
321
    scalar_inverse_ntt(&a->v[i]);
229
321
  }
230
107
}
231
232
3.23k
static void scalar_add(scalar *lhs, const scalar *rhs) {
233
831k
  for (int i = 0; i < DEGREE; i++) {
234
828k
    lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
235
828k
  }
236
3.23k
}
237
238
0
static void scalar_sub(scalar *lhs, const scalar *rhs) {
239
0
  for (int i = 0; i < DEGREE; i++) {
240
0
    lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime);
241
0
  }
242
0
}
243
244
// Multiplying two scalars in the number theoretically transformed state. Since
245
// 3329 does not have a 512th root of unity, this means we have to interpret
246
// the 2*ith and (2*i+1)th entries of the scalar as elements of GF(3329)[X]/(X^2
247
// - 17^(2*bitreverse(i)+1)) The value of 17^(2*bitreverse(i)+1) mod 3329 is
248
// stored in the precomputed |kModRoots| table. Note that our Barrett transform
249
// only allows us to multiply two reduced numbers together, so we need some
250
// intermediate reduction steps, even if an uint64_t could hold 3 multiplied
251
// numbers.
252
2.34k
static void scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs) {
253
302k
  for (int i = 0; i < DEGREE / 2; i++) {
254
300k
    uint32_t real_real = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i];
255
300k
    uint32_t img_img = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i + 1];
256
300k
    uint32_t real_img = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i + 1];
257
300k
    uint32_t img_real = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i];
258
300k
    out->c[2 * i] =
259
300k
        reduce(real_real + (uint32_t)reduce(img_img) * kModRoots[i]);
260
300k
    out->c[2 * i + 1] = reduce(img_real + real_img);
261
300k
  }
262
2.34k
}
263
264
225
static void vector_add(vector *lhs, const vector *rhs) {
265
900
  for (int i = 0; i < RANK; i++) {
266
675
    scalar_add(&lhs->v[i], &rhs->v[i]);
267
675
  }
268
225
}
269
270
107
static void matrix_mult(vector *out, const matrix *m, const vector *a) {
271
107
  vector_zero(out);
272
428
  for (int i = 0; i < RANK; i++) {
273
1.28k
    for (int j = 0; j < RANK; j++) {
274
963
      scalar product;
275
963
      scalar_mult(&product, &m->v[i][j], &a->v[j]);
276
963
      scalar_add(&out->v[i], &product);
277
963
    }
278
321
  }
279
107
}
280
281
static void matrix_mult_transpose(vector *out, const matrix *m,
282
118
                                  const vector *a) {
283
118
  vector_zero(out);
284
472
  for (int i = 0; i < RANK; i++) {
285
1.41k
    for (int j = 0; j < RANK; j++) {
286
1.06k
      scalar product;
287
1.06k
      scalar_mult(&product, &m->v[j][i], &a->v[j]);
288
1.06k
      scalar_add(&out->v[i], &product);
289
1.06k
    }
290
354
  }
291
118
}
292
293
static void scalar_inner_product(scalar *out, const vector *lhs,
294
107
                                 const vector *rhs) {
295
107
  scalar_zero(out);
296
428
  for (int i = 0; i < RANK; i++) {
297
321
    scalar product;
298
321
    scalar_mult(&product, &lhs->v[i], &rhs->v[i]);
299
321
    scalar_add(out, &product);
300
321
  }
301
107
}
302
303
// Algorithm 1 of the Kyber spec. Rejection samples a Keccak stream to get
304
// uniformly distributed elements. This is used for matrix expansion and only
305
// operates on public inputs.
306
static void scalar_from_keccak_vartime(scalar *out,
307
2.02k
                                       struct BORINGSSL_keccak_st *keccak_ctx) {
308
2.02k
  assert(keccak_ctx->squeeze_offset == 0);
309
2.02k
  assert(keccak_ctx->rate_bytes == 168);
310
2.02k
  static_assert(168 % 3 == 0, "block and coefficient boundaries do not align");
311
312
2.02k
  int done = 0;
313
8.12k
  while (done < DEGREE) {
314
6.09k
    uint8_t block[168];
315
6.09k
    BORINGSSL_keccak_squeeze(keccak_ctx, block, sizeof(block));
316
325k
    for (size_t i = 0; i < sizeof(block) && done < DEGREE; i += 3) {
317
319k
      uint16_t d1 = block[i] + 256 * (block[i + 1] % 16);
318
319k
      uint16_t d2 = block[i + 1] / 16 + 16 * block[i + 2];
319
319k
      if (d1 < kPrime) {
320
259k
        out->c[done++] = d1;
321
259k
      }
322
319k
      if (d2 < kPrime && done < DEGREE) {
323
258k
        out->c[done++] = d2;
324
258k
      }
325
319k
    }
326
6.09k
  }
327
2.02k
}
328
329
// Algorithm 2 of the Kyber spec, with eta fixed to two and the PRF call
330
// included. Creates binominally distributed elements by sampling 2*|eta| bits,
331
// and setting the coefficient to the count of the first bits minus the count of
332
// the second bits, resulting in a centered binomial distribution. Since eta is
333
// two this gives -2/2 with a probability of 1/16, -1/1 with probability 1/4,
334
// and 0 with probability 3/8.
335
static void scalar_centered_binomial_distribution_eta_2_with_prf(
336
1.45k
    scalar *out, const uint8_t input[33]) {
337
1.45k
  uint8_t entropy[128];
338
1.45k
  static_assert(sizeof(entropy) == 2 * /*kEta=*/2 * DEGREE / 8);
339
1.45k
  prf(entropy, sizeof(entropy), input);
340
341
187k
  for (int i = 0; i < DEGREE; i += 2) {
342
186k
    uint8_t byte = entropy[i / 2];
343
344
186k
    uint16_t value = (byte & 1) + ((byte >> 1) & 1);
345
186k
    value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
346
    // Add |kPrime| if |value| underflowed. See |reduce_once| for a discussion
347
    // on why the value barrier is omitted. While this could have been written
348
    // reduce_once(value + kPrime), this is one extra addition and small range
349
    // of |value| tempts some versions of Clang to emit a branch.
350
186k
    uint16_t mask = 0u - (value >> 15);
351
186k
    out->c[i] = value + (kPrime & mask);
352
353
186k
    byte >>= 4;
354
186k
    value = (byte & 1) + ((byte >> 1) & 1);
355
186k
    value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
356
    // See above.
357
186k
    mask = 0u - (value >> 15);
358
186k
    out->c[i + 1] = value + (kPrime & mask);
359
186k
  }
360
1.45k
}
361
362
// Generates a secret vector by using
363
// |scalar_centered_binomial_distribution_eta_2_with_prf|, using the given seed
364
// appending and incrementing |counter| for entry of the vector.
365
static void vector_generate_secret_eta_2(vector *out, uint8_t *counter,
366
450
                                         const uint8_t seed[32]) {
367
450
  uint8_t input[33];
368
450
  OPENSSL_memcpy(input, seed, 32);
369
1.80k
  for (int i = 0; i < RANK; i++) {
370
1.35k
    input[32] = (*counter)++;
371
1.35k
    scalar_centered_binomial_distribution_eta_2_with_prf(&out->v[i], input);
372
1.35k
  }
373
450
}
374
375
// Expands the matrix of a seed for key generation and for encaps-CPA.
376
225
static void matrix_expand(matrix *out, const uint8_t rho[32]) {
377
225
  uint8_t input[34];
378
225
  OPENSSL_memcpy(input, rho, 32);
379
900
  for (int i = 0; i < RANK; i++) {
380
2.70k
    for (int j = 0; j < RANK; j++) {
381
2.02k
      input[32] = i;
382
2.02k
      input[33] = j;
383
2.02k
      struct BORINGSSL_keccak_st keccak_ctx;
384
2.02k
      BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake128);
385
2.02k
      BORINGSSL_keccak_absorb(&keccak_ctx, input, sizeof(input));
386
2.02k
      scalar_from_keccak_vartime(&out->v[i][j], &keccak_ctx);
387
2.02k
    }
388
675
  }
389
225
}
390
391
static const uint8_t kMasks[8] = {0x01, 0x03, 0x07, 0x0f,
392
                                  0x1f, 0x3f, 0x7f, 0xff};
393
394
782
static void scalar_encode(uint8_t *out, const scalar *s, int bits) {
395
782
  assert(bits <= (int)sizeof(*s->c) * 8 && bits != 1);
396
397
782
  uint8_t out_byte = 0;
398
782
  int out_byte_bits = 0;
399
400
200k
  for (int i = 0; i < DEGREE; i++) {
401
200k
    uint16_t element = s->c[i];
402
200k
    int element_bits_done = 0;
403
404
573k
    while (element_bits_done < bits) {
405
372k
      int chunk_bits = bits - element_bits_done;
406
372k
      int out_bits_remaining = 8 - out_byte_bits;
407
372k
      if (chunk_bits >= out_bits_remaining) {
408
252k
        chunk_bits = out_bits_remaining;
409
252k
        out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits;
410
252k
        *out = out_byte;
411
252k
        out++;
412
252k
        out_byte_bits = 0;
413
252k
        out_byte = 0;
414
252k
      } else {
415
120k
        out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits;
416
120k
        out_byte_bits += chunk_bits;
417
120k
      }
418
419
372k
      element_bits_done += chunk_bits;
420
372k
      element >>= chunk_bits;
421
372k
    }
422
200k
  }
423
424
782
  if (out_byte_bits > 0) {
425
0
    *out = out_byte;
426
0
  }
427
782
}
428
429
// scalar_encode_1 is |scalar_encode| specialised for |bits| == 1.
430
0
static void scalar_encode_1(uint8_t out[32], const scalar *s) {
431
0
  for (int i = 0; i < DEGREE; i += 8) {
432
0
    uint8_t out_byte = 0;
433
0
    for (int j = 0; j < 8; j++) {
434
0
      out_byte |= (s->c[i + j] & 1) << j;
435
0
    }
436
0
    *out = out_byte;
437
0
    out++;
438
0
  }
439
0
}
440
441
// Encodes an entire vector into 32*|RANK|*|bits| bytes. Note that since 256
442
// (DEGREE) is divisible by 8, the individual vector entries will always fill a
443
// whole number of bytes, so we do not need to worry about bit packing here.
444
225
static void vector_encode(uint8_t *out, const vector *a, int bits) {
445
900
  for (int i = 0; i < RANK; i++) {
446
675
    scalar_encode(out + i * bits * DEGREE / 8, &a->v[i], bits);
447
675
  }
448
225
}
449
450
// scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
451
// |out|. It returns one on success and zero if any parsed value is >=
452
// |kPrime|.
453
417
static int scalar_decode(scalar *out, const uint8_t *in, int bits) {
454
417
  assert(bits <= (int)sizeof(*out->c) * 8 && bits != 1);
455
456
417
  uint8_t in_byte = 0;
457
417
  int in_byte_bits_left = 0;
458
459
97.0k
  for (int i = 0; i < DEGREE; i++) {
460
96.6k
    uint16_t element = 0;
461
96.6k
    int element_bits_done = 0;
462
463
289k
    while (element_bits_done < bits) {
464
193k
      if (in_byte_bits_left == 0) {
465
145k
        in_byte = *in;
466
145k
        in++;
467
145k
        in_byte_bits_left = 8;
468
145k
      }
469
470
193k
      int chunk_bits = bits - element_bits_done;
471
193k
      if (chunk_bits > in_byte_bits_left) {
472
96.6k
        chunk_bits = in_byte_bits_left;
473
96.6k
      }
474
475
193k
      element |= (in_byte & kMasks[chunk_bits - 1]) << element_bits_done;
476
193k
      in_byte_bits_left -= chunk_bits;
477
193k
      in_byte >>= chunk_bits;
478
479
193k
      element_bits_done += chunk_bits;
480
193k
    }
481
482
    // An element is only out of range in the case of invalid input, in which
483
    // case it is okay to leak the comparison.
484
96.6k
    if (constant_time_declassify_int(element >= kPrime)) {
485
47
      return 0;
486
47
    }
487
96.6k
    out->c[i] = element;
488
96.6k
  }
489
490
370
  return 1;
491
417
}
492
493
// scalar_decode_1 is |scalar_decode| specialised for |bits| == 1.
494
107
static void scalar_decode_1(scalar *out, const uint8_t in[32]) {
495
3.53k
  for (int i = 0; i < DEGREE; i += 8) {
496
3.42k
    uint8_t in_byte = *in;
497
3.42k
    in++;
498
30.8k
    for (int j = 0; j < 8; j++) {
499
27.3k
      out->c[i + j] = in_byte & 1;
500
27.3k
      in_byte >>= 1;
501
27.3k
    }
502
3.42k
  }
503
107
}
504
505
// Decodes 32*|RANK|*|bits| bytes from |in| into |out|. It returns one on
506
// success or zero if any parsed value is >= |kPrime|.
507
154
static int vector_decode(vector *out, const uint8_t *in, int bits) {
508
524
  for (int i = 0; i < RANK; i++) {
509
417
    if (!scalar_decode(&out->v[i], in + i * bits * DEGREE / 8, bits)) {
510
47
      return 0;
511
47
    }
512
417
  }
513
107
  return 1;
514
154
}
515
516
// Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
517
// numbers close to each other together. The formula used is
518
// round(2^|bits|/kPrime*x) mod 2^|bits|.
519
// Uses Barrett reduction to achieve constant time. Since we need both the
520
// remainder (for rounding) and the quotient (as the result), we cannot use
521
// |reduce| here, but need to do the Barrett reduction directly.
522
109k
static uint16_t compress(uint16_t x, int bits) {
523
109k
  uint32_t shifted = (uint32_t)x << bits;
524
109k
  uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
525
109k
  uint32_t quotient = (uint32_t)(product >> kBarrettShift);
526
109k
  uint32_t remainder = shifted - quotient * kPrime;
527
528
  // Adjust the quotient to round correctly:
529
  //   0 <= remainder <= kHalfPrime round to 0
530
  //   kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
531
  //   kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
532
109k
  declassify_assert(remainder < 2u * kPrime);
533
109k
  quotient += 1 & constant_time_lt_w(kHalfPrime, remainder);
534
109k
  quotient += 1 & constant_time_lt_w(kPrime + kHalfPrime, remainder);
535
109k
  return quotient & ((1 << bits) - 1);
536
109k
}
537
538
// Decompresses |x| by using an equi-distant representative. The formula is
539
// round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us to
540
// implement this logic using only bit operations.
541
27.3k
static uint16_t decompress(uint16_t x, int bits) {
542
27.3k
  uint32_t product = (uint32_t)x * kPrime;
543
27.3k
  uint32_t power = 1 << bits;
544
  // This is |product| % power, since |power| is a power of 2.
545
27.3k
  uint32_t remainder = product & (power - 1);
546
  // This is |product| / power, since |power| is a power of 2.
547
27.3k
  uint32_t lower = product >> bits;
548
  // The rounding logic works since the first half of numbers mod |power| have a
549
  // 0 as first bit, and the second half has a 1 as first bit, since |power| is
550
  // a power of 2. As a 12 bit number, |remainder| is always positive, so we
551
  // will shift in 0s for a right shift.
552
27.3k
  return lower + (remainder >> (bits - 1));
553
27.3k
}
554
555
428
static void scalar_compress(scalar *s, int bits) {
556
109k
  for (int i = 0; i < DEGREE; i++) {
557
109k
    s->c[i] = compress(s->c[i], bits);
558
109k
  }
559
428
}
560
561
107
static void scalar_decompress(scalar *s, int bits) {
562
27.4k
  for (int i = 0; i < DEGREE; i++) {
563
27.3k
    s->c[i] = decompress(s->c[i], bits);
564
27.3k
  }
565
107
}
566
567
107
static void vector_compress(vector *a, int bits) {
568
428
  for (int i = 0; i < RANK; i++) {
569
321
    scalar_compress(&a->v[i], bits);
570
321
  }
571
107
}
572
573
0
static void vector_decompress(vector *a, int bits) {
574
0
  for (int i = 0; i < RANK; i++) {
575
0
    scalar_decompress(&a->v[i], bits);
576
0
  }
577
0
}
578
579
namespace {
580
581
struct public_key {
582
  vector t;
583
  uint8_t rho[32];
584
  uint8_t public_key_hash[32];
585
  matrix m;
586
};
587
588
static struct public_key *public_key_from_external(
589
261
    const struct KYBER_public_key *external) {
590
261
  static_assert(sizeof(struct KYBER_public_key) >= sizeof(struct public_key),
591
261
                "Kyber public key is too small");
592
261
  static_assert(alignof(struct KYBER_public_key) >= alignof(struct public_key),
593
261
                "Kyber public key align incorrect");
594
261
  return (struct public_key *)external;
595
261
}
596
597
struct private_key {
598
  struct public_key pub;
599
  vector s;
600
  uint8_t fo_failure_secret[32];
601
};
602
603
static struct private_key *private_key_from_external(
604
118
    const struct KYBER_private_key *external) {
605
118
  static_assert(sizeof(struct KYBER_private_key) >= sizeof(struct private_key),
606
118
                "Kyber private key too small");
607
118
  static_assert(
608
118
      alignof(struct KYBER_private_key) >= alignof(struct private_key),
609
118
      "Kyber private key align incorrect");
610
118
  return (struct private_key *)external;
611
118
}
612
613
}  // namespace
614
615
// Calls |KYBER_generate_key_external_entropy| with random bytes from
616
// |RAND_bytes|.
617
void bssl::KYBER_generate_key(
618
    uint8_t out_encoded_public_key[KYBER_PUBLIC_KEY_BYTES],
619
118
    struct KYBER_private_key *out_private_key) {
620
118
  uint8_t entropy[KYBER_GENERATE_KEY_ENTROPY];
621
118
  RAND_bytes(entropy, sizeof(entropy));
622
118
  CONSTTIME_SECRET(entropy, sizeof(entropy));
623
118
  KYBER_generate_key_external_entropy(out_encoded_public_key, out_private_key,
624
118
                                      entropy);
625
118
}
626
627
118
static int kyber_marshal_public_key(CBB *out, const struct public_key *pub) {
628
118
  uint8_t *vector_output;
629
118
  if (!CBB_add_space(out, &vector_output, kEncodedVectorSize)) {
630
0
    return 0;
631
0
  }
632
118
  vector_encode(vector_output, &pub->t, kLog2Prime);
633
118
  if (!CBB_add_bytes(out, pub->rho, sizeof(pub->rho))) {
634
0
    return 0;
635
0
  }
636
118
  return 1;
637
118
}
638
639
// Algorithms 4 and 7 of the Kyber spec. Algorithms are combined since key
640
// generation is not part of the FO transform, and the spec uses Algorithm 7 to
641
// specify the actual key format.
642
void bssl::KYBER_generate_key_external_entropy(
643
    uint8_t out_encoded_public_key[KYBER_PUBLIC_KEY_BYTES],
644
    struct KYBER_private_key *out_private_key,
645
118
    const uint8_t entropy[KYBER_GENERATE_KEY_ENTROPY]) {
646
118
  struct private_key *priv = private_key_from_external(out_private_key);
647
118
  uint8_t hashed[64];
648
118
  hash_g(hashed, entropy, 32);
649
118
  const uint8_t *const rho = hashed;
650
118
  const uint8_t *const sigma = hashed + 32;
651
  // rho is public.
652
118
  CONSTTIME_DECLASSIFY(rho, 32);
653
118
  OPENSSL_memcpy(priv->pub.rho, hashed, sizeof(priv->pub.rho));
654
118
  matrix_expand(&priv->pub.m, rho);
655
118
  uint8_t counter = 0;
656
118
  vector_generate_secret_eta_2(&priv->s, &counter, sigma);
657
118
  vector_ntt(&priv->s);
658
118
  vector error;
659
118
  vector_generate_secret_eta_2(&error, &counter, sigma);
660
118
  vector_ntt(&error);
661
118
  matrix_mult_transpose(&priv->pub.t, &priv->pub.m, &priv->s);
662
118
  vector_add(&priv->pub.t, &error);
663
  // t is part of the public key and thus is public.
664
118
  CONSTTIME_DECLASSIFY(&priv->pub.t, sizeof(priv->pub.t));
665
666
118
  CBB cbb;
667
118
  CBB_init_fixed(&cbb, out_encoded_public_key, KYBER_PUBLIC_KEY_BYTES);
668
118
  if (!kyber_marshal_public_key(&cbb, &priv->pub)) {
669
0
    abort();
670
0
  }
671
672
118
  hash_h(priv->pub.public_key_hash, out_encoded_public_key,
673
118
         KYBER_PUBLIC_KEY_BYTES);
674
118
  OPENSSL_memcpy(priv->fo_failure_secret, entropy + 32, 32);
675
118
}
676
677
void bssl::KYBER_public_from_private(
678
    struct KYBER_public_key *out_public_key,
679
0
    const struct KYBER_private_key *private_key) {
680
0
  struct public_key *const pub = public_key_from_external(out_public_key);
681
0
  const struct private_key *const priv = private_key_from_external(private_key);
682
0
  *pub = priv->pub;
683
0
}
684
685
// Algorithm 5 of the Kyber spec. Encrypts a message with given randomness to
686
// the ciphertext in |out|. Without applying the Fujisaki-Okamoto transform this
687
// would not result in a CCA secure scheme, since lattice schemes are vulnerable
688
// to decryption failure oracles.
689
static void encrypt_cpa(uint8_t out[KYBER_CIPHERTEXT_BYTES],
690
                        const struct public_key *pub, const uint8_t message[32],
691
107
                        const uint8_t randomness[32]) {
692
107
  uint8_t counter = 0;
693
107
  vector secret;
694
107
  vector_generate_secret_eta_2(&secret, &counter, randomness);
695
107
  vector_ntt(&secret);
696
107
  vector error;
697
107
  vector_generate_secret_eta_2(&error, &counter, randomness);
698
107
  uint8_t input[33];
699
107
  OPENSSL_memcpy(input, randomness, 32);
700
107
  input[32] = counter;
701
107
  scalar scalar_error;
702
107
  scalar_centered_binomial_distribution_eta_2_with_prf(&scalar_error, input);
703
107
  vector u;
704
107
  matrix_mult(&u, &pub->m, &secret);
705
107
  vector_inverse_ntt(&u);
706
107
  vector_add(&u, &error);
707
107
  scalar v;
708
107
  scalar_inner_product(&v, &pub->t, &secret);
709
107
  scalar_inverse_ntt(&v);
710
107
  scalar_add(&v, &scalar_error);
711
107
  scalar expanded_message;
712
107
  scalar_decode_1(&expanded_message, message);
713
107
  scalar_decompress(&expanded_message, 1);
714
107
  scalar_add(&v, &expanded_message);
715
107
  vector_compress(&u, kDU);
716
107
  vector_encode(out, &u, kDU);
717
107
  scalar_compress(&v, kDV);
718
107
  scalar_encode(out + kCompressedVectorSize, &v, kDV);
719
107
}
720
721
// Calls KYBER_encap_external_entropy| with random bytes from |RAND_bytes|
722
void bssl::KYBER_encap(uint8_t out_ciphertext[KYBER_CIPHERTEXT_BYTES],
723
                       uint8_t out_shared_secret[KYBER_SHARED_SECRET_BYTES],
724
107
                       const struct KYBER_public_key *public_key) {
725
107
  uint8_t entropy[KYBER_ENCAP_ENTROPY];
726
107
  RAND_bytes(entropy, KYBER_ENCAP_ENTROPY);
727
107
  CONSTTIME_SECRET(entropy, KYBER_ENCAP_ENTROPY);
728
107
  KYBER_encap_external_entropy(out_ciphertext, out_shared_secret, public_key,
729
107
                               entropy);
730
107
}
731
732
// Algorithm 8 of the Kyber spec, safe for line 2 of the spec. The spec there
733
// hashes the output of the system's random number generator, since the FO
734
// transform will reveal it to the decrypting party. There is no reason to do
735
// this when a secure random number generator is used. When an insecure random
736
// number generator is used, the caller should switch to a secure one before
737
// calling this method.
738
void bssl::KYBER_encap_external_entropy(
739
    uint8_t out_ciphertext[KYBER_CIPHERTEXT_BYTES],
740
    uint8_t out_shared_secret[KYBER_SHARED_SECRET_BYTES],
741
    const struct KYBER_public_key *public_key,
742
107
    const uint8_t entropy[KYBER_ENCAP_ENTROPY]) {
743
107
  const struct public_key *pub = public_key_from_external(public_key);
744
107
  uint8_t input[64];
745
107
  OPENSSL_memcpy(input, entropy, KYBER_ENCAP_ENTROPY);
746
107
  OPENSSL_memcpy(input + KYBER_ENCAP_ENTROPY, pub->public_key_hash,
747
107
                 sizeof(input) - KYBER_ENCAP_ENTROPY);
748
107
  uint8_t prekey_and_randomness[64];
749
107
  hash_g(prekey_and_randomness, input, sizeof(input));
750
107
  encrypt_cpa(out_ciphertext, pub, entropy, prekey_and_randomness + 32);
751
  // The ciphertext is public.
752
107
  CONSTTIME_DECLASSIFY(out_ciphertext, KYBER_CIPHERTEXT_BYTES);
753
107
  hash_h(prekey_and_randomness + 32, out_ciphertext, KYBER_CIPHERTEXT_BYTES);
754
107
  kdf(out_shared_secret, KYBER_SHARED_SECRET_BYTES, prekey_and_randomness,
755
107
      sizeof(prekey_and_randomness));
756
107
}
757
758
// Algorithm 6 of the Kyber spec.
759
static void decrypt_cpa(uint8_t out[32], const struct private_key *priv,
760
0
                        const uint8_t ciphertext[KYBER_CIPHERTEXT_BYTES]) {
761
0
  vector u;
762
0
  vector_decode(&u, ciphertext, kDU);
763
0
  vector_decompress(&u, kDU);
764
0
  vector_ntt(&u);
765
0
  scalar v;
766
0
  scalar_decode(&v, ciphertext + kCompressedVectorSize, kDV);
767
0
  scalar_decompress(&v, kDV);
768
0
  scalar mask;
769
0
  scalar_inner_product(&mask, &priv->s, &u);
770
0
  scalar_inverse_ntt(&mask);
771
0
  scalar_sub(&v, &mask);
772
0
  scalar_compress(&v, 1);
773
0
  scalar_encode_1(out, &v);
774
0
}
775
776
// Algorithm 9 of the Kyber spec, performing the FO transform by running
// encrypt_cpa on the decrypted message. The spec does not allow the decryption
// failure to be passed on to the caller, and instead returns a result that is
// deterministic but unpredictable to anyone without knowledge of the private
// key.
void bssl::KYBER_decap(uint8_t out_shared_secret[KYBER_SHARED_SECRET_BYTES],
                       const uint8_t ciphertext[KYBER_CIPHERTEXT_BYTES],
                       const struct KYBER_private_key *private_key) {
  const struct private_key *priv = private_key_from_external(private_key);
  // Recover the candidate message m', then append H(pk) to form the input to
  // G, mirroring what encapsulation hashed.
  uint8_t decrypted[64];
  decrypt_cpa(decrypted, priv, ciphertext);
  OPENSSL_memcpy(decrypted + 32, priv->pub.public_key_hash,
                 sizeof(decrypted) - 32);
  // G(m' || H(pk)) gives the prekey (first 32 bytes) and the encryption
  // randomness (last 32 bytes) that a valid ciphertext must have used.
  uint8_t prekey_and_randomness[64];
  hash_g(prekey_and_randomness, decrypted, sizeof(decrypted));
  // Re-encrypt m' and compare against the received ciphertext. A mismatch
  // means decryption failed (or the ciphertext was forged).
  uint8_t expected_ciphertext[KYBER_CIPHERTEXT_BYTES];
  encrypt_cpa(expected_ciphertext, &priv->pub, decrypted,
              prekey_and_randomness + 32);
  // |mask| is all-ones on match, all-zeros on mismatch. CRYPTO_memcmp and
  // constant_time_eq_int_8 keep the comparison timing-independent of the
  // secret comparison result.
  uint8_t mask =
      constant_time_eq_int_8(CRYPTO_memcmp(ciphertext, expected_ciphertext,
                                           sizeof(expected_ciphertext)),
                             0);
  // On failure, substitute the implicit-rejection secret so the output is
  // still deterministic but unpredictable without the private key. The
  // selection is done per byte in constant time.
  uint8_t input[64];
  for (int i = 0; i < 32; i++) {
    input[i] = constant_time_select_8(mask, prekey_and_randomness[i],
                                      priv->fo_failure_secret[i]);
  }
  // Shared secret = KDF(selected_prekey || H(ciphertext)).
  hash_h(input + 32, ciphertext, KYBER_CIPHERTEXT_BYTES);
  kdf(out_shared_secret, KYBER_SHARED_SECRET_BYTES, input, sizeof(input));
}
806
807
int bssl::KYBER_marshal_public_key(CBB *out,
808
0
                                   const struct KYBER_public_key *public_key) {
809
0
  return kyber_marshal_public_key(out, public_key_from_external(public_key));
810
0
}
811
812
// kyber_parse_public_key_no_hash parses |in| into |pub| but doesn't calculate
813
// the value of |pub->public_key_hash|.
814
154
static int kyber_parse_public_key_no_hash(struct public_key *pub, CBS *in) {
815
154
  CBS t_bytes;
816
154
  if (!CBS_get_bytes(in, &t_bytes, kEncodedVectorSize) ||
817
154
      !vector_decode(&pub->t, CBS_data(&t_bytes), kLog2Prime) ||
818
107
      !CBS_copy_bytes(in, pub->rho, sizeof(pub->rho))) {
819
47
    return 0;
820
47
  }
821
107
  matrix_expand(&pub->m, pub->rho);
822
107
  return 1;
823
154
}
824
825
154
int bssl::KYBER_parse_public_key(struct KYBER_public_key *public_key, CBS *in) {
826
154
  struct public_key *pub = public_key_from_external(public_key);
827
154
  CBS orig_in = *in;
828
154
  if (!kyber_parse_public_key_no_hash(pub, in) ||  //
829
107
      CBS_len(in) != 0) {
830
47
    return 0;
831
47
  }
832
107
  hash_h(pub->public_key_hash, CBS_data(&orig_in), CBS_len(&orig_in));
833
107
  return 1;
834
154
}
835
836
int bssl::KYBER_marshal_private_key(
837
0
    CBB *out, const struct KYBER_private_key *private_key) {
838
0
  const struct private_key *const priv = private_key_from_external(private_key);
839
0
  uint8_t *s_output;
840
0
  if (!CBB_add_space(out, &s_output, kEncodedVectorSize)) {
841
0
    return 0;
842
0
  }
843
0
  vector_encode(s_output, &priv->s, kLog2Prime);
844
0
  if (!kyber_marshal_public_key(out, &priv->pub) ||
845
0
      !CBB_add_bytes(out, priv->pub.public_key_hash,
846
0
                     sizeof(priv->pub.public_key_hash)) ||
847
0
      !CBB_add_bytes(out, priv->fo_failure_secret,
848
0
                     sizeof(priv->fo_failure_secret))) {
849
0
    return 0;
850
0
  }
851
0
  return 1;
852
0
}
853
854
int bssl::KYBER_parse_private_key(struct KYBER_private_key *out_private_key,
855
0
                                  CBS *in) {
856
0
  struct private_key *const priv = private_key_from_external(out_private_key);
857
858
0
  CBS s_bytes;
859
0
  if (!CBS_get_bytes(in, &s_bytes, kEncodedVectorSize) ||
860
0
      !vector_decode(&priv->s, CBS_data(&s_bytes), kLog2Prime) ||
861
0
      !kyber_parse_public_key_no_hash(&priv->pub, in) ||
862
0
      !CBS_copy_bytes(in, priv->pub.public_key_hash,
863
0
                      sizeof(priv->pub.public_key_hash)) ||
864
0
      !CBS_copy_bytes(in, priv->fo_failure_secret,
865
0
                      sizeof(priv->fo_failure_secret)) ||
866
0
      CBS_len(in) != 0) {
867
0
    return 0;
868
0
  }
869
0
  return 1;
870
0
}