1
#pragma once
2

            
3
#include <chrono>
4
#include <cstdint>
5

            
6
#include "source/common/grpc/buffered_message_ttl_manager.h"
7
#include "source/common/grpc/typed_async_client.h"
8
#include "source/common/protobuf/utility.h"
9

            
10
#include "absl/container/btree_map.h"
11

            
12
namespace Envoy {
13
namespace Grpc {
14

            
15
// Per-message state tracked by BufferedAsyncClient's internal buffer.
enum class BufferState {
  // Held in the buffer and eligible to be written by the next flush.
  Buffered,
  // Already written to the stream; waiting for onSuccess()/onError().
  PendingFlush,
};
16

            
17
// This class wraps bidirectional gRPC and provides message arrival guarantee.
18
// It stores messages to be sent or in the process of being sent in a buffer,
19
// and can track the status of the message based on the ID assigned to each message.
20
// If a message fails to be sent, it can be re-buffered to guarantee its arrival.
21
template <class RequestType, class ResponseType> class BufferedAsyncClient {
22
public:
23
  BufferedAsyncClient(uint32_t max_buffer_bytes, const Protobuf::MethodDescriptor& service_method,
24
                      Grpc::AsyncStreamCallbacks<ResponseType>& callbacks,
25
                      const Grpc::AsyncClient<RequestType, ResponseType>& client,
26
                      Event::Dispatcher& dispatcher, std::chrono::milliseconds message_timeout_msec)
27
3
      : max_buffer_bytes_(max_buffer_bytes), service_method_(service_method), callbacks_(callbacks),
28
3
        client_(client),
29
3
        ttl_manager_(dispatcher, [this](uint64_t id) { onError(id); }, message_timeout_msec) {}
30

            
31
3
  ~BufferedAsyncClient() {
32
3
    if (active_stream_ != nullptr) {
33
3
      active_stream_ = nullptr;
34
3
    }
35
3
  }
36

            
37
  // It push message into internal message buffer.
38
  // If the buffer is full, it will return absl::nullopt.
39
4
  absl::optional<uint64_t> bufferMessage(RequestType& message) {
40
4
    const auto buffer_size = message.ByteSizeLong();
41
4
    if (current_buffer_bytes_ + buffer_size > max_buffer_bytes_) {
42
1
      return absl::nullopt;
43
1
    }
44

            
45
3
    auto id = publishId();
46
3
    message_buffer_[id] = std::make_pair(BufferState::Buffered, message);
47
3
    current_buffer_bytes_ += buffer_size;
48
3
    return id;
49
4
  }
50

            
51
6
  absl::flat_hash_set<uint64_t> sendBufferedMessages() {
52
6
    if (active_stream_ == nullptr) {
53
3
      active_stream_ =
54
3
          client_.start(service_method_, callbacks_, Http::AsyncClient::StreamOptions());
55
3
    }
56

            
57
6
    if (active_stream_->isAboveWriteBufferHighWatermark()) {
58
1
      return {};
59
1
    }
60

            
61
5
    absl::flat_hash_set<uint64_t> inflight_message_ids;
62

            
63
5
    for (auto&& it : message_buffer_) {
64
4
      const auto id = it.first;
65
4
      auto& state = it.second.first;
66
4
      auto& message = it.second.second;
67

            
68
4
      if (state == BufferState::PendingFlush) {
69
1
        continue;
70
1
      }
71

            
72
3
      state = BufferState::PendingFlush;
73
3
      inflight_message_ids.emplace(id);
74
3
      active_stream_->sendMessage(message, false);
75
3
    }
76

            
77
5
    ttl_manager_.addDeadlineEntry(inflight_message_ids);
78
5
    return inflight_message_ids;
79
6
  }
80

            
81
3
  void onSuccess(uint64_t message_id) { erasePendingMessage(message_id); }
82

            
83
3
  void onError(uint64_t message_id) {
84
3
    const auto& message_it = message_buffer_.find(message_id);
85

            
86
3
    if (message_it == message_buffer_.end() ||
87
3
        message_it->second.first != Grpc::BufferState::PendingFlush) {
88
2
      return;
89
2
    }
90

            
91
1
    message_buffer_.at(message_id).first = BufferState::Buffered;
92
1
  }
93

            
94
3
  bool hasActiveStream() { return active_stream_ != nullptr; }
95

            
96
10
  const absl::btree_map<uint64_t, std::pair<BufferState, RequestType>>& messageBuffer() {
97
10
    return message_buffer_;
98
10
  }
99

            
100
private:
101
3
  void erasePendingMessage(uint64_t message_id) {
102
    // This case will be considered if `onSuccess` had called with unknown message id that is not
103
    // received by envoy as response.
104
3
    if (message_buffer_.find(message_id) == message_buffer_.end()) {
105
      return;
106
    }
107
3
    auto& buffer = message_buffer_.at(message_id);
108

            
109
    // There may be cases where the buffer status is not PendingFlush when
110
    // this function is called. For example, a message_buffer that was
111
    // PendingFlush may become Buffered due to an external state change
112
    // (e.g. re-buffering due to timeout).
113
3
    if (buffer.first == BufferState::PendingFlush) {
114
2
      const auto buffer_size = buffer.second.ByteSizeLong();
115
2
      current_buffer_bytes_ -= buffer_size;
116
2
      message_buffer_.erase(message_id);
117
2
    }
118
3
  }
119

            
120
3
  uint64_t publishId() { return next_message_id_++; }
121

            
122
  const uint32_t max_buffer_bytes_ = 0;
123
  const Protobuf::MethodDescriptor& service_method_;
124
  Grpc::AsyncStreamCallbacks<ResponseType>& callbacks_;
125
  Grpc::AsyncClient<RequestType, ResponseType> client_;
126
  Grpc::AsyncStream<RequestType> active_stream_;
127
  absl::btree_map<uint64_t, std::pair<BufferState, RequestType>> message_buffer_;
128
  uint32_t current_buffer_bytes_ = 0;
129
  uint64_t next_message_id_ = 0;
130
  BufferedMessageTtlManager ttl_manager_;
131
};
132

            
133
template <class RequestType, class ResponseType>
134
using BufferedAsyncClientPtr = std::unique_ptr<BufferedAsyncClient<RequestType, ResponseType>>;
135

            
136
} // namespace Grpc
137
} // namespace Envoy