Coverage Report

Created: 2022-06-23 06:44

/src/botan/src/lib/pubkey/ec_h2c/ec_h2c.cpp
Line
Count
Source (jump to first uncovered line)
1
/*
2
* (C) 2019,2020,2021 Jack Lloyd
3
*
4
* Botan is released under the Simplified BSD License (see license.txt)
5
*/
6
7
#include <botan/internal/ec_h2c.h>
8
#include <botan/ec_group.h>
9
#include <botan/numthry.h>
10
#include <botan/reducer.h>
11
#include <botan/hash.h>
12
13
namespace Botan {
14
15
void expand_message_xmd(const std::string& hash_fn,
16
                        uint8_t output[],
17
                        size_t output_len,
18
                        const uint8_t input[],
19
                        size_t input_len,
20
                        const uint8_t domain_sep[],
21
                        size_t domain_sep_len)
22
0
   {
23
0
   if(domain_sep_len > 0xFF)
24
0
      throw Invalid_Argument("expand_message_xmd domain seperator too long");
25
26
0
   auto hash = HashFunction::create_or_throw(hash_fn);
27
0
   const size_t block_size = hash->hash_block_size();
28
0
   if(block_size == 0)
29
0
      throw Invalid_Argument("expand_message_xmd cannot be used with " + hash_fn);
30
31
0
   const size_t hash_output_size = hash->output_length();
32
0
   if(output_len > 255*hash_output_size || output_len > 0xFFFF)
33
0
      throw Invalid_Argument("expand_message_xmd requested output length too long");
34
35
   // Compute b_0 = H(msg_prime) = H(Z_pad || msg || l_i_b_str || 0x00 || DST_prime)
36
37
0
   hash->update(std::vector<uint8_t>(block_size));
38
0
   hash->update(input, input_len);
39
0
   hash->update_be(static_cast<uint16_t>(output_len));
40
0
   hash->update(0x00);
41
0
   hash->update(domain_sep, domain_sep_len);
42
0
   hash->update(static_cast<uint8_t>(domain_sep_len));
43
44
0
   const secure_vector<uint8_t> b_0 = hash->final();
45
46
   // Compute b_1 = H(b_0 || 0x01 || DST_prime)
47
48
0
   hash->update(b_0);
49
0
   hash->update(0x01);
50
0
   hash->update(domain_sep, domain_sep_len);
51
0
   hash->update(static_cast<uint8_t>(domain_sep_len));
52
53
0
   secure_vector<uint8_t> b_i = hash->final();
54
55
0
   uint8_t cnt = 2;
56
0
   while(output_len > 0)
57
0
      {
58
0
      const size_t produced = std::min(output_len, hash_output_size);
59
60
0
      copy_mem(output, b_i.data(), produced);
61
0
      output += produced;
62
0
      output_len -= produced;
63
64
      // Now compute the next b_i
65
66
0
      b_i ^= b_0;
67
0
      hash->update(b_i);
68
0
      hash->update(cnt);
69
0
      hash->update(domain_sep, domain_sep_len);
70
0
      hash->update(static_cast<uint8_t>(domain_sep_len));
71
0
      hash->final(b_i.data());
72
0
      cnt += 1;
73
0
      }
74
0
   }
75
76
namespace {
77
78
std::vector<BigInt>
79
hash_to_field(const EC_Group& group,
80
              const Modular_Reducer& mod_p,
81
              const std::string& hash_fn,
82
              uint8_t count,
83
              const uint8_t input[], size_t input_len,
84
              const uint8_t domain_sep[], size_t domain_sep_len)
85
0
   {
86
0
   const size_t k = (group.get_order_bits() + 1) / 2;
87
0
   const size_t L = (group.get_p_bits() + k + 7) / 8;
88
89
0
   std::vector<BigInt> results;
90
0
   results.reserve(count);
91
92
0
   secure_vector<uint8_t> output(L * count);
93
0
   expand_message_xmd(hash_fn,
94
0
                      output.data(), output.size(),
95
0
                      input, input_len,
96
0
                      domain_sep, domain_sep_len);
97
98
0
   for(size_t i = 0; i != count; ++i)
99
0
      {
100
0
      BigInt v(&output[i*L], L);
101
0
      results.push_back(mod_p.reduce(v));
102
0
      }
103
104
0
   return results;
105
0
   }
106
107
/*
* Return the Simplified SWU constant Z for the supported curves, as the
* residue p - |Z|:
*    secp256r1: Z = -10
*    secp384r1: Z = -12
*    secp521r1: Z = -4
* Any other curve yields 0, which the SSWU map rejects as unsupported.
*/
BigInt sswu_z(const EC_Group& group)
   {
   const OID& curve_oid = group.get_curve_oid();
   const BigInt& p = group.get_p();

   if(curve_oid == OID{1,2,840,10045,3,1,7}) // secp256r1
      {
      return p - 10;
      }
   else if(curve_oid == OID{1,3,132,0,34}) // secp384r1
      {
      return p - 12;
      }
   else if(curve_oid == OID{1,3,132,0,35}) // secp521r1
      {
      return p - 4;
      }
   else
      {
      return 0;
      }
   }
121
122
/*
* Return a copy of x when `first` is true, otherwise a copy of y.
* The selection is delegated to BigInt::ct_cond_assign, so this function
* itself contains no branch on `first`.
*/
BigInt ct_choose(bool first, const BigInt& x, const BigInt& y)
   {
   BigInt chosen(x);
   chosen.ct_cond_assign(!first, y);
   return chosen;
   }
128
129
/*
* Simplified Shallue-van de Woestijne-Ulas map from a field element u to a
* point on the curve y^2 = x^3 + A*x + B, following the straight-line
* procedure of Appendix F.2 of draft-irtf-cfrg-hash-to-curve.
*
* @param group the curve; it must have A != 0, B != 0, p == 3 (mod 4) and a
*        known Z (see sswu_z)
* @param mod_p reducer for the field prime p
* @param u field element to map
* @return affine point (x, y) on the curve
* @throws Invalid_Argument if the curve is not supported
*/
PointGFp map_to_curve_sswu(const EC_Group& group, const Modular_Reducer& mod_p, const BigInt& u)
   {
   const BigInt& p = group.get_p();
   const BigInt& A = group.get_a();
   const BigInt& B = group.get_b();
   const BigInt Z = sswu_z(group);

   if(Z.is_zero() || A.is_zero() || B.is_zero() || p % 4 != 3)
      throw Invalid_Argument("map_to_curve_sswu does not support this curve");

   // These values could be precomputed:
   // c1 = -B / A, c2 = -1 / Z
   const BigInt c1 = mod_p.multiply(p - B, inverse_mod(A, p));
   const BigInt c2 = mod_p.multiply(p - 1, inverse_mod(Z, p));

   // tv1 = Z * u^2, tv2 = tv1^2
   const BigInt tv1 = mod_p.multiply(Z, mod_p.square(u));
   const BigInt tv2 = mod_p.square(tv1);

   // x1 = c1 * (1 + 1/(tv1 + tv2)). When tv1 + tv2 is not invertible (the
   // code treats a zero result from inverse_mod as that case), use
   // x1 = c1 * c2 = B / (Z * A) instead.
   BigInt x1 = inverse_mod(tv1 + tv2, p);
   const bool e1 = x1.is_zero();
   x1 += 1;
   x1.ct_cond_assign(e1, c2);
   x1 = mod_p.multiply(x1, c1);

   // gx1 = x1^3 + A*x1 + B, evaluated as (x1^2 + A) * x1 + B
   BigInt gx1 = mod_p.square(x1);
   gx1 += A;
   gx1 = mod_p.multiply(gx1, x1);
   gx1 += B;
   gx1 = mod_p.reduce(gx1);

   // x2 = Z * u^2 * x1
   const BigInt x2 = mod_p.multiply(tv1, x1);

   // gx2 = (Z * u^2)^3 * gx1
   const BigInt gx2 = mod_p.multiply(gx1, mod_p.multiply(tv1, tv2));

   // Euler's criterion: gx1^((p-1)/2) is 0 or 1 iff gx1 is a square mod p
   const bool gx1_is_square = (power_mod(gx1, (p-1)/2, p) <= 1);

   const BigInt x = ct_choose(gx1_is_square, x1, x2);
   const BigInt y2 = ct_choose(gx1_is_square, gx1, gx2);

   // Square root as y2^((p+1)/4); valid because p == 3 (mod 4)
   const BigInt y = power_mod(y2, (p + 1)/4, p);

   // -y mod p. Reduce so that y == 0 gives 0 rather than p, which is not a
   // canonical residue.
   const BigInt neg_y = mod_p.reduce(p - y);

   // Pick the square root whose parity (sgn0) matches that of u
   const bool uy_sign = u.get_bit(0) != y.get_bit(0);
   return group.point(x, ct_choose(uy_sign, neg_y, y));
   }
181
182
}
183
184
/*
* Hash input to a curve point using hash_to_field plus the Simplified SWU map.
*
* If random_oracle is set, two field elements are derived and the sum of
* their mapped points is returned. Otherwise a single field element is
* derived and mapped. These correspond to the spec's hash_to_curve and
* encode_to_curve encodings respectively.
*/
PointGFp hash_to_curve_sswu(const EC_Group& group,
                            const std::string& hash_fn,
                            const uint8_t input[],
                            size_t input_len,
                            const uint8_t domain_sep[],
                            size_t domain_sep_len,
                            bool random_oracle)
   {
   const Modular_Reducer mod_p(group.get_p());

   const uint8_t elements_needed = random_oracle ? 2 : 1;

   const std::vector<BigInt> field_elems =
      hash_to_field(group, mod_p, hash_fn, elements_needed,
                    input, input_len,
                    domain_sep, domain_sep_len);

   PointGFp result = map_to_curve_sswu(group, mod_p, field_elems.front());

   // Accumulate the images of any remaining field elements
   for(auto it = field_elems.begin() + 1; it != field_elems.end(); ++it)
      result += map_to_curve_sswu(group, mod_p, *it);

   return result;
   }
207
208
}