LCOV - code coverage report
Current view: top level - source/extensions/transport_sockets/starttls - starttls_socket.h (source / functions) Hit Total Coverage
Test: coverage.dat Lines: 0 54 0.0 %
Date: 2024-01-05 06:35:25 Functions: 0 30 0.0 %

          Line data    Source code
       1             : #pragma once
       2             : 
       3             : #include "envoy/extensions/transport_sockets/starttls/v3/starttls.pb.h"
       4             : #include "envoy/extensions/transport_sockets/starttls/v3/starttls.pb.validate.h"
       5             : #include "envoy/network/connection.h"
       6             : #include "envoy/network/transport_socket.h"
       7             : #include "envoy/stats/scope.h"
       8             : #include "envoy/stats/stats_macros.h"
       9             : 
      10             : #include "source/common/buffer/buffer_impl.h"
      11             : #include "source/common/common/logger.h"
      12             : #include "source/common/network/transport_socket_options_impl.h"
      13             : 
      14             : namespace Envoy {
      15             : namespace Extensions {
      16             : namespace TransportSockets {
      17             : namespace StartTls {
      18             : 
      19             : class StartTlsSocket : public Network::TransportSocket, Logger::Loggable<Logger::Id::filter> {
      20             : public:
      21             :   StartTlsSocket(Network::TransportSocketPtr raw_socket, // RawBufferSocket
      22             :                  Network::TransportSocketPtr tls_socket, // TlsSocket
      23             :                  const Network::TransportSocketOptionsConstSharedPtr&)
      24           0 :       : active_socket_(std::move(raw_socket)), tls_socket_(std::move(tls_socket)) {}
      25             : 
      26           0 :   void setTransportSocketCallbacks(Network::TransportSocketCallbacks& callbacks) override {
      27           0 :     callbacks_ = &callbacks;
      28           0 :     active_socket_->setTransportSocketCallbacks(callbacks_);
      29           0 :   }
      30             : 
      31           0 :   std::string protocol() const override { return "starttls"; }
      32             : 
      33           0 :   absl::string_view failureReason() const override { return active_socket_->failureReason(); }
      34             : 
      35           0 :   void onConnected() override { active_socket_->onConnected(); }
      36           0 :   bool canFlushClose() override { return active_socket_->canFlushClose(); }
      37           0 :   Ssl::ConnectionInfoConstSharedPtr ssl() const override { return active_socket_->ssl(); }
      38             : 
      39           0 :   void closeSocket(Network::ConnectionEvent event) override {
      40           0 :     return active_socket_->closeSocket(event);
      41           0 :   }
      42             : 
      43           0 :   Network::IoResult doRead(Buffer::Instance& buffer) override {
      44           0 :     return active_socket_->doRead(buffer);
      45           0 :   }
      46             : 
      47           0 :   Network::IoResult doWrite(Buffer::Instance& buffer, bool end_stream) override {
      48           0 :     return active_socket_->doWrite(buffer, end_stream);
      49           0 :   }
      50             : 
      51             :   // Method to enable TLS.
      52             :   bool startSecureTransport() override;
      53             : 
      54             :   void configureInitialCongestionWindow(uint64_t bandwidth_bits_per_sec,
      55           0 :                                         std::chrono::microseconds rtt) override {
      56           0 :     return active_socket_->configureInitialCongestionWindow(bandwidth_bits_per_sec, rtt);
      57           0 :   }
      58             : 
      59             : private:
      60             :   // This is a proxy for wrapping the transport callback object passed from the consumer.
      61             :   // Its primary purpose is to filter Connected events to ensure they only happen once per
      62             :   // connection open.
      63             :   class CallbackProxy : public Network::TransportSocketCallbacks {
      64             :   public:
      65           0 :     CallbackProxy(Network::TransportSocketCallbacks* callbacks) : parent_(callbacks) {}
      66             : 
      67           0 :     Network::IoHandle& ioHandle() override { return parent_->ioHandle(); }
      68           0 :     const Network::IoHandle& ioHandle() const override {
      69           0 :       return static_cast<const Network::TransportSocketCallbacks*>(parent_)->ioHandle();
      70           0 :     }
      71           0 :     Network::Connection& connection() override { return parent_->connection(); }
      72           0 :     bool shouldDrainReadBuffer() override { return parent_->shouldDrainReadBuffer(); }
      73           0 :     void setTransportSocketIsReadable() override { return parent_->setTransportSocketIsReadable(); }
      74           0 :     void raiseEvent(Network::ConnectionEvent event) override {
      75           0 :       if (event == Network::ConnectionEvent::Connected) {
      76             :         // Already connected: flush the write buffer instead of re-sending Connected.
      77           0 :         if (connected_) {
      78           0 :           parent_->flushWriteBuffer();
      79           0 :           return;
      80           0 :         }
      81           0 :         connected_ = true;
      82           0 :       } else {
      83           0 :         connected_ = false;
      84           0 :       }
      85             : 
      86           0 :       parent_->raiseEvent(event);
      87           0 :     }
      88           0 :     void flushWriteBuffer() override { parent_->flushWriteBuffer(); }
      89             : 
      90             :   private:
      91             :     Network::TransportSocketCallbacks* parent_;
      92             :     bool connected_{false};
      93             :   };
      94             : 
      95             :   // Socket used in all transport socket operations.
      96             :   // Initially it is set to the raw buffer socket, but it
      97             :   // can be converted to use TLS.
      98             :   Network::TransportSocketPtr active_socket_;
      99             :   // Secure transport socket. It will replace the raw buffer socket
     100             :   // when startSecureTransport is called.
     101             :   Network::TransportSocketPtr tls_socket_;
     102             : 
     103             :   CallbackProxy callbacks_{nullptr};
     104             : 
     105             :   bool using_tls_{false};
     106             : };
     107             : 
     108             : class StartTlsSocketFactory : public Network::CommonUpstreamTransportSocketFactory,
     109             :                               Logger::Loggable<Logger::Id::config> {
     110             : public:
     111           0 :   ~StartTlsSocketFactory() override = default;
     112             : 
     113             :   StartTlsSocketFactory(Network::UpstreamTransportSocketFactoryPtr raw_socket_factory,
     114             :                         Network::UpstreamTransportSocketFactoryPtr tls_socket_factory)
     115             :       : raw_socket_factory_(std::move(raw_socket_factory)),
     116           0 :         tls_socket_factory_(std::move(tls_socket_factory)) {}
     117             : 
     118             :   Network::TransportSocketPtr
     119             :   createTransportSocket(Network::TransportSocketOptionsConstSharedPtr options,
     120             :                         Upstream::HostDescriptionConstSharedPtr host) const override;
     121           0 :   bool implementsSecureTransport() const override { return false; }
     122           0 :   absl::string_view defaultServerNameIndication() const override { return ""; }
     123           0 :   Envoy::Ssl::ClientContextSharedPtr sslCtx() override { return tls_socket_factory_->sslCtx(); }
     124           0 :   OptRef<const Ssl::ClientContextConfig> clientContextConfig() const override {
     125           0 :     return tls_socket_factory_->clientContextConfig();
     126           0 :   }
     127             : 
     128             : private:
     129             :   Network::UpstreamTransportSocketFactoryPtr raw_socket_factory_;
     130             :   Network::UpstreamTransportSocketFactoryPtr tls_socket_factory_;
     131             : };
     132             : 
     133             : class StartTlsDownstreamSocketFactory : public Network::DownstreamTransportSocketFactory,
     134             :                                         Logger::Loggable<Logger::Id::config> {
     135             : public:
     136           0 :   ~StartTlsDownstreamSocketFactory() override = default;
     137             : 
     138             :   StartTlsDownstreamSocketFactory(Network::DownstreamTransportSocketFactoryPtr raw_socket_factory,
     139             :                                   Network::DownstreamTransportSocketFactoryPtr tls_socket_factory)
     140             :       : raw_socket_factory_(std::move(raw_socket_factory)),
     141           0 :         tls_socket_factory_(std::move(tls_socket_factory)) {}
     142             : 
     143             :   Network::TransportSocketPtr createDownstreamTransportSocket() const override;
     144           0 :   bool implementsSecureTransport() const override { return false; }
     145             : 
     146             : private:
     147             :   Network::DownstreamTransportSocketFactoryPtr raw_socket_factory_;
     148             :   Network::DownstreamTransportSocketFactoryPtr tls_socket_factory_;
     149             : };
     150             : 
     151             : } // namespace StartTls
     152             : } // namespace TransportSockets
     153             : } // namespace Extensions
     154             : } // namespace Envoy

Generated by: LCOV version 1.15