LCOV - code coverage report
Current view: top level - source/extensions/transport_sockets/alts - tsi_socket.h (source / functions) Hit Total Coverage
Test: coverage.dat Lines: 0 13 0.0 %
Date: 2024-01-05 06:35:25 Functions: 0 13 0.0 %

          Line data    Source code
       1             : #pragma once
       2             : 
       3             : #include "envoy/network/transport_socket.h"
       4             : 
       5             : #include "source/common/buffer/buffer_impl.h"
       6             : #include "source/common/buffer/watermark_buffer.h"
       7             : #include "source/common/network/raw_buffer_socket.h"
       8             : #include "source/common/network/transport_socket_options_impl.h"
       9             : #include "source/extensions/transport_sockets/alts/noop_transport_socket_callbacks.h"
      10             : #include "source/extensions/transport_sockets/alts/tsi_frame_protector.h"
      11             : #include "source/extensions/transport_sockets/alts/tsi_handshaker.h"
      12             : 
      13             : namespace Envoy {
      14             : namespace Extensions {
      15             : namespace TransportSockets {
      16             : namespace Alts {
      17             : 
      18             : struct TsiInfo {
      19             :   std::string peer_identity_;
      20             : };
      21             : 
      22             : /**
      23             :  * A factory function to create TsiHandshaker
      24             :  * @param dispatcher the dispatcher for the thread where the socket is running on.
      25             :  * @param local_address the local address of the connection.
      26             :  * @param remote_address the remote address of the connection.
      27             :  */
      28             : using HandshakerFactory = std::function<TsiHandshakerPtr(
      29             :     Event::Dispatcher& dispatcher, const Network::Address::InstanceConstSharedPtr& local_address,
      30             :     const Network::Address::InstanceConstSharedPtr& remote_address)>;
      31             : 
      32             : /**
      33             :  * A function to validate the peer of the connection.
      34             :  * @param err an error message to indicate why the peer is invalid. This is an
      35             :  * output param that should be populated by the function implementation.
      36             :  * @return true if the peer is valid or false if the peer is invalid.
      37             :  */
      38             : using HandshakeValidator = std::function<bool(TsiInfo& tsi_info, std::string& err)>;
      39             : 
      40             : /* Forward declaration */
      41             : class TsiTransportSocketCallbacks;
      42             : 
      43             : /**
      44             :  * A implementation of Network::TransportSocket based on gRPC TSI
      45             :  */
      46             : class TsiSocket : public Network::TransportSocket,
      47             :                   public TsiHandshakerCallbacks,
      48             :                   public Logger::Loggable<Logger::Id::connection> {
      49             : public:
      50             :   // For Test
      51             :   TsiSocket(HandshakerFactory handshaker_factory, HandshakeValidator handshake_validator,
      52             :             Network::TransportSocketPtr&& raw_socket_ptr, bool downstream);
      53             : 
      54             :   /**
      55             :    * @param handshaker_factory a function to initiate a TsiHandshaker
      56             :    * @param handshake_validator a function to validate the peer. Called right
      57             :    * after the handshake completed with peer data to do the peer validation.
      58             :    * The connection will be closed immediately if it returns false.
      59             :    * @param downstream is true for downstream transport socket.
      60             :    */
      61             :   TsiSocket(HandshakerFactory handshaker_factory, HandshakeValidator handshake_validator,
      62             :             bool downstream);
      63             :   ~TsiSocket() override;
      64             : 
      65             :   // Network::TransportSocket
      66             :   void setTransportSocketCallbacks(Envoy::Network::TransportSocketCallbacks& callbacks) override;
      67             :   std::string protocol() const override;
      68             :   absl::string_view failureReason() const override;
      69           0 :   bool canFlushClose() override { return handshake_complete_; }
      70           0 :   Envoy::Ssl::ConnectionInfoConstSharedPtr ssl() const override { return nullptr; }
      71           0 :   bool startSecureTransport() override { return false; }
      72           0 :   void configureInitialCongestionWindow(uint64_t, std::chrono::microseconds) override {}
      73             :   Network::IoResult doWrite(Buffer::Instance& buffer, bool end_stream) override;
      74             :   void closeSocket(Network::ConnectionEvent event) override;
      75             :   Network::IoResult doRead(Buffer::Instance& buffer) override;
      76             :   void onConnected() override;
      77             : 
      78             :   // TsiHandshakerCallbacks
      79             :   void onNextDone(NextResultPtr&& result) override;
      80             : 
      81             :   // This API should be called only after ALTS handshake finishes successfully.
      82           0 :   size_t actualFrameSizeToUse() { return actual_frame_size_to_use_; }
      83             :   // Set actual_frame_size_to_use_. Exposed for testing purpose.
      84           0 :   void setActualFrameSizeToUse(size_t frame_size) { actual_frame_size_to_use_ = frame_size; }
      85             :   // Set frame_overhead_size_. Exposed for testing purpose.
      86           0 :   void setFrameOverheadSize(size_t overhead_size) { frame_overhead_size_ = overhead_size; }
      87             : 
      88             : private:
      89             :   Network::PostIoAction doHandshake();
      90             :   Network::PostIoAction doHandshakeNext();
      91             :   Network::PostIoAction doHandshakeNextDone(NextResultPtr&& next_result);
      92             : 
      93             :   // Helper function to perform repeated read and unprotect operations.
      94             :   Network::IoResult repeatReadAndUnprotect(Buffer::Instance& buffer, Network::IoResult prev_result);
      95             :   // Helper function to perform repeated protect and write operations.
      96             :   Network::IoResult repeatProtectAndWrite(Buffer::Instance& buffer, bool end_stream);
      97             :   // Helper function to read from a raw socket and update status.
      98             :   Network::IoResult readFromRawSocket();
      99             : 
     100             :   HandshakerFactory handshaker_factory_;
     101             :   HandshakeValidator handshake_validator_;
     102             :   TsiHandshakerPtr handshaker_{};
     103             :   bool handshaker_next_calling_{};
     104             : 
     105             :   TsiFrameProtectorPtr frame_protector_;
     106             :   // default_max_frame_size_ is the maximum frame size supported by
     107             :   // TsiSocket.
     108             :   size_t default_max_frame_size_{16384};
     109             :   // actual_frame_size_to_use_ is the actual frame size used by
     110             :   // frame protector, which is the result of frame size negotiation.
     111             :   size_t actual_frame_size_to_use_{0};
     112             :   // frame_overhead_size_ includes 4 bytes frame message type and 16 bytes tag length.
     113             :   // It is consistent with gRPC ALTS zero copy frame protector implementation.
     114             :   // The maximum size of data that can be protected for each frame is equal to
     115             :   // actual_frame_size_to_use_ - frame_overhead_size_.
     116             :   size_t frame_overhead_size_{20};
     117             : 
     118             :   Envoy::Network::TransportSocketCallbacks* callbacks_{};
     119             :   std::unique_ptr<TsiTransportSocketCallbacks> tsi_callbacks_;
     120             :   Network::TransportSocketPtr raw_buffer_socket_;
     121             :   const bool downstream_;
     122             : 
     123           0 :   Buffer::WatermarkBuffer raw_read_buffer_{[]() {}, []() {}, []() {}};
     124             :   Envoy::Buffer::OwnedImpl raw_write_buffer_;
     125             :   bool handshake_complete_{};
     126             :   bool end_stream_read_{};
     127             :   bool read_error_{};
     128             :   uint64_t prev_bytes_to_drain_{};
     129             : };
     130             : 
     131             : /**
     132             :  * An implementation of Network::UpstreamTransportSocketFactory for TsiSocket
     133             :  */
     134             : class TsiSocketFactory : public Network::DownstreamTransportSocketFactory,
     135             :                          public Network::CommonUpstreamTransportSocketFactory {
     136             : public:
     137             :   TsiSocketFactory(HandshakerFactory handshaker_factory, HandshakeValidator handshake_validator);
     138             : 
     139             :   bool implementsSecureTransport() const override;
     140           0 :   absl::string_view defaultServerNameIndication() const override { return ""; }
     141             : 
     142             :   Network::TransportSocketPtr
     143             :   createTransportSocket(Network::TransportSocketOptionsConstSharedPtr options,
     144             :                         Upstream::HostDescriptionConstSharedPtr) const override;
     145             : 
     146             :   Network::TransportSocketPtr createDownstreamTransportSocket() const override;
     147             : 
     148             : private:
     149             :   HandshakerFactory handshaker_factory_;
     150             :   HandshakeValidator handshake_validator_;
     151             : };
     152             : 
     153             : /**
     154             :  * An implementation of Network::TransportSocketCallbacks for TsiSocket
     155             :  */
     156             : class TsiTransportSocketCallbacks : public NoOpTransportSocketCallbacks {
     157             : public:
     158             :   TsiTransportSocketCallbacks(Network::TransportSocketCallbacks& parent,
     159             :                               const Buffer::WatermarkBuffer& read_buffer)
     160           0 :       : NoOpTransportSocketCallbacks(parent), raw_read_buffer_(read_buffer) {}
     161           0 :   bool shouldDrainReadBuffer() override {
     162           0 :     return raw_read_buffer_.length() >= raw_read_buffer_.highWatermark();
     163           0 :   }
     164             : 
     165             : private:
     166             :   const Buffer::WatermarkBuffer& raw_read_buffer_;
     167             : };
     168             : 
     169             : } // namespace Alts
     170             : } // namespace TransportSockets
     171             : } // namespace Extensions
     172             : } // namespace Envoy

Generated by: LCOV version 1.15