Line data Source code
1 : #pragma once 2 : 3 : #include "envoy/extensions/transport_sockets/starttls/v3/starttls.pb.h" 4 : #include "envoy/extensions/transport_sockets/starttls/v3/starttls.pb.validate.h" 5 : #include "envoy/network/connection.h" 6 : #include "envoy/network/transport_socket.h" 7 : #include "envoy/stats/scope.h" 8 : #include "envoy/stats/stats_macros.h" 9 : 10 : #include "source/common/buffer/buffer_impl.h" 11 : #include "source/common/common/logger.h" 12 : #include "source/common/network/transport_socket_options_impl.h" 13 : 14 : namespace Envoy { 15 : namespace Extensions { 16 : namespace TransportSockets { 17 : namespace StartTls { 18 : 19 : class StartTlsSocket : public Network::TransportSocket, Logger::Loggable<Logger::Id::filter> { 20 : public: 21 : StartTlsSocket(Network::TransportSocketPtr raw_socket, // RawBufferSocket 22 : Network::TransportSocketPtr tls_socket, // TlsSocket 23 : const Network::TransportSocketOptionsConstSharedPtr&) 24 0 : : active_socket_(std::move(raw_socket)), tls_socket_(std::move(tls_socket)) {} 25 : 26 0 : void setTransportSocketCallbacks(Network::TransportSocketCallbacks& callbacks) override { 27 0 : callbacks_ = &callbacks; 28 0 : active_socket_->setTransportSocketCallbacks(callbacks_); 29 0 : } 30 : 31 0 : std::string protocol() const override { return "starttls"; } 32 : 33 0 : absl::string_view failureReason() const override { return active_socket_->failureReason(); } 34 : 35 0 : void onConnected() override { active_socket_->onConnected(); } 36 0 : bool canFlushClose() override { return active_socket_->canFlushClose(); } 37 0 : Ssl::ConnectionInfoConstSharedPtr ssl() const override { return active_socket_->ssl(); } 38 : 39 0 : void closeSocket(Network::ConnectionEvent event) override { 40 0 : return active_socket_->closeSocket(event); 41 0 : } 42 : 43 0 : Network::IoResult doRead(Buffer::Instance& buffer) override { 44 0 : return active_socket_->doRead(buffer); 45 0 : } 46 : 47 0 : Network::IoResult doWrite(Buffer::Instance& buffer, bool end_stream) override { 48 0 : return active_socket_->doWrite(buffer, end_stream); 49 0 : } 50 : 51 : // Method to enable TLS. 52 : bool startSecureTransport() override; 53 : 54 : void configureInitialCongestionWindow(uint64_t bandwidth_bits_per_sec, 55 0 : std::chrono::microseconds rtt) override { 56 0 : return active_socket_->configureInitialCongestionWindow(bandwidth_bits_per_sec, rtt); 57 0 : } 58 : 59 : private: 60 : // This is a proxy for wrapping the transport callback object passed from the consumer. 61 : // Its primary purpose is to filter Connected events to ensure they only happen once per open. 62 : // connection open. 63 : class CallbackProxy : public Network::TransportSocketCallbacks { 64 : public: 65 0 : CallbackProxy(Network::TransportSocketCallbacks* callbacks) : parent_(callbacks) {} 66 : 67 0 : Network::IoHandle& ioHandle() override { return parent_->ioHandle(); } 68 0 : const Network::IoHandle& ioHandle() const override { 69 0 : return static_cast<const Network::TransportSocketCallbacks*>(parent_)->ioHandle(); 70 0 : } 71 0 : Network::Connection& connection() override { return parent_->connection(); } 72 0 : bool shouldDrainReadBuffer() override { return parent_->shouldDrainReadBuffer(); } 73 0 : void setTransportSocketIsReadable() override { return parent_->setTransportSocketIsReadable(); } 74 0 : void raiseEvent(Network::ConnectionEvent event) override { 75 0 : if (event == Network::ConnectionEvent::Connected) { 76 : // Don't send the connected event if we're already open 77 0 : if (connected_) { 78 0 : parent_->flushWriteBuffer(); 79 0 : return; 80 0 : } 81 0 : connected_ = true; 82 0 : } else { 83 0 : connected_ = false; 84 0 : } 85 : 86 0 : parent_->raiseEvent(event); 87 0 : } 88 0 : void flushWriteBuffer() override { parent_->flushWriteBuffer(); } 89 : 90 : private: 91 : Network::TransportSocketCallbacks* parent_; 92 : bool connected_{false}; 93 : }; 94 : 95 : // Socket used in all transport socket operations. 96 : // initially it is set to use raw buffer socket but 97 : // can be converted to use tls. 98 : Network::TransportSocketPtr active_socket_; 99 : // Secure transport socket. It will replace raw buffer socket 100 : // when startSecureTransport is called. 101 : Network::TransportSocketPtr tls_socket_; 102 : 103 : CallbackProxy callbacks_{nullptr}; 104 : 105 : bool using_tls_{false}; 106 : }; 107 : 108 : class StartTlsSocketFactory : public Network::CommonUpstreamTransportSocketFactory, 109 : Logger::Loggable<Logger::Id::config> { 110 : public: 111 0 : ~StartTlsSocketFactory() override = default; 112 : 113 : StartTlsSocketFactory(Network::UpstreamTransportSocketFactoryPtr raw_socket_factory, 114 : Network::UpstreamTransportSocketFactoryPtr tls_socket_factory) 115 : : raw_socket_factory_(std::move(raw_socket_factory)), 116 0 : tls_socket_factory_(std::move(tls_socket_factory)) {} 117 : 118 : Network::TransportSocketPtr 119 : createTransportSocket(Network::TransportSocketOptionsConstSharedPtr options, 120 : Upstream::HostDescriptionConstSharedPtr host) const override; 121 0 : bool implementsSecureTransport() const override { return false; } 122 0 : absl::string_view defaultServerNameIndication() const override { return ""; } 123 0 : Envoy::Ssl::ClientContextSharedPtr sslCtx() override { return tls_socket_factory_->sslCtx(); } 124 0 : OptRef<const Ssl::ClientContextConfig> clientContextConfig() const override { 125 0 : return tls_socket_factory_->clientContextConfig(); 126 0 : } 127 : 128 : private: 129 : Network::UpstreamTransportSocketFactoryPtr raw_socket_factory_; 130 : Network::UpstreamTransportSocketFactoryPtr tls_socket_factory_; 131 : }; 132 : 133 : class StartTlsDownstreamSocketFactory : public Network::DownstreamTransportSocketFactory, 134 : Logger::Loggable<Logger::Id::config> { 135 : public: 136 0 : ~StartTlsDownstreamSocketFactory() override = default; 137 : 138 : StartTlsDownstreamSocketFactory(Network::DownstreamTransportSocketFactoryPtr raw_socket_factory, 139 : Network::DownstreamTransportSocketFactoryPtr tls_socket_factory) 140 : : raw_socket_factory_(std::move(raw_socket_factory)), 141 0 : tls_socket_factory_(std::move(tls_socket_factory)) {} 142 : 143 : Network::TransportSocketPtr createDownstreamTransportSocket() const override; 144 0 : bool implementsSecureTransport() const override { return false; } 145 : 146 : private: 147 : Network::DownstreamTransportSocketFactoryPtr raw_socket_factory_; 148 : Network::DownstreamTransportSocketFactoryPtr tls_socket_factory_; 149 : }; 150 : 151 : } // namespace StartTls 152 : } // namespace TransportSockets 153 : } // namespace Extensions 154 : } // namespace Envoy