1
#pragma once
2

            
3
#include "envoy/extensions/transport_sockets/starttls/v3/starttls.pb.h"
4
#include "envoy/extensions/transport_sockets/starttls/v3/starttls.pb.validate.h"
5
#include "envoy/network/connection.h"
6
#include "envoy/network/transport_socket.h"
7
#include "envoy/stats/scope.h"
8
#include "envoy/stats/stats_macros.h"
9

            
10
#include "source/common/buffer/buffer_impl.h"
11
#include "source/common/common/logger.h"
12
#include "source/common/network/transport_socket_options_impl.h"
13

            
14
namespace Envoy {
15
namespace Extensions {
16
namespace TransportSockets {
17
namespace StartTls {
18

            
19
/**
 * Transport socket that begins by delegating to a cleartext socket (a RawBufferSocket) and can be
 * switched to a TLS socket at runtime via startSecureTransport(). Every Network::TransportSocket
 * operation is forwarded to whichever socket is currently active.
 */
class StartTlsSocket : public Network::TransportSocket, Logger::Loggable<Logger::Id::filter> {
public:
  /**
   * @param raw_socket socket that is active until startSecureTransport() is called
   *        (RawBufferSocket).
   * @param tls_socket socket held in reserve to replace raw_socket once TLS is enabled
   *        (TlsSocket).
   * The transport socket options argument is accepted but not used here.
   */
  StartTlsSocket(Network::TransportSocketPtr raw_socket, // RawBufferSocket
                 Network::TransportSocketPtr tls_socket, // TlsSocket
                 const Network::TransportSocketOptionsConstSharedPtr&)
      : active_socket_(std::move(raw_socket)), tls_socket_(std::move(tls_socket)) {}

  void setTransportSocketCallbacks(Network::TransportSocketCallbacks& callbacks) override {
    // Re-create the proxy around the consumer's callbacks (this also resets its connected state)
    // and register the proxy, not the raw callbacks, with the active socket so that Connected
    // events are filtered.
    callbacks_ = CallbackProxy(&callbacks);
    active_socket_->setTransportSocketCallbacks(callbacks_);
  }

  std::string protocol() const override { return "starttls"; }

  absl::string_view failureReason() const override { return active_socket_->failureReason(); }

  void onConnected() override { active_socket_->onConnected(); }
  bool canFlushClose() override { return active_socket_->canFlushClose(); }
  Ssl::ConnectionInfoConstSharedPtr ssl() const override { return active_socket_->ssl(); }

  void closeSocket(Network::ConnectionEvent event) override {
    active_socket_->closeSocket(event);
  }

  Network::IoResult doRead(Buffer::Instance& buffer) override {
    return active_socket_->doRead(buffer);
  }

  Network::IoResult doWrite(Buffer::Instance& buffer, bool end_stream) override {
    return active_socket_->doWrite(buffer, end_stream);
  }

  // Enables TLS: the TLS socket replaces the raw buffer socket as the active socket.
  // Defined out of line.
  bool startSecureTransport() override;

  void configureInitialCongestionWindow(uint64_t bandwidth_bits_per_sec,
                                        std::chrono::microseconds rtt) override {
    active_socket_->configureInitialCongestionWindow(bandwidth_bits_per_sec, rtt);
  }

private:
  // Proxy wrapping the transport socket callbacks passed in by the consumer. Its primary purpose
  // is to filter Connected events so they happen only once per connection open; every other call
  // is forwarded to the wrapped callbacks unchanged.
  class CallbackProxy : public Network::TransportSocketCallbacks {
  public:
    // explicit: wrapping a raw callbacks pointer (which also resets connected_) must be visible
    // at the call site rather than happen through an implicit conversion.
    explicit CallbackProxy(Network::TransportSocketCallbacks* callbacks) : parent_(callbacks) {}

