1
#include "source/common/websocket/codec.h"
2

            
3
#include <algorithm>
4
#include <array>
5
#include <cstdint>
6
#include <memory>
7
#include <vector>
8

            
9
#include "source/common/buffer/buffer_impl.h"
10
#include "source/common/common/scalar_to_byte_vector.h"
11

            
12
namespace Envoy {
13
namespace WebSocket {
14

            
15
16
// Serializes only the header of `frame` into a fresh byte vector; payload bytes are not written.
// The header consists of:
//   1. one byte holding the final-fragment (FIN) flag in its high bit and the opcode,
//   2. one byte holding the mask flag in its high bit and a 7-bit length field; the field is
//      the payload length itself when it is <= 125, 0x7e when a 16-bit length follows
//      (length <= 65535), or 0x7f when a 64-bit length follows,
//   3. the optional masking key when frame.masking_key_ is set.
// Multi-byte fields are converted with htobe16/htobe64/htobe32 before being appended.
//
// Returns absl::nullopt (after a debug log) when frame.opcode_ is not one of kFrameOpcodes.
absl::optional<std::vector<uint8_t>> Encoder::encodeFrameHeader(const Frame& frame) {
  const bool opcode_is_known =
      std::find(kFrameOpcodes.begin(), kFrameOpcodes.end(), frame.opcode_) != kFrameOpcodes.end();
  if (!opcode_is_known) {
    ENVOY_LOG(debug, "Failed to encode websocket frame with invalid opcode: {}", frame.opcode_);
    return absl::nullopt;
  }

  // High bit of the first and second header bytes respectively.
  const uint8_t fin_bit = frame.final_fragment_ ? 0x80 : 0x00;
  const uint8_t mask_bit = frame.masking_key_.has_value() ? 0x80 : 0x00;

  std::vector<uint8_t> header;
  // First byte: FIN flag + opcode.
  pushScalarToByteVector(static_cast<uint8_t>(fin_bit | frame.opcode_), header);

  // Second byte (mask flag + length indicator), followed by any extended length.
  if (frame.payload_length_ <= 125) {
    pushScalarToByteVector(static_cast<uint8_t>(mask_bit | frame.payload_length_), header);
  } else if (frame.payload_length_ <= 65535) {
    pushScalarToByteVector(static_cast<uint8_t>(mask_bit | 0x7e), header);
    pushScalarToByteVector(htobe16(frame.payload_length_), header);
  } else {
    pushScalarToByteVector(static_cast<uint8_t>(mask_bit | 0x7f), header);
    pushScalarToByteVector(htobe64(frame.payload_length_), header);
  }

  // Trailing masking key, only for masked frames.
  if (frame.masking_key_.has_value()) {
    pushScalarToByteVector(htobe32(frame.masking_key_.value()), header);
  }
  return header;
}
51

            
52
46
void Decoder::frameDataStart() {
53
46
  frame_.payload_length_ = length_;
54
46
  if (length_ == 0) {
55
7
    state_ = State::FrameFinished;
56
39
  } else {
57
39
    if (max_payload_buffer_length_ > 0) {
58
33
      frame_.payload_ = std::make_unique<Buffer::OwnedImpl>();
59
33
    }
60
39
    state_ = State::FramePayload;
61
39
  }
62
46
}
63

            
64
824
void Decoder::frameData(const uint8_t* mem, uint64_t length) {
65
824
  if (max_payload_buffer_length_ > 0) {
66
562
    uint64_t allowed_length = max_payload_buffer_length_ - frame_.payload_->length();
67
562
    frame_.payload_->add(mem, length <= allowed_length ? length : allowed_length);
68
562
  }
69
824
}
70

            
71
43
void Decoder::frameDataEnd(std::vector<Frame>& output) {
72
43
  output.push_back(std::move(frame_));
73
43
  resetDecoder();
74
43
}
75

            
76
43
void Decoder::resetDecoder() {
77
43
  frame_ = {false, 0, absl::nullopt, 0, nullptr};
78
43
  state_ = State::FrameHeaderFlagsAndOpcode;
79
43
  length_ = 0;
80
43
  num_remaining_extended_length_bytes_ = 0;
81
43
  num_remaining_masking_key_bytes_ = 0;
82
43
}
83

            
84
49
// Parses the first header byte: the low 4 bits are the opcode and the high bit is the
// final-fragment flag.
// On success, stores both on frame_, advances to FrameHeaderMaskFlagAndLength and returns 1
// (bytes consumed).
// If the opcode is not one of kFrameOpcodes, logs at debug level and returns 0 without touching
// the decoder state. decode() treats a 0 here as a decoding failure.
uint8_t Decoder::doDecodeFlagsAndOpcode(absl::Span<const uint8_t>& data) {
  const uint8_t first_byte = data.front();
  const uint8_t opcode = first_byte & 0x0f;

  if (std::find(kFrameOpcodes.begin(), kFrameOpcodes.end(), opcode) != kFrameOpcodes.end()) {
    frame_.opcode_ = opcode;
    frame_.final_fragment_ = first_byte & 0x80;
    state_ = State::FrameHeaderMaskFlagAndLength;
    return 1;
  }

  ENVOY_LOG(debug, "Failed to decode websocket frame with invalid opcode: {}", opcode);
  return 0;
}
96

            
97
48
// Parses the second header byte and always consumes exactly that one byte (returns 1).
//   - High bit: mask flag. When set, kMaskingKeyLength key bytes are expected after the length.
//   - Low 7 bits: length indicator.
//       - 0x7e: a 16-bit extended length follows.
//       - 0x7f: a 64-bit extended length follows.
//       - Any other value: it is the payload length itself. The decoder then moves on to the
//         masking key (masked frames) or straight to the payload via frameDataStart().
uint8_t Decoder::doDecodeMaskFlagAndLength(absl::Span<const uint8_t>& data) {
  const uint8_t second_byte = data.front();
  num_remaining_masking_key_bytes_ = (second_byte & 0x80) ? kMaskingKeyLength : 0;

  const uint8_t length_indicator = second_byte & 0x7f;
  switch (length_indicator) {
  case 0x7e:
    num_remaining_extended_length_bytes_ = kPayloadLength16Bit;
    state_ = State::FrameHeaderExtendedLength16Bit;
    break;
  case 0x7f:
    num_remaining_extended_length_bytes_ = kPayloadLength64Bit;
    state_ = State::FrameHeaderExtendedLength64Bit;
    break;
  default:
    length_ = length_indicator;
    if (num_remaining_masking_key_bytes_ > 0) {
      state_ = State::FrameHeaderMaskingKey;
    } else {
      frameDataStart();
    }
    break;
  }
  return 1;
}
115

            
116
15
// Consumes up to num_remaining_extended_length_bytes_ bytes of the extended payload length
// (kPayloadLength16Bit or kPayloadLength64Bit bytes in total, depending on the current state)
// from the front of `data`. The length bytes may arrive split across several calls.
// They are accumulated most-significant byte first, so the finished length_ is the bytes read
// as a big-endian integer. This matches what Encoder::encodeFrameHeader writes via
// htobe16/htobe64, and it works unchanged on little- and big-endian hosts.
// Once the last length byte is in, moves to FrameHeaderMaskingKey (masked frame) or starts the
// payload via frameDataStart().
// Returns the number of bytes consumed from `data`; the caller strips them from the span.
uint8_t Decoder::doDecodeExtendedLength(absl::Span<const uint8_t>& data) {
  const uint8_t size_of_extended_length =
      state_ == State::FrameHeaderExtendedLength16Bit ? kPayloadLength16Bit : kPayloadLength64Bit;
  // Guards the bookkeeping: there must be at least one byte left, and never more than the
  // field holds.
  ASSERT(num_remaining_extended_length_bytes_ > 0);
  ASSERT(num_remaining_extended_length_bytes_ <= size_of_extended_length);

  // Start from a clean accumulator on the first byte of the field, so the 16-bit form never
  // inherits stray bits from whatever length_ held before.
  if (num_remaining_extended_length_bytes_ == size_of_extended_length) {
    length_ = 0;
  }

  const uint64_t bytes_to_decode =
      std::min<uint64_t>(data.length(), num_remaining_extended_length_bytes_);
  // Shifting bytes in one at a time decodes network byte order without type punning or
  // per-platform byte swapping.
  for (const uint8_t byte : data.subspan(0, bytes_to_decode)) {
    length_ = (length_ << 8) | byte;
  }
  num_remaining_extended_length_bytes_ -= bytes_to_decode;

  if (num_remaining_extended_length_bytes_ == 0) {
    if (num_remaining_masking_key_bytes_ > 0) {
      state_ = State::FrameHeaderMaskingKey;
    } else {
      frameDataStart();
    }
  }
  return static_cast<uint8_t>(bytes_to_decode);
}
144

            
145
20
// Consumes up to num_remaining_masking_key_bytes_ bytes of the kMaskingKeyLength-byte masking
// key from the front of `data`. The key bytes may arrive split across several calls.
// They are accumulated most-significant byte first, so the finished frame_.masking_key_ is the
// key bytes read as a big-endian integer. This is the inverse of the htobe32 write in
// Encoder::encodeFrameHeader, and it works identically on little- and big-endian hosts.
// Once the whole key is in, starts the payload via frameDataStart().
// Returns the number of bytes consumed from `data`; the caller strips them from the span.
uint8_t Decoder::doDecodeMaskingKey(absl::Span<const uint8_t>& data) {
  // Guards the bookkeeping: at least one key byte must still be outstanding, and never more
  // than the key length.
  ASSERT(num_remaining_masking_key_bytes_ > 0);
  ASSERT(num_remaining_masking_key_bytes_ <= kMaskingKeyLength);

  // Begin from zero on the first key byte. A key partially decoded by an earlier call is kept
  // and extended.
  if (!frame_.masking_key_.has_value() ||
      num_remaining_masking_key_bytes_ == kMaskingKeyLength) {
    frame_.masking_key_ = 0;
  }

  const uint64_t bytes_to_decode =
      std::min<uint64_t>(data.length(), num_remaining_masking_key_bytes_);
  auto& key = frame_.masking_key_.value();
  for (const uint8_t byte : data.subspan(0, bytes_to_decode)) {
    key = (key << 8) | byte;
  }
  num_remaining_masking_key_bytes_ -= bytes_to_decode;

  if (num_remaining_masking_key_bytes_ == 0) {
    frameDataStart();
  }
  return static_cast<uint8_t>(bytes_to_decode);
}
166

            
167
824
// Feeds payload bytes of the current frame to frameData(). It takes at most length_ bytes
// (what this frame still expects) from `data`; any remaining bytes belong to the next frame.
// When nothing more is expected the state becomes FrameFinished.
// Returns the number of bytes consumed.
uint64_t Decoder::doDecodePayload(absl::Span<const uint8_t>& data) {
  const uint64_t consumed = std::min<uint64_t>(data.length(), length_);
  frameData(data.data(), consumed);
  length_ -= consumed;
  if (length_ == 0) {
    state_ = State::FrameFinished;
  }
  return consumed;
}
184

            
185
42
// Runs every byte of `input`, slice by slice, through the frame-decoding state machine.
// The state lives in decoder members, so a frame split across slices or across calls is
// finished on a later step/call. Whenever a frame reaches FrameFinished it is moved into the
// result and the decoder is reset.
//
// Returns the frames completed during this call, or absl::nullopt when none were completed.
// Also returns absl::nullopt as soon as a header carries an invalid opcode; frames completed
// earlier in the same call are dropped in that case.
absl::optional<std::vector<Frame>> Decoder::decode(const Buffer::Instance& input) {
  absl::optional<std::vector<Frame>> output = std::vector<Frame>();

  for (const auto& slice : input.getRawSlices()) {
    absl::Span<const uint8_t> remaining(static_cast<const uint8_t*>(slice.mem_), slice.len_);

    // Loop while bytes remain. Also take one more step after a frame finishes, so a frame
    // ending exactly at the end of the slice is still emitted.
    while (!remaining.empty() || state_ == State::FrameFinished) {
      uint64_t consumed = 0;
      switch (state_) {
      case State::FrameHeaderFlagsAndOpcode:
        consumed = doDecodeFlagsAndOpcode(remaining);
        if (consumed == 0) {
          // Invalid opcode: give up on this call.
          return absl::nullopt;
        }
        break;
      case State::FrameHeaderMaskFlagAndLength:
        consumed = doDecodeMaskFlagAndLength(remaining);
        break;
      case State::FrameHeaderExtendedLength16Bit:
      case State::FrameHeaderExtendedLength64Bit:
        consumed = doDecodeExtendedLength(remaining);
        break;
      case State::FrameHeaderMaskingKey:
        consumed = doDecodeMaskingKey(remaining);
        break;
      case State::FramePayload:
        consumed = doDecodePayload(remaining);
        break;
      case State::FrameFinished:
        frameDataEnd(*output);
        break;
      }
      remaining.remove_prefix(consumed);
    }
  }

  if (output->empty()) {
    return absl::nullopt;
  }
  return output;
}
220

            
221
} // namespace WebSocket
222
} // namespace Envoy