LCOV - code coverage report
Current view: top level - source/extensions/transport_sockets/alts - alts_tsi_handshaker.cc (source / functions) Hit Total Coverage
Test: coverage.dat Lines: 0 110 0.0 %
Date: 2024-01-05 06:35:25 Functions: 0 6 0.0 %

          Line data    Source code
       1             : #include "source/extensions/transport_sockets/alts/alts_tsi_handshaker.h"
       2             : 
       3             : #include <algorithm>
       4             : #include <cstdint>
       5             : #include <cstring>
       6             : #include <memory>
       7             : #include <string>
       8             : #include <utility>
       9             : #include <vector>
      10             : 
      11             : #include "absl/memory/memory.h"
      12             : #include "absl/status/status.h"
      13             : #include "absl/strings/str_format.h"
      14             : #include "absl/types/span.h"
      15             : #include "src/core/lib/iomgr/exec_ctx.h"
      16             : #include "src/core/tsi/alts/zero_copy_frame_protector/alts_zero_copy_grpc_protector.h"
      17             : 
      18             : namespace Envoy {
      19             : namespace Extensions {
      20             : namespace TransportSockets {
      21             : namespace Alts {
      22             : 
      23             : using ::grpc::gcp::HandshakerResp;
      24             : 
      25             : constexpr std::size_t AltsAes128GcmRekeyKeyLength = 44;
      26             : 
      27             : std::unique_ptr<AltsTsiHandshaker>
      28           0 : AltsTsiHandshaker::createForClient(std::shared_ptr<grpc::Channel> handshaker_service_channel) {
      29           0 :   return absl::WrapUnique(new AltsTsiHandshaker(/*is_client=*/true, handshaker_service_channel));
      30           0 : }
      31             : 
      32             : std::unique_ptr<AltsTsiHandshaker>
      33           0 : AltsTsiHandshaker::createForServer(std::shared_ptr<grpc::Channel> handshaker_service_channel) {
      34           0 :   return absl::WrapUnique(new AltsTsiHandshaker(/*is_client=*/false, handshaker_service_channel));
      35           0 : }
      36             : 
      37             : AltsTsiHandshaker::AltsTsiHandshaker(bool is_client,
      38             :                                      std::shared_ptr<grpc::Channel> handshaker_service_channel)
      39           0 :     : is_client_(is_client), handshaker_service_channel_(handshaker_service_channel) {}
      40             : 
      41             : absl::Status AltsTsiHandshaker::next(void* handshaker, const unsigned char* received_bytes,
      42           0 :                                      size_t received_bytes_size, OnNextDone on_next_done) {
      43             :   // Argument and state checks.
      44           0 :   if (handshaker == nullptr || (received_bytes == nullptr && received_bytes_size > 0) ||
      45           0 :       on_next_done == nullptr) {
      46           0 :     return absl::InvalidArgumentError("Invalid nullptr argument to AltsTsiHandshaker::Next.");
      47           0 :   }
      48           0 :   if (is_handshake_complete_) {
      49           0 :     return absl::InternalError("Handshake is already complete.");
      50           0 :   }
      51             : 
      52             :   // Get a handshake message from the handshaker service.
      53           0 :   absl::Span<const uint8_t> in_bytes = absl::MakeConstSpan(received_bytes, received_bytes_size);
      54           0 :   HandshakerResp response;
      55           0 :   if (!has_sent_initial_handshake_message_) {
      56           0 :     has_sent_initial_handshake_message_ = true;
      57           0 :     auto alts_proxy = AltsProxy::create(handshaker_service_channel_);
      58           0 :     if (!alts_proxy.ok()) {
      59           0 :       return alts_proxy.status();
      60           0 :     }
      61           0 :     alts_proxy_ = *std::move(alts_proxy);
      62           0 :     if (is_client_) {
      63           0 :       auto client_start = alts_proxy_->sendStartClientHandshakeReq();
      64           0 :       if (!client_start.ok()) {
      65           0 :         return client_start.status();
      66           0 :       }
      67           0 :       response = *std::move(client_start);
      68           0 :     } else {
      69           0 :       auto server_start = alts_proxy_->sendStartServerHandshakeReq(in_bytes);
      70           0 :       if (!server_start.ok()) {
      71           0 :         return server_start.status();
      72           0 :       }
      73           0 :       response = *std::move(server_start);
      74           0 :     }
      75           0 :   } else {
      76           0 :     auto next = alts_proxy_->sendNextHandshakeReq(in_bytes);
      77           0 :     if (!next.ok()) {
      78           0 :       return next.status();
      79           0 :     }
      80           0 :     response = *std::move(next);
      81           0 :   }
      82             : 
      83             :   // Maybe prepare the handshake result.
      84           0 :   std::unique_ptr<AltsHandshakeResult> handshake_result = nullptr;
      85           0 :   if (response.has_result()) {
      86           0 :     is_handshake_complete_ = true;
      87           0 :     auto result = getHandshakeResult(response.result(), in_bytes, response.bytes_consumed());
      88           0 :     if (!result.ok()) {
      89           0 :       return result.status();
      90           0 :     }
      91           0 :     handshake_result = *std::move(result);
      92           0 :   }
      93             : 
      94             :   // Write the out bytes.
      95           0 :   const std::string& out_bytes = response.out_frames();
      96           0 :   on_next_done(absl::OkStatus(), handshaker,
      97           0 :                reinterpret_cast<const unsigned char*>(out_bytes.c_str()), out_bytes.size(),
      98           0 :                std::move(handshake_result));
      99           0 :   return absl::OkStatus();
     100           0 : }
     101             : 
     102           0 : std::size_t AltsTsiHandshaker::computeMaxFrameSize(const grpc::gcp::HandshakerResult& result) {
     103           0 :   if (result.max_frame_size() > 0) {
     104           0 :     return std::clamp(static_cast<std::size_t>(result.max_frame_size()), AltsMinFrameSize,
     105           0 :                       MaxFrameSize);
     106           0 :   }
     107           0 :   return AltsMinFrameSize;
     108           0 : }
     109             : 
     110             : absl::StatusOr<std::unique_ptr<AltsHandshakeResult>>
     111             : AltsTsiHandshaker::getHandshakeResult(const grpc::gcp::HandshakerResult& result,
     112             :                                       absl::Span<const uint8_t> received_bytes,
     113           0 :                                       std::size_t bytes_consumed) {
     114             :   // Validate the HandshakerResult message.
     115           0 :   if (!result.has_peer_identity()) {
     116           0 :     return absl::FailedPreconditionError("Handshake result is missing peer identity.");
     117           0 :   }
     118           0 :   if (!result.has_local_identity()) {
     119           0 :     return absl::FailedPreconditionError("Handshake result is missing local identity.");
     120           0 :   }
     121           0 :   if (!result.has_peer_rpc_versions()) {
     122           0 :     return absl::FailedPreconditionError("Handshake result is missing peer rpc versions.");
     123           0 :   }
     124           0 :   if (result.application_protocol().empty()) {
     125           0 :     return absl::FailedPreconditionError("Handshake result has empty application protocol.");
     126           0 :   }
     127           0 :   if (result.record_protocol() != RecordProtocol) {
     128           0 :     return absl::FailedPreconditionError(
     129           0 :         "Handshake result's record protocol is not ALTSRP_GCM_AES128_REKEY.");
     130           0 :   }
     131           0 :   if (result.key_data().size() < AltsAes128GcmRekeyKeyLength) {
     132           0 :     return absl::FailedPreconditionError("Handshake result's key data is too short.");
     133           0 :   }
     134           0 :   if (bytes_consumed > received_bytes.size()) {
     135           0 :     return absl::FailedPreconditionError(
     136           0 :         "Handshaker service consumed more bytes than were received from the "
     137           0 :         "peer.");
     138           0 :   }
     139             : 
     140             :   // Create the frame protector.
     141           0 :   std::size_t max_frame_size = computeMaxFrameSize(result);
     142           0 :   tsi_zero_copy_grpc_protector* protector = nullptr;
     143           0 :   grpc_core::ExecCtx exec_ctx;
     144           0 :   tsi_result ok = alts_zero_copy_grpc_protector_create(
     145           0 :       reinterpret_cast<const uint8_t*>(result.key_data().data()), AltsAes128GcmRekeyKeyLength, true,
     146           0 :       is_client_,
     147           0 :       /*is_integrity_only=*/false, /*enable_extra_copy=*/false, &max_frame_size, &protector);
     148           0 :   if (ok != TSI_OK) {
     149           0 :     return absl::InternalError(absl::StrFormat("Failed to create frame protector: %zu", ok));
     150           0 :   }
     151             : 
     152             :   // Calculate the unused bytes.
     153           0 :   std::size_t unused_bytes_size = received_bytes.size() - bytes_consumed;
     154           0 :   const uint8_t* unused_bytes_ptr = received_bytes.data() + bytes_consumed;
     155           0 :   std::vector<uint8_t> unused_bytes(unused_bytes_ptr, unused_bytes_ptr + unused_bytes_size);
     156             : 
     157             :   // Create and return the AltsHandshakeResult.
     158           0 :   auto handshake_result = std::make_unique<AltsHandshakeResult>();
     159           0 :   handshake_result->frame_protector = std::make_unique<TsiFrameProtector>(protector);
     160           0 :   handshake_result->peer_identity = result.peer_identity().service_account();
     161           0 :   handshake_result->unused_bytes = unused_bytes;
     162           0 :   return handshake_result;
     163           0 : }
     164             : 
     165             : } // namespace Alts
     166             : } // namespace TransportSockets
     167             : } // namespace Extensions
     168             : } // namespace Envoy

Generated by: LCOV version 1.15