1
#pragma once
2

            
3
#include <algorithm>
4
#include <iosfwd>
5
#include <memory>
6
#include <queue>
7
#include <utility>
8
#include <vector>
9

            
10
#include "envoy/common/random_generator.h"
11
#include "envoy/upstream/scheduler.h"
12

            
13
#include "source/common/common/assert.h"
14
#include "source/common/common/logger.h"
15

            
16
#include "absl/container/flat_hash_map.h"
17

            
18
namespace Envoy {
19
namespace Upstream {
20

            
21
// Weighted Random Selection Queue (WRSQ) Scheduler
// ------------------------------------------------
// This scheduler keeps a queue for each unique weight among all objects inserted and adds the
// objects to their respective queue based on weight. When performing a pick operation, a queue is
// selected and an object is pulled. Each queue gets its own selection probability which is weighted
// as the sum of all weights of objects contained within. Once a queue is picked, you can simply
// pull from the top and honor the expected selection probability of each object.
//
// Adding an object will cause the scheduler to rebuild internal structures on the first pick that
// follows. This first pick operation will be linear on the number of unique weights among objects
// inserted. Subsequent picks will be logarithmic with the number of unique weights. Adding objects
// is always constant time.
//
// For the case where all object weights are the same, WRSQ behaves identically to vanilla
// round-robin. If all object weights are different, it behaves identically to weighted random
// selection.
//
// NOTE: While the base scheduler interface allows for mutation of object weights with each pick,
// this implementation is not meant for circumstances where the object weights change with each pick
// (like in the least request LB). This scheduler implementation will perform quite poorly if the
// object weights change often.
42
template <class C>
43
class WRSQScheduler : public Scheduler<C>, protected Logger::Loggable<Logger::Id::upstream> {
44
public:
45
11
  WRSQScheduler(Random::RandomGenerator& random) : random_(random) {}
46

            
47
17667
  std::shared_ptr<C> peekAgain(std::function<double(const C&)> calculate_weight) override {
48
17667
    std::shared_ptr<C> picked{pickAndAddInternal(calculate_weight)};
49
17667
    if (picked != nullptr) {
50
17664
      prepick_queue_.emplace(picked);
51
17664
    }
52
17667
    return picked;
53
17667
  }
54

            
55
17803
  std::shared_ptr<C> pickAndAdd(std::function<double(const C&)> calculate_weight) override {
56
    // Burn through the pre-pick queue.
57
17806
    while (!prepick_queue_.empty()) {
58
16663
      std::shared_ptr<C> prepicked_obj = prepick_queue_.front().lock();
59
16663
      prepick_queue_.pop();
60
16663
      if (prepicked_obj != nullptr) {
61
16660
        return prepicked_obj;
62
16660
      }
63
16663
    }
64

            
65
1143
    return pickAndAddInternal(calculate_weight);
66
17803
  }
67

            
68
418
  void add(double weight, std::shared_ptr<C> entry) override {
69
418
    rebuild_cumulative_weights_ = true;
70
418
    queue_map_[weight].emplace(std::move(entry));
71
418
  }
72

            
73
  bool empty() const override { return queue_map_.empty(); }
74

            
75
private:
76
  using ObjQueue = std::queue<std::weak_ptr<C>>;
77

            
78
  // TODO(tonya11en): We can reduce memory utilization by using an absl::flat_hash_map of QueueInfo
79
  // with heterogeneous lookup on the weight. This would allow us to save 8 bytes per unique weight.
80
  using QueueMap = absl::flat_hash_map<double, ObjQueue>;
81

            
82
  // Used to store a queue's weight info necessary to perform the weighted random selection.
83
  struct QueueInfo {
84
    double cumulative_weight;
85
    double weight;
86
    ObjQueue& q;
87
  };
88

            
89
  // If needed, such as after object expiry or addition, rebuild the cumulative weights vector.
90
18814
  void maybeRebuildCumulativeWeights() {
91
18814
    if (!rebuild_cumulative_weights_) {
92
18796
      return;
93
18796
    }
94

            
95
18
    cumulative_weights_.clear();
96
18
    cumulative_weights_.reserve(queue_map_.size());
97

            
98
18
    double weight_sum = 0;
99
44
    for (auto& it : queue_map_) {
100
44
      const auto weight_val = it.first;
101
44
      weight_sum += weight_val * it.second.size();
102
44
      cumulative_weights_.push_back({weight_sum, weight_val, it.second});
103
44
    }
104

            
105
18
    rebuild_cumulative_weights_ = false;
106
18
  }
107

            
108
  // Performs a weighted random selection on the queues containing objects of the same weight.
109
  // Popping off the top of the queue to pick an object will honor the selection probability based
110
  // on the weight provided when the object was added.
111
18814
  QueueInfo& chooseQueue() {
112
18814
    ASSERT(!queue_map_.empty());
113

            
114
18814
    maybeRebuildCumulativeWeights();
115

            
116
18814
    const double weight_sum = cumulative_weights_.back().cumulative_weight;
117
18814
    uint64_t rnum = random_.random() % static_cast<uint32_t>(weight_sum);
118
18814
    auto it = std::upper_bound(cumulative_weights_.begin(), cumulative_weights_.end(), rnum,
119
20502
                               [](auto a, auto b) { return a < b.cumulative_weight; });
120
18814
    ASSERT(it != cumulative_weights_.end());
121
18814
    return *it;
122
18814
  }
123

            
124
  // Remove objects from the queue until it's empty or there is an unexpired object at the front. If
125
  // the queue is purged to empty, it's removed from the queue map and we return true.
126
18814
  bool purgeExpired(QueueInfo& qinfo) {
127
18825
    while (!qinfo.q.empty() && qinfo.q.front().expired()) {
128
11
      qinfo.q.pop();
129
11
      rebuild_cumulative_weights_ = true;
130
11
    }
131

            
132
18814
    if (qinfo.q.empty()) {
133
9
      queue_map_.erase(qinfo.weight);
134
9
      return true;
135
9
    }
136
18805
    return false;
137
18814
  }
138

            
139
18810
  std::shared_ptr<C> pickAndAddInternal(std::function<double(const C&)> calculate_weight) {
140
18819
    while (!queue_map_.empty()) {
141
18814
      QueueInfo& qinfo = chooseQueue();
142
18814
      if (purgeExpired(qinfo)) {
143
        // The chosen queue was purged to empty and removed from the queue map. Try again.
144
9
        continue;
145
9
      }
146

            
147
18805
      auto obj = qinfo.q.front().lock();
148
18805
      qinfo.q.pop();
149
18805
      if (obj == nullptr) {
150
        // The object expired after the purge.
151
        continue;
152
      }
153

            
154
18805
      const double new_weight = calculate_weight ? calculate_weight(*obj) : qinfo.weight;
155
18805
      if (new_weight == qinfo.weight) {
156
18804
        qinfo.q.emplace(obj);
157
18804
      } else {
158
        // The weight has changed for this object, so we must re-add it to the scheduler.
159
1
        ENVOY_LOG_EVERY_POW_2(
160
1
            warn, "WRSQ scheduler is used with a load balancer that mutates host weights with each "
161
1
                  "selection, this will likely result in poor selection performance");
162
1
        add(new_weight, obj);
163
1
      }
164

            
165
18805
      return obj;
166
18805
    }
167

            
168
5
    return nullptr;
169
18810
  }
170

            
171
  Random::RandomGenerator& random_;
172

            
173
  // Objects already picked via peekAgain().
174
  ObjQueue prepick_queue_;
175

            
176
  // A mapping from an object weight to the associated queue.
177
  QueueMap queue_map_;
178

            
179
  // Stores the necessary information to perform a weighted random selection of the different
180
  // queues. A cumulative sum is also kept of the total object weights for a queue, which allows for
181
  // a single random number generation and a binary search to pick a queue.
182
  std::vector<QueueInfo> cumulative_weights_;
183

            
184
  // Keeps state that determines whether the cumulative weights need to be rebuilt. If any objects
185
  // contained in a queue change from addition or expiry, it throws off the cumulative weight
186
  // values. Therefore, they must be recalculated.
187
  bool rebuild_cumulative_weights_{true};
188
};
189

            
190
} // namespace Upstream
191
} // namespace Envoy