Coverage Report

Created: 2023-11-12 09:30

/proc/self/cwd/source/extensions/filters/network/thrift_proxy/header_transport_impl.cc
Line
Count
Source (jump to first uncovered line)
1
#include "source/extensions/filters/network/thrift_proxy/header_transport_impl.h"
2
3
#include <limits>
4
5
#include "envoy/common/exception.h"
6
#include "envoy/http/header_formatter.h"
7
8
#include "source/common/buffer/buffer_impl.h"
9
#include "source/extensions/filters/network/thrift_proxy/buffer_helper.h"
10
11
#include "absl/strings/str_replace.h"
12
13
namespace Envoy {
14
namespace Extensions {
15
namespace NetworkFilters {
16
namespace ThriftProxy {
17
namespace {

// Protocol ids carried as the first varint of the variable header data.
// c.f.
// https://github.com/apache/thrift/blob/master/lib/cpp/src/thrift/protocol/TProtocolTypes.h#L27
enum class HeaderProtocolType {
  Binary = 0,
  JSON = 1,
  Compact = 2,

  FirstHeaderProtocolType = Binary,
  LastHeaderProtocolType = Compact,
};

// Fixed portion of the frame header (everything after the 4-byte frame size field and before the
// variable header data): header magic (2 bytes) + flags (2 bytes) + sequence number (4 bytes) +
// header data size (2 bytes).
constexpr uint64_t MinFrameStartSizeNoHeaders = 2 + 2 + 4 + 2;

// Smallest acceptable frame size value: the fixed header plus the minimum 4 bytes of header data.
constexpr int32_t MinFrameStartSize = MinFrameStartSizeNoHeaders + 4;

// Bytes that must be buffered before decoding can start: the 4-byte frame size field followed by
// the fixed portion of the frame header.
constexpr uint64_t MinDecodeBytes = 4 + MinFrameStartSizeNoHeaders;

// Upper bound on the size of the variable header data.
constexpr int32_t MaxHeadersSize = 64 * 1024;

} // namespace
47
48
0
/**
 * Decodes the start of a header-transport frame: the 4-byte frame size, the 10 fixed header bytes
 * (magic, flags, sequence id, header size / 4) and the variable header data (protocol id,
 * transforms and info blocks), recording what it finds in `metadata`.
 *
 * @param buffer received bytes; when true is returned the frame size field, fixed header and the
 *        whole variable header data have been drained from it.
 * @param metadata receives the remaining frame size, header flags, sequence id, protocol and any
 *        key/value headers (request or response headers depending on metadata.isRequest()).
 * @return false when more data must be buffered (nothing is drained in that case); true once the
 *         frame start has been consumed, including the unknown-transform case, where an app
 *         exception is recorded in `metadata` instead.
 * @throws EnvoyException if the frame is malformed.
 */
bool HeaderTransportImpl::decodeFrameStart(Buffer::Instance& buffer, MessageMetadata& metadata) {
  if (buffer.length() < MinDecodeBytes) {
    return false;
  }

  // Size of frame, not including the length bytes.
  const int32_t frame_size = buffer.peekBEInt<int32_t>();

  // Minimum header frame size is 18 bytes (4 bytes of frame size + 10 bytes of fixed header +
  // minimum 4 bytes of variable header data), so frame_size must be at least 14.
  if (frame_size < MinFrameStartSize || frame_size > MaxFrameSize) {
    throw EnvoyException(absl::StrCat("invalid thrift header transport frame size ", frame_size));
  }

  // offset 4: 16 bit magic. Kept unsigned so an invalid value such as 0x8001 is reported as "8001"
  // in the error message instead of as a negative number.
  const uint16_t magic = buffer.peekBEInt<uint16_t>(4);
  if (!isMagic(magic)) {
    throw EnvoyException(fmt::format("invalid thrift header transport magic {:04x}", magic));
  }

  // offset 6: 16 bit flags field
  const int16_t header_flags = buffer.peekBEInt<int16_t>(6);

  // offset 8: 32 bit sequence number field
  const int32_t seq_id = buffer.peekBEInt<int32_t>(8);

  // offset 12: 16 bit (remaining) header size / 4 (spec erroneously claims / 32).
  const int16_t raw_header_size = buffer.peekBEInt<int16_t>(12);
  // Bytes of variable header data still unread; the drain* helpers decrement it as they consume.
  int32_t header_size = static_cast<int32_t>(raw_header_size) * 4;
  if (header_size < 0 || header_size > MaxHeadersSize) {
    throw EnvoyException(fmt::format("invalid thrift header transport header size {} ({:04x})",
                                     header_size, static_cast<uint16_t>(raw_header_size)));
  }

  if (header_size == 0) {
    throw EnvoyException("no header data");
  }

  // The variable header data is counted in frame_size, so it cannot be larger than what is left of
  // the frame after the fixed header fields. Without this check the remaining frame size computed
  // below goes negative and wraps around to a huge value when cast to uint32_t.
  if (header_size > frame_size - static_cast<int32_t>(MinFrameStartSizeNoHeaders)) {
    throw EnvoyException(
        fmt::format("invalid thrift header transport header size {} exceeds frame size {}",
                    header_size, frame_size));
  }

  if (buffer.length() < static_cast<uint64_t>(header_size) + MinDecodeBytes) {
    // Need more header data.
    return false;
  }

  // Header data starts at offset 14 (4 bytes of frame size followed by 10 bytes of fixed header).
  buffer.drain(MinDecodeBytes);

  // Remaining frame size is the original frame size (which does not count itself), less the 10
  // fixed bytes of the header (magic, flags, etc), less the size of the variable header data
  // (header_size). Non-negative thanks to the check above.
  metadata.setFrameSize(
      static_cast<uint32_t>(frame_size - header_size - MinFrameStartSizeNoHeaders));
  metadata.setHeaderFlags(header_flags);
  metadata.setSequenceId(seq_id);

  ProtocolType proto = ProtocolType::Auto;
  const HeaderProtocolType header_proto =
      static_cast<HeaderProtocolType>(drainVarIntI16(buffer, header_size, "protocol id"));
  switch (header_proto) {
  case HeaderProtocolType::Binary:
    proto = ProtocolType::Binary;
    break;
  case HeaderProtocolType::Compact:
    proto = ProtocolType::Compact;
    break;
  default:
    throw EnvoyException(fmt::format("Unknown protocol {}", static_cast<int>(header_proto)));
  }
  metadata.setProtocol(proto);

  const int16_t num_xforms = drainVarIntI16(buffer, header_size, "transform count");
  if (num_xforms < 0) {
    throw EnvoyException(absl::StrCat("invalid header transport transform count ", num_xforms));
  }

  if (num_xforms > 0) {
    const int32_t xform_id = drainVarIntI32(buffer, header_size, "transform id");

    // To date, no transforms have a data field. In the future, some transform IDs may require
    // consuming another varint 32 at this point. The known transform IDs are:
    // 1: zlib compression
    // 2: hmac (appended to end of packet)
    // 3: snappy compression
    // No transform is supported: skip the rest of the header data and report the first transform
    // id through an app exception.
    buffer.drain(header_size);
    metadata.setAppException(AppExceptionType::MissingResult,
                             absl::StrCat("Unknown transform ", xform_id));
    return true;
  }

  const bool is_request = metadata.isRequest();
  auto formatter =
      is_request ? metadata.requestHeaders().formatter() : metadata.responseHeaders().formatter();

  while (header_size > 0) {
    // Attempt to read info blocks
    const int32_t info_id = drainVarIntI32(buffer, header_size, "info id");
    if (info_id != 1) {
      // 0 indicates a padding byte, and the end of the info block.
      // 1 indicates an info id header/value pair.
      // Any other value is an unknown info id block, which we ignore.
      break;
    }

    int32_t num_headers = drainVarIntI32(buffer, header_size, "header count");
    if (num_headers < 0) {
      throw EnvoyException(absl::StrCat("invalid header transport header count ", num_headers));
    }

    while (num_headers-- > 0) {
      std::string key_string = drainVarString(buffer, header_size, "header key");
      if (formatter) {
        formatter->processKey(key_string);
      }
      // LowerCaseString doesn't allow '\0', '\n', and '\r'.
      key_string =
          absl::StrReplaceAll(key_string, {{std::string(1, '\0'), ""}, {"\n", ""}, {"\r", ""}});

      const Http::LowerCaseString key(key_string);
      const std::string value = drainVarString(buffer, header_size, "header value");

      if (is_request) {
        metadata.requestHeaders().addCopy(key, value);
      } else {
        metadata.responseHeaders().addCopy(key, value);
      }
    }
  }

  // Remaining bytes are padding or ignored info blocks.
  if (header_size > 0) {
    buffer.drain(header_size);
  }

  return true;
}
181
182
0
/**
 * Finishes decoding a frame. The buffer is not read or drained; only the exception state held in
 * exception_ and exception_reason_ is cleared.
 *
 * @return always true.
 */
bool HeaderTransportImpl::decodeFrameEnd(Buffer::Instance& /*buffer*/) {
  exception_.reset();
  exception_reason_.clear();

  return true;
}
188
189
void HeaderTransportImpl::encodeFrame(Buffer::Instance& buffer, const MessageMetadata& metadata,
190
0
                                      Buffer::Instance& message) {
191
0
  uint64_t msg_size = message.length();
192
0
  if (msg_size == 0) {
193
0
    throw EnvoyException(absl::StrCat("invalid thrift header transport message size ", msg_size));
194
0
  }
195
196
0
  const uint32_t headers_size =
197
0
      metadata.isRequest() ? metadata.requestHeaders().size() : metadata.responseHeaders().size();
198
0
  if (headers_size > MaxHeadersSize / 2) {
199
    // Each header takes a minimum of 2 bytes, yielding this limit.
200
0
    throw EnvoyException(
201
0
        absl::StrCat("invalid thrift header transport too many headers ", headers_size));
202
0
  }
203
204
0
  Buffer::OwnedImpl header_buffer;
205
206
0
  if (!metadata.hasProtocol()) {
207
0
    throw EnvoyException("missing header transport protocol");
208
0
  }
209
210
0
  switch (metadata.protocol()) {
211
0
  case ProtocolType::Binary:
212
0
    BufferHelper::writeVarIntI32(header_buffer, static_cast<int32_t>(HeaderProtocolType::Binary));
213
0
    break;
214
0
  case ProtocolType::Compact:
215
0
    BufferHelper::writeVarIntI32(header_buffer, static_cast<int32_t>(HeaderProtocolType::Compact));
216
0
    break;
217
0
  default:
218
0
    throw EnvoyException(fmt::format("invalid header transport protocol {}",
219
0
                                     ProtocolNames::get().fromType(metadata.protocol())));
220
0
  }
221
222
0
  BufferHelper::writeVarIntI32(header_buffer, 0); // num transforms
223
224
0
  if (headers_size > 0) {
225
    // Info ID 1
226
0
    header_buffer.writeByte(1);
227
228
    // Num headers
229
0
    BufferHelper::writeVarIntI32(header_buffer, static_cast<int32_t>(headers_size));
230
231
0
    auto formatter = metadata.isRequest() ? metadata.requestHeaders().formatter()
232
0
                                          : metadata.responseHeaders().formatter();
233
234
0
    auto header_writer = [&header_buffer,
235
0
                          formatter](const Http::HeaderEntry& header) -> Http::HeaderMap::Iterate {
236
0
      const auto header_key = header.key().getStringView();
237
238
0
      writeVarString(header_buffer, formatter ? formatter->format(header_key) : header_key);
239
0
      writeVarString(header_buffer, header.value().getStringView());
240
0
      return Http::HeaderMap::Iterate::Continue;
241
0
    };
242
243
0
    if (metadata.isRequest()) {
244
0
      metadata.requestHeaders().iterate(header_writer);
245
0
    } else {
246
0
      metadata.responseHeaders().iterate(header_writer);
247
0
    }
248
0
  }
249
250
0
  uint64_t header_size = header_buffer.length();
251
252
  // Always pad (as the Apache implementation does).
253
0
  const int padding = 4 - (header_size % 4);
254
0
  header_buffer.add("\0\0\0\0", padding);
255
0
  header_size += padding;
256
257
0
  if (header_size > MaxHeadersSize) {
258
0
    throw EnvoyException(absl::StrCat("invalid thrift header transport header size ", header_size));
259
0
  }
260
261
  // Frame size does not include the frame length itself.
262
0
  uint64_t size = header_size + msg_size + MinFrameStartSizeNoHeaders;
263
0
  if (size > MaxFrameSize) {
264
0
    throw EnvoyException(absl::StrCat("invalid thrift header transport frame size ", size));
265
0
  }
266
267
0
  int32_t seq_id = 0;
268
0
  if (metadata.hasSequenceId()) {
269
0
    seq_id = metadata.sequenceId();
270
0
  }
271
0
  int16_t header_flags = 0;
272
0
  if (metadata.hasHeaderFlags()) {
273
0
    header_flags = metadata.headerFlags();
274
0
  }
275
276
0
  buffer.writeBEInt<uint32_t>(static_cast<uint32_t>(size));
277
0
  buffer.writeBEInt<uint16_t>(Magic);
278
0
  buffer.writeBEInt<uint16_t>(header_flags); // flags
279
0
  buffer.writeBEInt<int32_t>(seq_id);
280
0
  buffer.writeBEInt<uint16_t>(static_cast<uint16_t>(header_size / 4));
281
282
0
  buffer.move(header_buffer);
283
0
  buffer.move(message);
284
0
}
285
286
int16_t HeaderTransportImpl::drainVarIntI16(Buffer::Instance& buffer, int32_t& header_size,
287
0
                                            const char* desc) {
288
0
  int32_t value = drainVarIntI32(buffer, header_size, desc);
289
0
  if (value > static_cast<int32_t>(std::numeric_limits<int16_t>::max())) {
290
0
    throw EnvoyException(fmt::format("header transport {}: value {} exceeds max i16 ({})", desc,
291
0
                                     value, std::numeric_limits<int16_t>::max()));
292
0
  }
293
0
  return static_cast<int16_t>(value);
294
0
}
295
296
int32_t HeaderTransportImpl::drainVarIntI32(Buffer::Instance& buffer, int32_t& header_size,
297
0
                                            const char* desc) {
298
0
  if (header_size <= 0) {
299
0
    throw EnvoyException(fmt::format("unable to read header transport {}: header too small", desc));
300
0
  }
301
302
0
  int size;
303
0
  int32_t value = BufferHelper::peekVarIntI32(buffer, 0, size);
304
0
  if (size < 0 || (header_size - size) < 0) {
305
0
    throw EnvoyException(fmt::format("unable to read header transport {}: header too small", desc));
306
0
  }
307
0
  buffer.drain(size);
308
0
  header_size -= size;
309
0
  return value;
310
0
}
311
312
std::string HeaderTransportImpl::drainVarString(Buffer::Instance& buffer, int32_t& header_size,
313
0
                                                const char* desc) {
314
0
  const int16_t str_len = drainVarIntI16(buffer, header_size, desc);
315
0
  if (str_len == 0) {
316
0
    return "";
317
0
  }
318
319
0
  if (header_size < static_cast<int32_t>(str_len)) {
320
0
    throw EnvoyException(fmt::format("unable to read header transport {}: header too small", desc));
321
0
  }
322
323
0
  const std::string value(static_cast<char*>(buffer.linearize(str_len)), str_len);
324
0
  buffer.drain(str_len);
325
0
  header_size -= str_len;
326
0
  return value;
327
0
}
328
329
0
/**
 * Appends `str` to `buffer` as a varint length followed by the raw bytes.
 *
 * @throws EnvoyException if `str` is longer than the int16_t maximum (32767 bytes).
 */
void HeaderTransportImpl::writeVarString(Buffer::Instance& buffer, const absl::string_view str) {
  constexpr uint32_t max_len = static_cast<uint32_t>(std::numeric_limits<int16_t>::max());
  const std::string::size_type len = str.length();
  if (len > max_len) {
    throw EnvoyException(absl::StrCat("header string too long: ", len));
  }

  BufferHelper::writeVarIntI32(buffer, static_cast<int32_t>(len));
  // An empty string is just its zero length prefix.
  if (len != 0) {
    buffer.add(str.data(), len);
  }
}
341
342
class HeaderTransportConfigFactory : public TransportFactoryBase<HeaderTransportImpl> {
343
public:
344
4
  HeaderTransportConfigFactory() : TransportFactoryBase(TransportNames::get().HEADER) {}
345
};
346
347
/**
348
 * Static registration for the header transport. @see RegisterFactory.
349
 */
350
REGISTER_FACTORY(HeaderTransportConfigFactory, NamedTransportConfigFactory);
351
352
} // namespace ThriftProxy
353
} // namespace NetworkFilters
354
} // namespace Extensions
355
} // namespace Envoy