Coverage Report

Created: 2023-11-12 09:30

/proc/self/cwd/source/extensions/filters/network/thrift_proxy/binary_protocol_impl.cc
Line
Count
Source (jump to first uncovered line)
1
#include "source/extensions/filters/network/thrift_proxy/binary_protocol_impl.h"
2
3
#include <limits>
4
5
#include "envoy/common/exception.h"
6
7
#include "source/common/common/assert.h"
8
#include "source/common/common/fmt.h"
9
#include "source/common/common/macros.h"
10
#include "source/common/runtime/runtime_features.h"
11
#include "source/extensions/filters/network/thrift_proxy/buffer_helper.h"
12
13
namespace Envoy {
14
namespace Extensions {
15
namespace NetworkFilters {
16
namespace ThriftProxy {
17
18
// Leading 16-bit value of a strict binary protocol message: readMessageBegin rejects any message
// whose first two bytes differ from it, and writeMessageBegin emits it first.
const uint16_t BinaryProtocolImpl::Magic = 0x8001;
19
20
0
// Decodes a strict binary protocol message header from the front of `buffer`.
//
// Bytes consumed, in order: 2-byte big-endian version (must equal Magic), 1 ignored byte,
// 1-byte message type, 4-byte big-endian method name length, the method name itself, and a
// 4-byte big-endian sequence id.
//
// @param buffer   input data; only drained once the whole header is available.
// @param metadata receives the method name, message type and sequence id.
// @return false when more data is required (nothing is drained), true once the header has been
//         consumed.
// @throw EnvoyException if the version does not match Magic or the message type is out of range.
bool BinaryProtocolImpl::readMessageBegin(Buffer::Instance& buffer, MessageMetadata& metadata) {
  if (buffer.length() < MinMessageBeginLength) {
    return false;
  }

  const uint16_t version = buffer.peekBEInt<uint16_t>();
  if (version != Magic) {
    throw EnvoyException(
        fmt::format("invalid binary protocol version 0x{:04x} != 0x{:04x}", version, Magic));
  }

  // The byte at offset 2 is unused and ignored.

  const MessageType type = static_cast<MessageType>(buffer.peekInt<int8_t>(3));
  if (type < MessageType::Call || type > MessageType::LastMessageType) {
    throw EnvoyException(
        fmt::format("invalid binary protocol message type {}", static_cast<int8_t>(type)));
  }

  const uint32_t name_len = buffer.peekBEInt<uint32_t>(4);
  // Widen before adding so that a name length close to UINT32_MAX cannot wrap around and make a
  // short buffer look complete, whatever integer type MinMessageBeginLength is declared with.
  if (buffer.length() < static_cast<uint64_t>(name_len) + MinMessageBeginLength) {
    return false;
  }

  // Version (2) + unused (1) + message type (1) + name length (4).
  buffer.drain(8);

  if (name_len > 0) {
    metadata.setMethodName(
        std::string(static_cast<const char*>(buffer.linearize(name_len)), name_len));
    buffer.drain(name_len);
  } else {
    metadata.setMethodName("");
  }
  metadata.setMessageType(type);
  metadata.setSequenceId(buffer.drainBEInt<int32_t>());

  return true;
}
58
59
0
// No bytes are consumed at the end of a message; always reports success.
bool BinaryProtocolImpl::readMessageEnd(Buffer::Instance& buffer) {
  UNREFERENCED_PARAMETER(buffer);

  return true;
}
63
64
0
// Classifies a reply by inspecting, without draining, the first field header in `buffer`.
//
// @param buffer     reply payload positioned at the first field header.
// @param reply_type set to Success for a leading Stop field or field id 0, Error otherwise.
// @return false if not enough bytes are buffered yet, true once reply_type has been set.
// @throw EnvoyException via validateFieldId when the field id is rejected.
bool BinaryProtocolImpl::peekReplyPayload(Buffer::Instance& buffer, ReplyType& reply_type) {
  // The binary protocol does not transmit struct names, so the payload starts directly with a
  // field header. A Stop field occupies a single byte.
  const auto available = buffer.length();
  if (available < 1) {
    return false;
  }

  const auto first_field_type = static_cast<FieldType>(buffer.peekInt<int8_t>());
  if (first_field_type == FieldType::Stop) {
    // An immediate Stop means a void, successful response.
    reply_type = ReplyType::Success;
    return true;
  }

  // Any other field header is the type byte plus a 2-byte field id.
  if (available < 3) {
    return false;
  }

  const int16_t field_id = buffer.peekBEInt<int16_t>(1);
  validateFieldId(field_id);

  // Field id 0 carries the successful result struct; any other id is an error (IDL exception).
  if (field_id == 0) {
    reply_type = ReplyType::Success;
  } else {
    reply_type = ReplyType::Error;
  }
  return true;
}
88
89
0
void BinaryProtocolImpl::validateFieldId(int16_t id) {
90
0
  if (id >= 0) {
91
0
    return;
92
0
  }
93
94
0
  if (Runtime::runtimeFeatureEnabled("envoy.reloadable_features.thrift_allow_negative_field_ids")) {
95
0
    return;
96
0
  }
97
98
0
  throw EnvoyException(absl::StrCat("invalid binary protocol field id ", id));
99
0
}
100
101
0
// Struct names are not transmitted by the binary protocol: `name` is emptied, nothing is read
// from the buffer, and the call always succeeds.
bool BinaryProtocolImpl::readStructBegin(Buffer::Instance& buffer, std::string& name) {
  UNREFERENCED_PARAMETER(buffer);

  name.clear();
  return true;
}
106
107
0
// No bytes are consumed at the end of a struct; always reports success.
bool BinaryProtocolImpl::readStructEnd(Buffer::Instance& buffer) {
  UNREFERENCED_PARAMETER(buffer);

  return true;
}
111
112
bool BinaryProtocolImpl::readFieldBegin(Buffer::Instance& buffer, std::string& name,
113
0
                                        FieldType& field_type, int16_t& field_id) {
114
  // FieldType::Stop is encoded as 1 byte.
115
0
  if (buffer.length() < 1) {
116
0
    return false;
117
0
  }
118
119
0
  FieldType type = static_cast<FieldType>(buffer.peekInt<int8_t>());
120
0
  if (type == FieldType::Stop) {
121
0
    field_id = 0;
122
0
    buffer.drain(1);
123
0
  } else {
124
    // FieldType followed by 2 bytes of field id
125
0
    if (buffer.length() < 3) {
126
0
      return false;
127
0
    }
128
0
    int16_t id = buffer.peekBEInt<int16_t>(1);
129
0
    validateFieldId(id);
130
0
    field_id = id;
131
0
    buffer.drain(3);
132
0
  }
133
134
0
  name.clear(); // binary protocol does not transmit field names
135
0
  field_type = type;
136
137
0
  return true;
138
0
}
139
140
0
// No bytes are consumed at the end of a field; always reports success.
bool BinaryProtocolImpl::readFieldEnd(Buffer::Instance& buffer) {
  UNREFERENCED_PARAMETER(buffer);

  return true;
}
144
145
bool BinaryProtocolImpl::readMapBegin(Buffer::Instance& buffer, FieldType& key_type,
146
0
                                      FieldType& value_type, uint32_t& size) {
147
  // Minimum length:
148
  //   key type: 1 byte +
149
  //   value type: 1 byte +
150
  //   map size: 4 bytes
151
0
  if (buffer.length() < 6) {
152
0
    return false;
153
0
  }
154
155
0
  FieldType ktype = static_cast<FieldType>(buffer.peekInt<int8_t>(0));
156
0
  FieldType vtype = static_cast<FieldType>(buffer.peekInt<int8_t>(1));
157
0
  int32_t s = buffer.peekBEInt<int32_t>(2);
158
0
  if (s < 0) {
159
0
    throw EnvoyException(absl::StrCat("negative binary protocol map size ", s));
160
0
  }
161
162
0
  buffer.drain(6);
163
164
0
  key_type = ktype;
165
0
  value_type = vtype;
166
0
  size = static_cast<uint32_t>(s);
167
168
0
  return true;
169
0
}
170
171
0
// No bytes are consumed at the end of a map; always reports success.
bool BinaryProtocolImpl::readMapEnd(Buffer::Instance& buffer) {
  UNREFERENCED_PARAMETER(buffer);

  return true;
}
175
176
bool BinaryProtocolImpl::readListBegin(Buffer::Instance& buffer, FieldType& elem_type,
177
0
                                       uint32_t& size) {
178
  // Minimum length:
179
  //   elem type: 1 byte +
180
  //   map size: 4 bytes
181
0
  if (buffer.length() < 5) {
182
0
    return false;
183
0
  }
184
185
0
  FieldType type = static_cast<FieldType>(buffer.peekInt<int8_t>());
186
0
  int32_t s = buffer.peekBEInt<int32_t>(1);
187
0
  if (s < 0) {
188
0
    throw EnvoyException(fmt::format("negative binary protocol list/set size {}", s));
189
0
  }
190
0
  buffer.drain(5);
191
192
0
  elem_type = type;
193
0
  size = static_cast<uint32_t>(s);
194
195
0
  return true;
196
0
}
197
198
0
// No bytes are consumed at the end of a list; always reports success.
bool BinaryProtocolImpl::readListEnd(Buffer::Instance& buffer) {
  UNREFERENCED_PARAMETER(buffer);

  return true;
}
202
203
bool BinaryProtocolImpl::readSetBegin(Buffer::Instance& buffer, FieldType& elem_type,
204
0
                                      uint32_t& size) {
205
0
  return readListBegin(buffer, elem_type, size);
206
0
}
207
208
0
// Delegates to readListEnd.
bool BinaryProtocolImpl::readSetEnd(Buffer::Instance& buffer) {
  return readListEnd(buffer);
}
209
210
0
// Reads a 1-byte boolean; any non-zero byte decodes to true.
//
// @return false if the buffer is empty (value untouched), true after draining one byte.
bool BinaryProtocolImpl::readBool(Buffer::Instance& buffer, bool& value) {
  const bool have_byte = buffer.length() >= 1;
  if (have_byte) {
    const int8_t raw = buffer.drainInt<int8_t>();
    value = raw != 0;
  }
  return have_byte;
}
218
219
0
// Reads a single byte.
//
// @return false if the buffer is empty (value untouched), true after draining one byte.
bool BinaryProtocolImpl::readByte(Buffer::Instance& buffer, uint8_t& value) {
  const bool have_byte = buffer.length() >= 1;
  if (have_byte) {
    // Drained as a signed byte, then converted to the unsigned output.
    value = buffer.drainInt<int8_t>();
  }
  return have_byte;
}
226
227
0
// Reads a 2-byte big-endian signed integer.
//
// @return false if fewer than 2 bytes are buffered (value untouched), true after draining them.
bool BinaryProtocolImpl::readInt16(Buffer::Instance& buffer, int16_t& value) {
  const bool complete = buffer.length() >= 2;
  if (complete) {
    value = buffer.drainBEInt<int16_t>();
  }
  return complete;
}
234
235
0
// Reads a 4-byte big-endian signed integer.
//
// @return false if fewer than 4 bytes are buffered (value untouched), true after draining them.
bool BinaryProtocolImpl::readInt32(Buffer::Instance& buffer, int32_t& value) {
  const bool complete = buffer.length() >= 4;
  if (complete) {
    value = buffer.drainBEInt<int32_t>();
  }
  return complete;
}
242
243
0
// Reads an 8-byte big-endian signed integer.
//
// @return false if fewer than 8 bytes are buffered (value untouched), true after draining them.
bool BinaryProtocolImpl::readInt64(Buffer::Instance& buffer, int64_t& value) {
  const bool complete = buffer.length() >= 8;
  if (complete) {
    value = buffer.drainBEInt<int64_t>();
  }
  return complete;
}
250
251
0
// Reads an 8-byte double using BufferHelper::drainBEDouble.
//
// @return false if fewer than 8 bytes are buffered (value untouched), true after draining them.
bool BinaryProtocolImpl::readDouble(Buffer::Instance& buffer, double& value) {
  static_assert(sizeof(double) == sizeof(uint64_t), "sizeof(double) != size(uint64_t)");

  const bool complete = buffer.length() >= 8;
  if (complete) {
    value = BufferHelper::drainBEDouble(buffer);
  }
  return complete;
}
261
262
0
// Reads a string encoded as a 4-byte big-endian signed length followed by that many bytes.
//
// @param buffer input data; drained only when the length prefix and all payload bytes are present.
// @param value  receives the decoded bytes (emptied for a zero-length string).
// @return false if more data is needed, true once the string has been consumed.
// @throw EnvoyException if the encoded length is negative.
bool BinaryProtocolImpl::readString(Buffer::Instance& buffer, std::string& value) {
  constexpr uint64_t length_prefix_size = 4;
  if (buffer.length() < length_prefix_size) {
    return false;
  }

  const int32_t str_len = buffer.peekBEInt<int32_t>();
  if (str_len < 0) {
    throw EnvoyException(fmt::format("negative binary protocol string/binary length {}", str_len));
  }

  // For an empty string this is already satisfied by the prefix check above.
  if (buffer.length() < static_cast<uint64_t>(str_len) + length_prefix_size) {
    return false;
  }

  buffer.drain(length_prefix_size);
  if (str_len == 0) {
    value.clear();
    return true;
  }

  const char* payload = static_cast<const char*>(buffer.linearize(str_len));
  value.assign(payload, str_len);
  buffer.drain(str_len);
  return true;
}
288
289
0
// Binary values are decoded exactly like strings; see readString.
bool BinaryProtocolImpl::readBinary(Buffer::Instance& buffer, std::string& value) {
  const bool complete = readString(buffer, value);
  return complete;
}
292
293
void BinaryProtocolImpl::writeMessageBegin(Buffer::Instance& buffer,
294
0
                                           const MessageMetadata& metadata) {
295
0
  buffer.writeBEInt<uint16_t>(Magic);
296
0
  buffer.writeBEInt<uint16_t>(static_cast<uint16_t>(metadata.messageType()));
297
0
  writeString(buffer, metadata.methodName());
298
0
  buffer.writeBEInt<int32_t>(metadata.sequenceId());
299
0
}
300
301
0
// Nothing is written at the end of a message.
void BinaryProtocolImpl::writeMessageEnd(Buffer::Instance& buffer) { UNREFERENCED_PARAMETER(buffer); }
304
305
0
// Nothing is written at the start of a struct; the name is not transmitted.
void BinaryProtocolImpl::writeStructBegin(Buffer::Instance& buffer, const std::string& name) {
  UNREFERENCED_PARAMETER(buffer);

  UNREFERENCED_PARAMETER(name);
}
309
310
0
// Nothing is written at the end of a struct.
void BinaryProtocolImpl::writeStructEnd(Buffer::Instance& buffer) { UNREFERENCED_PARAMETER(buffer); }
313
314
void BinaryProtocolImpl::writeFieldBegin(Buffer::Instance& buffer, const std::string& name,
315
0
                                         FieldType field_type, int16_t field_id) {
316
0
  UNREFERENCED_PARAMETER(name);
317
318
0
  buffer.writeByte(static_cast<uint8_t>(field_type));
319
0
  if (field_type == FieldType::Stop) {
320
0
    return;
321
0
  }
322
323
0
  buffer.writeBEInt<int16_t>(field_id);
324
0
}
325
326
0
// Nothing is written at the end of a field.
void BinaryProtocolImpl::writeFieldEnd(Buffer::Instance& buffer) {
  UNREFERENCED_PARAMETER(buffer);
}
327
328
// Appends a map header: 1-byte key type, 1-byte value type and the entry count as a 4-byte
// big-endian signed integer.
//
// @throw EnvoyException if `size` exceeds the largest int32_t value (nothing is written).
void BinaryProtocolImpl::writeMapBegin(Buffer::Instance& buffer, FieldType key_type,
                                       FieldType value_type, uint32_t size) {
  // The count is encoded as a signed 32-bit value, so larger sizes cannot be represented.
  const uint32_t largest_encodable = static_cast<uint32_t>(std::numeric_limits<int32_t>::max());
  if (size > largest_encodable) {
    throw EnvoyException(absl::StrCat("illegal binary protocol map size ", size));
  }

  buffer.writeByte(static_cast<int8_t>(key_type));
  buffer.writeByte(static_cast<int8_t>(value_type));

  const auto encoded_size = static_cast<int32_t>(size);
  buffer.writeBEInt<int32_t>(encoded_size);
}
338
339
0
// Nothing is written at the end of a map.
void BinaryProtocolImpl::writeMapEnd(Buffer::Instance& buffer) {
  UNREFERENCED_PARAMETER(buffer);
}
340
341
void BinaryProtocolImpl::writeListBegin(Buffer::Instance& buffer, FieldType elem_type,
342
0
                                        uint32_t size) {
343
0
  if (size > static_cast<uint32_t>(std::numeric_limits<int32_t>::max())) {
344
0
    throw EnvoyException(fmt::format("illegal binary protocol list/set size {}", size));
345
0
  }
346
347
0
  buffer.writeByte(static_cast<int8_t>(elem_type));
348
0
  buffer.writeBEInt<int32_t>(static_cast<int32_t>(size));
349
0
}
350
351
0
// Nothing is written at the end of a list.
void BinaryProtocolImpl::writeListEnd(Buffer::Instance& buffer) {
  UNREFERENCED_PARAMETER(buffer);
}
352
353
void BinaryProtocolImpl::writeSetBegin(Buffer::Instance& buffer, FieldType elem_type,
354
0
                                       uint32_t size) {
355
0
  writeListBegin(buffer, elem_type, size);
356
0
}
357
358
0
// Delegates to writeListEnd.
void BinaryProtocolImpl::writeSetEnd(Buffer::Instance& buffer) {
  writeListEnd(buffer);
}
359
360
0
// Appends a boolean as a single byte: 1 for true, 0 for false.
void BinaryProtocolImpl::writeBool(Buffer::Instance& buffer, bool value) {
  const int encoded = value ? 1 : 0;
  buffer.writeByte(encoded);
}
363
364
0
// Appends a single byte.
void BinaryProtocolImpl::writeByte(Buffer::Instance& buffer, uint8_t value) { buffer.writeByte(value); }
367
368
0
// Appends a 2-byte big-endian signed integer.
void BinaryProtocolImpl::writeInt16(Buffer::Instance& buffer, int16_t value) { buffer.writeBEInt<int16_t>(value); }
371
372
0
// Appends a 4-byte big-endian signed integer.
void BinaryProtocolImpl::writeInt32(Buffer::Instance& buffer, int32_t value) { buffer.writeBEInt<int32_t>(value); }
375
376
0
// Appends an 8-byte big-endian signed integer.
void BinaryProtocolImpl::writeInt64(Buffer::Instance& buffer, int64_t value) { buffer.writeBEInt<int64_t>(value); }
379
380
0
// Appends a double using BufferHelper::writeBEDouble.
void BinaryProtocolImpl::writeDouble(Buffer::Instance& buffer, double value) { BufferHelper::writeBEDouble(buffer, value); }
383
384
0
// Appends a string as a 4-byte big-endian length followed by the string's bytes.
void BinaryProtocolImpl::writeString(Buffer::Instance& buffer, const std::string& value) {
  const auto byte_count = value.length();
  buffer.writeBEInt<uint32_t>(byte_count);

  buffer.add(value);
}
388
389
0
// Binary values are encoded exactly like strings; see writeString.
void BinaryProtocolImpl::writeBinary(Buffer::Instance& buffer, const std::string& value) { writeString(buffer, value); }
392
393
0
// Decodes a lax (unversioned) binary protocol message header from the front of `buffer`.
//
// Bytes consumed, in order: 4-byte big-endian method name length, the method name, a 1-byte
// message type and a 4-byte big-endian sequence id.
//
// @param buffer   input data; only drained once the whole header is available and valid.
// @param metadata receives the method name, message type and sequence id.
// @return false when more data is required (nothing is drained), true once the header has been
//         consumed.
// @throw EnvoyException if the message type is out of range.
bool LaxBinaryProtocolImpl::readMessageBegin(Buffer::Instance& buffer, MessageMetadata& metadata) {
  // Minimum message length:
  //   name len: 4 bytes +
  //   name: 0 bytes +
  //   msg type: 1 byte +
  //   seq id: 4 bytes
  constexpr uint64_t min_message_begin_length = 9;
  if (buffer.length() < min_message_begin_length) {
    return false;
  }

  const uint32_t name_len = buffer.peekBEInt<uint32_t>();

  // All offset/length arithmetic is done in 64 bits. In 32-bit arithmetic a name length close to
  // UINT32_MAX would wrap around, letting a short buffer pass the completeness check and then
  // linearizing far more bytes than are actually buffered.
  const uint64_t wide_name_len = name_len;
  if (buffer.length() < min_message_begin_length + wide_name_len) {
    return false;
  }

  // The message type byte follows the 4-byte length prefix and the name.
  const MessageType type = static_cast<MessageType>(buffer.peekInt<int8_t>(wide_name_len + 4));
  if (type < MessageType::Call || type > MessageType::LastMessageType) {
    throw EnvoyException(
        fmt::format("invalid (lax) binary protocol message type {}", static_cast<int8_t>(type)));
  }

  buffer.drain(4);
  if (name_len > 0) {
    metadata.setMethodName(
        std::string(static_cast<const char*>(buffer.linearize(name_len)), name_len));
    buffer.drain(name_len);
  } else {
    metadata.setMethodName("");
  }

  metadata.setMessageType(type);
  // Skip the already-validated message type byte (offset 0) to reach the sequence id.
  metadata.setSequenceId(buffer.peekBEInt<int32_t>(1));
  buffer.drain(5);

  return true;
}
430
431
void LaxBinaryProtocolImpl::writeMessageBegin(Buffer::Instance& buffer,
432
0
                                              const MessageMetadata& metadata) {
433
0
  writeString(buffer, metadata.methodName());
434
0
  buffer.writeByte(static_cast<int8_t>(metadata.messageType()));
435
0
  buffer.writeBEInt<int32_t>(metadata.sequenceId());
436
0
}
437
438
// Protocol factory for BinaryProtocolImpl, constructed with the ProtocolNames BINARY name.
class BinaryProtocolConfigFactory : public ProtocolFactoryBase<BinaryProtocolImpl> {
public:
  BinaryProtocolConfigFactory()
      : ProtocolFactoryBase<BinaryProtocolImpl>(ProtocolNames::get().BINARY) {}
};
442
443
/**
444
 * Static registration for the binary protocol. @see RegisterFactory.
445
 */
446
REGISTER_FACTORY(BinaryProtocolConfigFactory, NamedProtocolConfigFactory);
447
448
// Protocol factory for LaxBinaryProtocolImpl, constructed with the ProtocolNames LAX_BINARY name.
class LaxBinaryProtocolConfigFactory : public ProtocolFactoryBase<LaxBinaryProtocolImpl> {
public:
  LaxBinaryProtocolConfigFactory()
      : ProtocolFactoryBase<LaxBinaryProtocolImpl>(ProtocolNames::get().LAX_BINARY) {}
};
452
453
/**
454
 * Static registration for the lax binary protocol. @see RegisterFactory.
455
 */
456
REGISTER_FACTORY(LaxBinaryProtocolConfigFactory, NamedProtocolConfigFactory);
457
458
} // namespace ThriftProxy
459
} // namespace NetworkFilters
460
} // namespace Extensions
461
} // namespace Envoy