1
#include "cilium/websocket_codec.h"
2

            
3
#include <endian.h>
4
#include <fmt/format.h>
5
#include <http_parser.h>
6
#include <sys/types.h>
7

            
8
#include <algorithm>
9
#include <chrono>
10
#include <cstddef>
11
#include <cstdint>
12
#include <cstring>
13
#include <string>
14
#include <utility>
15

            
16
#include "envoy/buffer/buffer.h"
17
#include "envoy/common/pure.h"
18
#include "envoy/http/codes.h"
19
#include "envoy/http/header_map.h"
20
#include "envoy/network/address.h"
21
#include "envoy/network/connection.h"
22

            
23
#include "source/common/buffer/buffer_impl.h"
24
#include "source/common/common/assert.h"
25
#include "source/common/common/enum_to_int.h"
26
#include "source/common/common/hex.h"
27
#include "source/common/common/logger.h"
28
#include "source/common/common/utility.h"
29
#include "source/common/http/codes.h"
30
#include "source/common/http/header_map_impl.h"
31
#include "source/common/http/header_utility.h"
32
#include "source/common/http/headers.h"
33
#include "source/common/http/utility.h"
34
#include "source/common/network/utility.h"
35

            
36
#include "absl/strings/ascii.h"
37
#include "absl/strings/string_view.h"
38
#include "cilium/websocket_config.h"
39
#include "cilium/websocket_protocol.h"
40

            
41
namespace Envoy {
42
namespace Cilium {
43
namespace WebSocket {
44

            
45
namespace {
46

            
47
Http::RegisterCustomInlineHeader<Http::CustomInlineHeaderRegistry::Type::RequestHeaders>
48
    origin_handle(Http::CustomHeaders::get().Origin);
49
Http::RegisterCustomInlineHeader<Http::CustomInlineHeaderRegistry::Type::RequestHeaders>
50
    original_dst_host_handle(Http::Headers::get().EnvoyOriginalDstHost);
51
Http::RegisterCustomInlineHeader<Http::CustomInlineHeaderRegistry::Type::RequestHeaders>
52
    sec_websocket_key_handle(Http::LowerCaseString{"sec-websocket-key"});
53
Http::RegisterCustomInlineHeader<Http::CustomInlineHeaderRegistry::Type::RequestHeaders>
54
    sec_websocket_version_handle(Http::LowerCaseString{"sec-websocket-version"});
55
Http::RegisterCustomInlineHeader<Http::CustomInlineHeaderRegistry::Type::RequestHeaders>
56
    sec_websocket_protocol_handle(Http::LowerCaseString{"sec-websocket-protocol"});
57
Http::RegisterCustomInlineHeader<Http::CustomInlineHeaderRegistry::Type::RequestHeaders>
58
    sec_websocket_extensions_handle(Http::LowerCaseString{"sec-websocket-extensions"});
59

            
60
Http::RegisterCustomInlineHeader<Http::CustomInlineHeaderRegistry::Type::ResponseHeaders>
61
    sec_websocket_accept_handle(Http::LowerCaseString{"sec-websocket-accept"});
62

            
63
// Thin wrapper around the C http_parser library. It parses the header block of a single HTTP/1.x
// message and hands every completed (name, value) pair to the subclass through addHeader().
// Header names are lower-cased and values are trimmed of surrounding whitespace before delivery.
class HttpParser : public Logger::Loggable<Logger::Id::filter> {
public:
  virtual ~HttpParser() = default;
  explicit HttpParser(http_parser_type type) : type_(type) {}

  // Parses 'msg' as one HTTP message.
  // Returns true only when http_parser reported no error AND the message-complete callback fired.
  // Per-message state is reset first, so the same object may be used for more than one call.
  bool parse(absl::string_view msg) {
    // Clear anything left over from a previous parse() call.
    current_header_field_.clear();
    current_header_value_.clear();
    parsing_value_ = false;
    message_complete_ = false;

    http_parser_init(&parser_, type_);
    parser_.data = this;

    // Callbacks are assigned by field name rather than by position. Fields left unset stay null,
    // which http_parser treats as "no callback".
    http_parser_settings settings{};
    settings.on_header_field = [](http_parser* parser, const char* at, size_t length) -> int {
      return static_cast<HttpParser*>(parser->data)->onHeaderField(at, length);
    };
    settings.on_header_value = [](http_parser* parser, const char* at, size_t length) -> int {
      return static_cast<HttpParser*>(parser->data)->onHeaderValue(at, length);
    };
    settings.on_headers_complete = [](http_parser* parser) -> int {
      return static_cast<HttpParser*>(parser->data)->onHeadersComplete();
    };
    settings.on_message_complete = [](http_parser* parser) -> int {
      static_cast<HttpParser*>(parser->data)->message_complete_ = true;
      return 0;
    };

    ssize_t rc = http_parser_execute(&parser_, &settings, msg.data(), msg.length());
    ENVOY_LOG(trace, "websocket: http_parser parsed {} chars, error code: {}", rc,
              static_cast<int>(HTTP_PARSER_ERRNO(&parser_)));

    // Errors in parsing HTTP (including a non-zero return from one of our callbacks).
    if (HTTP_PARSER_ERRNO(&parser_) != HPE_OK) {
      return false;
    }

    return message_complete_;
  }

  // True when the parsed message declared HTTP/1.1.
  bool versionIsHttp11() const {
    ENVOY_LOG(trace, "websocket: http_parser got version major: {} minor: {}", parser_.http_major,
              parser_.http_minor);
    return parser_.http_major == 1 && parser_.http_minor == 1;
  }

