Line data Source code
1 : #include "source/extensions/common/aws/sigv4a_key_derivation.h"
2 :
3 : #include <openssl/ssl.h>
4 :
5 : #include "source/common/common/logger.h"
6 : #include "source/common/crypto/utility.h"
7 : #include "source/extensions/common/aws/sigv4a_signer_impl.h"
8 :
9 : namespace Envoy {
10 : namespace Extensions {
11 : namespace Common {
12 : namespace Aws {
13 :
14 : EC_KEY* SigV4AKeyDerivation::derivePrivateKey(absl::string_view access_key_id,
15 0 : absl::string_view secret_access_key) {
16 :
17 0 : auto& crypto_util = Envoy::Common::Crypto::UtilitySingleton::get();
18 :
19 0 : const uint8_t key_length = 32; // AWS_CAL_ECDSA_P256
20 0 : std::vector<uint8_t> private_key_buf(key_length);
21 :
22 0 : const uint8_t access_key_length = access_key_id.length();
23 0 : const uint8_t required_fixed_input_length = 32 + access_key_length;
24 0 : std::vector<uint8_t> fixed_input(required_fixed_input_length);
25 :
26 0 : const auto secret_key =
27 0 : absl::StrCat(SigV4ASignatureConstants::get().SigV4ASignatureVersion, secret_access_key);
28 :
29 0 : enum SigV4AKeyDerivationResult result = AkdrNextCounter;
30 0 : uint8_t external_counter = 1;
31 :
32 0 : BIGNUM* priv_key_num;
33 0 : EC_KEY* ec_key;
34 :
35 0 : while ((result == AkdrNextCounter) &&
36 0 : (external_counter <= 254)) // MAX_KEY_DERIVATION_COUNTER_VALUE
37 0 : {
38 0 : fixed_input.clear();
39 :
40 0 : fixed_input.insert(fixed_input.begin(), {0x00, 0x00, 0x00, 0x01});
41 0 : fixed_input.insert(fixed_input.end(), SigV4ASignatureConstants::get().SigV4ALabel.begin(),
42 0 : SigV4ASignatureConstants::get().SigV4ALabel.end());
43 0 : fixed_input.insert(fixed_input.end(), 0x00);
44 0 : fixed_input.insert(fixed_input.end(), access_key_id.begin(), access_key_id.end());
45 0 : fixed_input.insert(fixed_input.end(), external_counter);
46 0 : fixed_input.insert(fixed_input.end(), {0x00, 0x00, 0x01, 0x00});
47 :
48 0 : auto k0 = crypto_util.getSha256Hmac(
49 0 : std::vector<uint8_t>(secret_key.begin(), secret_key.end()),
50 0 : absl::string_view(reinterpret_cast<char*>(fixed_input.data()), fixed_input.size()));
51 :
52 : // ECDSA q - 2
53 0 : std::vector<uint8_t> s_n_minus_2 = {
54 0 : 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF,
55 0 : 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xBC, 0xE6, 0xFA, 0xAD, 0xA7, 0x17,
56 0 : 0x9E, 0x84, 0xF3, 0xB9, 0xCA, 0xC2, 0xFC, 0x63, 0x25, 0x4F,
57 0 : };
58 :
59 : // check that k0 < s_n_minus_2
60 0 : bool lt_result = constantTimeLessThanOrEqualTo(k0, s_n_minus_2);
61 :
62 0 : if (!lt_result) {
63 : // Loop if k0 >= s_n_minus_2 and the counter will cause a new hmac to be generated
64 0 : external_counter++;
65 0 : } else {
66 0 : result = SigV4AKeyDerivationResult::AkdrSuccess;
67 : // PrivateKey d = c+1
68 0 : constantTimeAddOne(&k0);
69 :
70 0 : priv_key_num = BN_bin2bn(k0.data(), k0.size(), nullptr);
71 :
72 : // Create a new OpenSSL EC_KEY by curve nid for secp256r1 (NIST P-256)
73 0 : ec_key = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1);
74 :
75 : // And set the private key we calculated above
76 0 : if (!EC_KEY_set_private_key(ec_key, priv_key_num)) {
77 0 : ENVOY_LOG(debug, "Failed to set openssl private key");
78 0 : BN_free(priv_key_num);
79 0 : OPENSSL_free(ec_key);
80 0 : return nullptr;
81 0 : }
82 0 : BN_free(priv_key_num);
83 0 : }
84 0 : }
85 :
86 0 : if (result == SigV4AKeyDerivationResult::AkdrNextCounter) {
87 0 : ENVOY_LOG(debug, "Key derivation exceeded retries, returning no signature");
88 0 : return nullptr;
89 0 : }
90 :
91 0 : return ec_key;
92 0 : }
93 :
94 0 : bool SigV4AKeyDerivation::derivePublicKey(EC_KEY* ec_key) {
95 :
96 0 : const BIGNUM* priv_key_num = EC_KEY_get0_private_key(ec_key);
97 0 : const EC_GROUP* group = EC_KEY_get0_group(ec_key);
98 0 : EC_POINT* point = EC_POINT_new(group);
99 :
100 0 : EC_POINT_mul(group, point, priv_key_num, nullptr, nullptr, nullptr);
101 :
102 0 : EC_KEY_set_public_key(ec_key, point);
103 :
104 0 : EC_POINT_free(point);
105 0 : return true;
106 0 : }
107 :
108 : // code based on aws sdk key derivation constant time implementations
109 : // https://github.com/awslabs/aws-c-auth/blob/baeffa791d9d1cf61460662a6d9ac2186aaf05df/source/key_derivation.c#L152
110 :
111 : bool SigV4AKeyDerivation::constantTimeLessThanOrEqualTo(std::vector<uint8_t> lhs_raw_be_bigint,
112 0 : std::vector<uint8_t> rhs_raw_be_bigint) {
113 :
114 0 : volatile uint8_t gt = 0;
115 0 : volatile uint8_t eq = 1;
116 :
117 0 : for (uint8_t i = 0; i < lhs_raw_be_bigint.size(); ++i) {
118 0 : volatile int32_t lhs_digit = lhs_raw_be_bigint[i];
119 0 : volatile int32_t rhs_digit = rhs_raw_be_bigint[i];
120 :
121 0 : gt = gt | (((rhs_digit - lhs_digit) >> 31) & eq);
122 0 : eq = eq & ((((lhs_digit ^ rhs_digit) - 1) >> 31) & 0x01);
123 0 : }
124 0 : return (gt + gt + eq - 1) <= 0;
125 0 : }
126 :
127 0 : void SigV4AKeyDerivation::constantTimeAddOne(std::vector<uint8_t>* raw_be_bigint) {
128 :
129 0 : const uint8_t byte_count = raw_be_bigint->size();
130 :
131 0 : volatile uint32_t carry = 1;
132 :
133 0 : for (size_t i = 0; i < byte_count; ++i) {
134 0 : const size_t index = byte_count - i - 1;
135 :
136 0 : volatile uint32_t current_digit = (*raw_be_bigint)[index];
137 0 : current_digit = current_digit + carry;
138 :
139 0 : carry = (current_digit >> 8) & 0x01;
140 :
141 0 : (*raw_be_bigint)[index] = (current_digit & 0xFF);
142 0 : }
143 0 : }
144 :
145 : } // namespace Aws
146 : } // namespace Common
147 : } // namespace Extensions
148 : } // namespace Envoy
|