    Network::IoHandle& ioHandle() override { return parent_->ioHandle(); }
    const Network::IoHandle& ioHandle() const override {
      // Go through a const pointer so the parent's const overload is selected.
      return static_cast<const Network::TransportSocketCallbacks*>(parent_)->ioHandle();
    }
    Network::Connection& connection() override { return parent_->connection(); }
    bool shouldDrainReadBuffer() override { return parent_->shouldDrainReadBuffer(); }
    void setTransportSocketIsReadable() override { parent_->setTransportSocketIsReadable(); }
    void raiseEvent(Network::ConnectionEvent event) override {
      if (event == Network::ConnectionEvent::Connected) {
        // Already open: don't send the Connected event again, just flush pending writes.
        if (connected_) {
          parent_->flushWriteBuffer();
          return;
        }
        connected_ = true;
      } else {
        // Any other event resets the state, so the next Connected event is forwarded again.
        connected_ = false;
      }

      parent_->raiseEvent(event);
    }
    void flushWriteBuffer() override { parent_->flushWriteBuffer(); }

  private:
    // Consumer's callbacks; not owned. nullptr until setTransportSocketCallbacks() is called.
    Network::TransportSocketCallbacks* parent_;
    // True between a forwarded Connected event and the next non-Connected event.
    bool connected_{false};
  };

  // Socket used in all transport socket operations. Initially it is the raw buffer socket, but
  // it can be converted to use TLS.
  Network::TransportSocketPtr active_socket_;
  // Secure transport socket. It replaces the raw buffer socket when startSecureTransport() is
  // called.
  Network::TransportSocketPtr tls_socket_;

  // Proxy registered with the active socket in place of the consumer's callbacks.
  CallbackProxy callbacks_{nullptr};

  // NOTE(review): presumably set once the TLS socket has been activated (maintained out of
  // line) — confirm against startSecureTransport().
  bool using_tls_{false};
};
107

            
108
class StartTlsSocketFactory : public Network::CommonUpstreamTransportSocketFactory,
109
                              Logger::Loggable<Logger::Id::config> {
110
public:
111
3
  ~StartTlsSocketFactory() override = default;
112

            
113
  StartTlsSocketFactory(Network::UpstreamTransportSocketFactoryPtr raw_socket_factory,
114
                        Network::UpstreamTransportSocketFactoryPtr tls_socket_factory)
115
3
      : raw_socket_factory_(std::move(raw_socket_factory)),
116
3
        tls_socket_factory_(std::move(tls_socket_factory)) {}
117

            
118
  Network::TransportSocketPtr
119
  createTransportSocket(Network::TransportSocketOptionsConstSharedPtr options,
120
                        Upstream::HostDescriptionConstSharedPtr host) const override;
121
1
  bool implementsSecureTransport() const override { return false; }
122
  absl::string_view defaultServerNameIndication() const override { return ""; }
123
1
  Envoy::Ssl::ClientContextSharedPtr sslCtx() override { return tls_socket_factory_->sslCtx(); }
124
1
  OptRef<const Ssl::ClientContextConfig> clientContextConfig() const override {
125
1
    return tls_socket_factory_->clientContextConfig();
126
1
  }
127

            
128
private:
129
  Network::UpstreamTransportSocketFactoryPtr raw_socket_factory_;
130
  Network::UpstreamTransportSocketFactoryPtr tls_socket_factory_;
131
};
132

            
133
class StartTlsDownstreamSocketFactory : public Network::DownstreamTransportSocketFactory,
134
                                        Logger::Loggable<Logger::Id::config> {
135
public:
136
2
  ~StartTlsDownstreamSocketFactory() override = default;
137

            
138
  StartTlsDownstreamSocketFactory(Network::DownstreamTransportSocketFactoryPtr raw_socket_factory,
139
                                  Network::DownstreamTransportSocketFactoryPtr tls_socket_factory)
140
2
      : raw_socket_factory_(std::move(raw_socket_factory)),
141
2
        tls_socket_factory_(std::move(tls_socket_factory)) {}
142

            
143
  Network::TransportSocketPtr createDownstreamTransportSocket() const override;
144
  bool implementsSecureTransport() const override { return false; }
145

            
146
private:
147
  Network::DownstreamTransportSocketFactoryPtr raw_socket_factory_;
148
  Network::DownstreamTransportSocketFactoryPtr tls_socket_factory_;
149
};
150

            
151
} // namespace StartTls
152
} // namespace TransportSockets
153
} // namespace Extensions
154
} // namespace Envoy