Coverage Report

Created: 2025-07-18 06:20

/src/Botan-3.4.0/src/lib/pubkey/sm2/sm2_enc.cpp
Line
Count
Source (jump to first uncovered line)
1
/*
2
* SM2 Encryption
3
* (C) 2017 Ribose Inc
4
*
5
* Botan is released under the Simplified BSD License (see license.txt)
6
*/
7
8
#include <botan/sm2.h>
9
10
#include <botan/ber_dec.h>
11
#include <botan/der_enc.h>
12
#include <botan/hash.h>
13
#include <botan/kdf.h>
14
#include <botan/pk_ops.h>
15
#include <botan/internal/ct_utils.h>
16
#include <botan/internal/fmt.h>
17
#include <botan/internal/point_mul.h>
18
19
namespace Botan {
20
21
namespace {
22
23
class SM2_Encryption_Operation final : public PK_Ops::Encryption {
24
   public:
25
      SM2_Encryption_Operation(const SM2_Encryption_PublicKey& key,
26
                               RandomNumberGenerator& rng,
27
                               std::string_view kdf_hash) :
28
0
            m_group(key.domain()), m_ws(EC_Point::WORKSPACE_SIZE), m_mul_public_point(key.public_point(), rng, m_ws) {
29
0
         m_hash = HashFunction::create_or_throw(kdf_hash);
30
31
0
         const std::string kdf_name = fmt("KDF2({})", kdf_hash);
32
0
         m_kdf = KDF::create_or_throw(kdf_name);
33
0
      }
34
35
0
      size_t max_input_bits() const override {
36
         // This is arbitrary, but assumes SM2 is used for key encapsulation
37
0
         return 512;
38
0
      }
39
40
0
      size_t ciphertext_length(size_t ptext_len) const override {
41
0
         const size_t elem_size = m_group.get_order_bytes();
42
0
         const size_t der_overhead = 16;
43
44
0
         return der_overhead + 2 * elem_size + m_hash->output_length() + ptext_len;
45
0
      }
46
47
0
      secure_vector<uint8_t> encrypt(const uint8_t msg[], size_t msg_len, RandomNumberGenerator& rng) override {
48
0
         const size_t p_bytes = m_group.get_p_bytes();
49
50
0
         const BigInt k = m_group.random_scalar(rng);
51
52
0
         const EC_Point C1 = m_group.blinded_base_point_multiply(k, rng, m_ws);
53
0
         const BigInt x1 = C1.get_affine_x();
54
0
         const BigInt y1 = C1.get_affine_y();
55
0
         std::vector<uint8_t> x1_bytes(p_bytes);
56
0
         std::vector<uint8_t> y1_bytes(p_bytes);
57
0
         BigInt::encode_1363(x1_bytes.data(), x1_bytes.size(), x1);
58
0
         BigInt::encode_1363(y1_bytes.data(), y1_bytes.size(), y1);
59
60
0
         const EC_Point kPB = m_mul_public_point.mul(k, rng, m_group.get_order(), m_ws);
61
62
0
         const BigInt x2 = kPB.get_affine_x();
63
0
         const BigInt y2 = kPB.get_affine_y();
64
0
         std::vector<uint8_t> x2_bytes(p_bytes);
65
0
         std::vector<uint8_t> y2_bytes(p_bytes);
66
0
         BigInt::encode_1363(x2_bytes.data(), x2_bytes.size(), x2);
67
0
         BigInt::encode_1363(y2_bytes.data(), y2_bytes.size(), y2);
68
69
0
         secure_vector<uint8_t> kdf_input;
70
0
         kdf_input += x2_bytes;
71
0
         kdf_input += y2_bytes;
72
73
0
         const secure_vector<uint8_t> kdf_output = m_kdf->derive_key(msg_len, kdf_input.data(), kdf_input.size());
74
75
0
         secure_vector<uint8_t> masked_msg(msg_len);
76
0
         xor_buf(masked_msg.data(), msg, kdf_output.data(), msg_len);
77
78
0
         m_hash->update(x2_bytes);
79
0
         m_hash->update(msg, msg_len);
80
0
         m_hash->update(y2_bytes);
81
0
         std::vector<uint8_t> C3(m_hash->output_length());
82
0
         m_hash->final(C3.data());
83
84
0
         return DER_Encoder()
85
0
            .start_sequence()
86
0
            .encode(x1)
87
0
            .encode(y1)
88
0
            .encode(C3, ASN1_Type::OctetString)
89
0
            .encode(masked_msg, ASN1_Type::OctetString)
90
0
            .end_cons()
91
0
            .get_contents();
92
0
      }
93
94
   private:
95
      const EC_Group m_group;
96
      std::unique_ptr<HashFunction> m_hash;
97
      std::unique_ptr<KDF> m_kdf;
98
      std::vector<BigInt> m_ws;
99
      EC_Point_Var_Point_Precompute m_mul_public_point;
100
};
101
102
class SM2_Decryption_Operation final : public PK_Ops::Decryption {
103
   public:
104
      SM2_Decryption_Operation(const SM2_Encryption_PrivateKey& key,
105
                               RandomNumberGenerator& rng,
106
                               std::string_view kdf_hash) :
107
0
            m_key(key), m_rng(rng) {
108
0
         m_hash = HashFunction::create_or_throw(kdf_hash);
109
110
0
         const std::string kdf_name = fmt("KDF2({})", kdf_hash);
111
0
         m_kdf = KDF::create_or_throw(kdf_name);
112
0
      }
113
114
0
      size_t plaintext_length(size_t ptext_len) const override {
115
         /*
116
         * This ignores the DER encoding and so overestimates the
117
         * plaintext length by 12 bytes or so
118
         */
119
0
         const size_t elem_size = m_key.domain().get_order_bytes();
120
121
0
         if(ptext_len < 2 * elem_size + m_hash->output_length()) {
122
0
            return 0;
123
0
         }
124
125
0
         return ptext_len - (2 * elem_size + m_hash->output_length());
126
0
      }
127
128
0
      secure_vector<uint8_t> decrypt(uint8_t& valid_mask, const uint8_t ciphertext[], size_t ciphertext_len) override {
129
0
         const EC_Group& group = m_key.domain();
130
0
         const BigInt& cofactor = group.get_cofactor();
131
0
         const size_t p_bytes = group.get_p_bytes();
132
133
0
         valid_mask = 0x00;
134
135
         // Too short to be valid - no timing problem from early return
136
0
         if(ciphertext_len < 1 + p_bytes * 2 + m_hash->output_length()) {
137
0
            return secure_vector<uint8_t>();
138
0
         }
139
140
0
         BigInt x1, y1;
141
0
         secure_vector<uint8_t> C3, masked_msg;
142
143
0
         BER_Decoder(ciphertext, ciphertext_len)
144
0
            .start_sequence()
145
0
            .decode(x1)
146
0
            .decode(y1)
147
0
            .decode(C3, ASN1_Type::OctetString)
148
0
            .decode(masked_msg, ASN1_Type::OctetString)
149
0
            .end_cons()
150
0
            .verify_end();
151
152
0
         std::vector<uint8_t> recode_ctext;
153
0
         DER_Encoder(recode_ctext)
154
0
            .start_sequence()
155
0
            .encode(x1)
156
0
            .encode(y1)
157
0
            .encode(C3, ASN1_Type::OctetString)
158
0
            .encode(masked_msg, ASN1_Type::OctetString)
159
0
            .end_cons();
160
161
0
         if(recode_ctext.size() != ciphertext_len) {
162
0
            return secure_vector<uint8_t>();
163
0
         }
164
165
0
         if(CT::is_equal(recode_ctext.data(), ciphertext, ciphertext_len).as_bool() == false) {
166
0
            return secure_vector<uint8_t>();
167
0
         }
168
169
0
         EC_Point C1 = group.point(x1, y1);
170
0
         C1.randomize_repr(m_rng);
171
172
         // Here C1 is publically invalid, so no problem with early return:
173
0
         if(!C1.on_the_curve()) {
174
0
            return secure_vector<uint8_t>();
175
0
         }
176
177
0
         if(cofactor > 1 && (C1 * cofactor).is_zero()) {
178
0
            return secure_vector<uint8_t>();
179
0
         }
180
181
0
         const EC_Point dbC1 = group.blinded_var_point_multiply(C1, m_key.private_value(), m_rng, m_ws);
182
183
0
         const BigInt x2 = dbC1.get_affine_x();
184
0
         const BigInt y2 = dbC1.get_affine_y();
185
186
0
         secure_vector<uint8_t> x2_bytes(p_bytes);
187
0
         secure_vector<uint8_t> y2_bytes(p_bytes);
188
0
         BigInt::encode_1363(x2_bytes.data(), x2_bytes.size(), x2);
189
0
         BigInt::encode_1363(y2_bytes.data(), y2_bytes.size(), y2);
190
191
0
         secure_vector<uint8_t> kdf_input;
192
0
         kdf_input += x2_bytes;
193
0
         kdf_input += y2_bytes;
194
195
0
         const secure_vector<uint8_t> kdf_output =
196
0
            m_kdf->derive_key(masked_msg.size(), kdf_input.data(), kdf_input.size());
197
198
0
         xor_buf(masked_msg.data(), kdf_output.data(), kdf_output.size());
199
200
0
         m_hash->update(x2_bytes);
201
0
         m_hash->update(masked_msg);
202
0
         m_hash->update(y2_bytes);
203
0
         secure_vector<uint8_t> u = m_hash->final();
204
205
0
         if(!CT::is_equal(u.data(), C3.data(), m_hash->output_length()).as_bool()) {
206
0
            return secure_vector<uint8_t>();
207
0
         }
208
209
0
         valid_mask = 0xFF;
210
0
         return masked_msg;
211
0
      }
212
213
   private:
214
      const SM2_Encryption_PrivateKey& m_key;
215
      RandomNumberGenerator& m_rng;
216
      std::vector<BigInt> m_ws;
217
      std::unique_ptr<HashFunction> m_hash;
218
      std::unique_ptr<KDF> m_kdf;
219
};
220
221
}  // namespace
222
223
std::unique_ptr<PK_Ops::Encryption> SM2_PublicKey::create_encryption_op(RandomNumberGenerator& rng,
224
                                                                        std::string_view params,
225
0
                                                                        std::string_view provider) const {
226
0
   if(provider == "base" || provider.empty()) {
227
0
      if(params.empty()) {
228
0
         return std::make_unique<SM2_Encryption_Operation>(*this, rng, "SM3");
229
0
      } else {
230
0
         return std::make_unique<SM2_Encryption_Operation>(*this, rng, params);
231
0
      }
232
0
   }
233
234
0
   throw Provider_Not_Found(algo_name(), provider);
235
0
}
236
237
std::unique_ptr<PK_Ops::Decryption> SM2_PrivateKey::create_decryption_op(RandomNumberGenerator& rng,
238
                                                                         std::string_view params,
239
0
                                                                         std::string_view provider) const {
240
0
   if(provider == "base" || provider.empty()) {
241
0
      if(params.empty()) {
242
0
         return std::make_unique<SM2_Decryption_Operation>(*this, rng, "SM3");
243
0
      } else {
244
0
         return std::make_unique<SM2_Decryption_Operation>(*this, rng, params);
245
0
      }
246
0
   }
247
248
0
   throw Provider_Not_Found(algo_name(), provider);
249
0
}
250
251
}  // namespace Botan