  // Number of bytes http_parser has consumed so far.
  uint32_t size() const { return parser_.nread; }

protected:
  // Validates and emits the header accumulated in current_header_field_/current_header_value_.
  // Returns -1 (aborting the parse) for header names containing an underscore, 0 otherwise.
  int completeLastHeader() {
    if (Http::HeaderUtility::headerNameContainsUnderscore(current_header_field_.getStringView())) {
      ENVOY_LOG(debug, "websocket: Rejecting invalid header: key={} value={}",
                current_header_field_.getStringView(), current_header_value_.getStringView());
      return -1;
    }
    ENVOY_LOG(trace, "websocket: completed header: key={} value={}",
              current_header_field_.getStringView(), current_header_value_.getStringView());

    if (!current_header_field_.empty()) {
      // Strip trailing whitespace of the current header value if any. Leading whitespace was
      // trimmed in onHeaderValue. http_parser does not strip leading or trailing whitespace as the
      // spec requires: https://tools.ietf.org/html/rfc7230#section-3.2.4
      current_header_value_.rtrim();

      current_header_field_.inlineTransform([](char c) { return absl::ascii_tolower(c); });

      addHeader(std::move(current_header_field_), std::move(current_header_value_));
    }
    return 0;
  }

  // http_parser may deliver a header name in several chunks; a name chunk that follows a value
  // means the previous header is complete.
  int onHeaderField(const char* data, size_t length) {
    if (parsing_value_) {
      auto code = completeLastHeader();
      if (code != 0) {
        return code;
      }
    }
    parsing_value_ = false;
    current_header_field_.append(data, length);
    return 0;
  }

  // Accumulates a (possibly partial) header value after validating its characters.
  int onHeaderValue(const char* data, size_t length) {
    parsing_value_ = true;
    absl::string_view header_value{data, length};
    if (!Http::HeaderUtility::headerValueIsValid(header_value)) {
      ENVOY_LOG(debug, "websocket: invalid header value: {}", header_value);
      return -1;
    }

    if (current_header_value_.empty()) {
      // Strip leading whitespace if the current header value input contains the first bytes of the
      // encoded header value. Trailing whitespace is stripped once the full header value is known
      // in completeLastHeader. http_parser does not strip leading or trailing
      // whitespace as the spec requires: https://tools.ietf.org/html/rfc7230#section-3.2.4 .
      header_value = StringUtil::ltrim(header_value);
    }
    current_header_value_.append(header_value.data(), header_value.length());
    return 0;
  }

  // Flushes the final header; subclasses extend this to capture request/status line data.
  virtual int onHeadersComplete() { return completeLastHeader(); }

  // Receives each completed header with a lower-cased name and trimmed value.
  virtual void addHeader(Http::HeaderString&& key, Http::HeaderString&& value) PURE;

  http_parser_type type_;
  // Value-initialized so accessors are well-defined even before parse() has run.
  http_parser parser_{};

