1
#include "cilium/websocket.h"
2

            
3
#include <cstdint>
4
#include <memory>
5
#include <string>
6

            
7
#include "envoy/buffer/buffer.h"
8
#include "envoy/http/header_map.h"
9
#include "envoy/network/address.h"
10
#include "envoy/network/filter.h"
11
#include "envoy/registry/registry.h"
12
#include "envoy/server/factory_context.h"
13
#include "envoy/server/filter_config.h"
14
#include "envoy/stream_info/filter_state.h"
15

            
16
#include "source/common/common/logger.h"
17
#include "source/common/http/headers.h"
18
#include "source/common/network/utility.h"
19
#include "source/common/protobuf/protobuf.h"
20
#include "source/common/protobuf/utility.h"
21
#include "source/common/stream_info/bool_accessor_impl.h"
22
#include "source/common/tcp_proxy/tcp_proxy.h"
23

            
24
#include "absl/status/statusor.h"
25
#include "cilium/api/websocket.pb.h"
26
#include "cilium/api/websocket.pb.validate.h" // IWYU pragma: keep
27
#include "cilium/filter_state_cilium_policy.h"
28
#include "cilium/websocket_codec.h"
29
#include "cilium/websocket_config.h"
30

            
31
namespace Envoy {
32
namespace Cilium {
33
namespace WebSocket {
34

            
35
namespace {
36

            
37
Http::RegisterCustomInlineHeader<Http::CustomInlineHeaderRegistry::Type::RequestHeaders>
38
    origin_handle(Http::CustomHeaders::get().Origin);
39
Http::RegisterCustomInlineHeader<Http::CustomInlineHeaderRegistry::Type::RequestHeaders>
40
    original_dst_host_handle(Http::Headers::get().EnvoyOriginalDstHost);
41
Http::RegisterCustomInlineHeader<Http::CustomInlineHeaderRegistry::Type::RequestHeaders>
42
    sec_websocket_key_handle(Http::LowerCaseString{"sec-websocket-key"});
43
Http::RegisterCustomInlineHeader<Http::CustomInlineHeaderRegistry::Type::RequestHeaders>
44
    sec_websocket_version_handle(Http::LowerCaseString{"sec-websocket-version"});
45
Http::RegisterCustomInlineHeader<Http::CustomInlineHeaderRegistry::Type::RequestHeaders>
46
    sec_websocket_protocol_handle(Http::LowerCaseString{"sec-websocket-protocol"});
47
Http::RegisterCustomInlineHeader<Http::CustomInlineHeaderRegistry::Type::RequestHeaders>
48
    sec_websocket_extensions_handle(Http::LowerCaseString{"sec-websocket-extensions"});
49

            
50
Http::RegisterCustomInlineHeader<Http::CustomInlineHeaderRegistry::Type::ResponseHeaders>
51
    sec_websocket_accept_handle(Http::LowerCaseString{"sec-websocket-accept"});
52

            
53
} // namespace
54

            
55
/**
56
 * Config registration for the WebSocket server filter. @see
57
 * NamedNetworkFilterConfigFactory.
58
 */
59
class CiliumWebSocketServerConfigFactory
60
    : public Server::Configuration::NamedNetworkFilterConfigFactory {
61
public:
62
  // NamedNetworkFilterConfigFactory
63
  absl::StatusOr<Network::FilterFactoryCb>
64
  createFilterFactoryFromProto(const Protobuf::Message& proto_config,
65
13
                               Server::Configuration::FactoryContext& context) override {
66
13
    auto config = std::make_shared<Cilium::WebSocket::Config>(
67
13
        MessageUtil::downcastAndValidate<const ::cilium::WebSocketServer&>(
68
13
            proto_config, context.messageValidationVisitor()),
69
13
        context);
70
13
    return [config](Network::FilterManager& filter_manager) mutable -> void {
71
13
      filter_manager.addFilter(std::make_shared<Cilium::WebSocket::Instance>(config));
72
13
    };
73
13
  }
74

            
75
16
  ProtobufTypes::MessagePtr createEmptyConfigProto() override {
76
16
    return std::make_unique<::cilium::WebSocketServer>();
77
16
  }
78

            
79
29
  std::string name() const override { return "cilium.network.websocket.server"; }
80
};
81

            
82
/**
83
 * Static registration for the websocket server network filter. @see RegisterFactory.
84
 */
85
REGISTER_FACTORY(CiliumWebSocketServerConfigFactory,
86
                 Server::Configuration::NamedNetworkFilterConfigFactory);
87

            
88
/**
89
 * Config registration for the WebSocket client filter. @see
90
 * NamedNetworkFilterConfigFactory.
91
 */
92
class CiliumWebSocketClientConfigFactory
93
    : public Server::Configuration::NamedNetworkFilterConfigFactory {
94
public:
95
  // NamedNetworkFilterConfigFactory
96
  absl::StatusOr<Network::FilterFactoryCb>
97
  createFilterFactoryFromProto(const Protobuf::Message& proto_config,
98
11
                               Server::Configuration::FactoryContext& context) override {
99
11
    auto config = std::make_shared<Cilium::WebSocket::Config>(
100
11
        MessageUtil::downcastAndValidate<const ::cilium::WebSocketClient&>(
101
11
            proto_config, context.messageValidationVisitor()),
102
11
        context);
103
11
    return [config](Network::FilterManager& filter_manager) mutable -> void {
104
11
      filter_manager.addFilter(std::make_shared<Cilium::WebSocket::Instance>(config));
105
11
    };
106
11
  }
107

            
108
14
  ProtobufTypes::MessagePtr createEmptyConfigProto() override {
109
14
    return std::make_unique<::cilium::WebSocketClient>();
110
14
  }
111

            
112
27
  std::string name() const override { return "cilium.network.websocket.client"; }
113
};
114

            
115
/**
116
 * Static registration for the websocket client network filter. @see RegisterFactory.
117
 */
118
REGISTER_FACTORY(CiliumWebSocketClientConfigFactory,
119
                 Server::Configuration::NamedNetworkFilterConfigFactory);
120

            
121
24
/**
 * Stores the read filter callbacks and marks the connection's filter state so that TcpProxy
 * keeps reading before the upstream connection is established.
 */
void Instance::initializeReadFilterCallbacks(Network::ReadFilterCallbacks& callbacks) {
  callbacks_ = &callbacks;

  // Setting ReceiveBeforeConnectKey tells TcpProxy not to disable reads, so the WebSocket
  // handshake can be done before the upstream connection is established.
  // The value is stored as Mutable (not ReadOnly) so that tests can place both the client and
  // the server filter in the same filter chain, each setting this same key.
  auto& filter_state = *callbacks_->connection().streamInfo().filterState();
  filter_state.setData(TcpProxy::ReceiveBeforeConnectKey,
                       std::make_unique<StreamInfo::BoolAccessorImpl>(true),
                       StreamInfo::FilterState::StateType::Mutable,
                       StreamInfo::FilterState::LifeSpan::Connection);
}
132

            
133
24
/**
 * Called when a new connection is established.
 *
 * Enables half close on the connection and initializes the access log entry. The log entry uses
 * the Cilium policy filter state when present, or ingress defaults otherwise. Then creates the
 * codec.
 *
 * @return StopIteration for a server filter, so upstream processing waits until the handshake
 *         has been received. Continue for a client filter, which schedules its handshake to be
 *         sent later in the current dispatcher iteration.
 */
Network::FilterStatus Instance::onNewConnection() {
  auto& conn = callbacks_->connection();

  ENVOY_CONN_LOG(debug, "cilium.network.websocket: onNewConnection", conn);

  // Enable half close if not already enabled
  if (!conn.isHalfCloseEnabled()) {
    conn.enableHalfClose(true);
  }

  const Network::Address::InstanceConstSharedPtr& dst_address =
      conn.connectionInfoProvider().localAddress();
  const Network::Address::Ip* dip = dst_address ? dst_address->ip() : nullptr;

  const auto policy_fs =
      conn.streamInfo().filterState()->getDataReadOnly<Cilium::CiliumPolicyFilterState>(
          Cilium::CiliumPolicyFilterState::key());

  // Defaults used when there is no Cilium policy filter state: ingress to the local
  // (destination) address, proxy ID 0, and no security identities.
  std::string pod_ip;
  bool is_ingress = true;
  uint32_t proxy_id = 0;
  uint32_t identity = 0;
  uint32_t destination_identity = 0;

  if (policy_fs) {
    proxy_id = policy_fs->proxy_id_;
    pod_ip = policy_fs->pod_ip_;
    is_ingress = policy_fs->ingress_;
    identity = policy_fs->source_identity_;
    destination_identity = dip ? policy_fs->resolvePolicyId(dip) : 0;
  } else if (dip) {
    pod_ip = dip->addressAsString();
  }

  // Initialize the log entry
  log_entry_.initFromConnection(pod_ip, proxy_id, is_ingress, identity,
                                conn.connectionInfoProvider().remoteAddress(),
                                destination_identity, dst_address, &config_->time_source_);

  codec_ = std::make_unique<Codec>(this, conn);

  if (!config_->client_) {
    // Server allows upstream processing only after the handshake has been received
    return Network::FilterStatus::StopIteration;
  }

  // The handshake cannot be injected from within this (onNewConnection()) callback, so schedule
  // it to run afterwards, but still during the current dispatcher iteration.
  // NOTE(review): the lambda captures `this`. It is held by the member client_handshake_cb_, so
  // it presumably cannot fire after this instance is destroyed — confirm.
  client_handshake_cb_ =
      conn.dispatcher().createSchedulableCallback([this]() { codec_->handshake(); });
  client_handshake_cb_->scheduleCallbackCurrentIteration();

  return Network::FilterStatus::Continue;
}
190

            
191
149
/**
 * Handles data read from the connection.
 *
 * A client filter encodes the data; a server filter decodes it. The codec passes the result on
 * via injectEncoded()/injectDecoded(), so the original data is never passed on directly.
 *
 * @param data bytes read from the connection.
 * @param end_stream whether this is the last data in the read direction.
 * @return always StopIteration.
 */
Network::FilterStatus Instance::onData(Buffer::Instance& data, bool end_stream) {
  auto& conn = callbacks_->connection();
  // Per-chunk logging uses the same level and format as onWrite().
  ENVOY_CONN_LOG(trace, "cilium.network.websocket: onData {} bytes, end_stream: {}", conn,
                 data.length(), end_stream);
  if (codec_) {
    if (config_->client_) {
      codec_->encode(data, end_stream);
    } else {
      codec_->decode(data, end_stream);
    }
  }
  // codec passes the data on via injectEncoded()/injectDecoded(), data is now empty
  return Network::FilterStatus::StopIteration;
}
204

            
205
117
/**
 * Handles data being written to the connection; the mirror image of onData().
 *
 * A client filter decodes the data; a server filter encodes it. The codec passes the result on
 * via injectEncoded()/injectDecoded().
 *
 * @return always StopIteration, whether or not a codec exists.
 */
Network::FilterStatus Instance::onWrite(Buffer::Instance& data, bool end_stream) {
  auto& conn = callbacks_->connection();
  ENVOY_CONN_LOG(trace, "cilium.network.websocket: onWrite {} bytes, end_stream: {}", conn,
                 data.length(), end_stream);

  if (!codec_) {
    return Network::FilterStatus::StopIteration;
  }

  if (config_->client_) {
    codec_->decode(data, end_stream);
  } else {
    codec_->encode(data, end_stream);
  }

  // The codec has forwarded the bytes via injectEncoded()/injectDecoded(), leaving 'data' empty.
  return Network::FilterStatus::StopIteration;
}
219

            
220
245
/**
 * Passes codec-encoded data on.
 *
 * A server filter continues the write filter chain; a client filter continues the read filter
 * chain.
 */
void Instance::injectEncoded(Buffer::Instance& data, bool end_stream) {
  if (!config_->client_) {
    write_callbacks_->injectWriteDataToFilterChain(data, end_stream);
    return;
  }
  callbacks_->injectReadDataToFilterChain(data, end_stream);
}
227

            
228
207
/**
 * Passes codec-decoded data on.
 *
 * A server filter continues the read filter chain; a client filter continues the write filter
 * chain.
 */
void Instance::injectDecoded(Buffer::Instance& data, bool end_stream) {
  if (!config_->client_) {
    callbacks_->injectReadDataToFilterChain(data, end_stream);
    return;
  }
  write_callbacks_->injectWriteDataToFilterChain(data, end_stream);
}
235

            
236
13
/**
 * Updates the access log entry from the received WebSocket handshake request.
 *
 * When Cilium policy filter state is present and the request carries a non-empty
 * 'x-envoy-original-dst-host' header, that header is parsed as the original destination
 * address. If the parsed address has an IP, its policy identity is resolved. Otherwise the log
 * entry is updated with a null address and destination identity 0.
 */
void Instance::onHandshakeRequest(const Http::RequestHeaderMap& headers) {
  const auto& conn = callbacks_->connection();
  const auto policy_fs =
      conn.streamInfo().filterState().getDataReadOnly<Cilium::CiliumPolicyFilterState>(
          Cilium::CiliumPolicyFilterState::key());

  Network::Address::InstanceConstSharedPtr orig_dst_address;
  uint32_t destination_identity = 0;

  // The override header is only consulted when policy state exists, so a non-null
  // 'override_header' below also implies a non-null 'policy_fs'.
  const auto override_header =
      policy_fs ? headers.getInline(original_dst_host_handle.handle()) : nullptr;

  if (override_header != nullptr && !override_header->value().empty()) {
    const std::string request_override_host(override_header->value().getStringView());
    orig_dst_address =
        Network::Utility::parseInternetAddressAndPortNoThrow(request_override_host, false);
    const Network::Address::Ip* orig_ip = orig_dst_address ? orig_dst_address->ip() : nullptr;
    if (orig_ip != nullptr) {
      destination_identity = policy_fs->resolvePolicyId(orig_ip);
    }
  }

  // Update the log entry with the request details
  log_entry_.updateFromRequest(destination_identity, orig_dst_address, headers);
}
263

            
264
} // namespace WebSocket
265
} // namespace Cilium
266
} // namespace Envoy