LCOV - code coverage report
Current view: top level - source/extensions/transport_sockets/alts - tsi_socket.cc (source / functions) Hit Total Coverage
Test: coverage.dat Lines: 0 278 0.0 %
Date: 2024-01-05 06:35:25 Functions: 0 22 0.0 %

          Line data    Source code
       1             : #include "source/extensions/transport_sockets/alts/tsi_socket.h"
       2             : 
       3             : #include <algorithm>
       4             : #include <memory>
       5             : #include <string>
       6             : #include <utility>
       7             : 
       8             : #include "source/common/common/assert.h"
       9             : #include "source/common/common/cleanup.h"
      10             : #include "source/common/common/empty_string.h"
      11             : #include "source/common/common/enum_to_int.h"
      12             : #include "source/common/network/raw_buffer_socket.h"
      13             : 
      14             : namespace Envoy {
      15             : namespace Extensions {
      16             : namespace TransportSockets {
      17             : namespace Alts {
      18             : 
      // Primary constructor: stores the handshaker factory and the handshake validator, takes
      // ownership of `raw_socket` (moved into raw_buffer_socket_) and records whether this socket
      // is on the downstream side.
      // NOTE(review): the two by-value functor arguments are copied into the members rather than
      // moved; presumably harmless, but std::move would avoid the extra copy - confirm.
      19             : TsiSocket::TsiSocket(HandshakerFactory handshaker_factory, HandshakeValidator handshake_validator,
      20             :                      Network::TransportSocketPtr&& raw_socket, bool downstream)
      21             :     : handshaker_factory_(handshaker_factory), handshake_validator_(handshake_validator),
      22           0 :       raw_buffer_socket_(std::move(raw_socket)), downstream_(downstream) {}
      23             : 
      // Convenience constructor: delegates to the four-argument constructor with a freshly created
      // Network::RawBufferSocket as the raw socket, then sets the watermarks of raw_read_buffer_
      // to default_max_frame_size_.
      // NOTE(review): the watermark is only configured on this path, not when a raw socket is
      // injected through the four-argument constructor - confirm that is intentional.
      24             : TsiSocket::TsiSocket(HandshakerFactory handshaker_factory, HandshakeValidator handshake_validator,
      25             :                      bool downstream)
      26             :     : TsiSocket(handshaker_factory, handshake_validator,
      27           0 :                 std::make_unique<Network::RawBufferSocket>(), downstream) {
      28           0 :   raw_read_buffer_.setWatermarks(default_max_frame_size_);
      29           0 : }
      30             : 
      // Destructor. Asserts that the handshaker has already been given up; closeSocket() is the
      // place in this file that releases it for deferred deletion.
      31           0 : TsiSocket::~TsiSocket() { ASSERT(!handshaker_); }
      32             : 
      // Installs the connection's transport socket callbacks. May be called only once (asserted).
      // The raw socket does not receive `callbacks` directly: it is handed a
      // TsiTransportSocketCallbacks wrapper constructed from `callbacks` and raw_read_buffer_.
      33           0 : void TsiSocket::setTransportSocketCallbacks(Envoy::Network::TransportSocketCallbacks& callbacks) {
      34           0 :   ASSERT(!callbacks_);
      35           0 :   callbacks_ = &callbacks;
      36             : 
      37           0 :   tsi_callbacks_ = std::make_unique<TsiTransportSocketCallbacks>(callbacks, raw_read_buffer_);
      38           0 :   raw_buffer_socket_->setTransportSocketCallbacks(*tsi_callbacks_);
      39           0 : }
      40             : 
      // Returns the negotiated application-layer protocol. Always the empty string here, because
      // no protocol information is extracted from TSI (see the TODO below).
      41           0 : std::string TsiSocket::protocol() const {
      42             :   // TSI doesn't have a generic way to indicate application layer protocol.
      43             :   // TODO(lizan): support application layer protocol from TSI for known TSIs.
      44           0 :   return EMPTY_STRING;
      45           0 : }
      46             : 
      // Returns a description of the last transport failure. Not implemented yet: always returns
      // the empty string, even after a failed handshake.
      47           0 : absl::string_view TsiSocket::failureReason() const {
      48             :   // TODO(htuch): Implement error reason for TSI.
      49           0 :   return EMPTY_STRING;
      50           0 : }
      51             : 
      // Advances the handshake if possible. Must only be called while the handshake is incomplete
      // (asserted). Returns the result of doHandshakeNext() when a step is started, otherwise
      // KeepOpen.
      52           0 : Network::PostIoAction TsiSocket::doHandshake() {
      53           0 :   ASSERT(!handshake_complete_);
      54           0 :   ENVOY_CONN_LOG(debug, "TSI: doHandshake", callbacks_->connection());
                         // Start a new step only when no next() call is outstanding and there are received bytes
                         // to feed to the handshaker.
      55           0 :   if (!handshaker_next_calling_ && raw_read_buffer_.length() > 0) {
      56           0 :     return doHandshakeNext();
      57           0 :   }
      58           0 :   return Network::PostIoAction::KeepOpen;
      59           0 : }
      60             : 
      // Runs one handshake step: creates the handshaker on first use, then passes everything
      // currently in raw_read_buffer_ to handshaker_->next().
      // Returns Close if the handshaker cannot be created or next() reports a non-OK status,
      // otherwise KeepOpen. The outcome of the step itself is delivered through onNextDone().
      61           0 : Network::PostIoAction TsiSocket::doHandshakeNext() {
      62           0 :   ENVOY_CONN_LOG(debug, "TSI: doHandshake next: received: {}", callbacks_->connection(),
      63           0 :                  raw_read_buffer_.length());
      64             : 
                         // Lazily build the handshaker from the connection's dispatcher and its local/remote
                         // addresses.
      65           0 :   if (!handshaker_) {
      66           0 :     handshaker_ =
      67           0 :         handshaker_factory_(callbacks_->connection().dispatcher(),
      68           0 :                             callbacks_->connection().connectionInfoProvider().localAddress(),
      69           0 :                             callbacks_->connection().connectionInfoProvider().remoteAddress());
      70           0 :     if (!handshaker_) {
      71           0 :       ENVOY_CONN_LOG(warn, "TSI: failed to create handshaker", callbacks_->connection());
                             // The connection is closed here directly (no flush) in addition to returning Close.
      72           0 :       callbacks_->connection().close(Network::ConnectionCloseType::NoFlush,
      73           0 :                                      "failed_creating_handshaker");
      74           0 :       return Network::PostIoAction::Close;
      75           0 :     }
      76             : 
                           // Register this socket so it receives onNextDone().
      77           0 :     handshaker_->setHandshakerCallbacks(*this);
      78           0 :   }
      79             : 
                         // Mark a next() call as in flight; doHandshake() will not start another one until
                         // onNextDone() clears the flag.
      80           0 :   handshaker_next_calling_ = true;
                         // Hand all buffered raw bytes to the handshaker; raw_read_buffer_ is left empty.
      81           0 :   Buffer::OwnedImpl handshaker_buffer;
      82           0 :   handshaker_buffer.move(raw_read_buffer_);
      83           0 :   absl::Status status = handshaker_->next(handshaker_buffer);
      84           0 :   if (!status.ok()) {
                           // NOTE(review): handshaker_next_calling_ stays true on this path; presumably fine
                           // because the caller closes the connection - confirm.
      85           0 :     ENVOY_CONN_LOG(debug, "TSI: Handshake failed: status: {}", callbacks_->connection(), status);
      86           0 :     return Network::PostIoAction::Close;
      87           0 :   }
      88           0 :   return Network::PostIoAction::KeepOpen;
      89           0 : }
      90             : 
      // Processes the result of one handshaker next() call.
      //   next_result: must be non-null (asserted); carries status_, the bytes to send to the peer
      //                (to_send_) and, once the handshake has finished, result_.
      // Returns Close when the step failed, peer validation failed, or the raw socket hit a read
      // error / end of stream before the handshake completed; otherwise KeepOpen or the action of
      // the raw write performed at the end.
      91           0 : Network::PostIoAction TsiSocket::doHandshakeNextDone(NextResultPtr&& next_result) {
      92           0 :   ASSERT(next_result);
      93             : 
      94           0 :   ENVOY_CONN_LOG(debug, "TSI: doHandshake next done: status: {} to_send: {}",
      95           0 :                  callbacks_->connection(), next_result->status_, next_result->to_send_->length());
      96             : 
      97           0 :   absl::Status status = next_result->status_;
      98           0 :   AltsHandshakeResult* handshake_result = next_result->result_.get();
      99           0 :   if (!status.ok()) {
     100           0 :     ENVOY_CONN_LOG(debug, "TSI: Handshake failed: status: {}", callbacks_->connection(), status);
     101           0 :     return Network::PostIoAction::Close;
     102           0 :   }
     103             : 
                         // Queue any handshake bytes destined for the peer; they are written at the end of this
                         // function.
     104           0 :   if (next_result->to_send_->length() > 0) {
     105           0 :     raw_write_buffer_.move(*next_result->to_send_);
     106           0 :   }
     107             : 
                         // A non-null result means the handshake has finished.
                         // NOTE(review): status.ok() is always true here because of the early return above.
     108           0 :   if (status.ok() && handshake_result != nullptr) {
     109           0 :     if (handshake_validator_) {
     110           0 :       std::string err;
     111           0 :       TsiInfo tsi_info;
     112           0 :       tsi_info.peer_identity_ = handshake_result->peer_identity;
     113           0 :       const bool peer_validated = handshake_validator_(tsi_info, err);
     114           0 :       if (peer_validated) {
     115           0 :         ENVOY_CONN_LOG(debug, "TSI: Handshake validation succeeded.", callbacks_->connection());
     116           0 :       } else {
     117           0 :         ENVOY_CONN_LOG(debug, "TSI: Handshake validation failed: {}", callbacks_->connection(),
     118           0 :                        err);
     119           0 :         return Network::PostIoAction::Close;
     120           0 :       }
                             // Publish the validated peer identity as dynamic metadata on the connection's stream
                             // info. This only happens when a validator is configured.
     121           0 :       ProtobufWkt::Struct dynamic_metadata;
     122           0 :       ProtobufWkt::Value val;
     123           0 :       val.set_string_value(tsi_info.peer_identity_);
     124           0 :       dynamic_metadata.mutable_fields()->insert({std::string("peer_identity"), val});
     125           0 :       callbacks_->connection().streamInfo().setDynamicMetadata(
     126           0 :           "envoy.transport_sockets.peer_information", dynamic_metadata);
     127           0 :       ENVOY_CONN_LOG(debug, "TSI handshake with peer: {}", callbacks_->connection(),
     128           0 :                      tsi_info.peer_identity_);
     129           0 :     } else {
     130           0 :       ENVOY_CONN_LOG(debug, "TSI: Handshake validation skipped.", callbacks_->connection());
     131           0 :     }
     132             : 
                           // Bytes the handshaker received but did not consume are put back at the front of
                           // raw_read_buffer_ so the read path can process them.
     133           0 :     if (!handshake_result->unused_bytes.empty()) {
     134             :       // All handshake data is consumed.
     135           0 :       ASSERT(raw_read_buffer_.length() == 0);
     136           0 :       absl::string_view unused_bytes(
     137           0 :           reinterpret_cast<const char*>(handshake_result->unused_bytes.data()),
     138           0 :           handshake_result->unused_bytes.size());
     139           0 :       raw_read_buffer_.prepend(unused_bytes);
     140           0 :     }
     141           0 :     ENVOY_CONN_LOG(debug, "TSI: Handshake successful: unused_bytes: {}", callbacks_->connection(),
     142           0 :                    handshake_result->unused_bytes.size());
     143             :     // Reset the watermarks with actual negotiated max frame size.
     144           0 :     raw_read_buffer_.setWatermarks(
     145           0 :         std::max<size_t>(actual_frame_size_to_use_, callbacks_->connection().bufferLimit()));
     146           0 :     frame_protector_ = std::move(handshake_result->frame_protector);
     147             : 
     148           0 :     handshake_complete_ = true;
                           // With nothing left to send, Connected is raised right away; otherwise it is raised
                           // after the write attempt near the end of this function.
     149           0 :     if (raw_write_buffer_.length() == 0) {
     150           0 :       callbacks_->raiseEvent(Network::ConnectionEvent::Connected);
     151           0 :     }
     152           0 :   }
     153             : 
                         // A raw read error, or end of stream while the handshake is still incomplete, is fatal.
                         // The log text below is used for both cases.
     154           0 :   if (read_error_ || (!handshake_complete_ && end_stream_read_)) {
     155           0 :     ENVOY_CONN_LOG(debug, "TSI: Handshake failed: end of stream without enough data",
     156           0 :                    callbacks_->connection());
     157           0 :     return Network::PostIoAction::Close;
     158           0 :   }
     159             : 
                         // Data is still buffered in raw_read_buffer_: flag the transport socket as readable.
     160           0 :   if (raw_read_buffer_.length() > 0) {
     161           0 :     callbacks_->setTransportSocketIsReadable();
     162           0 :   }
     163             : 
     164             :   // Try to write raw buffer when next call is done, even this is not in do[Read|Write] stack.
     165           0 :   if (raw_write_buffer_.length() > 0) {
     166           0 :     Network::IoResult result = raw_buffer_socket_->doWrite(raw_write_buffer_, false);
     167           0 :     if (handshake_complete_ && result.action_ != Network::PostIoAction::Close) {
     168           0 :       callbacks_->raiseEvent(Network::ConnectionEvent::Connected);
     169           0 :     }
     170           0 :     return result.action_;
     171           0 :   }
     172             : 
     173           0 :   return Network::PostIoAction::KeepOpen;
     174           0 : }
     175             : 
     // Alternates between unprotecting buffered raw bytes into `buffer` and reading more bytes
     // from the raw socket, until a stop condition is hit.
     //   buffer:      destination for unprotected (plaintext) bytes.
     //   prev_result: result of a raw read performed before this call; its action_ seeds the loop.
     // Returns the last IoResult with bytes_processed_ replaced by the number of plaintext bytes
     // appended to `buffer` during this call.
     176             : Network::IoResult TsiSocket::repeatReadAndUnprotect(Buffer::Instance& buffer,
     177           0 :                                                     Network::IoResult prev_result) {
     178           0 :   Network::IoResult result = prev_result;
     179           0 :   uint64_t total_bytes_processed = 0;
     180             : 
     181           0 :   while (true) {
     182             :     // Do unprotect.
     183           0 :     if (raw_read_buffer_.length() > 0) {
     184           0 :       uint64_t prev_size = buffer.length();
     185           0 :       ENVOY_CONN_LOG(debug, "TSI: unprotecting buffer size: {}", callbacks_->connection(),
     186           0 :                      raw_read_buffer_.length());
     187           0 :       tsi_result status = frame_protector_->unprotect(raw_read_buffer_, buffer);
     188           0 :       if (status != TSI_OK) {
     189           0 :         ENVOY_CONN_LOG(debug, "TSI: unprotect failed: status: {}", callbacks_->connection(),
     190           0 :                        status);
     191           0 :         result.action_ = Network::PostIoAction::Close;
     192           0 :         break;
     193           0 :       }
                             // NOTE(review): the assert says unprotect() drains raw_read_buffer_ completely, so the
                             // "left" value logged below is always 0 when asserts are enabled.
     194           0 :       ASSERT(raw_read_buffer_.length() == 0);
     195           0 :       ENVOY_CONN_LOG(debug, "TSI: unprotected buffer left: {} result: {}", callbacks_->connection(),
     196           0 :                      raw_read_buffer_.length(), tsi_result_to_string(status));
                             // Count only the plaintext bytes added by this unprotect call.
     197           0 :       total_bytes_processed += buffer.length() - prev_size;
     198             : 
     199             :       // Check if buffer needs to be drained.
     200           0 :       if (callbacks_->shouldDrainReadBuffer()) {
     201           0 :         callbacks_->setTransportSocketIsReadable();
     202           0 :         break;
     203           0 :       }
     204           0 :     }
     205             : 
                           // The most recent raw read asked to close: stop once buffered data has been unprotected.
     206           0 :     if (result.action_ == Network::PostIoAction::Close) {
     207           0 :       break;
     208           0 :     }
     209             : 
     210             :     // End of stream is reached in the previous read.
     211           0 :     if (end_stream_read_) {
     212           0 :       result.end_stream_read_ = true;
     213           0 :       break;
     214           0 :     }
     215             :     // Do another read.
     216           0 :     result = readFromRawSocket();
     217             :     // No data is read.
     218           0 :     if (result.bytes_processed_ == 0) {
     219           0 :       break;
     220           0 :     }
     221           0 :   };
                         // Report plaintext bytes appended to `buffer`, not the raw byte count of the last read.
     222           0 :   result.bytes_processed_ = total_bytes_processed;
     223           0 :   ENVOY_CONN_LOG(debug, "TSI: do read result action {} bytes {} end_stream {}",
     224           0 :                  callbacks_->connection(), enumToInt(result.action_), result.bytes_processed_,
     225           0 :                  result.end_stream_read_);
     226           0 :   return result;
     227           0 : }
     228             : 
     // Reads from the raw socket into raw_read_buffer_ and remembers the outcome:
     // end_stream_read_ mirrors the result's end-of-stream flag and read_error_ is set when the
     // raw socket asked to close. Returns the raw IoResult unchanged.
     229           0 : Network::IoResult TsiSocket::readFromRawSocket() {
     230           0 :   Network::IoResult result = raw_buffer_socket_->doRead(raw_read_buffer_);
     231           0 :   end_stream_read_ = result.end_stream_read_;
     232           0 :   read_error_ = result.action_ == Network::PostIoAction::Close;
     233           0 :   return result;
     234           0 : }
     235             : 
     // Transport socket read entry point.
     // While the handshake is incomplete, raw bytes are read and fed to the handshake, and no
     // application data is reported. Once it is complete, raw data is read and unprotected into
     // `buffer` via repeatReadAndUnprotect().
     236           0 : Network::IoResult TsiSocket::doRead(Buffer::Instance& buffer) {
     237           0 :   Network::IoResult result = {Network::PostIoAction::KeepOpen, 0, false};
     238           0 :   if (!handshake_complete_) {
                           // Skip the raw read if an earlier read already reported end of stream or an error.
     239           0 :     if (!end_stream_read_ && !read_error_) {
     240           0 :       result = readFromRawSocket();
     241           0 :       ENVOY_CONN_LOG(debug, "TSI: raw read result action {} bytes {} end_stream {}",
     242           0 :                      callbacks_->connection(), enumToInt(result.action_), result.bytes_processed_,
     243           0 :                      result.end_stream_read_);
                             // Close requested and nothing was read: return the raw result as is.
     244           0 :       if (result.action_ == Network::PostIoAction::Close && result.bytes_processed_ == 0) {
     245           0 :         return result;
     246           0 :       }
     247             : 
                             // End of stream with no new bytes while the handshake is still pending: close.
     248           0 :       if (result.end_stream_read_ && result.bytes_processed_ == 0) {
     249           0 :         return {Network::PostIoAction::Close, result.bytes_processed_, result.end_stream_read_};
     250           0 :       }
     251           0 :     }
     252           0 :     Network::PostIoAction action = doHandshake();
                           // Unless the handshake completed during this call, no application bytes are reported.
     253           0 :     if (action == Network::PostIoAction::Close || !handshake_complete_) {
     254           0 :       return {action, 0, false};
     255           0 :     }
     256           0 :   }
     257             :   // Handshake finishes.
     258           0 :   ASSERT(handshake_complete_);
     259           0 :   ASSERT(frame_protector_);
     260           0 :   return repeatReadAndUnprotect(buffer, result);
     261           0 : }
     262             : 
     // Protects application data from `buffer` one frame at a time and writes each frame to the
     // raw socket, until `buffer` is empty or the raw socket performs a short write.
     //   buffer:     plaintext to send; bytes are drained only after their frame is fully written.
     //   end_stream: passed on to the raw socket only when `buffer` is already empty at the time
     //               of the raw write (see the raw doWrite call below).
     // Returns the action of the last raw write, the number of plaintext bytes drained from
     // `buffer`, and end_stream_read_ = false.
     263           0 : Network::IoResult TsiSocket::repeatProtectAndWrite(Buffer::Instance& buffer, bool end_stream) {
     264           0 :   uint64_t total_bytes_written = 0;
     265           0 :   Network::IoResult result = {Network::PostIoAction::KeepOpen, 0, false};
     266             :   // There should be no handshake bytes in raw_write_buffer_.
     267           0 :   ASSERT(!(raw_write_buffer_.length() > 0 && prev_bytes_to_drain_ == 0));
     268           0 :   while (true) {
                           // After a short write, resume with the plaintext size of the frame still pending in
                           // raw_write_buffer_; otherwise take at most one frame's worth of payload.
                           // NOTE(review): assumes actual_frame_size_to_use_ >= frame_overhead_size_; otherwise
                           // the subtraction would presumably wrap (unsigned) - confirm against the header.
     269           0 :     uint64_t bytes_to_drain_this_iteration =
     270           0 :         prev_bytes_to_drain_ > 0
     271           0 :             ? prev_bytes_to_drain_
     272           0 :             : std::min<uint64_t>(buffer.length(), actual_frame_size_to_use_ - frame_overhead_size_);
     273             :     // Consumed all data. Exit.
     274           0 :     if (bytes_to_drain_this_iteration == 0) {
     275           0 :       break;
     276           0 :     }
     277             :     // Short write did not occur previously.
     278           0 :     if (raw_write_buffer_.length() == 0) {
     279           0 :       ASSERT(frame_protector_);
     280           0 :       ASSERT(prev_bytes_to_drain_ == 0);
     281             : 
     282             :       // Do protect.
     283           0 :       ENVOY_CONN_LOG(debug, "TSI: protecting buffer size: {}", callbacks_->connection(),
     284           0 :                      bytes_to_drain_this_iteration);
                             // NOTE(review): `status` is only logged; a failed protect() is not treated as an
                             // error and the loop goes on to the raw write - confirm this is intended.
     285           0 :       tsi_result status = frame_protector_->protect(
     286           0 :           grpc_slice_from_static_buffer(buffer.linearize(bytes_to_drain_this_iteration),
     287           0 :                                         bytes_to_drain_this_iteration),
     288           0 :           raw_write_buffer_);
                             // NOTE(review): the "left" value logged here is the size that was just protected.
     289           0 :       ENVOY_CONN_LOG(debug, "TSI: protected buffer left: {} result: {}", callbacks_->connection(),
     290           0 :                      bytes_to_drain_this_iteration, tsi_result_to_string(status));
     291           0 :     }
     292             : 
     293             :     // Write raw_write_buffer_ to network.
     294           0 :     ENVOY_CONN_LOG(debug, "TSI: raw_write length {} end_stream {}", callbacks_->connection(),
     295           0 :                    raw_write_buffer_.length(), end_stream);
     296           0 :     result = raw_buffer_socket_->doWrite(raw_write_buffer_, end_stream && (buffer.length() == 0));
     297             : 
     298             :     // Short write. Exit.
     299           0 :     if (raw_write_buffer_.length() > 0) {
                             // Remember the plaintext size so it is drained from `buffer` once the rest of the
                             // frame has been written by a later call.
     300           0 :       prev_bytes_to_drain_ = bytes_to_drain_this_iteration;
     301           0 :       break;
     302           0 :     } else {
     303           0 :       buffer.drain(bytes_to_drain_this_iteration);
     304           0 :       prev_bytes_to_drain_ = 0;
     305           0 :       total_bytes_written += bytes_to_drain_this_iteration;
     306           0 :     }
     307           0 :   }
     308             : 
     309           0 :   return {result.action_, total_bytes_written, false};
     310           0 : }
     311             : 
     // Transport socket write entry point.
     // Before the handshake completes no application data is written: the call only tries to
     // advance the handshake and reports 0 bytes. Afterwards, leftover handshake bytes are flushed
     // first and then `buffer` is protected and written via repeatProtectAndWrite().
     312           0 : Network::IoResult TsiSocket::doWrite(Buffer::Instance& buffer, bool end_stream) {
     313           0 :   if (!handshake_complete_) {
     314           0 :     Network::PostIoAction action = doHandshake();
                           // NOTE(review): asserts that doHandshake() never completes the handshake within this
                           // call; presumably completion only happens via onNextDone() - confirm.
     315           0 :     ASSERT(!handshake_complete_);
     316           0 :     return {action, 0, false};
     317           0 :   } else {
     318           0 :     ASSERT(frame_protector_);
     319             :     // Check if we need to flush outstanding handshake bytes.
                           // prev_bytes_to_drain_ == 0 distinguishes handshake bytes from a protected frame left
                           // over after a short write (repeatProtectAndWrite sets it to non-zero in that case).
     320           0 :     if (raw_write_buffer_.length() > 0 && prev_bytes_to_drain_ == 0) {
     321           0 :       ENVOY_CONN_LOG(debug, "TSI: raw_write length {} end_stream {}", callbacks_->connection(),
     322           0 :                      raw_write_buffer_.length(), end_stream);
     323           0 :       Network::IoResult result =
     324           0 :           raw_buffer_socket_->doWrite(raw_write_buffer_, end_stream && (buffer.length() == 0));
     325             :       // Check if short write occurred.
     326           0 :       if (raw_write_buffer_.length() > 0) {
     327           0 :         return {result.action_, 0, false};
     328           0 :       }
     329           0 :     }
     330           0 :     return repeatProtectAndWrite(buffer, end_stream);
     331           0 :   }
     332           0 : }
     333             : 
     // Called when the socket is being closed; the event argument is ignored. If a handshaker
     // still exists, ownership is given up via release() and deferredDelete() is called on it,
     // which leaves handshaker_ null as the destructor asserts.
     334           0 : void TsiSocket::closeSocket(Network::ConnectionEvent) {
     335           0 :   ENVOY_CONN_LOG(debug, "TSI: closing socket", callbacks_->connection());
     336           0 :   if (handshaker_) {
     337           0 :     handshaker_.release()->deferredDelete();
     338           0 :   }
     339           0 : }
     340             : 
     // Connection-established hook. Must be called before the handshake completes (asserted).
     // When this socket is not on the downstream side it kicks off the first handshake step.
     // NOTE(review): the PostIoAction returned by doHandshakeNext() is discarded here; a failed
     // handshaker->next() (which only returns Close) would go unnoticed on this path - confirm.
     341           0 : void TsiSocket::onConnected() {
     342           0 :   ASSERT(!handshake_complete_);
     343             :   // Client initiates the handshake, so ignore onConnect call on the downstream.
     344           0 :   if (!downstream_) {
     345           0 :     doHandshakeNext();
     346           0 :   }
     347           0 : }
     348             : 
     // Handshaker callback invoked when a next() call has finished. Clears the in-flight flag so
     // doHandshake() may start another step, processes the result, and closes the connection
     // without flushing if processing asks for Close.
     349           0 : void TsiSocket::onNextDone(NextResultPtr&& result) {
     350           0 :   handshaker_next_calling_ = false;
     351             : 
     352           0 :   Network::PostIoAction action = doHandshakeNextDone(std::move(result));
     353           0 :   if (action == Network::PostIoAction::Close) {
     354           0 :     callbacks_->connection().close(Network::ConnectionCloseType::NoFlush, "tsi_handshake_failed");
     355           0 :   }
     356           0 : }
     357             : 
     // Stores the handshaker factory and handshake validator (both moved in) that are passed to
     // every TsiSocket this factory creates.
     358             : TsiSocketFactory::TsiSocketFactory(HandshakerFactory handshaker_factory,
     359             :                                    HandshakeValidator handshake_validator)
     360             :     : handshaker_factory_(std::move(handshaker_factory)),
     361           0 :       handshake_validator_(std::move(handshake_validator)) {}
     362             : 
     // Always reports true: sockets from this factory are advertised as a secure transport.
     363           0 : bool TsiSocketFactory::implementsSecureTransport() const { return true; }
     364             : 
     // Creates a TsiSocket with downstream = false. Both arguments (transport socket options and
     // host description) are unnamed and ignored.
     365             : Network::TransportSocketPtr
     366             : TsiSocketFactory::createTransportSocket(Network::TransportSocketOptionsConstSharedPtr,
     367           0 :                                         Upstream::HostDescriptionConstSharedPtr) const {
     368           0 :   return std::make_unique<TsiSocket>(handshaker_factory_, handshake_validator_, false);
     369           0 : }
     370             : 
     // Creates a TsiSocket with downstream = true; such a socket does not start the handshake
     // itself in onConnected().
     371           0 : Network::TransportSocketPtr TsiSocketFactory::createDownstreamTransportSocket() const {
     372           0 :   return std::make_unique<TsiSocket>(handshaker_factory_, handshake_validator_, true);
     373           0 : }
     374             : 
     375             : } // namespace Alts
     376             : } // namespace TransportSockets
     377             : } // namespace Extensions
     378             : } // namespace Envoy

Generated by: LCOV version 1.15