Line data Source code
1 : #include "source/extensions/transport_sockets/alts/alts_tsi_handshaker.h" 2 : 3 : #include <algorithm> 4 : #include <cstdint> 5 : #include <cstring> 6 : #include <memory> 7 : #include <string> 8 : #include <utility> 9 : #include <vector> 10 : 11 : #include "absl/memory/memory.h" 12 : #include "absl/status/status.h" 13 : #include "absl/strings/str_format.h" 14 : #include "absl/types/span.h" 15 : #include "src/core/lib/iomgr/exec_ctx.h" 16 : #include "src/core/tsi/alts/zero_copy_frame_protector/alts_zero_copy_grpc_protector.h" 17 : 18 : namespace Envoy { 19 : namespace Extensions { 20 : namespace TransportSockets { 21 : namespace Alts { 22 : 23 : using ::grpc::gcp::HandshakerResp; 24 : 25 : constexpr std::size_t AltsAes128GcmRekeyKeyLength = 44; 26 : 27 : std::unique_ptr<AltsTsiHandshaker> 28 0 : AltsTsiHandshaker::createForClient(std::shared_ptr<grpc::Channel> handshaker_service_channel) { 29 0 : return absl::WrapUnique(new AltsTsiHandshaker(/*is_client=*/true, handshaker_service_channel)); 30 0 : } 31 : 32 : std::unique_ptr<AltsTsiHandshaker> 33 0 : AltsTsiHandshaker::createForServer(std::shared_ptr<grpc::Channel> handshaker_service_channel) { 34 0 : return absl::WrapUnique(new AltsTsiHandshaker(/*is_client=*/false, handshaker_service_channel)); 35 0 : } 36 : 37 : AltsTsiHandshaker::AltsTsiHandshaker(bool is_client, 38 : std::shared_ptr<grpc::Channel> handshaker_service_channel) 39 0 : : is_client_(is_client), handshaker_service_channel_(handshaker_service_channel) {} 40 : 41 : absl::Status AltsTsiHandshaker::next(void* handshaker, const unsigned char* received_bytes, 42 0 : size_t received_bytes_size, OnNextDone on_next_done) { 43 : // Argument and state checks. 44 0 : if (handshaker == nullptr || (received_bytes == nullptr && received_bytes_size > 0) || 45 0 : on_next_done == nullptr) { 46 0 : return absl::InvalidArgumentError("Invalid nullptr argument to AltsTsiHandshaker::Next."); 47 0 : } 48 0 : if (is_handshake_complete_) { 49 0 : return absl::InternalError("Handshake is already complete."); 50 0 : } 51 : 52 : // Get a handshake message from the handshaker service. 53 0 : absl::Span<const uint8_t> in_bytes = absl::MakeConstSpan(received_bytes, received_bytes_size); 54 0 : HandshakerResp response; 55 0 : if (!has_sent_initial_handshake_message_) { 56 0 : has_sent_initial_handshake_message_ = true; 57 0 : auto alts_proxy = AltsProxy::create(handshaker_service_channel_); 58 0 : if (!alts_proxy.ok()) { 59 0 : return alts_proxy.status(); 60 0 : } 61 0 : alts_proxy_ = *std::move(alts_proxy); 62 0 : if (is_client_) { 63 0 : auto client_start = alts_proxy_->sendStartClientHandshakeReq(); 64 0 : if (!client_start.ok()) { 65 0 : return client_start.status(); 66 0 : } 67 0 : response = *std::move(client_start); 68 0 : } else { 69 0 : auto server_start = alts_proxy_->sendStartServerHandshakeReq(in_bytes); 70 0 : if (!server_start.ok()) { 71 0 : return server_start.status(); 72 0 : } 73 0 : response = *std::move(server_start); 74 0 : } 75 0 : } else { 76 0 : auto next = alts_proxy_->sendNextHandshakeReq(in_bytes); 77 0 : if (!next.ok()) { 78 0 : return next.status(); 79 0 : } 80 0 : response = *std::move(next); 81 0 : } 82 : 83 : // Maybe prepare the handshake result. 84 0 : std::unique_ptr<AltsHandshakeResult> handshake_result = nullptr; 85 0 : if (response.has_result()) { 86 0 : is_handshake_complete_ = true; 87 0 : auto result = getHandshakeResult(response.result(), in_bytes, response.bytes_consumed()); 88 0 : if (!result.ok()) { 89 0 : return result.status(); 90 0 : } 91 0 : handshake_result = *std::move(result); 92 0 : } 93 : 94 : // Write the out bytes. 95 0 : const std::string& out_bytes = response.out_frames(); 96 0 : on_next_done(absl::OkStatus(), handshaker, 97 0 : reinterpret_cast<const unsigned char*>(out_bytes.c_str()), out_bytes.size(), 98 0 : std::move(handshake_result)); 99 0 : return absl::OkStatus(); 100 0 : } 101 : 102 0 : std::size_t AltsTsiHandshaker::computeMaxFrameSize(const grpc::gcp::HandshakerResult& result) { 103 0 : if (result.max_frame_size() > 0) { 104 0 : return std::clamp(static_cast<std::size_t>(result.max_frame_size()), AltsMinFrameSize, 105 0 : MaxFrameSize); 106 0 : } 107 0 : return AltsMinFrameSize; 108 0 : } 109 : 110 : absl::StatusOr<std::unique_ptr<AltsHandshakeResult>> 111 : AltsTsiHandshaker::getHandshakeResult(const grpc::gcp::HandshakerResult& result, 112 : absl::Span<const uint8_t> received_bytes, 113 0 : std::size_t bytes_consumed) { 114 : // Validate the HandshakerResult message. 115 0 : if (!result.has_peer_identity()) { 116 0 : return absl::FailedPreconditionError("Handshake result is missing peer identity."); 117 0 : } 118 0 : if (!result.has_local_identity()) { 119 0 : return absl::FailedPreconditionError("Handshake result is missing local identity."); 120 0 : } 121 0 : if (!result.has_peer_rpc_versions()) { 122 0 : return absl::FailedPreconditionError("Handshake result is missing peer rpc versions."); 123 0 : } 124 0 : if (result.application_protocol().empty()) { 125 0 : return absl::FailedPreconditionError("Handshake result has empty application protocol."); 126 0 : } 127 0 : if (result.record_protocol() != RecordProtocol) { 128 0 : return absl::FailedPreconditionError( 129 0 : "Handshake result's record protocol is not ALTSRP_GCM_AES128_REKEY."); 130 0 : } 131 0 : if (result.key_data().size() < AltsAes128GcmRekeyKeyLength) { 132 0 : return absl::FailedPreconditionError("Handshake result's key data is too short."); 133 0 : } 134 0 : if (bytes_consumed > received_bytes.size()) { 135 0 : return absl::FailedPreconditionError( 136 0 : "Handshaker service consumed more bytes than were received from the " 137 0 : "peer."); 138 0 : } 139 : 140 : // Create the frame protector. 141 0 : std::size_t max_frame_size = computeMaxFrameSize(result); 142 0 : tsi_zero_copy_grpc_protector* protector = nullptr; 143 0 : grpc_core::ExecCtx exec_ctx; 144 0 : tsi_result ok = alts_zero_copy_grpc_protector_create( 145 0 : reinterpret_cast<const uint8_t*>(result.key_data().data()), AltsAes128GcmRekeyKeyLength, true, 146 0 : is_client_, 147 0 : /*is_integrity_only=*/false, /*enable_extra_copy=*/false, &max_frame_size, &protector); 148 0 : if (ok != TSI_OK) { 149 0 : return absl::InternalError(absl::StrFormat("Failed to create frame protector: %zu", ok)); 150 0 : } 151 : 152 : // Calculate the unused bytes. 153 0 : std::size_t unused_bytes_size = received_bytes.size() - bytes_consumed; 154 0 : const uint8_t* unused_bytes_ptr = received_bytes.data() + bytes_consumed; 155 0 : std::vector<uint8_t> unused_bytes(unused_bytes_ptr, unused_bytes_ptr + unused_bytes_size); 156 : 157 : // Create and return the AltsHandshakeResult. 158 0 : auto handshake_result = std::make_unique<AltsHandshakeResult>(); 159 0 : handshake_result->frame_protector = std::make_unique<TsiFrameProtector>(protector); 160 0 : handshake_result->peer_identity = result.peer_identity().service_account(); 161 0 : handshake_result->unused_bytes = unused_bytes; 162 0 : return handshake_result; 163 0 : } 164 : 165 : } // namespace Alts 166 : } // namespace TransportSockets 167 : } // namespace Extensions 168 : } // namespace Envoy