/src/botan/src/lib/pubkey/sm2/sm2_enc.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | /* |
2 | | * SM2 Encryption |
3 | | * (C) 2017 Ribose Inc |
4 | | * |
5 | | * Botan is released under the Simplified BSD License (see license.txt) |
6 | | */ |
7 | | |
8 | | #include <botan/sm2.h> |
9 | | #include <botan/internal/point_mul.h> |
10 | | #include <botan/pk_ops.h> |
11 | | #include <botan/der_enc.h> |
12 | | #include <botan/ber_dec.h> |
13 | | #include <botan/kdf.h> |
14 | | #include <botan/hash.h> |
15 | | |
16 | | namespace Botan { |
17 | | |
18 | | namespace { |
19 | | |
20 | | class SM2_Encryption_Operation final : public PK_Ops::Encryption |
21 | | { |
22 | | public: |
23 | | SM2_Encryption_Operation(const SM2_Encryption_PublicKey& key, |
24 | | RandomNumberGenerator& rng, |
25 | | const std::string& kdf_hash) : |
26 | | m_group(key.domain()), |
27 | | m_kdf_hash(kdf_hash), |
28 | | m_ws(PointGFp::WORKSPACE_SIZE), |
29 | | m_mul_public_point(key.public_point(), rng, m_ws) |
30 | 0 | { |
31 | 0 | std::unique_ptr<HashFunction> hash = HashFunction::create_or_throw(m_kdf_hash); |
32 | 0 | m_hash_size = hash->output_length(); |
33 | 0 | } |
34 | | |
35 | | size_t max_input_bits() const override |
36 | 0 | { |
37 | | // This is arbitrary, but assumes SM2 is used for key encapsulation |
38 | 0 | return 512; |
39 | 0 | } |
40 | | |
41 | | size_t ciphertext_length(size_t ptext_len) const override |
42 | 0 | { |
43 | 0 | const size_t elem_size = m_group.get_order_bytes(); |
44 | 0 | const size_t der_overhead = 16; |
45 | 0 |
|
46 | 0 | return der_overhead + 2*elem_size + m_hash_size + ptext_len; |
47 | 0 | } |
48 | | |
49 | | secure_vector<uint8_t> encrypt(const uint8_t msg[], |
50 | | size_t msg_len, |
51 | | RandomNumberGenerator& rng) override |
52 | 0 | { |
53 | 0 | std::unique_ptr<HashFunction> hash = HashFunction::create_or_throw(m_kdf_hash); |
54 | 0 | std::unique_ptr<KDF> kdf = KDF::create_or_throw("KDF2(" + m_kdf_hash + ")"); |
55 | 0 |
|
56 | 0 | const size_t p_bytes = m_group.get_p_bytes(); |
57 | 0 |
|
58 | 0 | const BigInt k = m_group.random_scalar(rng); |
59 | 0 |
|
60 | 0 | const PointGFp C1 = m_group.blinded_base_point_multiply(k, rng, m_ws); |
61 | 0 | const BigInt x1 = C1.get_affine_x(); |
62 | 0 | const BigInt y1 = C1.get_affine_y(); |
63 | 0 | std::vector<uint8_t> x1_bytes(p_bytes); |
64 | 0 | std::vector<uint8_t> y1_bytes(p_bytes); |
65 | 0 | BigInt::encode_1363(x1_bytes.data(), x1_bytes.size(), x1); |
66 | 0 | BigInt::encode_1363(y1_bytes.data(), y1_bytes.size(), y1); |
67 | 0 |
|
68 | 0 | const PointGFp kPB = m_mul_public_point.mul(k, rng, m_group.get_order(), m_ws); |
69 | 0 |
|
70 | 0 | const BigInt x2 = kPB.get_affine_x(); |
71 | 0 | const BigInt y2 = kPB.get_affine_y(); |
72 | 0 | std::vector<uint8_t> x2_bytes(p_bytes); |
73 | 0 | std::vector<uint8_t> y2_bytes(p_bytes); |
74 | 0 | BigInt::encode_1363(x2_bytes.data(), x2_bytes.size(), x2); |
75 | 0 | BigInt::encode_1363(y2_bytes.data(), y2_bytes.size(), y2); |
76 | 0 |
|
77 | 0 | secure_vector<uint8_t> kdf_input; |
78 | 0 | kdf_input += x2_bytes; |
79 | 0 | kdf_input += y2_bytes; |
80 | 0 |
|
81 | 0 | const secure_vector<uint8_t> kdf_output = |
82 | 0 | kdf->derive_key(msg_len, kdf_input.data(), kdf_input.size()); |
83 | 0 |
|
84 | 0 | secure_vector<uint8_t> masked_msg(msg_len); |
85 | 0 | xor_buf(masked_msg.data(), msg, kdf_output.data(), msg_len); |
86 | 0 |
|
87 | 0 | hash->update(x2_bytes); |
88 | 0 | hash->update(msg, msg_len); |
89 | 0 | hash->update(y2_bytes); |
90 | 0 | std::vector<uint8_t> C3(hash->output_length()); |
91 | 0 | hash->final(C3.data()); |
92 | 0 |
|
93 | 0 | return DER_Encoder() |
94 | 0 | .start_cons(SEQUENCE) |
95 | 0 | .encode(x1) |
96 | 0 | .encode(y1) |
97 | 0 | .encode(C3, OCTET_STRING) |
98 | 0 | .encode(masked_msg, OCTET_STRING) |
99 | 0 | .end_cons() |
100 | 0 | .get_contents(); |
101 | 0 | } |
102 | | |
103 | | private: |
104 | | const EC_Group m_group; |
105 | | const std::string m_kdf_hash; |
106 | | |
107 | | std::vector<BigInt> m_ws; |
108 | | PointGFp_Var_Point_Precompute m_mul_public_point; |
109 | | size_t m_hash_size; |
110 | | }; |
111 | | |
112 | | class SM2_Decryption_Operation final : public PK_Ops::Decryption |
113 | | { |
114 | | public: |
115 | | SM2_Decryption_Operation(const SM2_Encryption_PrivateKey& key, |
116 | | RandomNumberGenerator& rng, |
117 | | const std::string& kdf_hash) : |
118 | | m_key(key), |
119 | | m_rng(rng), |
120 | | m_kdf_hash(kdf_hash) |
121 | 0 | { |
122 | 0 | std::unique_ptr<HashFunction> hash = HashFunction::create_or_throw(m_kdf_hash); |
123 | 0 | m_hash_size = hash->output_length(); |
124 | 0 | } |
125 | | |
126 | | size_t plaintext_length(size_t ptext_len) const override |
127 | 0 | { |
128 | | /* |
129 | | * This ignores the DER encoding and so overestimates the |
130 | | * plaintext length by 12 bytes or so |
131 | | */ |
132 | 0 | const size_t elem_size = m_key.domain().get_order_bytes(); |
133 | 0 |
|
134 | 0 | if(ptext_len < 2*elem_size + m_hash_size) |
135 | 0 | return 0; |
136 | 0 | |
137 | 0 | return ptext_len - (2*elem_size + m_hash_size); |
138 | 0 | } |
139 | | |
140 | | secure_vector<uint8_t> decrypt(uint8_t& valid_mask, |
141 | | const uint8_t ciphertext[], |
142 | | size_t ciphertext_len) override |
143 | 0 | { |
144 | 0 | const EC_Group& group = m_key.domain(); |
145 | 0 | const BigInt& cofactor = group.get_cofactor(); |
146 | 0 | const size_t p_bytes = group.get_p_bytes(); |
147 | 0 |
|
148 | 0 | valid_mask = 0x00; |
149 | 0 |
|
150 | 0 | std::unique_ptr<HashFunction> hash = HashFunction::create_or_throw(m_kdf_hash); |
151 | 0 | std::unique_ptr<KDF> kdf = KDF::create_or_throw("KDF2(" + m_kdf_hash + ")"); |
152 | 0 |
|
153 | | // Too short to be valid - no timing problem from early return |
154 | 0 | if(ciphertext_len < 1 + p_bytes*2 + hash->output_length()) |
155 | 0 | { |
156 | 0 | return secure_vector<uint8_t>(); |
157 | 0 | } |
158 | 0 | |
159 | 0 | BigInt x1, y1; |
160 | 0 | secure_vector<uint8_t> C3, masked_msg; |
161 | 0 |
|
162 | 0 | BER_Decoder(ciphertext, ciphertext_len) |
163 | 0 | .start_cons(SEQUENCE) |
164 | 0 | .decode(x1) |
165 | 0 | .decode(y1) |
166 | 0 | .decode(C3, OCTET_STRING) |
167 | 0 | .decode(masked_msg, OCTET_STRING) |
168 | 0 | .end_cons() |
169 | 0 | .verify_end(); |
170 | 0 |
|
171 | 0 | std::vector<uint8_t> recode_ctext; |
172 | 0 | DER_Encoder(recode_ctext) |
173 | 0 | .start_cons(SEQUENCE) |
174 | 0 | .encode(x1) |
175 | 0 | .encode(y1) |
176 | 0 | .encode(C3, OCTET_STRING) |
177 | 0 | .encode(masked_msg, OCTET_STRING) |
178 | 0 | .end_cons(); |
179 | 0 |
|
180 | 0 | if(recode_ctext.size() != ciphertext_len) |
181 | 0 | return secure_vector<uint8_t>(); |
182 | 0 | |
183 | 0 | if(same_mem(recode_ctext.data(), ciphertext, ciphertext_len) == false) |
184 | 0 | return secure_vector<uint8_t>(); |
185 | 0 | |
186 | 0 | PointGFp C1 = group.point(x1, y1); |
187 | 0 | C1.randomize_repr(m_rng); |
188 | 0 |
|
189 | | // Here C1 is publically invalid, so no problem with early return: |
190 | 0 | if(!C1.on_the_curve()) |
191 | 0 | return secure_vector<uint8_t>(); |
192 | 0 | |
193 | 0 | if(cofactor > 1 && (C1 * cofactor).is_zero()) |
194 | 0 | { |
195 | 0 | return secure_vector<uint8_t>(); |
196 | 0 | } |
197 | 0 | |
198 | 0 | const PointGFp dbC1 = group.blinded_var_point_multiply( |
199 | 0 | C1, m_key.private_value(), m_rng, m_ws); |
200 | 0 |
|
201 | 0 | const BigInt x2 = dbC1.get_affine_x(); |
202 | 0 | const BigInt y2 = dbC1.get_affine_y(); |
203 | 0 |
|
204 | 0 | secure_vector<uint8_t> x2_bytes(p_bytes); |
205 | 0 | secure_vector<uint8_t> y2_bytes(p_bytes); |
206 | 0 | BigInt::encode_1363(x2_bytes.data(), x2_bytes.size(), x2); |
207 | 0 | BigInt::encode_1363(y2_bytes.data(), y2_bytes.size(), y2); |
208 | 0 |
|
209 | 0 | secure_vector<uint8_t> kdf_input; |
210 | 0 | kdf_input += x2_bytes; |
211 | 0 | kdf_input += y2_bytes; |
212 | 0 |
|
213 | 0 | const secure_vector<uint8_t> kdf_output = |
214 | 0 | kdf->derive_key(masked_msg.size(), kdf_input.data(), kdf_input.size()); |
215 | 0 |
|
216 | 0 | xor_buf(masked_msg.data(), kdf_output.data(), kdf_output.size()); |
217 | 0 |
|
218 | 0 | hash->update(x2_bytes); |
219 | 0 | hash->update(masked_msg); |
220 | 0 | hash->update(y2_bytes); |
221 | 0 | secure_vector<uint8_t> u = hash->final(); |
222 | 0 |
|
223 | 0 | if(constant_time_compare(u.data(), C3.data(), hash->output_length()) == false) |
224 | 0 | return secure_vector<uint8_t>(); |
225 | 0 | |
226 | 0 | valid_mask = 0xFF; |
227 | 0 | return masked_msg; |
228 | 0 | } |
229 | | private: |
230 | | const SM2_Encryption_PrivateKey& m_key; |
231 | | RandomNumberGenerator& m_rng; |
232 | | const std::string m_kdf_hash; |
233 | | std::vector<BigInt> m_ws; |
234 | | size_t m_hash_size; |
235 | | }; |
236 | | |
237 | | } |
238 | | |
239 | | std::unique_ptr<PK_Ops::Encryption> |
240 | | SM2_PublicKey::create_encryption_op(RandomNumberGenerator& rng, |
241 | | const std::string& params, |
242 | | const std::string& provider) const |
243 | 0 | { |
244 | 0 | if(provider == "base" || provider.empty()) |
245 | 0 | { |
246 | 0 | const std::string kdf_hash = (params.empty() ? "SM3" : params); |
247 | 0 | return std::unique_ptr<PK_Ops::Encryption>(new SM2_Encryption_Operation(*this, rng, kdf_hash)); |
248 | 0 | } |
249 | 0 |
|
250 | 0 | throw Provider_Not_Found(algo_name(), provider); |
251 | 0 | } |
252 | | |
253 | | std::unique_ptr<PK_Ops::Decryption> |
254 | | SM2_PrivateKey::create_decryption_op(RandomNumberGenerator& rng, |
255 | | const std::string& params, |
256 | | const std::string& provider) const |
257 | 0 | { |
258 | 0 | if(provider == "base" || provider.empty()) |
259 | 0 | { |
260 | 0 | const std::string kdf_hash = (params.empty() ? "SM3" : params); |
261 | 0 | return std::unique_ptr<PK_Ops::Decryption>(new SM2_Decryption_Operation(*this, rng, kdf_hash)); |
262 | 0 | } |
263 | 0 |
|
264 | 0 | throw Provider_Not_Found(algo_name(), provider); |
265 | 0 | } |
266 | | |
267 | | } |