1
#pragma once
2

            
3
#include <cstdint>
#include <limits>
#include <memory>
#include <string>
#include <utility>

#include "source/common/network/filter_impl.h"
#include "source/extensions/filters/network/thrift_proxy/config.h"
#include "source/extensions/filters/network/thrift_proxy/decoder.h"
#include "source/extensions/filters/network/thrift_proxy/passthrough_decoder_event_handler.h"
#include "source/extensions/health_checkers/thrift/client.h"
8

            
9
namespace Envoy {
10
namespace Extensions {
11
namespace HealthCheckers {
12
namespace ThriftHealthChecker {
13

            
14
using namespace Envoy::Extensions::NetworkFilters;
15
using namespace Envoy::Extensions::NetworkFilters::ThriftProxy;
16

            
17
// The simple response decoder decodes the response and informs the health
18
// check session if it's a success response or not.
19
class SimpleResponseDecoder : public DecoderCallbacks,
20
                              public PassThroughDecoderEventHandler,
21
                              protected Logger::Loggable<Logger::Id::hc> {
22
public:
23
  SimpleResponseDecoder(TransportPtr transport, ProtocolPtr protocol)
24
6
      : transport_(std::move(transport)), protocol_(std::move(protocol)),
25
6
        decoder_(std::make_unique<Decoder>(*transport_, *protocol_, *this)) {}
26

            
27
  // Return if the response is complete.
28
  bool onData(Buffer::Instance& data);
29

            
30
  // Check if it is a success response or not.
31
  bool responseSuccess();
32

            
33
  // PassThroughDecoderEventHandler
34
  FilterStatus messageBegin(MessageMetadataSharedPtr metadata) override;
35
  FilterStatus messageEnd() override;
36

            
37
  // DecoderCallbacks
38
8
  DecoderEventHandler& newDecoderEventHandler() override { return *this; }
39
8
  bool passthroughEnabled() const override { return true; }
40
8
  bool isRequest() const override { return false; }
41
8
  bool headerKeysPreserveCase() const override { return false; }
42

            
43
private:
44
  TransportPtr transport_;
45
  ProtocolPtr protocol_;
46
  DecoderPtr decoder_;
47
  Buffer::OwnedImpl buffer_;
48
  absl::optional<bool> success_;
49
  bool complete_{};
50
};
51

            
52
using SimpleResponseDecoderPtr = std::unique_ptr<SimpleResponseDecoder>;
53

            
54
class ClientImpl;
55

            
56
// Network::ClientConnection takes a shared pointer callback but we need a
57
// unique DeferredDeletable pointer for connection management. Therefore we
58
// need an additional wrapper class.
59
class ThriftSessionCallbacks : public Network::ConnectionCallbacks,
60
                               public Network::ReadFilterBaseImpl {
61
public:
62
6
  ThriftSessionCallbacks(ClientImpl& parent) : parent_(parent) {}
63

            
64
  // Network::ConnectionCallbacks
65
  void onEvent(Network::ConnectionEvent event) override;
66
  void onAboveWriteBufferHighWatermark() override;
67
  void onBelowWriteBufferLowWatermark() override;
68

            
69
  // Network::ReadFilter
70
  Network::FilterStatus onData(Buffer::Instance& data, bool) override;
71

            
72
private:
73
  ClientImpl& parent_;
74
};
75

            
76
using ThriftSessionCallbacksSharedPtr = std::shared_ptr<ThriftSessionCallbacks>;
77

            
78
class ClientImpl : public Client,
79
                   public Network::ConnectionCallbacks,
80
                   protected Logger::Loggable<Logger::Id::hc> {
81
public:
82
  ClientImpl(ClientCallback& callback, TransportType transport, ProtocolType protocol,
83
             const std::string& method_name, Upstream::HostSharedPtr host, int32_t seq_id,
84
             bool fixed_seq_id)
85
6
      : parent_(callback), transport_(transport), protocol_(protocol), method_name_(method_name),
86
6
        host_(host), seq_id_(seq_id), fixed_seq_id_(fixed_seq_id) {}
87

            
88
  void onData(Buffer::Instance& data);
89

            
90
  // Client
91
  void start() override;
92
  bool sendRequest() override;
93
  void close() override;
94

            
95
  // Network::ConnectionCallbacks
96
6
  void onEvent(Network::ConnectionEvent event) override { parent_.onEvent(event); }
97
6
  void onAboveWriteBufferHighWatermark() override { parent_.onAboveWriteBufferHighWatermark(); }
98
6
  void onBelowWriteBufferLowWatermark() override { parent_.onBelowWriteBufferLowWatermark(); }
99

            
100
private:
101
14
  TransportPtr createTransport() {
102
14
    return NamedTransportConfigFactory::getFactory(transport_).createTransport();
103
14
  }
104

            
105
14
  ProtocolPtr createProtocol() {
106
14
    return NamedProtocolConfigFactory::getFactory(protocol_).createProtocol();
107
14
  }
108

            
109
8
  int32_t sequenceId() {
110
8
    if (fixed_seq_id_) {
111
2
      return seq_id_;
112
2
    }
113

            
114
6
    if (seq_id_ != std::numeric_limits<int32_t>::max()) {
115
5
      return seq_id_++;
116
5
    }
117

            
118
1
    seq_id_ = 0;
119
1
    return std::numeric_limits<int32_t>::max();
120
6
  }
121

            
122
  ClientCallback& parent_;
123
  const TransportType transport_;
124
  const ProtocolType protocol_;
125
  const std::string& method_name_;
126
  Upstream::HostSharedPtr host_;
127
  Network::ClientConnectionPtr connection_;
128
  Upstream::HostDescriptionConstSharedPtr host_description_;
129

            
130
  int32_t seq_id_{0};
131
  bool fixed_seq_id_;
132
  ThriftSessionCallbacksSharedPtr session_callbacks_;
133
  SimpleResponseDecoderPtr response_decoder_;
134
};
135

            
136
class ClientFactoryImpl : public ClientFactory {
137
public:
138
  // ClientFactory
139
  ClientPtr create(ClientCallback& callbacks, TransportType transport, ProtocolType protocol,
140
                   const std::string& method_name, Upstream::HostSharedPtr host, int32_t seq_id,
141
                   bool fixed_seq_id) override;
142

            
143
  static ClientFactoryImpl instance_;
144
};
145

            
146
} // namespace ThriftHealthChecker
147
} // namespace HealthCheckers
148
} // namespace Extensions
149
} // namespace Envoy