/src/Botan-3.4.0/src/lib/pubkey/sm2/sm2_enc.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | /* |
2 | | * SM2 Encryption |
3 | | * (C) 2017 Ribose Inc |
4 | | * |
5 | | * Botan is released under the Simplified BSD License (see license.txt) |
6 | | */ |
7 | | |
8 | | #include <botan/sm2.h> |
9 | | |
10 | | #include <botan/ber_dec.h> |
11 | | #include <botan/der_enc.h> |
12 | | #include <botan/hash.h> |
13 | | #include <botan/kdf.h> |
14 | | #include <botan/pk_ops.h> |
15 | | #include <botan/internal/ct_utils.h> |
16 | | #include <botan/internal/fmt.h> |
17 | | #include <botan/internal/point_mul.h> |
18 | | |
19 | | namespace Botan { |
20 | | |
21 | | namespace { |
22 | | |
23 | | class SM2_Encryption_Operation final : public PK_Ops::Encryption { |
24 | | public: |
25 | | SM2_Encryption_Operation(const SM2_Encryption_PublicKey& key, |
26 | | RandomNumberGenerator& rng, |
27 | | std::string_view kdf_hash) : |
28 | 0 | m_group(key.domain()), m_ws(EC_Point::WORKSPACE_SIZE), m_mul_public_point(key.public_point(), rng, m_ws) { |
29 | 0 | m_hash = HashFunction::create_or_throw(kdf_hash); |
30 | |
|
31 | 0 | const std::string kdf_name = fmt("KDF2({})", kdf_hash); |
32 | 0 | m_kdf = KDF::create_or_throw(kdf_name); |
33 | 0 | } |
34 | | |
35 | 0 | size_t max_input_bits() const override { |
36 | | // This is arbitrary, but assumes SM2 is used for key encapsulation |
37 | 0 | return 512; |
38 | 0 | } |
39 | | |
40 | 0 | size_t ciphertext_length(size_t ptext_len) const override { |
41 | 0 | const size_t elem_size = m_group.get_order_bytes(); |
42 | 0 | const size_t der_overhead = 16; |
43 | |
|
44 | 0 | return der_overhead + 2 * elem_size + m_hash->output_length() + ptext_len; |
45 | 0 | } |
46 | | |
47 | 0 | secure_vector<uint8_t> encrypt(const uint8_t msg[], size_t msg_len, RandomNumberGenerator& rng) override { |
48 | 0 | const size_t p_bytes = m_group.get_p_bytes(); |
49 | |
|
50 | 0 | const BigInt k = m_group.random_scalar(rng); |
51 | |
|
52 | 0 | const EC_Point C1 = m_group.blinded_base_point_multiply(k, rng, m_ws); |
53 | 0 | const BigInt x1 = C1.get_affine_x(); |
54 | 0 | const BigInt y1 = C1.get_affine_y(); |
55 | 0 | std::vector<uint8_t> x1_bytes(p_bytes); |
56 | 0 | std::vector<uint8_t> y1_bytes(p_bytes); |
57 | 0 | BigInt::encode_1363(x1_bytes.data(), x1_bytes.size(), x1); |
58 | 0 | BigInt::encode_1363(y1_bytes.data(), y1_bytes.size(), y1); |
59 | |
|
60 | 0 | const EC_Point kPB = m_mul_public_point.mul(k, rng, m_group.get_order(), m_ws); |
61 | |
|
62 | 0 | const BigInt x2 = kPB.get_affine_x(); |
63 | 0 | const BigInt y2 = kPB.get_affine_y(); |
64 | 0 | std::vector<uint8_t> x2_bytes(p_bytes); |
65 | 0 | std::vector<uint8_t> y2_bytes(p_bytes); |
66 | 0 | BigInt::encode_1363(x2_bytes.data(), x2_bytes.size(), x2); |
67 | 0 | BigInt::encode_1363(y2_bytes.data(), y2_bytes.size(), y2); |
68 | |
|
69 | 0 | secure_vector<uint8_t> kdf_input; |
70 | 0 | kdf_input += x2_bytes; |
71 | 0 | kdf_input += y2_bytes; |
72 | |
|
73 | 0 | const secure_vector<uint8_t> kdf_output = m_kdf->derive_key(msg_len, kdf_input.data(), kdf_input.size()); |
74 | |
|
75 | 0 | secure_vector<uint8_t> masked_msg(msg_len); |
76 | 0 | xor_buf(masked_msg.data(), msg, kdf_output.data(), msg_len); |
77 | |
|
78 | 0 | m_hash->update(x2_bytes); |
79 | 0 | m_hash->update(msg, msg_len); |
80 | 0 | m_hash->update(y2_bytes); |
81 | 0 | std::vector<uint8_t> C3(m_hash->output_length()); |
82 | 0 | m_hash->final(C3.data()); |
83 | |
|
84 | 0 | return DER_Encoder() |
85 | 0 | .start_sequence() |
86 | 0 | .encode(x1) |
87 | 0 | .encode(y1) |
88 | 0 | .encode(C3, ASN1_Type::OctetString) |
89 | 0 | .encode(masked_msg, ASN1_Type::OctetString) |
90 | 0 | .end_cons() |
91 | 0 | .get_contents(); |
92 | 0 | } |
93 | | |
94 | | private: |
95 | | const EC_Group m_group; |
96 | | std::unique_ptr<HashFunction> m_hash; |
97 | | std::unique_ptr<KDF> m_kdf; |
98 | | std::vector<BigInt> m_ws; |
99 | | EC_Point_Var_Point_Precompute m_mul_public_point; |
100 | | }; |
101 | | |
102 | | class SM2_Decryption_Operation final : public PK_Ops::Decryption { |
103 | | public: |
104 | | SM2_Decryption_Operation(const SM2_Encryption_PrivateKey& key, |
105 | | RandomNumberGenerator& rng, |
106 | | std::string_view kdf_hash) : |
107 | 0 | m_key(key), m_rng(rng) { |
108 | 0 | m_hash = HashFunction::create_or_throw(kdf_hash); |
109 | |
|
110 | 0 | const std::string kdf_name = fmt("KDF2({})", kdf_hash); |
111 | 0 | m_kdf = KDF::create_or_throw(kdf_name); |
112 | 0 | } |
113 | | |
114 | 0 | size_t plaintext_length(size_t ptext_len) const override { |
115 | | /* |
116 | | * This ignores the DER encoding and so overestimates the |
117 | | * plaintext length by 12 bytes or so |
118 | | */ |
119 | 0 | const size_t elem_size = m_key.domain().get_order_bytes(); |
120 | |
|
121 | 0 | if(ptext_len < 2 * elem_size + m_hash->output_length()) { |
122 | 0 | return 0; |
123 | 0 | } |
124 | | |
125 | 0 | return ptext_len - (2 * elem_size + m_hash->output_length()); |
126 | 0 | } |
127 | | |
128 | 0 | secure_vector<uint8_t> decrypt(uint8_t& valid_mask, const uint8_t ciphertext[], size_t ciphertext_len) override { |
129 | 0 | const EC_Group& group = m_key.domain(); |
130 | 0 | const BigInt& cofactor = group.get_cofactor(); |
131 | 0 | const size_t p_bytes = group.get_p_bytes(); |
132 | |
|
133 | 0 | valid_mask = 0x00; |
134 | | |
135 | | // Too short to be valid - no timing problem from early return |
136 | 0 | if(ciphertext_len < 1 + p_bytes * 2 + m_hash->output_length()) { |
137 | 0 | return secure_vector<uint8_t>(); |
138 | 0 | } |
139 | | |
140 | 0 | BigInt x1, y1; |
141 | 0 | secure_vector<uint8_t> C3, masked_msg; |
142 | |
|
143 | 0 | BER_Decoder(ciphertext, ciphertext_len) |
144 | 0 | .start_sequence() |
145 | 0 | .decode(x1) |
146 | 0 | .decode(y1) |
147 | 0 | .decode(C3, ASN1_Type::OctetString) |
148 | 0 | .decode(masked_msg, ASN1_Type::OctetString) |
149 | 0 | .end_cons() |
150 | 0 | .verify_end(); |
151 | |
|
152 | 0 | std::vector<uint8_t> recode_ctext; |
153 | 0 | DER_Encoder(recode_ctext) |
154 | 0 | .start_sequence() |
155 | 0 | .encode(x1) |
156 | 0 | .encode(y1) |
157 | 0 | .encode(C3, ASN1_Type::OctetString) |
158 | 0 | .encode(masked_msg, ASN1_Type::OctetString) |
159 | 0 | .end_cons(); |
160 | |
|
161 | 0 | if(recode_ctext.size() != ciphertext_len) { |
162 | 0 | return secure_vector<uint8_t>(); |
163 | 0 | } |
164 | | |
165 | 0 | if(CT::is_equal(recode_ctext.data(), ciphertext, ciphertext_len).as_bool() == false) { |
166 | 0 | return secure_vector<uint8_t>(); |
167 | 0 | } |
168 | | |
169 | 0 | EC_Point C1 = group.point(x1, y1); |
170 | 0 | C1.randomize_repr(m_rng); |
171 | | |
172 | | // Here C1 is publically invalid, so no problem with early return: |
173 | 0 | if(!C1.on_the_curve()) { |
174 | 0 | return secure_vector<uint8_t>(); |
175 | 0 | } |
176 | | |
177 | 0 | if(cofactor > 1 && (C1 * cofactor).is_zero()) { |
178 | 0 | return secure_vector<uint8_t>(); |
179 | 0 | } |
180 | | |
181 | 0 | const EC_Point dbC1 = group.blinded_var_point_multiply(C1, m_key.private_value(), m_rng, m_ws); |
182 | |
|
183 | 0 | const BigInt x2 = dbC1.get_affine_x(); |
184 | 0 | const BigInt y2 = dbC1.get_affine_y(); |
185 | |
|
186 | 0 | secure_vector<uint8_t> x2_bytes(p_bytes); |
187 | 0 | secure_vector<uint8_t> y2_bytes(p_bytes); |
188 | 0 | BigInt::encode_1363(x2_bytes.data(), x2_bytes.size(), x2); |
189 | 0 | BigInt::encode_1363(y2_bytes.data(), y2_bytes.size(), y2); |
190 | |
|
191 | 0 | secure_vector<uint8_t> kdf_input; |
192 | 0 | kdf_input += x2_bytes; |
193 | 0 | kdf_input += y2_bytes; |
194 | |
|
195 | 0 | const secure_vector<uint8_t> kdf_output = |
196 | 0 | m_kdf->derive_key(masked_msg.size(), kdf_input.data(), kdf_input.size()); |
197 | |
|
198 | 0 | xor_buf(masked_msg.data(), kdf_output.data(), kdf_output.size()); |
199 | |
|
200 | 0 | m_hash->update(x2_bytes); |
201 | 0 | m_hash->update(masked_msg); |
202 | 0 | m_hash->update(y2_bytes); |
203 | 0 | secure_vector<uint8_t> u = m_hash->final(); |
204 | |
|
205 | 0 | if(!CT::is_equal(u.data(), C3.data(), m_hash->output_length()).as_bool()) { |
206 | 0 | return secure_vector<uint8_t>(); |
207 | 0 | } |
208 | | |
209 | 0 | valid_mask = 0xFF; |
210 | 0 | return masked_msg; |
211 | 0 | } |
212 | | |
213 | | private: |
214 | | const SM2_Encryption_PrivateKey& m_key; |
215 | | RandomNumberGenerator& m_rng; |
216 | | std::vector<BigInt> m_ws; |
217 | | std::unique_ptr<HashFunction> m_hash; |
218 | | std::unique_ptr<KDF> m_kdf; |
219 | | }; |
220 | | |
221 | | } // namespace |
222 | | |
223 | | std::unique_ptr<PK_Ops::Encryption> SM2_PublicKey::create_encryption_op(RandomNumberGenerator& rng, |
224 | | std::string_view params, |
225 | 0 | std::string_view provider) const { |
226 | 0 | if(provider == "base" || provider.empty()) { |
227 | 0 | if(params.empty()) { |
228 | 0 | return std::make_unique<SM2_Encryption_Operation>(*this, rng, "SM3"); |
229 | 0 | } else { |
230 | 0 | return std::make_unique<SM2_Encryption_Operation>(*this, rng, params); |
231 | 0 | } |
232 | 0 | } |
233 | | |
234 | 0 | throw Provider_Not_Found(algo_name(), provider); |
235 | 0 | } |
236 | | |
237 | | std::unique_ptr<PK_Ops::Decryption> SM2_PrivateKey::create_decryption_op(RandomNumberGenerator& rng, |
238 | | std::string_view params, |
239 | 0 | std::string_view provider) const { |
240 | 0 | if(provider == "base" || provider.empty()) { |
241 | 0 | if(params.empty()) { |
242 | 0 | return std::make_unique<SM2_Decryption_Operation>(*this, rng, "SM3"); |
243 | 0 | } else { |
244 | 0 | return std::make_unique<SM2_Decryption_Operation>(*this, rng, params); |
245 | 0 | } |
246 | 0 | } |
247 | | |
248 | 0 | throw Provider_Not_Found(algo_name(), provider); |
249 | 0 | } |
250 | | |
251 | | } // namespace Botan |