Coverage Report

Created: 2023-11-12 09:30

/proc/self/cwd/source/extensions/filters/network/thrift_proxy/twitter_protocol_impl.cc
Line
Count
Source (jump to first uncovered line)
1
#include "source/extensions/filters/network/thrift_proxy/twitter_protocol_impl.h"
2
3
#include "envoy/common/exception.h"
4
5
#include "source/common/buffer/buffer_impl.h"
6
#include "source/extensions/filters/network/thrift_proxy/buffer_helper.h"
7
#include "source/extensions/filters/network/thrift_proxy/thrift_object_impl.h"
8
#include "source/extensions/filters/network/thrift_proxy/unframed_transport_impl.h"
9
10
#include "absl/strings/str_replace.h"
11
12
namespace Envoy {
13
namespace Extensions {
14
namespace NetworkFilters {
15
namespace ThriftProxy {
16
namespace {
17
18
struct StructNameValues {
19
  const std::string connectionOptionsStruct = "ConnectionOptions";
20
  const std::string requestHeaderStruct = "RequestHeader";
21
  const std::string clientIdStruct = "ClientId";
22
  const std::string delegationStruct = "Delegation";
23
  const std::string requestContextStruct = "RequestContext";
24
  const std::string responseHeaderStruct = "ResponseHeader";
25
  const std::string spanStruct = "Span";
26
  const std::string annotationStruct = "Annotation";
27
  const std::string binaryAnnotationStruct = "BinaryAnnotation";
28
  const std::string endpointStruct = "Endpoint";
29
  const std::string upgradeReplyStruct = "UpgradeReply";
30
};
31
using StructNames = ConstSingleton<StructNameValues>;
32
33
struct RequestHeaderFieldNameValues {
34
  const std::string traceIdField = "trace_id";
35
  const std::string spanIdField = "span_id";
36
  const std::string parentSpanIdField = "parent_span_id";
37
  const std::string sampledField = "sampled";
38
  const std::string clientIdField = "client_id";
39
  const std::string flagsField = "flags";
40
  const std::string contextsField = "contexts";
41
  const std::string destField = "dest";
42
  const std::string delegationsField = "delegations";
43
  const std::string traceIdHighField = "trace_id_high";
44
};
45
using RequestHeaderFieldNames = ConstSingleton<RequestHeaderFieldNameValues>;
46
47
struct ClientIdFieldNameValues {
48
  const std::string nameField = "name";
49
};
50
using ClientIdFieldNames = ConstSingleton<ClientIdFieldNameValues>;
51
52
struct DelegationFieldNameValues {
53
  const std::string srcField = "src";
54
  const std::string dstField = "dst";
55
};
56
using DelegationFieldNames = ConstSingleton<DelegationFieldNameValues>;
57
58
struct RequestContextFieldNameValues {
59
  const std::string keyField = "key";
60
  const std::string valueField = "value";
61
};
62
using RequestContextFieldNames = ConstSingleton<RequestContextFieldNameValues>;
63
64
struct ResponseHeaderFieldNameValues {
65
  const std::string spansField = "spans";
66
  const std::string contextsField = "contexts";
67
};
68
using ResponseHeaderFieldNames = ConstSingleton<ResponseHeaderFieldNameValues>;
69
70
struct SpanFieldNameValues {
71
  const std::string traceIdField = "trace_id";
72
  const std::string nameField = "name";
73
  const std::string idField = "id";
74
  const std::string parentIdField = "parent_id";
75
  const std::string annotationsField = "annotations";
76
  const std::string binaryAnnotationsField = "binary_annotations";
77
  const std::string debugField = "debug";
78
};
79
using SpanFieldNames = ConstSingleton<SpanFieldNameValues>;
80
81
struct AnnotationFieldNameValues {
82
  const std::string timestampField = "timestamp";
83
  const std::string valueField = "value";
84
  const std::string hostField = "host";
85
};
86
using AnnotationFieldNames = ConstSingleton<AnnotationFieldNameValues>;
87
88
struct BinaryAnnotationFieldNameValues {
89
  const std::string keyField = "key";
90
  const std::string valueField = "value";
91
  const std::string annotationTypeField = "annotation_type";
92
  const std::string hostField = "host";
93
};
94
using BinaryAnnotationFieldNames = ConstSingleton<BinaryAnnotationFieldNameValues>;
95
96
struct EndpointFieldNameValues {
97
  const std::string ipv4Field = "ipv4";
98
  const std::string portField = "port";
99
  const std::string serviceNameField = "service_name";
100
};
101
using EndpointFieldNames = ConstSingleton<EndpointFieldNameValues>;
102
103
0
// Returns a shared empty string, created on first call and never freed. Used as the field name
// passed with FieldType::Stop markers.
const std::string& emptyString() {
  static const std::string* const empty_string = new std::string{""};
  return *empty_string;
}
104
105
/**
106
 * HeaderObjectProtocol implements BinaryProtocolImpl for the specific purpose of decoding the
107
 * Twitter protocol RequestHeader and ResponseHeader thrift structs. These appear after any
108
 * transport data (e.g. frame size) and before the start of a Thrift message. Decoding them
109
 * via a Protocol implementation allows us to reuse the Decoder and its state machine.
110
 */
111
class HeaderObjectProtocol : public BinaryProtocolImpl {
112
public:
113
0
  bool readMessageBegin(Buffer::Instance&, MessageMetadata&) override { return true; }
114
0
  bool readMessageEnd(Buffer::Instance&) override { return true; }
115
};
116
117
// Not const because the interfaces do not allow it, but these objects do not maintain internal
118
// state and are therefore not modifiable.
119
0
Transport& headerObjectTransport() {
120
0
  static UnframedTransportImpl* transport = new UnframedTransportImpl();
121
0
  return *transport;
122
0
}
123
124
0
// Returns the single, never-freed HeaderObjectProtocol used to read and write the Twitter protocol
// header structs. Non-const only because the Protocol interface requires it.
Protocol& headerObjectProtocol() {
  static auto* const shared_protocol = new HeaderObjectProtocol();
  return *shared_protocol;
}
128
129
/**
130
 * ClientId is a Twitter protocol client identifier.
131
 *
132
 * See https://github.com/twitter/finagle/blob/master/finagle-thrift/src/main/thrift/tracing.thrift
133
 */
134
class ClientId {
135
public:
136
0
  ClientId(const std::string& name) : name_(name) {}
137
0
  ClientId(const ThriftStructValue& value) {
138
0
    for (const auto& field : value.fields()) {
139
      // Unknown field id are ignored, to allow for future additional fields.
140
0
      if (field->fieldId() == NameFieldId) {
141
0
        name_ = field->getValue().getValueTyped<std::string>();
142
0
      }
143
0
    }
144
0
  }
145
146
0
  void write(Buffer::Instance& buffer) {
147
0
    Protocol& protocol = headerObjectProtocol();
148
0
    protocol.writeStructBegin(buffer, StructNames::get().clientIdStruct);
149
150
    // name
151
0
    protocol.writeFieldBegin(buffer, ClientIdFieldNames::get().nameField, FieldType::String,
152
0
                             NameFieldId);
153
0
    protocol.writeString(buffer, name_);
154
0
    protocol.writeFieldEnd(buffer);
155
156
0
    protocol.writeFieldBegin(buffer, emptyString(), FieldType::Stop, 0);
157
0
    protocol.writeStructEnd(buffer);
158
0
  }
159
160
  static constexpr int16_t NameFieldId = 1;
161
162
  std::string name_;
163
};
164
165
/**
166
 * UpgradeReply represents Twitter protocol upgrade responses.
167
 */
168
class UpgradeReply : public DirectResponse, public ThriftObject {
169
public:
170
0
  UpgradeReply() = default;
171
  UpgradeReply(Transport& transport)
172
0
      : thrift_obj_(std::make_unique<ThriftObjectImpl>(transport, protocol_)) {}
173
174
  // DirectResponse
175
  DirectResponse::ResponseType encode(MessageMetadata& metadata, Protocol&,
176
0
                                      Buffer::Instance& buffer) const override {
177
0
    if (!metadata.hasSequenceId()) {
178
0
      metadata.setSequenceId(0);
179
0
    };
180
181
0
    metadata.setMethodName(TwitterProtocolImpl::upgradeMethodName());
182
0
    metadata.setMessageType(MessageType::Reply);
183
184
    // The upgrade response cannot have Twitter protocol headers, so ignore the caller's Protocol.
185
0
    BinaryProtocolImpl protocol;
186
0
    protocol.writeMessageBegin(buffer, metadata);
187
188
    // Per the Thrift standard, this is an invalid reply. We should start a reply struct with a
189
    // single field of id 0 (0x0B 0x00 0x00) to indicate success, followed by an empty UpgradeReply
190
    // struct (0x00), followed by a stop field for the reply struct (0x00). The finagle-twitter
191
    // implementation, however, just emits a single stop field.
192
0
    protocol.writeStructBegin(buffer, StructNames::get().upgradeReplyStruct);
193
0
    protocol.writeFieldBegin(buffer, emptyString(), FieldType::Stop, 0);
194
0
    protocol.writeStructEnd(buffer);
195
196
0
    protocol.writeMessageEnd(buffer);
197
198
0
    return DirectResponse::ResponseType::SuccessReply;
199
0
  }
200
201
  // ThriftObject
202
0
  const ThriftFieldPtrList& fields() const override { return thrift_obj_->fields(); }
203
0
  bool onData(Buffer::Instance& buffer) override { return thrift_obj_->onData(buffer); }
204
205
private:
206
  BinaryProtocolImpl protocol_;
207
  ThriftObjectPtr thrift_obj_;
208
};
209
210
/**
211
 * ConnectionOptions is the Twitter protocol upgrade request. It is an empty struct.
212
 */
213
class ConnectionOptions : public ThriftStructValueImpl {
214
public:
215
0
  ConnectionOptions() : ThriftStructValueImpl(nullptr) {}
216
};
217
218
/**
219
 * RequestContext is a Twitter protocol request context (key/value pair).
220
 *
221
 * See https://github.com/twitter/finagle/blob/master/finagle-thrift/src/main/thrift/tracing.thrift
222
 */
223
class RequestContext {
224
public:
225
0
  RequestContext(const std::string& key, const std::string& value) : key_(key), value_(value) {}
226
0
  RequestContext(const ThriftStructValue& value) {
227
0
    for (const auto& field : value.fields()) {
228
      // Unknown field id are ignored, to allow for future additional fields.
229
0
      switch (field->fieldId()) {
230
0
      case 1:
231
0
        key_ = field->getValue().getValueTyped<std::string>();
232
0
        break;
233
0
      case 2:
234
0
        value_ = field->getValue().getValueTyped<std::string>();
235
0
        break;
236
0
      }
237
0
    }
238
0
  }
239
240
0
  void write(Buffer::Instance& buffer) const {
241
0
    Protocol& protocol = headerObjectProtocol();
242
0
    protocol.writeStructBegin(buffer, StructNames::get().requestContextStruct);
243
244
    // key
245
0
    protocol.writeFieldBegin(buffer, RequestContextFieldNames::get().keyField, FieldType::String,
246
0
                             KeyFieldId);
247
0
    protocol.writeString(buffer, key_);
248
0
    protocol.writeFieldEnd(buffer);
249
250
    // value
251
0
    protocol.writeFieldBegin(buffer, RequestContextFieldNames::get().valueField, FieldType::String,
252
0
                             ValueFieldId);
253
0
    protocol.writeString(buffer, value_);
254
0
    protocol.writeFieldEnd(buffer);
255
256
0
    protocol.writeFieldBegin(buffer, emptyString(), FieldType::Stop, 0);
257
0
    protocol.writeStructEnd(buffer);
258
0
  }
259
260
  static constexpr int16_t KeyFieldId = 1;
261
  static constexpr int16_t ValueFieldId = 2;
262
263
  std::string key_;
264
  std::string value_;
265
};
266
using RequestContextList = std::list<RequestContext>;
267
268
/**
269
 * Delegation is Twitter protocol delegation table entry.
270
 *
271
 * See https://github.com/twitter/finagle/blob/master/finagle-thrift/src/main/thrift/tracing.thrift
272
 */
273
class Delegation {
274
public:
275
0
  Delegation(const std::string& src, const std::string& dst) : src_(src), dst_(dst) {}
276
0
  Delegation(const ThriftStructValue& value) {
277
0
    for (const auto& field : value.fields()) {
278
      // Unknown field id are ignored, to allow for future additional fields.
279
0
      switch (field->fieldId()) {
280
0
      case SrcFieldId:
281
0
        src_ = field->getValue().getValueTyped<std::string>();
282
0
        break;
283
0
      case DstFieldId:
284
0
        dst_ = field->getValue().getValueTyped<std::string>();
285
0
        break;
286
0
      }
287
0
    }
288
0
  }
289
290
0
  void write(Buffer::Instance& buffer) const {
291
0
    Protocol& protocol = headerObjectProtocol();
292
0
    protocol.writeStructBegin(buffer, StructNames::get().delegationStruct);
293
294
    // src
295
0
    protocol.writeFieldBegin(buffer, DelegationFieldNames::get().srcField, FieldType::String,
296
0
                             SrcFieldId);
297
0
    protocol.writeString(buffer, src_);
298
0
    protocol.writeFieldEnd(buffer);
299
300
    // dst
301
0
    protocol.writeFieldBegin(buffer, DelegationFieldNames::get().dstField, FieldType::String,
302
0
                             DstFieldId);
303
0
    protocol.writeString(buffer, dst_);
304
0
    protocol.writeFieldEnd(buffer);
305
306
0
    protocol.writeFieldBegin(buffer, emptyString(), FieldType::Stop, 0);
307
0
    protocol.writeStructEnd(buffer);
308
0
  }
309
310
  static constexpr int16_t SrcFieldId = 1;
311
  static constexpr int16_t DstFieldId = 2;
312
313
  std::string src_;
314
  std::string dst_;
315
};
316
using DelegationList = std::list<Delegation>;
317
318
/**
319
 * RequestHeader is a Twitter protocol request header, inserted between the transport start and
320
 * message begin.
321
 *
322
 * See https://github.com/twitter/finagle/blob/master/finagle-thrift/src/main/thrift/tracing.thrift
323
 */
324
class RequestHeader {
325
public:
326
0
  RequestHeader(const ThriftObject& header) {
327
0
    for (const auto& field : header.fields()) {
328
      // Unknown field id are ignored, to allow for future additional fields.
329
0
      switch (field->fieldId()) {
330
0
      case TraceIdFieldId:
331
0
        trace_id_ = field->getValue().getValueTyped<int64_t>();
332
0
        break;
333
0
      case SpanIdFieldId:
334
0
        span_id_ = field->getValue().getValueTyped<int64_t>();
335
0
        break;
336
0
      case ParentSpanIdFieldId:
337
0
        parent_span_id_ = field->getValue().getValueTyped<int64_t>();
338
0
        break;
339
      // unused: field 4
340
0
      case SampledFieldId:
341
0
        sampled_ = field->getValue().getValueTyped<bool>();
342
0
        break;
343
0
      case ClientIdFieldId:
344
0
        client_id_ = ClientId(field->getValue().getValueTyped<ThriftStructValue>());
345
0
        break;
346
0
      case FlagsFieldId:
347
0
        flags_ = field->getValue().getValueTyped<int64_t>();
348
0
        break;
349
0
      case ContextsFieldId:
350
0
        readContexts(field->getValue().getValueTyped<ThriftListValue>());
351
0
        break;
352
0
      case DestFieldId:
353
0
        dest_ = field->getValue().getValueTyped<std::string>();
354
0
        break;
355
0
      case DelegationsFieldId:
356
0
        readDelegations(field->getValue().getValueTyped<ThriftListValue>());
357
0
        break;
358
0
      case TraceIdHighFieldId:
359
0
        trace_id_high_ = field->getValue().getValueTyped<int64_t>();
360
0
        break;
361
0
      }
362
0
    }
363
0
  }
364
365
0
  RequestHeader(const MessageMetadata& metadata) {
366
0
    if (metadata.traceId()) {
367
0
      trace_id_ = *metadata.traceId();
368
0
    }
369
0
    if (metadata.traceIdHigh()) {
370
0
      trace_id_high_ = *metadata.traceIdHigh();
371
0
    }
372
373
0
    if (metadata.spanId()) {
374
0
      span_id_ = *metadata.spanId();
375
0
    }
376
0
    if (metadata.parentSpanId()) {
377
0
      parent_span_id_ = *metadata.parentSpanId();
378
0
    }
379
380
0
    if (metadata.flags()) {
381
0
      flags_ = *metadata.flags();
382
0
    }
383
384
0
    if (metadata.sampled().has_value()) {
385
0
      sampled_ = metadata.sampled().value();
386
0
    }
387
388
0
    metadata.requestHeaders().iterate(
389
0
        [this](const Http::HeaderEntry& header) -> Http::HeaderMap::Iterate {
390
0
          absl::string_view key = header.key().getStringView();
391
0
          if (key.empty()) {
392
0
            return Http::HeaderMap::Iterate::Continue;
393
0
          }
394
395
0
          if (key == Headers::get().ClientId.get()) {
396
0
            client_id_ = ClientId(std::string(header.value().getStringView()));
397
0
          } else if (key == Headers::get().Dest.get()) {
398
0
            dest_ = std::string(header.value().getStringView());
399
0
          } else if (key.find(":d:") == 0 && key.size() > 3) {
400
0
            delegations_.emplace_back(std::string(key.substr(3)),
401
0
                                      std::string(header.value().getStringView()));
402
0
          } else if (key[0] != ':') {
403
0
            contexts_.emplace_back(std::string(key), std::string(header.value().getStringView()));
404
0
          }
405
0
          return Http::HeaderMap::Iterate::Continue;
406
0
        });
407
0
  }
408
409
0
  void write(Buffer::Instance& buffer) {
410
0
    Protocol& protocol = headerObjectProtocol();
411
0
    protocol.writeStructBegin(buffer, StructNames::get().requestHeaderStruct);
412
413
    // trace_id
414
0
    protocol.writeFieldBegin(buffer, RequestHeaderFieldNames::get().traceIdField, FieldType::I64,
415
0
                             TraceIdFieldId);
416
0
    protocol.writeInt64(buffer, trace_id_);
417
0
    protocol.writeFieldEnd(buffer);
418
419
    // span_id
420
0
    protocol.writeFieldBegin(buffer, RequestHeaderFieldNames::get().spanIdField, FieldType::I64,
421
0
                             SpanIdFieldId);
422
0
    protocol.writeInt64(buffer, span_id_);
423
0
    protocol.writeFieldEnd(buffer);
424
425
    // parent_span_id
426
0
    if (parent_span_id_) {
427
0
      protocol.writeFieldBegin(buffer, RequestHeaderFieldNames::get().parentSpanIdField,
428
0
                               FieldType::I64, ParentSpanIdFieldId);
429
0
      protocol.writeInt64(buffer, *parent_span_id_);
430
0
      protocol.writeFieldEnd(buffer);
431
0
    }
432
433
    // sampled
434
0
    if (sampled_) {
435
0
      protocol.writeFieldBegin(buffer, RequestHeaderFieldNames::get().sampledField, FieldType::Bool,
436
0
                               SampledFieldId);
437
0
      protocol.writeBool(buffer, *sampled_);
438
0
      protocol.writeFieldEnd(buffer);
439
0
    }
440
441
    // client_id
442
0
    if (client_id_) {
443
0
      protocol.writeFieldBegin(buffer, RequestHeaderFieldNames::get().clientIdField,
444
0
                               FieldType::Struct, ClientIdFieldId);
445
0
      client_id_->write(buffer);
446
0
      protocol.writeFieldEnd(buffer);
447
0
    }
448
449
    // flags
450
0
    if (flags_) {
451
0
      protocol.writeFieldBegin(buffer, RequestHeaderFieldNames::get().flagsField, FieldType::I64,
452
0
                               FlagsFieldId);
453
0
      protocol.writeInt64(buffer, *flags_);
454
0
      protocol.writeFieldEnd(buffer);
455
0
    }
456
457
    // contexts
458
0
    if (!contexts_.empty()) {
459
0
      protocol.writeFieldBegin(buffer, RequestHeaderFieldNames::get().contextsField,
460
0
                               FieldType::List, ContextsFieldId);
461
0
      protocol.writeListBegin(buffer, FieldType::Struct, contexts_.size());
462
0
      for (const auto& context : contexts_) {
463
0
        context.write(buffer);
464
0
      }
465
0
      protocol.writeListEnd(buffer);
466
0
      protocol.writeFieldEnd(buffer);
467
0
    }
468
469
    // dest
470
0
    if (dest_) {
471
0
      protocol.writeFieldBegin(buffer, RequestHeaderFieldNames::get().destField, FieldType::String,
472
0
                               DestFieldId);
473
0
      protocol.writeString(buffer, *dest_);
474
0
      protocol.writeFieldEnd(buffer);
475
0
    }
476
477
    // delegations
478
0
    if (!delegations_.empty()) {
479
0
      protocol.writeFieldBegin(buffer, RequestHeaderFieldNames::get().delegationsField,
480
0
                               FieldType::List, DelegationsFieldId);
481
0
      protocol.writeListBegin(buffer, FieldType::Struct, delegations_.size());
482
0
      for (const auto& delegation : delegations_) {
483
0
        delegation.write(buffer);
484
0
      }
485
0
      protocol.writeListEnd(buffer);
486
0
      protocol.writeFieldEnd(buffer);
487
0
    }
488
489
    // trace_id_high
490
0
    if (trace_id_high_) {
491
0
      protocol.writeFieldBegin(buffer, RequestHeaderFieldNames::get().traceIdHighField,
492
0
                               FieldType::I64, TraceIdHighFieldId);
493
0
      protocol.writeInt64(buffer, *trace_id_high_);
494
0
      protocol.writeFieldEnd(buffer);
495
0
    }
496
497
0
    protocol.writeFieldBegin(buffer, emptyString(), FieldType::Stop, 0);
498
0
    protocol.writeStructEnd(buffer);
499
0
  }
500
501
0
  int64_t traceId() const { return trace_id_; }
502
0
  int64_t spanId() const { return span_id_; }
503
0
  absl::optional<int64_t> parentSpanId() const { return parent_span_id_; }
504
0
  absl::optional<bool> sampled() const { return sampled_; }
505
0
  absl::optional<ClientId> clientId() const { return client_id_; }
506
0
  absl::optional<int64_t> flags() const { return flags_; }
507
0
  const RequestContextList& contexts() const { return contexts_; }
508
0
  RequestContextList* contexts() { return &contexts_; }
509
0
  absl::optional<std::string> dest() { return dest_; }
510
0
  const DelegationList& delegations() const { return delegations_; }
511
0
  DelegationList* delegations() { return &delegations_; }
512
0
  absl::optional<int64_t> traceIdHigh() const { return trace_id_high_; }
513
514
private:
515
  static constexpr int16_t TraceIdFieldId = 1;
516
  static constexpr int16_t SpanIdFieldId = 2;
517
  static constexpr int16_t ParentSpanIdFieldId = 3;
518
  static constexpr int16_t SampledFieldId = 5;
519
  static constexpr int16_t ClientIdFieldId = 6;
520
  static constexpr int16_t FlagsFieldId = 7;
521
  static constexpr int16_t ContextsFieldId = 8;
522
  static constexpr int16_t DestFieldId = 9;
523
  static constexpr int16_t DelegationsFieldId = 10;
524
  static constexpr int16_t TraceIdHighFieldId = 11;
525
526
0
  void readContexts(const ThriftListValue& ctxts_list) {
527
0
    contexts_.clear();
528
0
    for (const auto& elem : ctxts_list.elements()) {
529
0
      const ThriftStructValue& ctxt_struct = elem->getValueTyped<ThriftStructValue>();
530
0
      contexts_.emplace_back(ctxt_struct);
531
0
    }
532
0
  }
533
534
0
  void readDelegations(const ThriftListValue& delegations_list) {
535
0
    delegations_.clear();
536
0
    for (const auto& elem : delegations_list.elements()) {
537
0
      const ThriftStructValue& ctxt_struct = elem->getValueTyped<ThriftStructValue>();
538
0
      delegations_.emplace_back(ctxt_struct);
539
0
    }
540
0
  }
541
542
  int64_t trace_id_{0};
543
  int64_t span_id_{0};
544
  absl::optional<int64_t> parent_span_id_;
545
  absl::optional<bool> sampled_;
546
  absl::optional<ClientId> client_id_;
547
  absl::optional<int64_t> flags_;
548
  std::list<RequestContext> contexts_;
549
  absl::optional<std::string> dest_;
550
  DelegationList delegations_;
551
  absl::optional<int64_t> trace_id_high_;
552
};
553
554
/**
555
 * ResponseHeader is a Twitter protocol response header, inserted between the transport start and
556
 * message begin.
557
 *
558
 * See https://github.com/twitter/finagle/blob/master/finagle-thrift/src/main/thrift/tracing.thrift
559
 */
560
class ResponseHeader {
561
public:
562
0
  ResponseHeader(const ThriftObject& header) {
563
0
    for (const auto& field : header.fields()) {
564
      // Unknown field id are ignored, to allow for future additional fields.
565
0
      switch (field->fieldId()) {
566
0
      case SpansFieldId:
567
0
        readSpans(field->getValue().getValueTyped<ThriftListValue>());
568
0
        break;
569
0
      case ContextsFieldId:
570
0
        readContexts(field->getValue().getValueTyped<ThriftListValue>());
571
0
        break;
572
0
      }
573
0
    }
574
0
  }
575
0
  ResponseHeader(const MessageMetadata& metadata) : spans_(metadata.spans()) {
576
0
    metadata.responseHeaders().iterate(
577
0
        [this](const Http::HeaderEntry& header) -> Http::HeaderMap::Iterate {
578
0
          absl::string_view key = header.key().getStringView();
579
0
          if (!key.empty() && key[0] != ':') {
580
0
            contexts_.emplace_back(std::string(key), std::string(header.value().getStringView()));
581
0
          }
582
0
          return Http::HeaderMap::Iterate::Continue;
583
0
        });
584
0
  }
585
586
0
  void write(Buffer::Instance& buffer) {
587
0
    Protocol& protocol = headerObjectProtocol();
588
0
    protocol.writeStructBegin(buffer, StructNames::get().responseHeaderStruct);
589
590
    // spans
591
0
    if (!spans_.empty()) {
592
0
      protocol.writeFieldBegin(buffer, ResponseHeaderFieldNames::get().spansField, FieldType::List,
593
0
                               SpansFieldId);
594
0
      protocol.writeListBegin(buffer, FieldType::Struct, spans_.size());
595
0
      for (const auto& span : spans_) {
596
0
        writeSpan(buffer, span);
597
0
      }
598
0
      protocol.writeListEnd(buffer);
599
0
      protocol.writeFieldEnd(buffer);
600
0
    }
601
602
    // contexts
603
0
    if (!contexts_.empty()) {
604
0
      protocol.writeFieldBegin(buffer, ResponseHeaderFieldNames::get().contextsField,
605
0
                               FieldType::List, ContextsFieldId);
606
0
      protocol.writeListBegin(buffer, FieldType::Struct, contexts_.size());
607
0
      for (const auto& context : contexts_) {
608
0
        context.write(buffer);
609
0
      }
610
0
      protocol.writeListEnd(buffer);
611
0
      protocol.writeFieldEnd(buffer);
612
0
    }
613
614
0
    protocol.writeFieldBegin(buffer, emptyString(), FieldType::Stop, 0);
615
0
    protocol.writeStructEnd(buffer);
616
0
  }
617
618
0
  SpanList& spans() { return spans_; }
619
0
  RequestContextList& contexts() { return contexts_; }
620
621
private:
622
  static constexpr int16_t SpansFieldId = 1;
623
  static constexpr int16_t ContextsFieldId = 2;
624
625
  static constexpr int16_t SpanTraceIdFieldId = 1;
626
  static constexpr int16_t SpanNameFieldId = 3;
627
  static constexpr int16_t SpanIdFieldId = 4;
628
  static constexpr int16_t SpanParentIdFieldId = 5;
629
  static constexpr int16_t SpanAnnotationsFieldId = 6;
630
  static constexpr int16_t SpanBinaryAnnotationsFieldId = 8;
631
  static constexpr int16_t SpanDebugFieldId = 9;
632
633
  static constexpr int16_t AnnotationTimestampFieldId = 1;
634
  static constexpr int16_t AnnotationValueFieldId = 2;
635
  static constexpr int16_t AnnotationHostFieldId = 3;
636
637
  static constexpr int16_t BinaryAnnotationKeyFieldId = 1;
638
  static constexpr int16_t BinaryAnnotationValueFieldId = 2;
639
  static constexpr int16_t BinaryAnnotationAnnotationTypeFieldId = 3;
640
  static constexpr int16_t BinaryAnnotationHostFieldId = 4;
641
642
  static constexpr int16_t EndpointIpv4FieldId = 1;
643
  static constexpr int16_t EndpointPortFieldId = 2;
644
  static constexpr int16_t EndpointServiceNameFieldId = 3;
645
646
0
  void readSpans(const ThriftListValue& spans_list) {
647
0
    spans_.clear();
648
0
    for (const auto& elem : spans_list.elements()) {
649
0
      spans_.emplace_back();
650
0
      readSpan(spans_.back(), elem->getValueTyped<ThriftStructValue>());
651
0
    }
652
0
  }
653
654
0
  void readSpan(Span& span, const ThriftStructValue& thrift_struct) {
655
0
    for (const auto& field : thrift_struct.fields()) {
656
      // Unknown field id are ignored, to allow for future additional fields.
657
0
      switch (field->fieldId()) {
658
0
      case SpanTraceIdFieldId:
659
0
        span.trace_id_ = field->getValue().getValueTyped<int64_t>();
660
0
        break;
661
      // field 2: unused
662
0
      case SpanNameFieldId:
663
0
        span.name_ = field->getValue().getValueTyped<std::string>();
664
0
        break;
665
0
      case SpanIdFieldId:
666
0
        span.span_id_ = field->getValue().getValueTyped<int64_t>();
667
0
        break;
668
0
      case SpanParentIdFieldId:
669
0
        span.parent_span_id_ = field->getValue().getValueTyped<int64_t>();
670
0
        break;
671
0
      case SpanAnnotationsFieldId:
672
0
        readAnnotations(span.annotations_, field->getValue().getValueTyped<ThriftListValue>());
673
0
        break;
674
      // field 7: unused
675
0
      case SpanBinaryAnnotationsFieldId:
676
0
        readBinaryAnnotations(span.binary_annotations_,
677
0
                              field->getValue().getValueTyped<ThriftListValue>());
678
0
        break;
679
0
      case SpanDebugFieldId:
680
0
        span.debug_ = field->getValue().getValueTyped<bool>();
681
0
        break;
682
0
      }
683
0
    }
684
0
  }
685
686
0
  void writeSpan(Buffer::Instance& buffer, const Span& span) {
687
0
    Protocol& protocol = headerObjectProtocol();
688
689
0
    protocol.writeStructBegin(buffer, StructNames::get().spanStruct);
690
    // trace_id
691
0
    protocol.writeFieldBegin(buffer, SpanFieldNames::get().traceIdField, FieldType::I64,
692
0
                             SpanTraceIdFieldId);
693
0
    protocol.writeInt64(buffer, span.trace_id_);
694
0
    protocol.writeFieldEnd(buffer);
695
696
    // name
697
0
    protocol.writeFieldBegin(buffer, SpanFieldNames::get().nameField, FieldType::String,
698
0
                             SpanNameFieldId);
699
0
    protocol.writeString(buffer, span.name_);
700
0
    protocol.writeFieldEnd(buffer);
701
702
    // id
703
0
    protocol.writeFieldBegin(buffer, SpanFieldNames::get().idField, FieldType::I64, SpanIdFieldId);
704
0
    protocol.writeInt64(buffer, span.span_id_);
705
0
    protocol.writeFieldEnd(buffer);
706
707
    // parent_id
708
0
    if (span.parent_span_id_) {
709
0
      protocol.writeFieldBegin(buffer, SpanFieldNames::get().parentIdField, FieldType::I64,
710
0
                               SpanParentIdFieldId);
711
0
      protocol.writeInt64(buffer, *span.parent_span_id_);
712
0
      protocol.writeFieldEnd(buffer);
713
0
    }
714
715
    // annotations
716
0
    protocol.writeFieldBegin(buffer, SpanFieldNames::get().annotationsField, FieldType::List,
717
0
                             SpanAnnotationsFieldId);
718
0
    protocol.writeListBegin(buffer, FieldType::Struct, span.annotations_.size());
719
0
    for (const auto& annotation : span.annotations_) {
720
0
      writeAnnotation(buffer, annotation);
721
0
    }
722
0
    protocol.writeListEnd(buffer);
723
0
    protocol.writeFieldEnd(buffer);
724
725
    // binary_annotations
726
0
    protocol.writeFieldBegin(buffer, SpanFieldNames::get().binaryAnnotationsField, FieldType::List,
727
0
                             SpanBinaryAnnotationsFieldId);
728
0
    protocol.writeListBegin(buffer, FieldType::Struct, span.binary_annotations_.size());
729
0
    for (const auto& annotation : span.binary_annotations_) {
730
0
      writeBinaryAnnotation(buffer, annotation);
731
0
    }
732
0
    protocol.writeListEnd(buffer);
733
0
    protocol.writeFieldEnd(buffer);
734
735
    // debug
736
0
    protocol.writeFieldBegin(buffer, SpanFieldNames::get().debugField, FieldType::Bool,
737
0
                             SpanDebugFieldId);
738
0
    protocol.writeBool(buffer, span.debug_);
739
0
    protocol.writeFieldEnd(buffer);
740
741
0
    protocol.writeFieldBegin(buffer, emptyString(), FieldType::Stop, 0);
742
0
    protocol.writeStructEnd(buffer);
743
0
  }
744
745
0
  void readAnnotations(AnnotationList& annotations, const ThriftListValue& thrift_list) {
746
0
    annotations.clear();
747
0
    for (const auto& elem : thrift_list.elements()) {
748
0
      annotations.emplace_back();
749
0
      readAnnotation(annotations.back(), elem->getValueTyped<ThriftStructValue>());
750
0
    }
751
0
  }
752
753
0
  // Populates `annotation` from a decoded Annotation struct by matching each field's id.
  // Ids other than timestamp, value and host are skipped, to allow for future additional fields.
  void readAnnotation(Annotation& annotation, const ThriftStructValue& thrift_struct) {
    for (const auto& field : thrift_struct.fields()) {
      const auto field_id = field->fieldId();
      if (field_id == AnnotationTimestampFieldId) {
        annotation.timestamp_ = field->getValue().getValueTyped<int64_t>();
      } else if (field_id == AnnotationValueFieldId) {
        annotation.value_ = field->getValue().getValueTyped<std::string>();
      } else if (field_id == AnnotationHostFieldId) {
        annotation.host_.emplace();
        readEndpoint(*annotation.host_, field->getValue().getValueTyped<ThriftStructValue>());
      }
    }
  }
771
772
0
  // Serializes `annotation` as an Annotation struct with the header object protocol. It writes
  // timestamp (i64), value (string), host (Endpoint struct, only when set) and the stop marker.
  void writeAnnotation(Buffer::Instance& buffer, const Annotation& annotation) {
    Protocol& proto = headerObjectProtocol();
    const auto& names = AnnotationFieldNames::get();

    proto.writeStructBegin(buffer, StructNames::get().annotationStruct);

    proto.writeFieldBegin(buffer, names.timestampField, FieldType::I64, AnnotationTimestampFieldId);
    proto.writeInt64(buffer, annotation.timestamp_);
    proto.writeFieldEnd(buffer);

    proto.writeFieldBegin(buffer, names.valueField, FieldType::String, AnnotationValueFieldId);
    proto.writeString(buffer, annotation.value_);
    proto.writeFieldEnd(buffer);

    // The host endpoint is optional and omitted from the wire when unset.
    if (annotation.host_.has_value()) {
      proto.writeFieldBegin(buffer, names.hostField, FieldType::Struct, AnnotationHostFieldId);
      writeEndpoint(buffer, annotation.host_.value());
      proto.writeFieldEnd(buffer);
    }

    proto.writeFieldBegin(buffer, emptyString(), FieldType::Stop, 0);
    proto.writeStructEnd(buffer);
  }
800
801
  void readBinaryAnnotations(BinaryAnnotationList& annotations,
802
0
                             const ThriftListValue& thrift_list) {
803
0
    annotations.clear();
804
0
    for (const auto& elem : thrift_list.elements()) {
805
0
      annotations.emplace_back();
806
0
      readBinaryAnnotation(annotations.back(), elem->getValueTyped<ThriftStructValue>());
807
0
    }
808
0
  }
809
810
0
  // Populates `annotation` from a decoded BinaryAnnotation struct by matching each field's id.
  // Ids other than key, value, annotation_type and host are skipped, to allow for future
  // additional fields.
  void readBinaryAnnotation(BinaryAnnotation& annotation, const ThriftStructValue& thrift_struct) {
    for (const auto& field : thrift_struct.fields()) {
      const auto field_id = field->fieldId();
      if (field_id == BinaryAnnotationKeyFieldId) {
        annotation.key_ = field->getValue().getValueTyped<std::string>();
      } else if (field_id == BinaryAnnotationValueFieldId) {
        annotation.value_ = field->getValue().getValueTyped<std::string>();
      } else if (field_id == BinaryAnnotationAnnotationTypeFieldId) {
        // The annotation type arrives as its raw i32 value.
        const int32_t raw_type = field->getValue().getValueTyped<int32_t>();
        annotation.annotation_type_ = static_cast<AnnotationType>(raw_type);
      } else if (field_id == BinaryAnnotationHostFieldId) {
        annotation.host_.emplace();
        readEndpoint(*annotation.host_, field->getValue().getValueTyped<ThriftStructValue>());
      }
    }
  }
832
833
0
  // Serializes `annotation` as a BinaryAnnotation struct with the header object protocol. It
  // writes key (string), value (string), annotation_type (i32), host (Endpoint struct, only when
  // set) and the stop marker.
  void writeBinaryAnnotation(Buffer::Instance& buffer, const BinaryAnnotation& annotation) {
    Protocol& proto = headerObjectProtocol();
    const auto& names = BinaryAnnotationFieldNames::get();

    proto.writeStructBegin(buffer, StructNames::get().binaryAnnotationStruct);

    proto.writeFieldBegin(buffer, names.keyField, FieldType::String, BinaryAnnotationKeyFieldId);
    proto.writeString(buffer, annotation.key_);
    proto.writeFieldEnd(buffer);

    proto.writeFieldBegin(buffer, names.valueField, FieldType::String,
                          BinaryAnnotationValueFieldId);
    proto.writeString(buffer, annotation.value_);
    proto.writeFieldEnd(buffer);

    // The annotation type is transmitted as its underlying i32 value.
    proto.writeFieldBegin(buffer, names.annotationTypeField, FieldType::I32,
                          BinaryAnnotationAnnotationTypeFieldId);
    proto.writeInt32(buffer, static_cast<int32_t>(annotation.annotation_type_));
    proto.writeFieldEnd(buffer);

    // The host endpoint is optional and omitted from the wire when unset.
    if (annotation.host_.has_value()) {
      proto.writeFieldBegin(buffer, names.hostField, FieldType::Struct,
                            BinaryAnnotationHostFieldId);
      writeEndpoint(buffer, annotation.host_.value());
      proto.writeFieldEnd(buffer);
    }

    proto.writeFieldBegin(buffer, emptyString(), FieldType::Stop, 0);
    proto.writeStructEnd(buffer);
  }
867
868
0
  // Populates `endpoint` from a decoded Endpoint struct. Fields are matched using the same named
  // field-id constants that writeEndpoint() emits, rather than bare literals, so reader and
  // writer cannot drift apart.
  void readEndpoint(Endpoint& endpoint, const ThriftStructValue& thrift_struct) {
    for (const auto& field : thrift_struct.fields()) {
      switch (field->fieldId()) {
      case EndpointIpv4FieldId:
        endpoint.ipv4_ = field->getValue().getValueTyped<int32_t>();
        break;
      case EndpointPortFieldId:
        endpoint.port_ = field->getValue().getValueTyped<int16_t>();
        break;
      case EndpointServiceNameFieldId:
        endpoint.service_name_ = field->getValue().getValueTyped<std::string>();
        break;
      default:
        // Unknown field ids are ignored, to allow for future additional fields.
        break;
      }
    }
  }
884
885
0
  // Serializes `endpoint` as an Endpoint struct with the header object protocol. It always writes
  // ipv4 (i32), port (i16) and service_name (string), followed by the stop marker.
  void writeEndpoint(Buffer::Instance& buffer, const Endpoint& endpoint) {
    Protocol& proto = headerObjectProtocol();
    const auto& names = EndpointFieldNames::get();

    proto.writeStructBegin(buffer, StructNames::get().endpointStruct);

    proto.writeFieldBegin(buffer, names.ipv4Field, FieldType::I32, EndpointIpv4FieldId);
    proto.writeInt32(buffer, endpoint.ipv4_);
    proto.writeFieldEnd(buffer);

    proto.writeFieldBegin(buffer, names.portField, FieldType::I16, EndpointPortFieldId);
    proto.writeInt16(buffer, endpoint.port_);
    proto.writeFieldEnd(buffer);

    proto.writeFieldBegin(buffer, names.serviceNameField, FieldType::String,
                          EndpointServiceNameFieldId);
    proto.writeString(buffer, endpoint.service_name_);
    proto.writeFieldEnd(buffer);

    proto.writeFieldBegin(buffer, emptyString(), FieldType::Stop, 0);
    proto.writeStructEnd(buffer);
  }
911
912
0
  void readContexts(const ThriftListValue& ctxts_list) {
913
0
    contexts_.clear();
914
0
    for (const auto& elem : ctxts_list.elements()) {
915
0
      const ThriftStructValue& ctxt_struct = elem->getValueTyped<ThriftStructValue>();
916
0
      contexts_.emplace_back(ctxt_struct);
917
0
    }
918
0
  }
919
920
  std::list<Span> spans_;
921
  std::list<RequestContext> contexts_;
922
};
923
924
} // namespace
925
926
0
// Reads the start of a message into `metadata`.
//  - Before the connection is negotiated, a plain binary message is read. If its method name is
//    the upgrade method name, it is flagged as a protocol upgrade message. Otherwise the
//    connection is treated as not upgraded from then on.
//  - On a non-upgraded connection this is plain binary protocol.
//  - On an upgraded connection each message is preceded by a header struct. The header is
//    decoded, possibly across several calls, before the binary message begin. It is then applied
//    as a request or response header depending on the message type.
// Returns false when more data is needed.
bool TwitterProtocolImpl::readMessageBegin(Buffer::Instance& buffer, MessageMetadata& metadata) {
  // Negotiation not settled yet: the first message decides.
  if (!upgraded_.has_value()) {
    const bool have_message = BinaryProtocolImpl::readMessageBegin(buffer, metadata);
    if (!have_message) {
      return false;
    }

    ASSERT(metadata.hasMethodName());
    const bool is_upgrade_request = metadata.methodName() == upgradeMethodName();
    if (is_upgrade_request) {
      metadata.setProtocolUpgradeMessage(true);
    } else {
      upgraded_ = false;
    }
    return true;
  }

  // Negotiated without upgrade: plain binary protocol, no header object.
  if (!*upgraded_) {
    return BinaryProtocolImpl::readMessageBegin(buffer, metadata);
  }

  // Upgraded: finish decoding the header struct before the message begin.
  if (!header_complete_) {
    if (header_ == nullptr) {
      header_ = std::make_unique<ThriftObjectImpl>(headerObjectTransport(), headerObjectProtocol());
    }
    header_complete_ = header_->onData(buffer);
    if (!header_complete_) {
      return false;
    }
  }

  // The decoded header is retained across calls until the message begin is also available.
  if (!BinaryProtocolImpl::readMessageBegin(buffer, metadata)) {
    return false;
  }

  // Only now do we know whether the header was a request or a response header.
  ASSERT(metadata.hasMessageType());
  const auto message_type = metadata.messageType();
  if (message_type == MessageType::Call || message_type == MessageType::Oneway) {
    updateMetadataWithRequestHeader(*header_, metadata);
  } else if (message_type == MessageType::Reply || message_type == MessageType::Exception) {
    updateMetadataWithResponseHeader(*header_, metadata);
  }

  // Reset so the next message decodes a fresh header.
  header_complete_ = false;
  header_.reset();
  return true;
}
984
985
void TwitterProtocolImpl::writeMessageBegin(Buffer::Instance& buffer,
986
0
                                            const MessageMetadata& metadata) {
987
0
  if (upgraded_.value_or(false)) {
988
0
    switch (metadata.messageType()) {
989
0
    case MessageType::Call:
990
0
    case MessageType::Oneway:
991
0
      writeRequestHeader(buffer, metadata);
992
0
      break;
993
0
    case MessageType::Reply:
994
0
    case MessageType::Exception:
995
0
      writeResponseHeader(buffer, metadata);
996
0
      break;
997
0
    }
998
0
  }
999
1000
0
  BinaryProtocolImpl::writeMessageBegin(buffer, metadata);
1001
0
}
1002
1003
void TwitterProtocolImpl::updateMetadataWithRequestHeader(const ThriftObject& header_object,
1004
0
                                                          MessageMetadata& metadata) {
1005
0
  RequestHeader req_header(header_object);
1006
1007
0
  Http::HeaderMap& headers = metadata.requestHeaders();
1008
1009
0
  metadata.setTraceId(req_header.traceId());
1010
0
  metadata.setSpanId(req_header.spanId());
1011
0
  if (req_header.parentSpanId()) {
1012
0
    metadata.setParentSpanId(*req_header.parentSpanId());
1013
0
  }
1014
0
  if (req_header.sampled()) {
1015
0
    metadata.setSampled(*req_header.sampled());
1016
0
  }
1017
0
  if (req_header.clientId()) {
1018
0
    headers.addReferenceKey(Headers::get().ClientId, req_header.clientId()->name_);
1019
0
  }
1020
0
  if (req_header.flags()) {
1021
0
    metadata.setFlags(*req_header.flags());
1022
0
  }
1023
0
  for (const auto& context : *req_header.contexts()) {
1024
    // LowerCaseString doesn't allow '\0', '\n', and '\r'.
1025
0
    const std::string key =
1026
0
        absl::StrReplaceAll(context.key_, {{std::string(1, '\0'), ""}, {"\n", ""}, {"\r", ""}});
1027
0
    headers.addCopy(Http::LowerCaseString{key}, context.value_);
1028
0
  }
1029
0
  if (req_header.dest()) {
1030
0
    headers.addReferenceKey(Headers::get().Dest, *req_header.dest());
1031
0
  }
1032
  // TODO(zuercher): Delegations are stored as headers for now. Consider passing them as simple
1033
  // objects
1034
0
  for (const auto& delegation : *req_header.delegations()) {
1035
    // LowerCaseString doesn't allow '\0', '\n', and '\r'.
1036
0
    const std::string src =
1037
0
        absl::StrReplaceAll(delegation.src_, {{std::string(1, '\0'), ""}, {"\n", ""}, {"\r", ""}});
1038
0
    const std::string key = fmt::format(":d:{}", src);
1039
0
    headers.addCopy(Http::LowerCaseString{key}, delegation.dst_);
1040
0
  }
1041
0
  if (req_header.traceIdHigh()) {
1042
0
    metadata.setTraceIdHigh(*req_header.traceIdHigh());
1043
0
  }
1044
0
}
1045
1046
void TwitterProtocolImpl::writeRequestHeader(Buffer::Instance& buffer,
1047
0
                                             const MessageMetadata& metadata) {
1048
0
  RequestHeader req_header(metadata);
1049
0
  req_header.write(buffer);
1050
0
}
1051
1052
void TwitterProtocolImpl::updateMetadataWithResponseHeader(const ThriftObject& header_object,
1053
0
                                                           MessageMetadata& metadata) {
1054
0
  ResponseHeader resp_header(header_object);
1055
1056
0
  Http::HeaderMap& headers = metadata.responseHeaders();
1057
0
  for (const auto& context : resp_header.contexts()) {
1058
    // LowerCaseString doesn't allow '\0', '\n', and '\r'.
1059
0
    const std::string key =
1060
0
        absl::StrReplaceAll(context.key_, {{std::string(1, '\0'), ""}, {"\n", ""}, {"\r", ""}});
1061
0
    headers.addCopy(Http::LowerCaseString(key), context.value_);
1062
0
  }
1063
1064
0
  SpanList& spans = resp_header.spans();
1065
0
  std::copy(spans.begin(), spans.end(), std::back_inserter(metadata.mutableSpans()));
1066
0
}
1067
1068
void TwitterProtocolImpl::writeResponseHeader(Buffer::Instance& buffer,
1069
0
                                              const MessageMetadata& metadata) {
1070
0
  ResponseHeader resp_header(metadata);
1071
0
  resp_header.write(buffer);
1072
0
}
1073
1074
0
// Creates a new ThriftObjectImpl bound to the header object transport and protocol.
ThriftObjectPtr TwitterProtocolImpl::newHeader() {
  Protocol& protocol = headerObjectProtocol();
  return std::make_unique<ThriftObjectImpl>(headerObjectTransport(), protocol);
}
1077
1078
0
// Returns a decoder for the body of an upgrade request, which is a ConnectionOptions struct.
DecoderEventHandlerSharedPtr TwitterProtocolImpl::upgradeRequestDecoder() {
  DecoderEventHandlerSharedPtr decoder = std::make_shared<ConnectionOptions>();
  return decoder;
}
1081
1082
0
// Completes the server side of the upgrade. It marks this protocol instance as upgraded and
// returns the UpgradeReply to send back. `decoder` is asserted to be a ConnectionOptions
// (presumably the one created by upgradeRequestDecoder()).
DirectResponsePtr TwitterProtocolImpl::upgradeResponse(const DecoderEventHandler& decoder) {
  ASSERT(dynamic_cast<const ConnectionOptions*>(&decoder) != nullptr);
  upgraded_ = true;
  return std::make_unique<UpgradeReply>();
}
1087
1088
// Starts the client side of the upgrade handshake.
// If `state` shows an upgrade was already attempted on this connection, its recorded outcome is
// adopted and nullptr is returned. Otherwise an upgrade request is framed into `buffer` via
// `transport`: a Call to the upgrade method whose body is a field-less ConnectionOptions struct.
// The return value is an UpgradeReply object that decodes the response.
ThriftObjectPtr TwitterProtocolImpl::attemptUpgrade(Transport& transport,
                                                    ThriftConnectionState& state,
                                                    Buffer::Instance& buffer) {
  const bool already_attempted = state.upgradeAttempted();
  if (already_attempted) {
    upgraded_ = state.isUpgraded();
    return nullptr;
  }

  MessageMetadata upgrade_metadata;
  upgrade_metadata.setMethodName(upgradeMethodName());
  upgrade_metadata.setSequenceId(0);
  upgrade_metadata.setMessageType(MessageType::Call);

  // Use the plain binary message begin, bypassing this class's override (which may prepend a
  // TTwitter header), then an empty ConnectionOptions struct.
  Buffer::OwnedImpl upgrade_request;
  BinaryProtocolImpl::writeMessageBegin(upgrade_request, upgrade_metadata);
  writeStructBegin(upgrade_request, StructNames::get().connectionOptionsStruct);
  writeFieldBegin(upgrade_request, emptyString(), FieldType::Stop, 0);
  writeStructEnd(upgrade_request);
  writeMessageEnd(upgrade_request);
  transport.encodeFrame(buffer, upgrade_metadata, upgrade_request);

  return std::make_unique<UpgradeReply>(transport);
}
1113
1114
0
// Finishes the client side of the upgrade using the decoded `response`, which must be an
// UpgradeReply. A reply with no fields marks the connection as upgraded; any fields mark the
// upgrade as failed. The outcome is recorded in both `state` and this instance.
void TwitterProtocolImpl::completeUpgrade(ThriftConnectionState& state, ThriftObject& response) {
  UpgradeReply& reply = dynamic_cast<UpgradeReply&>(response);

  const bool accepted = reply.fields().empty();
  if (accepted) {
    state.markUpgraded();
  } else {
    state.markUpgradeFailed();
  }
  upgraded_ = accepted;
}
1125
1126
0
// Returns true if `buffer` starts like a binary-protocol message whose method name could be the
// upgrade method name. Only the part of the name already in the buffer is compared. The caller
// must supply at least 12 bytes, the minimum length for the start of a binary protocol message.
bool TwitterProtocolImpl::isUpgradePrefix(Buffer::Instance& buffer) {
  ASSERT(buffer.length() >= 12);

  // Must appear to be binary protocol (magic in the first two bytes).
  if (!isMagic(buffer.peekBEInt<uint16_t>())) {
    return false;
  }

  const absl::string_view expected_name(upgradeMethodName());

  // The encoded name length (bytes 4-7) must equal the upgrade method name's length.
  if (buffer.peekBEInt<uint32_t>(4) != expected_name.length()) {
    return false;
  }

  // The name follows a fixed 8-byte prefix; compare only as many name bytes as are available.
  constexpr uint32_t name_offset = 8;
  const uint32_t available_len = static_cast<uint32_t>(
      std::min(static_cast<uint64_t>(expected_name.length()), buffer.length() - name_offset));
  ASSERT(available_len <= expected_name.length());
  ASSERT(buffer.length() >= available_len + name_offset);

  const char* raw = static_cast<const char*>(buffer.linearize(available_len + name_offset));
  const absl::string_view available_name(raw + name_offset, available_len);

  return expected_name.compare(0, available_len, available_name) == 0;
}
1155
1156
class TwitterProtocolConfigFactory : public ProtocolFactoryBase<TwitterProtocolImpl> {
1157
public:
1158
4
  TwitterProtocolConfigFactory() : ProtocolFactoryBase(ProtocolNames::get().TWITTER) {}
1159
};
1160
1161
/**
1162
 * Static registration for the Twitter protocol. @see RegisterFactory.
1163
 */
1164
REGISTER_FACTORY(TwitterProtocolConfigFactory, NamedProtocolConfigFactory);
1165
1166
} // namespace ThriftProxy
1167
} // namespace NetworkFilters
1168
} // namespace Extensions
1169
} // namespace Envoy