1
#pragma once
2

            
3
#include <cstddef>
4
#include <cstdint>
5
#include <memory>
6
#include <string>
7

            
8
#include "envoy/buffer/buffer.h"
9
#include "envoy/common/pure.h"
10
#include "envoy/event/timer.h"
11
#include "envoy/http/header_map.h"
12
#include "envoy/network/address.h"
13
#include "envoy/network/connection.h"
14

            
15
#include "source/common/buffer/buffer_impl.h"
16
#include "source/common/common/logger.h"
17

            
18
#include "absl/strings/string_view.h"
19
#include "cilium/websocket_config.h"
20

            
21
namespace Envoy {
22
namespace Cilium {
23
namespace WebSocket {
24

            
25
class CodecCallbacks {
26
public:
27
24
  virtual ~CodecCallbacks() = default;
28

            
29
  virtual const ConfigSharedPtr& config() PURE;
30

            
31
  virtual void injectEncoded(Buffer::Instance& data, bool end_stream) PURE;
32
  virtual void injectDecoded(Buffer::Instance& data, bool end_stream) PURE;
33

            
34
  virtual void
35
  setOriginalDestinationAddress(const Network::Address::InstanceConstSharedPtr& orig_dst) PURE;
36

            
37
  virtual void onHandshakeCreated(const Http::RequestHeaderMap&) PURE;
38
  virtual void onHandshakeSent() PURE;
39
  virtual void onHandshakeRequest(const Http::RequestHeaderMap& headers) PURE;
40
  virtual void onHandshakeResponse(const Http::ResponseHeaderMap& headers) PURE;
41
  virtual void onHandshakeResponseSent(const Http::ResponseHeaderMap& headers) PURE;
42
};
43

            
44
// WebSocket codec. It drives the handshake, encodes/decodes data through the nested
// Encoder/Decoder helpers (which buffer encoded/decoded websocket frames), and manages an
// optional ping timer.
//
// NOTE(review): `parent` is held as a raw pointer and `conn` by reference; both are
// presumably required to outlive the codec — confirm against the owner.
class Codec : Logger::Loggable<Logger::Id::filter> {
public:
  Codec(CodecCallbacks* parent, Network::Connection& conn);

  // Out-of-line entry points; `end_stream` marks the final chunk of a direction.
  void handshake();
  void encode(Buffer::Instance& data, bool end_stream);
  void decode(Buffer::Instance& data, bool end_stream);

private:
  // Accumulates encoded websocket frames in `encoded_` until the codec consumes them.
  class Encoder : Logger::Loggable<Logger::Id::filter> {
  public:
    explicit Encoder(Codec& parent) : parent_(parent) {}

    // Defined out of line; `opcode` is the websocket opcode used for the produced frame(s).
    void encode(Buffer::Instance& data, bool end_stream, uint8_t opcode);

    // True when the encoded-frame buffer is non-empty. (Previously returned size_t for what
    // is a boolean result.)
    bool hasData() const { return encoded_.length() > 0; }
    Buffer::Instance& data() { return encoded_; }
    bool endStream() const { return end_stream_; }
    // Discards everything currently buffered in `encoded_`.
    void drain() { encoded_.drain(encoded_.length()); }

    Codec& parent_;
    bool end_stream_{false};
    Buffer::OwnedImpl encoded_{}; // Buffer for encoded websocket frames
  };

  // Reassembles websocket frames: partial input is held in `buffer_`, decoded payload is
  // accumulated in `decoded_`.
  class Decoder : Logger::Loggable<Logger::Id::filter> {
  public:
    explicit Decoder(Codec& parent) : parent_(parent) {}

    void decode(Buffer::Instance& data, bool end_stream);

    // True when the decoded-payload buffer is non-empty. (Previously returned size_t.)
    bool hasData() const { return decoded_.length() > 0; }
    Buffer::Instance& data() { return decoded_; }
    bool endStream() const { return end_stream_; }
    // Discards everything currently buffered in `decoded_`.
    void drain() { decoded_.drain(decoded_.length()); }

    Codec& parent_;
    bool end_stream_{false};
    Buffer::OwnedImpl buffer_{};  // Buffer for partial websocket frames
    Buffer::OwnedImpl decoded_{}; // Buffer for decoded websocket frames

    // State for unmasking the current frame's payload. The masking key is 4 bytes
    // (RFC 6455); value-initialized so it never holds indeterminate bytes.
    bool unmasking_{false};
    uint8_t mask_[4]{};
    // Presumably the position within / bytes left of the current frame payload — confirm.
    size_t payload_offset_{0};
    size_t payload_remaining_{0};
  };

  void startPingTimer();

  // Re-arms the ping timer for another `ping_interval_`, but only if a timer exists and the
  // config has `ping_when_idle_` set.
  void resetPingTimer() {
    if (ping_timer_ == nullptr) {
      return;
    }
    // Bind by reference: config() already returns a const reference, and copying the
    // ConfigSharedPtr on every reset would only add refcount traffic.
    const auto& config = parent_->config();
    if (config->ping_when_idle_) {
      ping_timer_->enableTimer(config->ping_interval_);
    }
  }

  bool ping(const void* payload, size_t len);
  bool pong(const void* payload, size_t len);

  static Network::Address::InstanceConstSharedPtr
  decodeHandshakeRequest(const ConfigSharedPtr& config, const Http::RequestHeaderMap& headers);
  static void encodeHandshakeResponse(Http::ResponseHeaderMap& headers, uint32_t status,
                                      absl::string_view hash,
                                      const Http::RequestHeaderMap* request_headers);

  // Convenience forwarder to the owner's configuration.
  const ConfigSharedPtr& config() const { return parent_->config(); }

  static bool checkPrefix(Buffer::Instance& data, const std::string& prefix);

  void closeOnError(const char* msg);
  void closeOnError(Buffer::Instance& data, const char* msg);

  CodecCallbacks* parent_;
  Network::Connection& connection_;
  Encoder encoder_;
  Decoder decoder_;

  Event::TimerPtr ping_timer_{nullptr};
  // Presumably the jitter applied by startPingTimer() to the ping interval — confirm.
  uint32_t ping_interval_jitter_percent_{15};
  uint64_t ping_count_{0};

  Event::TimerPtr handshake_timer_{nullptr};
  Buffer::OwnedImpl handshake_buffer_{};
  bool accepted_{false};
};
using CodecPtr = std::unique_ptr<Codec>;
131

            
132
} // namespace WebSocket
133
} // namespace Cilium
134
} // namespace Envoy