Line data Source code
1 : #pragma once 2 : 3 : #include "source/common/network/filter_impl.h" 4 : #include "source/extensions/filters/network/thrift_proxy/config.h" 5 : #include "source/extensions/filters/network/thrift_proxy/decoder.h" 6 : #include "source/extensions/filters/network/thrift_proxy/passthrough_decoder_event_handler.h" 7 : #include "source/extensions/health_checkers/thrift/client.h" 8 : 9 : namespace Envoy { 10 : namespace Extensions { 11 : namespace HealthCheckers { 12 : namespace ThriftHealthChecker { 13 : 14 : using namespace Envoy::Extensions::NetworkFilters; 15 : using namespace Envoy::Extensions::NetworkFilters::ThriftProxy; 16 : 17 : // The simple response decoder decodes the response and informs the health 18 : // check session if it's a success response or not. 19 : class SimpleResponseDecoder : public DecoderCallbacks, 20 : public PassThroughDecoderEventHandler, 21 : protected Logger::Loggable<Logger::Id::hc> { 22 : public: 23 : SimpleResponseDecoder(TransportPtr transport, ProtocolPtr protocol) 24 : : transport_(std::move(transport)), protocol_(std::move(protocol)), 25 0 : decoder_(std::make_unique<Decoder>(*transport_, *protocol_, *this)) {} 26 : 27 : // Return if the response is complete. 28 : bool onData(Buffer::Instance& data); 29 : 30 : // Check if it is a success response or not. 31 : bool responseSuccess(); 32 : 33 : // PassThroughDecoderEventHandler 34 : FilterStatus messageBegin(MessageMetadataSharedPtr metadata) override; 35 : FilterStatus messageEnd() override; 36 : 37 : // DecoderCallbacks 38 0 : DecoderEventHandler& newDecoderEventHandler() override { return *this; } 39 0 : bool passthroughEnabled() const override { return true; } 40 0 : bool isRequest() const override { return false; } 41 0 : bool headerKeysPreserveCase() const override { return false; } 42 : 43 : private: 44 : TransportPtr transport_; 45 : ProtocolPtr protocol_; 46 : DecoderPtr decoder_; 47 : Buffer::OwnedImpl buffer_; 48 : absl::optional<bool> success_; 49 : bool complete_{}; 50 : }; 51 : 52 : using SimpleResponseDecoderPtr = std::unique_ptr<SimpleResponseDecoder>; 53 : 54 : class ClientImpl; 55 : 56 : // Network::ClientConnection takes a shared pointer callback but we need a 57 : // unique DeferredDeletable pointer for connection management. Therefore we 58 : // need an additional wrapper class. 59 : class ThriftSessionCallbacks : public Network::ConnectionCallbacks, 60 : public Network::ReadFilterBaseImpl { 61 : public: 62 0 : ThriftSessionCallbacks(ClientImpl& parent) : parent_(parent) {} 63 : 64 : // Network::ConnectionCallbacks 65 : void onEvent(Network::ConnectionEvent event) override; 66 : void onAboveWriteBufferHighWatermark() override; 67 : void onBelowWriteBufferLowWatermark() override; 68 : 69 : // Network::ReadFilter 70 : Network::FilterStatus onData(Buffer::Instance& data, bool) override; 71 : 72 : private: 73 : ClientImpl& parent_; 74 : }; 75 : 76 : using ThriftSessionCallbacksSharedPtr = std::shared_ptr<ThriftSessionCallbacks>; 77 : 78 : class ClientImpl : public Client, 79 : public Network::ConnectionCallbacks, 80 : protected Logger::Loggable<Logger::Id::hc> { 81 : public: 82 : ClientImpl(ClientCallback& callback, TransportType transport, ProtocolType protocol, 83 : const std::string& method_name, Upstream::HostSharedPtr host, int32_t seq_id, 84 : bool fixed_seq_id) 85 : : parent_(callback), transport_(transport), protocol_(protocol), method_name_(method_name), 86 0 : host_(host), seq_id_(seq_id), fixed_seq_id_(fixed_seq_id) {} 87 : 88 : void onData(Buffer::Instance& data); 89 : 90 : // Client 91 : void start() override; 92 : bool sendRequest() override; 93 : void close() override; 94 : 95 : // Network::ConnectionCallbacks 96 0 : void onEvent(Network::ConnectionEvent event) override { parent_.onEvent(event); } 97 0 : void onAboveWriteBufferHighWatermark() override { parent_.onAboveWriteBufferHighWatermark(); } 98 0 : void onBelowWriteBufferLowWatermark() override { parent_.onBelowWriteBufferLowWatermark(); } 99 : 100 : private: 101 0 : TransportPtr createTransport() { 102 0 : return NamedTransportConfigFactory::getFactory(transport_).createTransport(); 103 0 : } 104 : 105 0 : ProtocolPtr createProtocol() { 106 0 : return NamedProtocolConfigFactory::getFactory(protocol_).createProtocol(); 107 0 : } 108 : 109 0 : int32_t sequenceId() { 110 0 : if (fixed_seq_id_) { 111 0 : return seq_id_; 112 0 : } 113 : 114 0 : if (seq_id_ != std::numeric_limits<int32_t>::max()) { 115 0 : return seq_id_++; 116 0 : } 117 : 118 0 : seq_id_ = 0; 119 0 : return std::numeric_limits<int32_t>::max(); 120 0 : } 121 : 122 : ClientCallback& parent_; 123 : const TransportType transport_; 124 : const ProtocolType protocol_; 125 : const std::string& method_name_; 126 : Upstream::HostSharedPtr host_; 127 : Network::ClientConnectionPtr connection_; 128 : Upstream::HostDescriptionConstSharedPtr host_description_; 129 : 130 : int32_t seq_id_{0}; 131 : bool fixed_seq_id_; 132 : ThriftSessionCallbacksSharedPtr session_callbacks_; 133 : SimpleResponseDecoderPtr response_decoder_; 134 : }; 135 : 136 : class ClientFactoryImpl : public ClientFactory { 137 : public: 138 : // ClientFactory 139 : ClientPtr create(ClientCallback& callbacks, TransportType transport, ProtocolType protocol, 140 : const std::string& method_name, Upstream::HostSharedPtr host, int32_t seq_id, 141 : bool fixed_seq_id) override; 142 : 143 : static ClientFactoryImpl instance_; 144 : }; 145 : 146 : } // namespace ThriftHealthChecker 147 : } // namespace HealthCheckers 148 : } // namespace Extensions 149 : } // namespace Envoy