Coverage Report

Created: 2020-10-17 06:46

/src/botan/src/lib/pubkey/sm2/sm2_enc.cpp
Line
Count
Source (jump to first uncovered line)
1
/*
2
* SM2 Encryption
3
* (C) 2017 Ribose Inc
4
*
5
* Botan is released under the Simplified BSD License (see license.txt)
6
*/
7
8
#include <botan/sm2.h>
9
#include <botan/internal/point_mul.h>
10
#include <botan/pk_ops.h>
11
#include <botan/der_enc.h>
12
#include <botan/ber_dec.h>
13
#include <botan/kdf.h>
14
#include <botan/hash.h>
15
16
namespace Botan {
17
18
namespace {
19
20
class SM2_Encryption_Operation final : public PK_Ops::Encryption
21
   {
22
   public:
23
      SM2_Encryption_Operation(const SM2_Encryption_PublicKey& key,
24
                               RandomNumberGenerator& rng,
25
                               const std::string& kdf_hash) :
26
         m_group(key.domain()),
27
         m_kdf_hash(kdf_hash),
28
         m_ws(PointGFp::WORKSPACE_SIZE),
29
         m_mul_public_point(key.public_point(), rng, m_ws)
30
0
         {
31
0
         std::unique_ptr<HashFunction> hash = HashFunction::create_or_throw(m_kdf_hash);
32
0
         m_hash_size = hash->output_length();
33
0
         }
34
35
      size_t max_input_bits() const override
36
0
         {
37
         // This is arbitrary, but assumes SM2 is used for key encapsulation
38
0
         return 512;
39
0
         }
40
41
      size_t ciphertext_length(size_t ptext_len) const override
42
0
         {
43
0
         const size_t elem_size = m_group.get_order_bytes();
44
0
         const size_t der_overhead = 16;
45
0
46
0
         return der_overhead + 2*elem_size + m_hash_size + ptext_len;
47
0
         }
48
49
      secure_vector<uint8_t> encrypt(const uint8_t msg[],
50
                                     size_t msg_len,
51
                                     RandomNumberGenerator& rng) override
52
0
         {
53
0
         std::unique_ptr<HashFunction> hash = HashFunction::create_or_throw(m_kdf_hash);
54
0
         std::unique_ptr<KDF> kdf = KDF::create_or_throw("KDF2(" + m_kdf_hash + ")");
55
0
56
0
         const size_t p_bytes = m_group.get_p_bytes();
57
0
58
0
         const BigInt k = m_group.random_scalar(rng);
59
0
60
0
         const PointGFp C1 = m_group.blinded_base_point_multiply(k, rng, m_ws);
61
0
         const BigInt x1 = C1.get_affine_x();
62
0
         const BigInt y1 = C1.get_affine_y();
63
0
         std::vector<uint8_t> x1_bytes(p_bytes);
64
0
         std::vector<uint8_t> y1_bytes(p_bytes);
65
0
         BigInt::encode_1363(x1_bytes.data(), x1_bytes.size(), x1);
66
0
         BigInt::encode_1363(y1_bytes.data(), y1_bytes.size(), y1);
67
0
68
0
         const PointGFp kPB = m_mul_public_point.mul(k, rng, m_group.get_order(), m_ws);
69
0
70
0
         const BigInt x2 = kPB.get_affine_x();
71
0
         const BigInt y2 = kPB.get_affine_y();
72
0
         std::vector<uint8_t> x2_bytes(p_bytes);
73
0
         std::vector<uint8_t> y2_bytes(p_bytes);
74
0
         BigInt::encode_1363(x2_bytes.data(), x2_bytes.size(), x2);
75
0
         BigInt::encode_1363(y2_bytes.data(), y2_bytes.size(), y2);
76
0
77
0
         secure_vector<uint8_t> kdf_input;
78
0
         kdf_input += x2_bytes;
79
0
         kdf_input += y2_bytes;
80
0
81
0
         const secure_vector<uint8_t> kdf_output =
82
0
            kdf->derive_key(msg_len, kdf_input.data(), kdf_input.size());
83
0
84
0
         secure_vector<uint8_t> masked_msg(msg_len);
85
0
         xor_buf(masked_msg.data(), msg, kdf_output.data(), msg_len);
86
0
87
0
         hash->update(x2_bytes);
88
0
         hash->update(msg, msg_len);
89
0
         hash->update(y2_bytes);
90
0
         std::vector<uint8_t> C3(hash->output_length());
91
0
         hash->final(C3.data());
92
0
93
0
         return DER_Encoder()
94
0
            .start_cons(SEQUENCE)
95
0
            .encode(x1)
96
0
            .encode(y1)
97
0
            .encode(C3, OCTET_STRING)
98
0
            .encode(masked_msg, OCTET_STRING)
99
0
            .end_cons()
100
0
            .get_contents();
101
0
         }
102
103
   private:
104
      const EC_Group m_group;
105
      const std::string m_kdf_hash;
106
107
      std::vector<BigInt> m_ws;
108
      PointGFp_Var_Point_Precompute m_mul_public_point;
109
      size_t m_hash_size;
110
   };
111
112
class SM2_Decryption_Operation final : public PK_Ops::Decryption
113
   {
114
   public:
115
      SM2_Decryption_Operation(const SM2_Encryption_PrivateKey& key,
116
                               RandomNumberGenerator& rng,
117
                               const std::string& kdf_hash) :
118
         m_key(key),
119
         m_rng(rng),
120
         m_kdf_hash(kdf_hash)
121
0
         {
122
0
         std::unique_ptr<HashFunction> hash = HashFunction::create_or_throw(m_kdf_hash);
123
0
         m_hash_size = hash->output_length();
124
0
         }
125
126
      size_t plaintext_length(size_t ptext_len) const override
127
0
         {
128
         /*
129
         * This ignores the DER encoding and so overestimates the
130
         * plaintext length by 12 bytes or so
131
         */
132
0
         const size_t elem_size = m_key.domain().get_order_bytes();
133
0
134
0
         if(ptext_len < 2*elem_size + m_hash_size)
135
0
            return 0;
136
0
137
0
         return ptext_len - (2*elem_size + m_hash_size);
138
0
         }
139
140
      secure_vector<uint8_t> decrypt(uint8_t& valid_mask,
141
                                     const uint8_t ciphertext[],
142
                                     size_t ciphertext_len) override
143
0
         {
144
0
         const EC_Group& group = m_key.domain();
145
0
         const BigInt& cofactor = group.get_cofactor();
146
0
         const size_t p_bytes = group.get_p_bytes();
147
0
148
0
         valid_mask = 0x00;
149
0
150
0
         std::unique_ptr<HashFunction> hash = HashFunction::create_or_throw(m_kdf_hash);
151
0
         std::unique_ptr<KDF> kdf = KDF::create_or_throw("KDF2(" + m_kdf_hash + ")");
152
0
153
         // Too short to be valid - no timing problem from early return
154
0
         if(ciphertext_len < 1 + p_bytes*2 + hash->output_length())
155
0
            {
156
0
            return secure_vector<uint8_t>();
157
0
            }
158
0
159
0
         BigInt x1, y1;
160
0
         secure_vector<uint8_t> C3, masked_msg;
161
0
162
0
         BER_Decoder(ciphertext, ciphertext_len)
163
0
            .start_cons(SEQUENCE)
164
0
            .decode(x1)
165
0
            .decode(y1)
166
0
            .decode(C3, OCTET_STRING)
167
0
            .decode(masked_msg, OCTET_STRING)
168
0
            .end_cons()
169
0
            .verify_end();
170
0
171
0
         std::vector<uint8_t> recode_ctext;
172
0
         DER_Encoder(recode_ctext)
173
0
            .start_cons(SEQUENCE)
174
0
            .encode(x1)
175
0
            .encode(y1)
176
0
            .encode(C3, OCTET_STRING)
177
0
            .encode(masked_msg, OCTET_STRING)
178
0
            .end_cons();
179
0
180
0
         if(recode_ctext.size() != ciphertext_len)
181
0
            return secure_vector<uint8_t>();
182
0
183
0
         if(same_mem(recode_ctext.data(), ciphertext, ciphertext_len) == false)
184
0
            return secure_vector<uint8_t>();
185
0
186
0
         PointGFp C1 = group.point(x1, y1);
187
0
         C1.randomize_repr(m_rng);
188
0
189
         // Here C1 is publically invalid, so no problem with early return:
190
0
         if(!C1.on_the_curve())
191
0
            return secure_vector<uint8_t>();
192
0
193
0
         if(cofactor > 1 && (C1 * cofactor).is_zero())
194
0
            {
195
0
            return secure_vector<uint8_t>();
196
0
            }
197
0
198
0
         const PointGFp dbC1 = group.blinded_var_point_multiply(
199
0
            C1, m_key.private_value(), m_rng, m_ws);
200
0
201
0
         const BigInt x2 = dbC1.get_affine_x();
202
0
         const BigInt y2 = dbC1.get_affine_y();
203
0
204
0
         secure_vector<uint8_t> x2_bytes(p_bytes);
205
0
         secure_vector<uint8_t> y2_bytes(p_bytes);
206
0
         BigInt::encode_1363(x2_bytes.data(), x2_bytes.size(), x2);
207
0
         BigInt::encode_1363(y2_bytes.data(), y2_bytes.size(), y2);
208
0
209
0
         secure_vector<uint8_t> kdf_input;
210
0
         kdf_input += x2_bytes;
211
0
         kdf_input += y2_bytes;
212
0
213
0
         const secure_vector<uint8_t> kdf_output =
214
0
            kdf->derive_key(masked_msg.size(), kdf_input.data(), kdf_input.size());
215
0
216
0
         xor_buf(masked_msg.data(), kdf_output.data(), kdf_output.size());
217
0
218
0
         hash->update(x2_bytes);
219
0
         hash->update(masked_msg);
220
0
         hash->update(y2_bytes);
221
0
         secure_vector<uint8_t> u = hash->final();
222
0
223
0
         if(constant_time_compare(u.data(), C3.data(), hash->output_length()) == false)
224
0
            return secure_vector<uint8_t>();
225
0
226
0
         valid_mask = 0xFF;
227
0
         return masked_msg;
228
0
         }
229
   private:
230
      const SM2_Encryption_PrivateKey& m_key;
231
      RandomNumberGenerator& m_rng;
232
      const std::string m_kdf_hash;
233
      std::vector<BigInt> m_ws;
234
      size_t m_hash_size;
235
   };
236
237
}
238
239
std::unique_ptr<PK_Ops::Encryption>
240
SM2_PublicKey::create_encryption_op(RandomNumberGenerator& rng,
241
                                    const std::string& params,
242
                                    const std::string& provider) const
243
0
   {
244
0
   if(provider == "base" || provider.empty())
245
0
      {
246
0
      const std::string kdf_hash = (params.empty() ? "SM3" : params);
247
0
      return std::unique_ptr<PK_Ops::Encryption>(new SM2_Encryption_Operation(*this, rng, kdf_hash));
248
0
      }
249
0
250
0
   throw Provider_Not_Found(algo_name(), provider);
251
0
   }
252
253
std::unique_ptr<PK_Ops::Decryption>
254
SM2_PrivateKey::create_decryption_op(RandomNumberGenerator& rng,
255
                                     const std::string& params,
256
                                     const std::string& provider) const
257
0
   {
258
0
   if(provider == "base" || provider.empty())
259
0
      {
260
0
      const std::string kdf_hash = (params.empty() ? "SM3" : params);
261
0
      return std::unique_ptr<PK_Ops::Decryption>(new SM2_Decryption_Operation(*this, rng, kdf_hash));
262
0
      }
263
0
264
0
   throw Provider_Not_Found(algo_name(), provider);
265
0
   }
266
267
}