Line data Source code
1 : #pragma once 2 : 3 : #include "envoy/network/transport_socket.h" 4 : 5 : #include "source/common/buffer/buffer_impl.h" 6 : #include "source/common/buffer/watermark_buffer.h" 7 : #include "source/common/network/raw_buffer_socket.h" 8 : #include "source/common/network/transport_socket_options_impl.h" 9 : #include "source/extensions/transport_sockets/alts/noop_transport_socket_callbacks.h" 10 : #include "source/extensions/transport_sockets/alts/tsi_frame_protector.h" 11 : #include "source/extensions/transport_sockets/alts/tsi_handshaker.h" 12 : 13 : namespace Envoy { 14 : namespace Extensions { 15 : namespace TransportSockets { 16 : namespace Alts { 17 : 18 : struct TsiInfo { 19 : std::string peer_identity_; 20 : }; 21 : 22 : /** 23 : * A factory function to create TsiHandshaker 24 : * @param dispatcher the dispatcher for the thread where the socket is running on. 25 : * @param local_address the local address of the connection. 26 : * @param remote_address the remote address of the connection. 27 : */ 28 : using HandshakerFactory = std::function<TsiHandshakerPtr( 29 : Event::Dispatcher& dispatcher, const Network::Address::InstanceConstSharedPtr& local_address, 30 : const Network::Address::InstanceConstSharedPtr& remote_address)>; 31 : 32 : /** 33 : * A function to validate the peer of the connection. 34 : * @param err an error message to indicate why the peer is invalid. This is an 35 : * output param that should be populated by the function implementation. 36 : * @return true if the peer is valid or false if the peer is invalid. 37 : */ 38 : using HandshakeValidator = std::function<bool(TsiInfo& tsi_info, std::string& err)>; 39 : 40 : /* Forward declaration */ 41 : class TsiTransportSocketCallbacks; 42 : 43 : /** 44 : * A implementation of Network::TransportSocket based on gRPC TSI 45 : */ 46 : class TsiSocket : public Network::TransportSocket, 47 : public TsiHandshakerCallbacks, 48 : public Logger::Loggable<Logger::Id::connection> { 49 : public: 50 : // For Test 51 : TsiSocket(HandshakerFactory handshaker_factory, HandshakeValidator handshake_validator, 52 : Network::TransportSocketPtr&& raw_socket_ptr, bool downstream); 53 : 54 : /** 55 : * @param handshaker_factory a function to initiate a TsiHandshaker 56 : * @param handshake_validator a function to validate the peer. Called right 57 : * after the handshake completed with peer data to do the peer validation. 58 : * The connection will be closed immediately if it returns false. 59 : * @param downstream is true for downstream transport socket. 60 : */ 61 : TsiSocket(HandshakerFactory handshaker_factory, HandshakeValidator handshake_validator, 62 : bool downstream); 63 : ~TsiSocket() override; 64 : 65 : // Network::TransportSocket 66 : void setTransportSocketCallbacks(Envoy::Network::TransportSocketCallbacks& callbacks) override; 67 : std::string protocol() const override; 68 : absl::string_view failureReason() const override; 69 0 : bool canFlushClose() override { return handshake_complete_; } 70 0 : Envoy::Ssl::ConnectionInfoConstSharedPtr ssl() const override { return nullptr; } 71 0 : bool startSecureTransport() override { return false; } 72 0 : void configureInitialCongestionWindow(uint64_t, std::chrono::microseconds) override {} 73 : Network::IoResult doWrite(Buffer::Instance& buffer, bool end_stream) override; 74 : void closeSocket(Network::ConnectionEvent event) override; 75 : Network::IoResult doRead(Buffer::Instance& buffer) override; 76 : void onConnected() override; 77 : 78 : // TsiHandshakerCallbacks 79 : void onNextDone(NextResultPtr&& result) override; 80 : 81 : // This API should be called only after ALTS handshake finishes successfully. 82 0 : size_t actualFrameSizeToUse() { return actual_frame_size_to_use_; } 83 : // Set actual_frame_size_to_use_. Exposed for testing purpose. 84 0 : void setActualFrameSizeToUse(size_t frame_size) { actual_frame_size_to_use_ = frame_size; } 85 : // Set frame_overhead_size_. Exposed for testing purpose. 86 0 : void setFrameOverheadSize(size_t overhead_size) { frame_overhead_size_ = overhead_size; } 87 : 88 : private: 89 : Network::PostIoAction doHandshake(); 90 : Network::PostIoAction doHandshakeNext(); 91 : Network::PostIoAction doHandshakeNextDone(NextResultPtr&& next_result); 92 : 93 : // Helper function to perform repeated read and unprotect operations. 94 : Network::IoResult repeatReadAndUnprotect(Buffer::Instance& buffer, Network::IoResult prev_result); 95 : // Helper function to perform repeated protect and write operations. 96 : Network::IoResult repeatProtectAndWrite(Buffer::Instance& buffer, bool end_stream); 97 : // Helper function to read from a raw socket and update status. 98 : Network::IoResult readFromRawSocket(); 99 : 100 : HandshakerFactory handshaker_factory_; 101 : HandshakeValidator handshake_validator_; 102 : TsiHandshakerPtr handshaker_{}; 103 : bool handshaker_next_calling_{}; 104 : 105 : TsiFrameProtectorPtr frame_protector_; 106 : // default_max_frame_size_ is the maximum frame size supported by 107 : // TsiSocket. 108 : size_t default_max_frame_size_{16384}; 109 : // actual_frame_size_to_use_ is the actual frame size used by 110 : // frame protector, which is the result of frame size negotiation. 111 : size_t actual_frame_size_to_use_{0}; 112 : // frame_overhead_size_ includes 4 bytes frame message type and 16 bytes tag length. 113 : // It is consistent with gRPC ALTS zero copy frame protector implementation. 114 : // The maximum size of data that can be protected for each frame is equal to 115 : // actual_frame_size_to_use_ - frame_overhead_size_. 116 : size_t frame_overhead_size_{20}; 117 : 118 : Envoy::Network::TransportSocketCallbacks* callbacks_{}; 119 : std::unique_ptr<TsiTransportSocketCallbacks> tsi_callbacks_; 120 : Network::TransportSocketPtr raw_buffer_socket_; 121 : const bool downstream_; 122 : 123 0 : Buffer::WatermarkBuffer raw_read_buffer_{[]() {}, []() {}, []() {}}; 124 : Envoy::Buffer::OwnedImpl raw_write_buffer_; 125 : bool handshake_complete_{}; 126 : bool end_stream_read_{}; 127 : bool read_error_{}; 128 : uint64_t prev_bytes_to_drain_{}; 129 : }; 130 : 131 : /** 132 : * An implementation of Network::UpstreamTransportSocketFactory for TsiSocket 133 : */ 134 : class TsiSocketFactory : public Network::DownstreamTransportSocketFactory, 135 : public Network::CommonUpstreamTransportSocketFactory { 136 : public: 137 : TsiSocketFactory(HandshakerFactory handshaker_factory, HandshakeValidator handshake_validator); 138 : 139 : bool implementsSecureTransport() const override; 140 0 : absl::string_view defaultServerNameIndication() const override { return ""; } 141 : 142 : Network::TransportSocketPtr 143 : createTransportSocket(Network::TransportSocketOptionsConstSharedPtr options, 144 : Upstream::HostDescriptionConstSharedPtr) const override; 145 : 146 : Network::TransportSocketPtr createDownstreamTransportSocket() const override; 147 : 148 : private: 149 : HandshakerFactory handshaker_factory_; 150 : HandshakeValidator handshake_validator_; 151 : }; 152 : 153 : /** 154 : * An implementation of Network::TransportSocketCallbacks for TsiSocket 155 : */ 156 : class TsiTransportSocketCallbacks : public NoOpTransportSocketCallbacks { 157 : public: 158 : TsiTransportSocketCallbacks(Network::TransportSocketCallbacks& parent, 159 : const Buffer::WatermarkBuffer& read_buffer) 160 0 : : NoOpTransportSocketCallbacks(parent), raw_read_buffer_(read_buffer) {} 161 0 : bool shouldDrainReadBuffer() override { 162 0 : return raw_read_buffer_.length() >= raw_read_buffer_.highWatermark(); 163 0 : } 164 : 165 : private: 166 : const Buffer::WatermarkBuffer& raw_read_buffer_; 167 : }; 168 : 169 : } // namespace Alts 170 : } // namespace TransportSockets 171 : } // namespace Extensions 172 : } // namespace Envoy