1
#include "cilium/tls_wrapper.h"
2

            
3
#include <chrono>
4
#include <cstdint>
5
#include <memory>
6
#include <string>
7
#include <utility>
8
#include <vector>
9

            
10
#include "envoy/buffer/buffer.h"
11
#include "envoy/network/address.h"
12
#include "envoy/network/post_io_action.h"
13
#include "envoy/network/transport_socket.h"
14
#include "envoy/registry/registry.h"
15
#include "envoy/server/transport_socket_config.h"
16
#include "envoy/ssl/connection.h"
17
#include "envoy/ssl/context.h"
18
#include "envoy/ssl/context_config.h"
19

            
20
#include "source/common/common/empty_string.h"
21
#include "source/common/common/logger.h"
22
#include "source/common/network/raw_buffer_socket.h"
23
#include "source/common/network/transport_socket_options_impl.h"
24
#include "source/common/protobuf/protobuf.h"
25
#include "source/common/tls/ssl_socket.h"
26

            
27
#include "absl/status/statusor.h"
28
#include "absl/strings/string_view.h"
29
#include "cilium/api/tls_wrapper.pb.h"
30
#include "cilium/filter_state_cilium_policy.h"
31
#include "cilium/network_policy.h"
32
#include "filter_state_cilium_policy.h"
33

            
34
namespace Envoy {
35
namespace Cilium {
36

            
37
namespace {
38

            
39
// Failure reason reported by SslSocketWrapper while no underlying transport socket exists.
constexpr absl::string_view NotReadyReason("Socket is not ready");
40

            
41
// This SslSocketWrapper wraps a real SslSocket and hooks it up with
42
// TLS configuration derived from Cilium Network Policy.
43
class SslSocketWrapper : public Network::TransportSocket, Logger::Loggable<Logger::Id::connection> {
44
public:
45
  SslSocketWrapper(Extensions::TransportSockets::Tls::InitialState state,
46
                   const Network::TransportSocketOptionsConstSharedPtr& transport_socket_options)
47
27
      : state_(state), transport_socket_options_(transport_socket_options) {}
48

            
49
  // Network::TransportSocket
50
27
  void setTransportSocketCallbacks(Network::TransportSocketCallbacks& callbacks) override {
51
27
    callbacks_ = &callbacks;
52
27
  }
53

            
54
8
  std::string protocol() const override { return socket_ ? socket_->protocol() : EMPTY_STRING; }
55

            
56
42
  absl::string_view failureReason() const override {
57
42
    return socket_ ? socket_->failureReason() : NotReadyReason;
58
42
  }
59

            
60
2
  bool canFlushClose() override { return socket_ ? socket_->canFlushClose() : true; }
61

            
62
  // Override if need to intercept client socket connect() call.
63
  // Api::SysCallIntResult connect(Network::ConnectionSocket& socket) override
64

            
65
27
  void closeSocket(Network::ConnectionEvent type) override {
66
27
    if (socket_) {
67
27
      socket_->closeSocket(type);
68
27
    }
69
27
  }
70

            
71
96
  Network::IoResult doRead(Buffer::Instance& buffer) override {
72
96
    if (socket_) {
73
96
      return socket_->doRead(buffer);
74
96
    }
75
    return {Network::PostIoAction::Close, 0, false};
76
96
  }
77

            
78
170
  Network::IoResult doWrite(Buffer::Instance& buffer, bool end_stream) override {
79
170
    if (socket_) {
80
170
      return socket_->doWrite(buffer, end_stream);
81
170
    }
82
    return {Network::PostIoAction::Close, 0, false};
83
170
  }
84

            
85
27
  void onConnected() override {
86
    // Get the Cilium policy filter state from the callbacks in order to get the TLS
87
    // configuration.
88
    // Cilium socket option is only created if the (initial) policy for the local pod exists.
89
    // If the policy requires TLS then a TLS socket is used, but if the policy does not require
90
    // TLS a raw socket is used instead.
91
27
    auto& conn = callbacks_->connection();
92

            
93
27
    ENVOY_CONN_LOG(trace, "retrieving policy filter state", conn);
94
27
    auto policy_fs =
95
27
        conn.streamInfo().filterState()->getDataReadOnly<Cilium::CiliumPolicyFilterState>(
96
27
            Cilium::CiliumPolicyFilterState::key());
97

            
98
27
    if (policy_fs) {
99
27
      const auto& policy = policy_fs->getPolicy();
100

            
101
      // Resolve the destination security ID and port
102
27
      uint32_t destination_identity = 0;
103
27
      uint32_t destination_port = policy_fs->port_;
104
27
      const Network::Address::Ip* dip = nullptr;
105
27
      bool is_client = state_ == Extensions::TransportSockets::Tls::InitialState::Client;
106

            
107
27
      if (!policy_fs->ingress_) {
108
        Network::Address::InstanceConstSharedPtr dst_address =
109
            is_client ? callbacks_->connection().connectionInfoProvider().remoteAddress()
110
                      : callbacks_->connection().connectionInfoProvider().localAddress();
111
        if (dst_address) {
112
          dip = dst_address->ip();
113
          if (dip) {
114
            destination_port = dip->port();
115
            destination_identity = policy_fs->resolvePolicyId(dip);
116
          } else {
117
            ENVOY_CONN_LOG(warn, "cilium.tls_wrapper: Non-IP destination address: {}", conn,
118
                           dst_address->asString());
119
          }
120
        } else {
121
          ENVOY_CONN_LOG(warn, "cilium.tls_wrapper: No destination address", conn);
122
        }
123
      }
124

            
125
      // get the requested server name from the connection, if any
126
27
      const auto& sni = policy_fs->sni_;
127

            
128
27
      auto remote_id = policy_fs->ingress_ ? policy_fs->source_identity_ : destination_identity;
129
27
      auto port_policy = policy.findPortPolicy(policy_fs->ingress_, destination_port);
130
27
      const Envoy::Ssl::ContextConfig* config = nullptr;
131
27
      bool raw_socket_allowed = false;
132
27
      auto proxy_id = policy_fs->proxy_id_;
133
27
      Envoy::Ssl::ContextSharedPtr ctx =
134
27
          is_client ? port_policy.getClientTlsContext(proxy_id, remote_id, sni, &config,
135
15
                                                      raw_socket_allowed)
136
27
                    : port_policy.getServerTlsContext(proxy_id, remote_id, sni, &config,
137
12
                                                      raw_socket_allowed);
138
27
      if (ctx) {
139
        // create the underlying SslSocket
140
27
        auto status_or_socket = Extensions::TransportSockets::Tls::SslSocket::create(
141
27
            std::move(ctx), state_, transport_socket_options_, config->createHandshaker());
142
27
        if (status_or_socket.ok()) {
143
27
          socket_ = std::move(status_or_socket.value());
144
          // Set the callbacks
145
27
          socket_->setTransportSocketCallbacks(*callbacks_);
146
          // explicitly configure ssl connection with the latest configuration from the SSL socket.
147
27
          callbacks_->connection().connectionInfoSetter().setSslConnection(socket_->ssl());
148
27
        } else {
149
          ENVOY_CONN_LOG(error, "Unable to create ssl socket {}", conn,
150
                         status_or_socket.status().message());
151
        }
152
27
      } else if (config == nullptr && raw_socket_allowed) {
153
        // Use RawBufferSocket when policy allows without TLS.
154
        // If policy has TLS context config then a raw socket must NOT be used.
155
        socket_ = std::make_unique<Network::RawBufferSocket>();
156
        // Set the callbacks
157
        socket_->setTransportSocketCallbacks(*callbacks_);
158
      } else {
159
        policy.tlsWrapperMissingPolicyInc();
160

            
161
        std::string ip_str("<none>");
162
        if (policy_fs->ingress_) {
163
          Network::Address::InstanceConstSharedPtr src_address =
164
              is_client ? callbacks_->connection().connectionInfoProvider().localAddress()
165
                        : callbacks_->connection().connectionInfoProvider().remoteAddress();
166
          if (src_address) {
167
            const auto sip = src_address->ip();
168
            if (sip) {
169
              ip_str = sip->addressAsString();
170
            }
171
          }
172
        } else {
173
          if (dip) {
174
            ip_str = dip->addressAsString();
175
          }
176
        }
177
        ENVOY_CONN_LOG(
178
            warn,
179
            "cilium.tls_wrapper: Could not get {} TLS context for pod {} on {} IP {} (id {}) port "
180
            "{} sni \"{}\" and raw socket is not allowed",
181
            conn, is_client ? "client" : "server", policy_fs->pod_ip_,
182
            policy_fs->ingress_ ? "source" : "destination", ip_str, remote_id, destination_port,
183
            sni);
184
      }
185
27
    } else {
186
      ENVOY_CONN_LOG(warn,
187
                     "cilium.tls_wrapper: Can not correlate connection with Cilium Network "
188
                     "Policy (Cilium socket option not found)",
189
                     conn);
190
    }
191

            
192
27
    if (socket_) {
193
27
      socket_->onConnected();
194
27
    }
195
27
  }
196

            
197
27
  Ssl::ConnectionInfoConstSharedPtr ssl() const override {
198
27
    return socket_ ? socket_->ssl() : nullptr;
199
27
  }
200

            
201
  bool startSecureTransport() override { return socket_ ? socket_->startSecureTransport() : false; }
202

            
203
  void configureInitialCongestionWindow(uint64_t bandwidth_bits_per_sec,
204
                                        std::chrono::microseconds rtt) override {
205
    if (socket_) {
206
      socket_->configureInitialCongestionWindow(bandwidth_bits_per_sec, rtt);
207
    }
208
  }
209

            
210
private:
211
  Extensions::TransportSockets::Tls::InitialState state_;
212
  const Network::TransportSocketOptionsConstSharedPtr transport_socket_options_;
213
  Network::TransportSocketPtr socket_{nullptr};
214
  Network::TransportSocketCallbacks* callbacks_{};
215
};
216

            
217
class ClientSslSocketFactory : public Network::CommonUpstreamTransportSocketFactory {
218
public:
219
  Network::TransportSocketPtr
220
  createTransportSocket(Network::TransportSocketOptionsConstSharedPtr options,
221
15
                        std::shared_ptr<const Upstream::HostDescription>) const override {
222
15
    return std::make_unique<SslSocketWrapper>(
223
15
        Extensions::TransportSockets::Tls::InitialState::Client, options);
224
15
  }
225

            
226
4
  absl::string_view defaultServerNameIndication() const override { return EMPTY_STRING; }
227

            
228
  bool implementsSecureTransport() const override { return true; }
229
};
230

            
231
class ServerSslSocketFactory : public Network::DownstreamTransportSocketFactory {
232
public:
233
12
  Network::TransportSocketPtr createDownstreamTransportSocket() const override {
234
12
    return std::make_unique<SslSocketWrapper>(
235
12
        Extensions::TransportSockets::Tls::InitialState::Server, nullptr);
236
12
  }
237

            
238
  bool implementsSecureTransport() const override { return true; }
239
};
240

            
241
} // namespace
242

            
243
// Returns a new ClientSslSocketFactory. Neither the proto config nor the factory context is
// consulted here.
absl::StatusOr<Network::UpstreamTransportSocketFactoryPtr>
UpstreamTlsWrapperFactory::createTransportSocketFactory(
    const Protobuf::Message& /*config*/,
    Server::Configuration::TransportSocketFactoryContext& /*context*/) {
  return std::make_unique<ClientSslSocketFactory>();
}
248

            
249
21
// Returns an empty ::cilium::UpstreamTlsWrapperContext message.
ProtobufTypes::MessagePtr UpstreamTlsWrapperFactory::createEmptyConfigProto() {
  using ConfigProto = ::cilium::UpstreamTlsWrapperContext;
  return std::make_unique<ConfigProto>();
}
252

            
253
// Register the upstream wrapper factory as an upstream transport socket config factory.
REGISTER_FACTORY(UpstreamTlsWrapperFactory,
                 Server::Configuration::UpstreamTransportSocketConfigFactory);
255

            
256
// Returns a new ServerSslSocketFactory. None of the arguments (proto config, factory context,
// server names) is consulted here.
absl::StatusOr<Network::DownstreamTransportSocketFactoryPtr>
DownstreamTlsWrapperFactory::createTransportSocketFactory(
    const Protobuf::Message& /*config*/,
    Server::Configuration::TransportSocketFactoryContext& /*context*/,
    const std::vector<std::string>& /*server_names*/) {
  return std::make_unique<ServerSslSocketFactory>();
}
262

            
263
14
// Returns an empty ::cilium::DownstreamTlsWrapperContext message.
ProtobufTypes::MessagePtr DownstreamTlsWrapperFactory::createEmptyConfigProto() {
  using ConfigProto = ::cilium::DownstreamTlsWrapperContext;
  return std::make_unique<ConfigProto>();
}
266

            
267
// Register the downstream wrapper factory as a downstream transport socket config factory.
REGISTER_FACTORY(DownstreamTlsWrapperFactory,
                 Server::Configuration::DownstreamTransportSocketConfigFactory);
269
} // namespace Cilium
270
} // namespace Envoy