LCOV - code coverage report
Current view: top level - source/extensions/common/aws - sigv4a_key_derivation.cc (source / functions) Hit Total Coverage
Test: coverage.dat Lines: 0 86 0.0 %
Date: 2024-01-05 06:35:25 Functions: 0 4 0.0 %

          Line data    Source code
       1             : #include "source/extensions/common/aws/sigv4a_key_derivation.h"
       2             : 
       3             : #include <openssl/ssl.h>
       4             : 
       5             : #include "source/common/common/logger.h"
       6             : #include "source/common/crypto/utility.h"
       7             : #include "source/extensions/common/aws/sigv4a_signer_impl.h"
       8             : 
       9             : namespace Envoy {
      10             : namespace Extensions {
      11             : namespace Common {
      12             : namespace Aws {
      13             : 
      14             : EC_KEY* SigV4AKeyDerivation::derivePrivateKey(absl::string_view access_key_id,
      15           0 :                                               absl::string_view secret_access_key) {
      16             : 
      17           0 :   auto& crypto_util = Envoy::Common::Crypto::UtilitySingleton::get();
      18             : 
      19           0 :   const uint8_t key_length = 32; // AWS_CAL_ECDSA_P256
      20           0 :   std::vector<uint8_t> private_key_buf(key_length);
      21             : 
      22           0 :   const uint8_t access_key_length = access_key_id.length();
      23           0 :   const uint8_t required_fixed_input_length = 32 + access_key_length;
      24           0 :   std::vector<uint8_t> fixed_input(required_fixed_input_length);
      25             : 
      26           0 :   const auto secret_key =
      27           0 :       absl::StrCat(SigV4ASignatureConstants::get().SigV4ASignatureVersion, secret_access_key);
      28             : 
      29           0 :   enum SigV4AKeyDerivationResult result = AkdrNextCounter;
      30           0 :   uint8_t external_counter = 1;
      31             : 
      32           0 :   BIGNUM* priv_key_num;
      33           0 :   EC_KEY* ec_key;
      34             : 
      35           0 :   while ((result == AkdrNextCounter) &&
      36           0 :          (external_counter <= 254)) // MAX_KEY_DERIVATION_COUNTER_VALUE
      37           0 :   {
      38           0 :     fixed_input.clear();
      39             : 
      40           0 :     fixed_input.insert(fixed_input.begin(), {0x00, 0x00, 0x00, 0x01});
      41           0 :     fixed_input.insert(fixed_input.end(), SigV4ASignatureConstants::get().SigV4ALabel.begin(),
      42           0 :                        SigV4ASignatureConstants::get().SigV4ALabel.end());
      43           0 :     fixed_input.insert(fixed_input.end(), 0x00);
      44           0 :     fixed_input.insert(fixed_input.end(), access_key_id.begin(), access_key_id.end());
      45           0 :     fixed_input.insert(fixed_input.end(), external_counter);
      46           0 :     fixed_input.insert(fixed_input.end(), {0x00, 0x00, 0x01, 0x00});
      47             : 
      48           0 :     auto k0 = crypto_util.getSha256Hmac(
      49           0 :         std::vector<uint8_t>(secret_key.begin(), secret_key.end()),
      50           0 :         absl::string_view(reinterpret_cast<char*>(fixed_input.data()), fixed_input.size()));
      51             : 
      52             :     // ECDSA q - 2
      53           0 :     std::vector<uint8_t> s_n_minus_2 = {
      54           0 :         0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF,
      55           0 :         0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xBC, 0xE6, 0xFA, 0xAD, 0xA7, 0x17,
      56           0 :         0x9E, 0x84, 0xF3, 0xB9, 0xCA, 0xC2, 0xFC, 0x63, 0x25, 0x4F,
      57           0 :     };
      58             : 
      59             :     // check that k0 < s_n_minus_2
      60           0 :     bool lt_result = constantTimeLessThanOrEqualTo(k0, s_n_minus_2);
      61             : 
      62           0 :     if (!lt_result) {
      63             :       // Loop if k0 >= s_n_minus_2 and the counter will cause a new hmac to be generated
      64           0 :       external_counter++;
      65           0 :     } else {
      66           0 :       result = SigV4AKeyDerivationResult::AkdrSuccess;
      67             :       // PrivateKey d = c+1
      68           0 :       constantTimeAddOne(&k0);
      69             : 
      70           0 :       priv_key_num = BN_bin2bn(k0.data(), k0.size(), nullptr);
      71             : 
      72             :       // Create a new OpenSSL EC_KEY by curve nid for secp256r1 (NIST P-256)
      73           0 :       ec_key = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1);
      74             : 
      75             :       // And set the private key we calculated above
      76           0 :       if (!EC_KEY_set_private_key(ec_key, priv_key_num)) {
      77           0 :         ENVOY_LOG(debug, "Failed to set openssl private key");
      78           0 :         BN_free(priv_key_num);
      79           0 :         OPENSSL_free(ec_key);
      80           0 :         return nullptr;
      81           0 :       }
      82           0 :       BN_free(priv_key_num);
      83           0 :     }
      84           0 :   }
      85             : 
      86           0 :   if (result == SigV4AKeyDerivationResult::AkdrNextCounter) {
      87           0 :     ENVOY_LOG(debug, "Key derivation exceeded retries, returning no signature");
      88           0 :     return nullptr;
      89           0 :   }
      90             : 
      91           0 :   return ec_key;
      92           0 : }
      93             : 
      94           0 : bool SigV4AKeyDerivation::derivePublicKey(EC_KEY* ec_key) {
      95             : 
      96           0 :   const BIGNUM* priv_key_num = EC_KEY_get0_private_key(ec_key);
      97           0 :   const EC_GROUP* group = EC_KEY_get0_group(ec_key);
      98           0 :   EC_POINT* point = EC_POINT_new(group);
      99             : 
     100           0 :   EC_POINT_mul(group, point, priv_key_num, nullptr, nullptr, nullptr);
     101             : 
     102           0 :   EC_KEY_set_public_key(ec_key, point);
     103             : 
     104           0 :   EC_POINT_free(point);
     105           0 :   return true;
     106           0 : }
     107             : 
     108             : // code based on aws sdk key derivation constant time implementations
     109             : // https://github.com/awslabs/aws-c-auth/blob/baeffa791d9d1cf61460662a6d9ac2186aaf05df/source/key_derivation.c#L152
     110             : 
     111             : bool SigV4AKeyDerivation::constantTimeLessThanOrEqualTo(std::vector<uint8_t> lhs_raw_be_bigint,
     112           0 :                                                         std::vector<uint8_t> rhs_raw_be_bigint) {
     113             : 
     114           0 :   volatile uint8_t gt = 0;
     115           0 :   volatile uint8_t eq = 1;
     116             : 
     117           0 :   for (uint8_t i = 0; i < lhs_raw_be_bigint.size(); ++i) {
     118           0 :     volatile int32_t lhs_digit = lhs_raw_be_bigint[i];
     119           0 :     volatile int32_t rhs_digit = rhs_raw_be_bigint[i];
     120             : 
     121           0 :     gt = gt | (((rhs_digit - lhs_digit) >> 31) & eq);
     122           0 :     eq = eq & ((((lhs_digit ^ rhs_digit) - 1) >> 31) & 0x01);
     123           0 :   }
     124           0 :   return (gt + gt + eq - 1) <= 0;
     125           0 : }
     126             : 
     127           0 : void SigV4AKeyDerivation::constantTimeAddOne(std::vector<uint8_t>* raw_be_bigint) {
     128             : 
     129           0 :   const uint8_t byte_count = raw_be_bigint->size();
     130             : 
     131           0 :   volatile uint32_t carry = 1;
     132             : 
     133           0 :   for (size_t i = 0; i < byte_count; ++i) {
     134           0 :     const size_t index = byte_count - i - 1;
     135             : 
     136           0 :     volatile uint32_t current_digit = (*raw_be_bigint)[index];
     137           0 :     current_digit = current_digit + carry;
     138             : 
     139           0 :     carry = (current_digit >> 8) & 0x01;
     140             : 
     141           0 :     (*raw_be_bigint)[index] = (current_digit & 0xFF);
     142           0 :   }
     143           0 : }
     144             : 
     145             : } // namespace Aws
     146             : } // namespace Common
     147             : } // namespace Extensions
     148             : } // namespace Envoy

Generated by: LCOV version 1.15