  Http::HeaderString current_header_field_;
  Http::HeaderString current_header_value_;
  bool parsing_value_{false};
  bool message_complete_{false};
};
179

            
180
class RequestParser : public HttpParser {
181
public:
182
13
  RequestParser() : HttpParser(HTTP_REQUEST), headers_(Http::RequestHeaderMapImpl::create()) {}
183

            
184
13
  const Http::RequestHeaderMap& headers() { return *(headers_.get()); }
185

            
186
protected:
187
13
  int onHeadersComplete() override {
188
13
    headers_->setMethod(http_method_str(static_cast<http_method>(parser_.method)));
189
13
    return HttpParser::onHeadersComplete();
190
13
  }
191

            
192
120
  void addHeader(Http::HeaderString&& key, Http::HeaderString&& value) override {
193
120
    headers_->addViaMove(std::move(key), std::move(value));
194
120
  }
195

            
196
private:
197
  Http::RequestHeaderMapPtr headers_;
198
};
199

            
200
class ResponseParser : public HttpParser {
201
public:
202
11
  ResponseParser() : HttpParser(HTTP_RESPONSE), headers_(Http::ResponseHeaderMapImpl::create()) {}
203

            
204
11
  const Http::ResponseHeaderMap& headers() { return *(headers_.get()); }
205

            
206
11
  unsigned int status() {
207
11
    ENVOY_LOG(trace, "websocket: http_parser got status: {}",
208
11
              static_cast<unsigned int>(parser_.status_code));
209
11
    return parser_.status_code;
210
11
  }
211

            
212
protected:
213
11
  int onHeadersComplete() override {
214
11
    headers_->setStatus(parser_.status_code);
215
11
    return HttpParser::onHeadersComplete();
216
11
  }
217

            
218
59
  void addHeader(Http::HeaderString&& key, Http::HeaderString&& value) override {
219
59
    headers_->addViaMove(std::move(key), std::move(value));
220
59
  }
221

            
222
private:
223
  Http::ResponseHeaderMapPtr headers_;
224
};
225

            
226
197
#define CRLF "\r\n"
227
static const char REQUEST_POSTFIX[] = " HTTP/1.1" CRLF;
228
static const std::string request_prefix = "GET ";
229
static const std::string response_prefix = "HTTP/1.1 ";
230
static const absl::string_view header_separator = {CRLF CRLF, sizeof(CRLF CRLF) - 1};
231

            
232
160
// Appends a single header line, "<key>: <value>\r\n", to 'buffer'.
void encodeHeader(Buffer::Instance& buffer, absl::string_view key, absl::string_view value) {
  for (const absl::string_view piece :
       {key, absl::string_view(": "), value, absl::string_view(CRLF)}) {
    buffer.add(piece);
  }
}
238

            
239
24
// Appends all regular headers of 'headers' to 'buffer', followed by "content-length: 0".
// Pseudo headers are dropped, except ':authority', which is emitted as 'host'.
void encodeHeaders(Buffer::Instance& buffer, Http::RequestOrResponseHeaderMap& headers) {
  const Http::HeaderValues& header_values = Http::Headers::get();
  const absl::string_view authority(header_values.Host.get());
  headers.iterate(
      [&buffer, &header_values,
       authority](const Http::HeaderEntry& header) -> Http::HeaderMap::Iterate {
        absl::string_view key = header.key().getStringView();
        if (!key.empty() && key.front() == ':') {
          // Translate :authority -> host so that upper layers do not need to deal with this.
          // Every other pseudo header (method, path, status) is already carried in the
          // request/status line and is skipped.
          if (key != authority) {
            return Http::HeaderMap::Iterate::Continue;
          }
          key = absl::string_view(header_values.HostLegacy.get());
        }
        encodeHeader(buffer, key, header.value().getStringView());
        return Http::HeaderMap::Iterate::Continue;
      });
  // NOTE(review): RFC 7230 section 3.3.2 says a server MUST NOT send Content-Length in a 1xx
  // response, yet this is also emitted for the 101 handshake response. Kept as-is because peers or
  // tests may depend on the exact bytes; TODO confirm.
  encodeHeader(buffer, header_values.ContentLength.get(), "0");
}
258

            
259
11
// Serializes a handshake request: the "<method> <path> HTTP/1.1" request line, the headers and the
// terminating blank line.
void encodeRequest(Buffer::Instance& buffer, Http::RequestHeaderMap& headers) {
  // The value getters yield an empty value for an absent header instead of a null entry pointer,
  // so a header map lacking ':method' or ':path' cannot crash the encoder.
  buffer.add(headers.getMethodValue());
  buffer.add(" ", 1);
  buffer.add(headers.getPathValue());
  buffer.add(REQUEST_POSTFIX, sizeof(REQUEST_POSTFIX) - 1);

  encodeHeaders(buffer, headers);

  // Blank line ending the header block.
  buffer.add(CRLF, 2);
}
272

            
273
13
// Serializes a handshake response: the "HTTP/1.1 <status> <reason>" status line, the headers and
// the terminating blank line.
void encodeResponse(Buffer::Instance& buffer, Http::ResponseHeaderMap& headers) {
  const uint64_t numeric_status = Http::Utility::getResponseStatus(headers);
  const char* reason_phrase = Http::CodeUtility::toString(static_cast<Http::Code>(numeric_status));

  buffer.add(response_prefix);
  // getStatusValue() avoids dereferencing a null entry if ':status' is absent.
  buffer.add(headers.getStatusValue());
  buffer.add(" ", 1);
  buffer.add(reason_phrase, strlen(reason_phrase));
  buffer.add(CRLF, 2);

  encodeHeaders(buffer, headers);

  // Blank line ending the header block.
  buffer.add(CRLF, 2);
}
288

            
289
} // namespace
290

            
291
//
292
// Codec
293
//
294

            
295
// Creates a codec bound to 'conn' and arms the handshake timer. If the timer fires, the
// handshake_timeout stat is incremented and the connection is closed.
Codec::Codec(CodecCallbacks* parent, Network::Connection& conn)
    : parent_(parent), connection_(conn), encoder_(*this), decoder_(*this) {
  const auto& config = parent_->config();
  ENVOY_CONN_LOG(trace, "Enabling websocket handshake timeout at {} ms", connection_,
                 config->handshake_timeout_.count());

  auto on_timeout = [this]() {
    parent_->config()->stats_.handshake_timeout_.inc();
    closeOnError("websocket handshake timed out");
  };
  handshake_timer_ = connection_.dispatcher().createTimer(on_timeout);
  handshake_timer_->enableTimer(config->handshake_timeout_);
}
305

            
306
namespace {
307

            
308
225
// XORs 'n_bytes' bytes of 'buf' in place with the 4-byte WebSocket 'mask'. The mask byte for each
// byte is chosen from its position within the whole payload ('payload_offset' counts bytes already
// processed). Returns the updated offset so masking can resume on the next chunk.
size_t maskData(uint8_t* buf, size_t n_bytes, uint8_t mask[4], size_t payload_offset = 0) {
  for (uint8_t* const end = buf + n_bytes; buf != end; ++buf) {
    *buf ^= mask[payload_offset++ % 4];
  }
  return payload_offset;
}
315

            
316
} // namespace
317

            
318
1
void Codec::closeOnError(const char* msg) {
319
1
  if (msg) {
320
1
    ENVOY_LOG(debug, "websocket: Closing connection: {}", msg);
321
1
  }
322
  // Close downstream, this should result also in the upstream getting closed (if any).
323
1
  connection_.close(Network::ConnectionCloseType::NoFlush, fmt::format("websocket error: {}", msg));
324
1
}
325

            
326
1
// Closes the connection like closeOnError(msg), then discards everything left in 'data'.
void Codec::closeOnError(Buffer::Instance& data, const char* msg) {
  closeOnError(msg);
  // Test infra insists on data being drained
  const uint64_t leftover = data.length();
  data.drain(leftover);
}
331

            
332
11
void Codec::handshake() {
333
11
  ENVOY_LOG(debug, "websocket: handshake");
334

            
335
11
  auto& config = parent_->config();
336

            
337
11
  if (!config->client_) {
338
    ENVOY_LOG(warn, "websocket: skipping handshake on a server");
339
    return;
340
  }
341

            
342
11
  const Network::Address::InstanceConstSharedPtr& dst_address =
343
11
      connection_.connectionInfoProvider().localAddress();
344

            
345
  // Create WebSocket Handshake
346
11
  const Http::HeaderValues& header_values = Http::Headers::get();
347
11
  Envoy::Http::RequestHeaderMapPtr headers = Http::RequestHeaderMapImpl::create();
348
11
  headers->setReferenceMethod(header_values.MethodValues.Get);
349
11
  headers->setReferencePath(config->path_);
350
11
  headers->setReferenceHost(config->host_);
351
11
  headers->setReferenceUpgrade(header_values.UpgradeValues.WebSocket);
352
11
  headers->setReferenceConnection(header_values.ConnectionValues.Upgrade);
353
11
  headers->setReferenceInline(sec_websocket_key_handle.handle(), config->key_);
354
11
  headers->setReferenceInline(sec_websocket_version_handle.handle(), config->version_);
355
11
  if (!config->origin_.empty()) {
356
11
    headers->setReferenceInline(origin_handle.handle(), config->origin_);
357
11
  }
358
  // Set original destination address header
359
11
  headers->setReferenceInline(original_dst_host_handle.handle(), dst_address->asStringView());
360
  // Set 'x-request-id' header
361
11
  config->request_id_extension_->set(*headers, true, false);
362

            
363
11
  parent_->onHandshakeCreated(*headers);
364

            
365
11
  Buffer::OwnedImpl handshake_buffer{};
366
11
  encodeRequest(handshake_buffer, *headers);
367
11
  parent_->injectEncoded(handshake_buffer, false);
368
  // Check that the buffer was drained
369
11
  ASSERT(handshake_buffer.length() == 0, "Handshake buffer not drained");
370
11
  parent_->onHandshakeSent();
371
11
}
372

            
373
58
// Frames 'data' as a binary WebSocket message. The frame stays buffered in the encoder until the
// handshake has been accepted; after that, encoded frames are forwarded immediately.
void Codec::encode(Buffer::Instance& data, bool end_stream) {
  ENVOY_LOG(debug, "websocket: encode {} bytes, end_stream: {}", data.length(), end_stream);

  encoder_.encode(data, end_stream, OPCODE_BIN);

  if (!accepted_) {
    return;
  }
  // Outgoing data counts as activity, so push back the next keep-alive ping.
  if (encoder_.hasData()) {
    resetPingTimer();
  }
  parent_->injectEncoded(encoder_.data(), encoder_.endStream());
}
387

            
388
void Codec::encodeHandshakeResponse(Http::ResponseHeaderMap& headers, uint32_t status,
389
                                    absl::string_view hash,
390
13
                                    const Http::RequestHeaderMap* request_headers) {
391
13
  if (status == 200) {
392
12
    ENVOY_LOG(debug, "websocket: Using hash {}", hash);
393
12
    headers.setStatus(enumToInt(Http::Code::SwitchingProtocols)); // 101
394
12
    headers.setReferenceConnection(Http::Headers::get().ConnectionValues.Upgrade);
395
12
    headers.setReferenceUpgrade(Http::Headers::get().UpgradeValues.WebSocket);
396
12
    headers.addCopy(Envoy::Http::LowerCaseString("sec-websocket-accept"), hash);
397
12
  } else {
398
1
    headers.setStatus(enumToInt(Http::Code::Forbidden)); // 403
399
1
  }
400
13
  if (request_headers != nullptr && request_headers->RequestId()) {
401
12
    headers.setRequestId(request_headers->getRequestIdValue());
402
12
  }
403
13
}
404

            
405
// Validates a client's WebSocket upgrade request against 'config'.
// Returns the address parsed from the original-destination header when the request is acceptable,
// nullptr otherwise.
// NOTE(review): path is lower-cased before comparison although URI paths are case-sensitive;
// presumably config->path_ is stored lower-case too. Verify against the config loader.
Network::Address::InstanceConstSharedPtr
Codec::decodeHandshakeRequest(const ConfigSharedPtr& config,
                              const Http::RequestHeaderMap& headers) {
  const Http::HeaderValues& header_values = Http::Headers::get();

  auto method = headers.getMethodValue();
  auto path = absl::AsciiStrToLower(headers.getPathValue());
  auto host = absl::AsciiStrToLower(headers.getHostValue());
  auto connection = absl::AsciiStrToLower(headers.getConnectionValue());
  auto upgrade = absl::AsciiStrToLower(headers.getUpgradeValue());
  auto version =
      absl::AsciiStrToLower(headers.getInlineValue(sec_websocket_version_handle.handle()));
  auto origin = headers.getInline(origin_handle.handle());
  auto protocol = headers.getInline(sec_websocket_protocol_handle.handle());
  auto extensions = headers.getInline(sec_websocket_extensions_handle.handle());
  auto override_header = headers.getInline(original_dst_host_handle.handle());
  auto key = headers.getInlineValue(sec_websocket_key_handle.handle());

  Network::Address::InstanceConstSharedPtr orig_dst{nullptr};
  if (override_header != nullptr && !override_header->value().empty()) {
    const std::string request_override_host(override_header->value().getStringView());
    orig_dst = Network::Utility::parseInternetAddressAndPortNoThrow(request_override_host, false);
  }

  // A value must be present (non-empty) and, if 'expected' is configured, equal to it.
  // The ternary form matters: "(expected.empty() && !actual.empty()) || actual == expected"
  // would accept an empty 'actual' when nothing is configured, because "" == "".
  const auto present_and_matching = [](absl::string_view actual, absl::string_view expected) {
    return expected.empty() ? !actual.empty() : actual == expected;
  };

  bool valid =
      (method == header_values.MethodValues.Get &&
       connection == header_values.ConnectionValues.Upgrade &&
       upgrade == header_values.UpgradeValues.WebSocket &&
       present_and_matching(path, config->path_) && present_and_matching(host, config->host_) &&
       present_and_matching(key, config->key_) &&
       present_and_matching(version, config->version_) &&
       // origin must be present with non-empty value and must match expected if configured,
       // origin may not be present if not configured
       (config->origin_.empty()
            ? origin == nullptr
            : (origin != nullptr &&
               absl::AsciiStrToLower(origin->value().getStringView()) == config->origin_)) &&
       // protocol and extensions are not allowed for now
       protocol == nullptr && extensions == nullptr &&
       // override header must be present and have a valid value
       orig_dst != nullptr);
  ENVOY_LOG(debug,
            "websocket: valid = {}, method: {}/{}, path: \"{}\"/\"{}\", host: {}/{}, connection: "
            "{}/{}, upgrade: {}/{}, key: {}/{}, version: {}/{}, origin: {}/{}, protocol: {}, "
            "extensions: {}, override: {}",
            valid, method, header_values.MethodValues.Get, path, config->path_, host,
            config->host_, connection, header_values.ConnectionValues.Upgrade, upgrade,
            header_values.UpgradeValues.WebSocket, key, config->key_, version,
            config->version_, origin ? origin->value().getStringView() : "<NONE>", config->origin_,
            protocol ? protocol->value().getStringView() : "<NONE>",
            extensions ? extensions->value().getStringView() : "<NONE>",
            override_header ? override_header->value().getStringView() : "<NONE>");

  return valid ? orig_dst : nullptr;
}
463

            
464
23
void Codec::startPingTimer() {
465
23
  auto& config = parent_->config();
466

            
467
  // Start ping timer if enabled
468
23
  if (config->ping_interval_.count()) {
469
11
    ENVOY_CONN_LOG(trace, "Enabling websocket PING timer at {} ms", connection_,
470
11
                   config->ping_interval_.count());
471
97
    ping_timer_ = connection_.dispatcher().createTimer([this]() {
472
97
      auto& config = parent_->config();
473
97
      char count_buffer[StringUtil::MIN_ITOA_OUT_LEN];
474
97
      const uint32_t count_len =
475
97
          StringUtil::itoa(count_buffer, StringUtil::MIN_ITOA_OUT_LEN, ++ping_count_);
476
97
      if (ping(count_buffer, count_len)) {
477
89
        ENVOY_CONN_LOG(trace, "Injected websocket PING {}", connection_, ping_count_);
478
        // Randomize ping interval with jitter when idle
479
89
        if (ping_timer_ != nullptr) {
480
89
          uint64_t interval_ms = config->ping_interval_.count();
481
89
          const uint64_t jitter_percent_mod = ping_interval_jitter_percent_ * interval_ms / 100;
482
89
          if (jitter_percent_mod > 0) {
483
            interval_ms += config->random_.random() % jitter_percent_mod;
484
          }
485
89
          ping_timer_->enableTimer(std::chrono::milliseconds(interval_ms));
486
89
        }
487
89
      }
488
97
    });
489
11
    ping_timer_->enableTimer(config->ping_interval_);
490
11
  }
491
23
}
492

            
493
24
// Sanity check the first chars to catch non-HTTP messages.
// Only min(data.length(), prefix.length()) bytes are compared. A buffer shorter than 'prefix'
// therefore passes as long as its bytes agree with the start of 'prefix'.
bool Codec::checkPrefix(Buffer::Instance& data, const std::string& prefix) {
  const auto available = std::min(data.length(), prefix.length());
  const char* head = static_cast<const char*>(data.linearize(available));
  return absl::string_view(prefix).substr(0, available) == absl::string_view(head, available);
}
499

            
500
208
// Entry point for bytes received from the peer.
// Until the handshake is accepted, input is buffered in handshake_buffer_ until the full
// request/response header block has arrived:
//   - a client validates the server's 101 response;
//   - a server validates the upgrade request and answers with 101 or 403.
// Bytes following the handshake, and all later input, are passed to the frame decoder and forwarded
// via parent_->injectDecoded().
void Codec::decode(Buffer::Instance& data, bool end_stream) {
  ENVOY_LOG(trace, "websocket: decode {} bytes, end_stream: {}", data.length(), end_stream);

  const auto& config = parent_->config();

  if (!accepted_) {
    // Buffer incoming data in case it arrives in parts
    handshake_buffer_.move(data);

    if (handshake_buffer_.length() > WEBSOCKET_HANDSHAKE_MAX_SIZE) {
      config->stats_.handshake_too_large_.inc();
      return closeOnError(handshake_buffer_, "handshake message too long.");
    }

    // A client expects an HTTP response, a server expects an HTTP request as the first message.
    // Checking the leading bytes rejects non-HTTP peers early.
    const bool is_client = config->client_;
    if (!checkPrefix(handshake_buffer_, is_client ? response_prefix : request_prefix)) {
      config->stats_.handshake_not_http_.inc();
      return closeOnError(handshake_buffer_,
                          is_client ? "response not http." : "request not http.");
    }

    // The blank line ending the header block marks the end of the handshake message.
    const ssize_t separator_pos =
        handshake_buffer_.search(header_separator.data(), header_separator.length(), 0, 0);
    if (separator_pos == -1) {
      if (!end_stream) {
        return; // Header separator not found, Need more data
      }
      config->stats_.protocol_error_.inc();
      return closeOnError(handshake_buffer_, "no request/response.");
    }

    // Got the request/response, can disable the handshake timer.
    handshake_timer_->disableTimer();

    // The message size includes the header separator.
    const size_t msg_size = separator_pos + header_separator.length();
    const absl::string_view message(
        static_cast<const char*>(handshake_buffer_.linearize(msg_size)), msg_size);

    if (is_client) {
      ResponseParser response;
      if (!response.parse(message)) {
        config->stats_.handshake_parse_error_.inc();
        return closeOnError(handshake_buffer_, "response parse failed.");
      }
      handshake_buffer_.drain(msg_size);

      const Http::ResponseHeaderMap& response_headers = response.headers();
      parent_->onHandshakeResponse(response_headers);

      if (!response.versionIsHttp11()) {
        config->stats_.handshake_invalid_http_version_.inc();
        return closeOnError(handshake_buffer_, "unsupported HTTP protocol");
      }
      if (response.status() != 101) {
        config->stats_.handshake_invalid_http_status_.inc();
        return closeOnError(handshake_buffer_, "Invalid HTTP status code for websocket");
      }

      // Validate response headers
      const Http::HeaderValues& header_values = Http::Headers::get();
      const auto connection = absl::AsciiStrToLower(response_headers.getConnectionValue());
      const auto upgrade = absl::AsciiStrToLower(response_headers.getUpgradeValue());
      const auto key_accept =
          response_headers.getInlineValue(sec_websocket_accept_handle.handle());
      const auto key_response = config->keyResponse(config->key_);

      accepted_ = connection == header_values.ConnectionValues.Upgrade &&
                  upgrade == header_values.UpgradeValues.WebSocket && key_accept == key_response;
      ENVOY_LOG(debug,
                "websocket: accepted_ = {}, connection: {}, upgrade: {}, accept: {} (expected {})",
                accepted_, connection, upgrade, key_accept, key_response);
      if (!accepted_) {
        config->stats_.handshake_invalid_websocket_response_.inc();
        return closeOnError(handshake_buffer_, "Invalid WebSocket response");
      }

      // Frames that encode() buffered while the handshake was pending can be written now.
      parent_->injectEncoded(encoder_.data(), encoder_.endStream());
    } else {
      RequestParser request;
      if (!request.parse(message)) {
        // Consider issuing HTTP response instead?
        config->stats_.handshake_parse_error_.inc();
        return closeOnError(handshake_buffer_, "request parse failed.");
      }
      handshake_buffer_.drain(msg_size);

      const Http::RequestHeaderMap& request_headers = request.headers();
      parent_->onHandshakeRequest(request_headers);

      if (!request.versionIsHttp11()) {
        config->stats_.handshake_invalid_http_version_.inc();
        return closeOnError(handshake_buffer_, "unsupported HTTP protocol");
      }

      auto response_headers = Http::ResponseHeaderMapImpl::create();
      Buffer::OwnedImpl response_buffer{};
      const auto orig_dst = decodeHandshakeRequest(config, request_headers);
      accepted_ = (orig_dst != nullptr);

      if (!accepted_) {
        config->stats_.handshake_invalid_websocket_request_.inc();

        // Reject with a 403 before closing.
        // NOTE(review): closeOnError() closes with NoFlush, which may discard this response if
        // it is still pending in the write buffer. Confirm with injectEncoded()'s behavior.
        encodeHandshakeResponse(*response_headers, 403, "", &request_headers);
        encodeResponse(response_buffer, *response_headers);
        parent_->injectEncoded(response_buffer, true);
        if (response_buffer.length() > 0) {
          // Not all of the response could be written.
          config->stats_.handshake_write_error_.inc();
        } else {
          parent_->onHandshakeResponseSent(*response_headers);
        }
        return closeOnError(handshake_buffer_, "Invalid WebSocket request");
      }

      // Accept with a 101 carrying the Sec-WebSocket-Accept hash of the client's key.
      const auto hash =
          Config::keyResponse(request_headers.getInlineValue(sec_websocket_key_handle.handle()));
      encodeHandshakeResponse(*response_headers, 200, hash, &request_headers);
      encodeResponse(response_buffer, *response_headers);
      parent_->injectEncoded(response_buffer, false);
      if (response_buffer.length() > 0) {
        config->stats_.handshake_write_error_.inc();
        return closeOnError(handshake_buffer_, "error writing handshake response");
      }
      // Set destination address for the original destination filter.
      parent_->setOriginalDestinationAddress(orig_dst);

      parent_->onHandshakeResponseSent(*response_headers);
    }

    startPingTimer();

    // Whatever followed the handshake message is WebSocket framing; decode it below.
    data.move(handshake_buffer_);
  }

  // Handshake done, process data.
  decoder_.decode(data, end_stream);

  // Incoming data counts as activity, so push back the next keep-alive ping.
  if (decoder_.hasData()) {
    resetPingTimer();
  }

  parent_->injectDecoded(decoder_.data(), decoder_.endStream());
}
664

            
665
97
bool Codec::ping(const void* payload, size_t len) {
666
97
  if (encoder_.endStream()) {
667
8
    return false;
668
8
  }
669
89
  Buffer::OwnedImpl buf(payload, len);
670
89
  encoder_.encode(buf, false, OPCODE_PING);
671
89
  parent_->config()->stats_.ping_sent_count_.inc();
672
89
  parent_->injectEncoded(encoder_.data(), encoder_.endStream());
673
89
  return true;
674
97
}
675

            
676
81
bool Codec::pong(const void* payload, size_t len) {
677
81
  if (encoder_.endStream()) {
678
15
    return false;
679
15
  }
680
66
  Buffer::OwnedImpl buf(payload, len);
681
66
  encoder_.encode(buf, false, OPCODE_PONG);
682
66
  parent_->injectEncoded(encoder_.data(), encoder_.endStream());
683
66
  return true;
684
81
}
685

            
686
// Encoder
687

            
688
// Encode 'data' and 'end_stream' as websocket frames into 'encoded_'. Uses 'opcode' as the
689
// websocket frame type for the data frames.
690
213
void Codec::Encoder::encode(Buffer::Instance& data, bool end_stream, uint8_t opcode) {
691
213
  auto hex_len = std::min(data.length(), 20UL);
692
213
  const uint8_t* hex_data = reinterpret_cast<uint8_t*>(data.linearize(hex_len));
693
213
  ENVOY_LOG(debug, "websocket encoder: {} bytes: 0x{}, end_stream: {}, opcode: {}", data.length(),
694
213
            Hex::encode(hex_data, hex_len), end_stream, opcode);
695

            
696
213
  auto& config = parent_.config();
697
  //
698
  // Encode data as a single WebSocket frame
699
  //
700
213
  if (data.length() > 0) {
701
196
    uint8_t frame_header[14];
702
196
    size_t frame_header_length = 2;
703
196
    size_t payload_len = data.length();
704

            
705
196
    frame_header[0] = FIN_MASK | opcode;
706
196
    if (payload_len < 126) {
707
177
      frame_header[1] = payload_len;
708
177
    } else if (payload_len < 65536) {
709
7
      uint16_t len16;
710

            
711
7
      frame_header[1] = 126;
712
7
      len16 = htobe16(payload_len);
713
7
      memcpy(frame_header + frame_header_length, &len16, 2); // NOLINT(safe-memcpy)
714
7
      frame_header_length += 2;
715
14
    } else {
716
12
      uint64_t len64;
717

            
718
12
      frame_header[1] = 127;
719
12
      len64 = htobe64(payload_len);
720
12
      memcpy(frame_header + frame_header_length, &len64, 8); // NOLINT(safe-memcpy)
721
12
      frame_header_length += 8;
722
12
    }
723

            
724
    // Client must mask the payload
725
196
    if (config->client_) {
726
110
      frame_header[1] |= MASK_MASK;
727

            
728
110
      union {
729
110
        uint8_t bytes[4];
730
110
        uint32_t word;
731
110
      } mask;
732

            
733
110
      mask.word = config->random_.random();
734
110
      memcpy(frame_header + frame_header_length, &mask, 4); // NOLINT(safe-memcpy)
735
110
      frame_header_length += 4;
736
110
      uint8_t* buf = reinterpret_cast<uint8_t*>(data.linearize(payload_len));
737
110
      maskData(buf, payload_len, mask.bytes);
738
110
    }
739

            
740
    // Add frame header and (masked) data
741
196
    encoded_.add(absl::string_view{reinterpret_cast<char*>(frame_header), frame_header_length});
742
196
    encoded_.move(data, payload_len);
743
196
  }
744

            
745
  //
746
  // Append closing frame if 'end_stream'
747
  //
748
213
  if (end_stream) {
749
22
    uint8_t frame_header[14];
750
22
    size_t frame_header_length = 2;
751
22
    size_t payload_len = 0;
752

            
753
22
    frame_header[0] = FIN_MASK | OPCODE_CLOSE;
754
22
    frame_header[1] = payload_len;
755
    // Client must mask the payload
756
22
    if (config->client_) {
757
11
      frame_header[1] |= MASK_MASK;
758

            
759
11
      uint32_t mask = config->random_.random();
760
11
      memcpy(frame_header + frame_header_length, &mask, 4); // NOLINT(safe-memcpy)
761
11
      frame_header_length += 4;
762
      // No data to mask
763
11
    }
764
22
    encoded_.add(reinterpret_cast<void*>(frame_header), frame_header_length);
765
22
    end_stream_ = true;
766

            
767
22
    ENVOY_LOG(debug, "websocket encoder: sent WebSocket CLOSE message, end_stream: {}", end_stream);
768
22
  }
769
213
}
770

            
771
// Decoder
772

            
773
/*
774
 * TRY_READ_NETWORK reads sizeof(*(DATA)) bytes from 'buffer_' if available.
775
 * Does not drain anything from the buffer,
776
 * draining has to be done separately.
777
 */
778
#define TRY_READ_NETWORK(DATA)                                                                     \
779
340
  {                                                                                                \
780
340
    if (buffer_.length() < frame_offset + sizeof(*(DATA))) {                                       \
781
      /* Try again when we have more data */                                                       \
782
      return;                                                                                      \
783
    }                                                                                              \
784
340
    ENVOY_LOG(trace, "websocket: copyOut {} bytes at offset {}", sizeof(*(DATA)), frame_offset);   \
785
340
    buffer_.copyOut(frame_offset, sizeof(*(DATA)), (DATA));                                        \
786
340
    frame_offset += sizeof(*(DATA));                                                               \
787
340
  }
788

            
789
// Decode 'data' into 'decoded_'.
790
207
void Codec::Decoder::decode(Buffer::Instance& data, bool end_stream) {
791
207
  ENVOY_LOG(trace, "websocket decoder: {} bytes, end_stream: {}", data.length(), end_stream);
792

            
793
207
  buffer_.move(data);
794

            
795
207
  if (end_stream_ && buffer_.length() > 0) {
796
    ENVOY_LOG(debug, "websocket decoder: data received after CLOSE: {} bytes", buffer_.length());
797
    buffer_.drain(buffer_.length());
798
    return;
799
  }
800

            
801
207
  if (end_stream) {
802
21
    end_stream_ = true;
803
21
  }
804

            
805
413
  while (buffer_.length() > 0) {
806
    // Try finish any frame in progress
807
3450
    while (payload_remaining_ > 0) {
808
3244
      auto slice = buffer_.frontSlice();
809
3244
      size_t n_bytes = std::min(slice.len_, payload_remaining_);
810

            
811
      // Unmask data in place
812
3244
      uint8_t* buf = static_cast<uint8_t*>(slice.mem_);
813
3244
      auto hex_len = std::min(n_bytes, 20UL);
814
3244
      if (unmasking_) {
815
25
        ENVOY_LOG(
816
25
            trace,
817
25
            "websocket decoder: unmasking payload remaining: {}, offset: {}, processing: {}: 0x{}",
818
25
            payload_remaining_, payload_offset_, n_bytes, Hex::encode(buf, hex_len));
819
25
        payload_offset_ = maskData(buf, n_bytes, mask_, payload_offset_);
820
25
      }
821
3244
      ENVOY_LOG(trace, "websocket decoder: payload remaining: {}, offset: {}, processing: {}: 0x{}",
822
3244
                payload_remaining_, payload_offset_, n_bytes, Hex::encode(buf, hex_len));
823

            
824
3244
      decoded_.move(buffer_, n_bytes);
825
3244
      payload_remaining_ -= n_bytes;
826

            
827
3244
      if (buffer_.length() == 0) {
828
36
        return;
829
36
      }
830
3244
    }
831
    //
832
    // Now at a frame boundary, reset state for a new frame.
833
    //
834
206
    unmasking_ = false;
835
206
    payload_offset_ = 0;
836
206
    RELEASE_ASSERT(payload_remaining_ == 0, "internal websocket framing error");
837

            
838
206
    uint8_t frame_header[2];
839
206
    size_t frame_offset = 0;
840
206
    uint8_t opcode;
841
206
    uint64_t payload_len;
842

            
843
206
    ENVOY_LOG(trace, "websocket decoder: remaining buffer: {} bytes", buffer_.length());
844

            
845
206
    TRY_READ_NETWORK(&frame_header);
846
206
    opcode = frame_header[0] & OPCODE_MASK;
847
206
    payload_len = frame_header[1] & PAYLOAD_LEN_MASK;
848

            
849
206
    if (payload_len == 126) {
850
8
      uint16_t len16;
851

            
852
8
      TRY_READ_NETWORK(&len16);
853
8
      payload_len = be16toh(len16);
854
198
    } else if (payload_len == 127) {
855
13
      uint64_t len64;
856

            
857
13
      TRY_READ_NETWORK(&len64);
858
13
      payload_len = be64toh(len64);
859
13
    }
860
206
    if (frame_header[1] & MASK_MASK) {
861
113
      TRY_READ_NETWORK(&mask_);
862
113
      unmasking_ = true;
863
113
    }
864

            
865
    //
866
    // Whole header received and decoded
867
    //
868

            
869
    // Terminate and respond to any control frames
870
206
    if (opcode >= OPCODE_CLOSE) {
871
      // Protect against too large control frames that could happen if the decoder ever loses
872
      // sync with the data stream.
873
162
      if (payload_len > WEBSOCKET_CONTROL_FRAME_MAX_SIZE) {
874
        ENVOY_LOG(debug, "websocket decoder: too large control frame: {} bytes", payload_len);
875
        buffer_.drain(buffer_.length());
876
        end_stream_ = true;
877
        return;
878
      }
879

            
880
      // Buffer until whole control frame has been received
881
162
      if (buffer_.length() < frame_offset + payload_len) {
882
        return;
883
      }
884

            
885
      // Drain control frame header, get the payload
886
162
      buffer_.drain(frame_offset);
887
162
      uint8_t* payload = reinterpret_cast<uint8_t*>(buffer_.linearize(payload_len));
888

            
889
      // Unmask the control frame payload
890
162
      if (unmasking_) {
891
90
        maskData(payload, payload_len, mask_);
892
90
      }
893

            
894
162
      switch (opcode) {
895
18
      case OPCODE_CLOSE:
896
18
        ENVOY_LOG(trace, "websocket decoder: CLOSE received");
897
18
        end_stream_ = true;
898
18
        break;
899
81
      case OPCODE_PING: {
900
81
        ENVOY_LOG(trace, "websocket decoder: PING received");
901
        // Reply with a PONG with the same payload
902
81
        parent_.pong(payload, payload_len);
903
81
        break;
904
      }
905
63
      case OPCODE_PONG:
906
63
        ENVOY_LOG(trace, "websocket decoder: PONG received");
907
63
        break;
908
162
      }
909
      // Drain control plane payload
910
162
      buffer_.drain(payload_len);
911
173
    } else {
912
      // Unframe and forward all non-control frames
913
44
      ENVOY_LOG(trace, "websocket decoder: received websocket data: header {} bytes, data {} bytes",
914
44
                frame_offset, payload_len);
915

            
916
44
      buffer_.drain(frame_offset);
917
44
      payload_remaining_ = payload_len;
918
44
    }
919
206
  }
920
207
}
921

            
922
} // namespace WebSocket
923
} // namespace Cilium
924
} // namespace Envoy