1
#include "source/extensions/load_balancing_policies/common/thread_aware_lb_impl.h"
2

            
3
#include <memory>
4
#include <random>
5

            
6
#include "source/common/common/hex.h"
7
#include "source/common/http/headers.h"
8
#include "source/common/http/utility.h"
9
#include "source/common/runtime/runtime_features.h"
10

            
11
namespace Envoy {
12
namespace Upstream {
13

            
14
// TODO(mergeconflict): Adjust locality weights for partial availability, as is done in
15
//                      HostSetImpl::effectiveLocalityWeight.
16
namespace {
17

            
18
void normalizeHostWeights(const HostVector& hosts, double normalized_locality_weight,
19
                          NormalizedHostWeightVector& normalized_host_weights,
20
680
                          double& min_normalized_weight, double& max_normalized_weight) {
21
  // sum should be at most uint32_t max value, so we can validate it by accumulating into unit64_t
22
  // and making sure there was no overflow
23
680
  uint64_t sum = 0;
24
5012
  for (const auto& host : hosts) {
25
4990
    sum += host->weight();
26
4990
    if (sum > std::numeric_limits<uint32_t>::max()) {
27
      IS_ENVOY_BUG("weights should have been previously validated in validateEndpoints()");
28
      return;
29
    }
30
4990
  }
31

            
32
5012
  for (const auto& host : hosts) {
33
4990
    const double weight = host->weight() * normalized_locality_weight / sum;
34
4990
    normalized_host_weights.push_back({host, weight});
35
4990
    min_normalized_weight = std::min(min_normalized_weight, weight);
36
4990
    max_normalized_weight = std::max(max_normalized_weight, weight);
37
4990
  }
38
680
}
39

            
40
void normalizeLocalityWeights(const HostsPerLocality& hosts_per_locality,
41
                              const LocalityWeights& locality_weights,
42
                              NormalizedHostWeightVector& normalized_host_weights,
43
22
                              double& min_normalized_weight, double& max_normalized_weight) {
44
22
  ASSERT(locality_weights.size() == hosts_per_locality.get().size());
45

            
46
  // sum should be at most uint32_t max value, so we can validate it by accumulating into unit64_t
47
  // and making sure there was no overflow
48
22
  uint64_t sum = 0;
49
52
  for (const auto weight : locality_weights) {
50
52
    sum += weight;
51
52
    if (sum > std::numeric_limits<uint32_t>::max()) {
52
      IS_ENVOY_BUG("locality weights should have been validated in validateEndpoints");
53
    }
54
52
  }
55

            
56
  // Locality weights (unlike host weights) may be 0. If _all_ locality weights were 0, bail out.
57
22
  if (sum == 0) {
58
4
    return;
59
4
  }
60

            
61
  // Compute normalized weights for all hosts in each locality. If a locality was assigned zero
62
  // weight, all hosts in that locality will be skipped.
63
64
  for (LocalityWeights::size_type i = 0; i < locality_weights.size(); ++i) {
64
46
    if (locality_weights[i] != 0) {
65
40
      const HostVector& hosts = hosts_per_locality.get()[i];
66
40
      const double normalized_locality_weight = static_cast<double>(locality_weights[i]) / sum;
67
40
      normalizeHostWeights(hosts, normalized_locality_weight, normalized_host_weights,
68
40
                           min_normalized_weight, max_normalized_weight);
69
40
    }
70
46
  }
71
18
}
72

            
73
void normalizeWeights(const HostSet& host_set, bool in_panic,
74
                      NormalizedHostWeightVector& normalized_host_weights,
75
                      double& min_normalized_weight, double& max_normalized_weight,
76
662
                      bool locality_weighted_balancing) {
77
662
  if (!locality_weighted_balancing || host_set.localityWeights() == nullptr ||
78
662
      host_set.localityWeights()->empty()) {
79
    // If we're not dealing with locality weights, just normalize weights for the flat set of hosts.
80
640
    const auto& hosts = in_panic ? host_set.hosts() : host_set.healthyHosts();
81
640
    normalizeHostWeights(hosts, 1.0, normalized_host_weights, min_normalized_weight,
82
640
                         max_normalized_weight);
83
640
  } else {
84
    // Otherwise, normalize weights across all localities.
85
22
    const auto& hosts_per_locality =
86
22
        in_panic ? host_set.hostsPerLocality() : host_set.healthyHostsPerLocality();
87
22
    normalizeLocalityWeights(hosts_per_locality, *(host_set.localityWeights()),
88
22
                             normalized_host_weights, min_normalized_weight, max_normalized_weight);
89
22
  }
90
662
}
91

            
92
std::string generateCookie(LoadBalancerContext* context, absl::string_view name,
93
                           absl::string_view path, std::chrono::seconds ttl,
94
24
                           absl::Span<const Http::CookieAttribute> attributes) {
95
24
  ASSERT(context != nullptr);
96
24
  const StreamInfo::StreamInfo* stream_info = context->requestStreamInfo();
97
24
  if (stream_info == nullptr) {
98
8
    return {};
99
8
  }
100

            
101
16
  const auto& conn = stream_info->downstreamAddressProvider();
102
16
  const auto& remote_address = conn.remoteAddress();
103
16
  const auto& local_address = conn.localAddress();
104
16
  if (remote_address == nullptr || local_address == nullptr) {
105
8
    return {};
106
8
  }
107

            
108
8
  const std::string value = remote_address->asString() + local_address->asString();
109
8
  std::string cookie_value = Hex::uint64ToHex(HashUtil::xxHash64(value));
110

            
111
8
  std::string cookie_header_value =
112
8
      Http::Utility::makeSetCookieValue(name, cookie_value, path, ttl, true, attributes);
113
8
  context->setHeadersModifier(
114
8
      [h = std::move(cookie_header_value)](Http::ResponseHeaderMap& headers) {
115
8
        headers.addReferenceKey(Http::Headers::get().SetCookie, h);
116
8
      });
117

            
118
8
  return cookie_value;
119
16
}
120

            
121
} // namespace
122

            
123
545
// Registers the host-update callbacks that trigger LB rebuilds and performs the initial
// synchronous rebuild. Always returns OK.
absl::Status ThreadAwareLoadBalancerBase::initialize() {
  // TODO(mattklein123): In the future, once initialized and the initial LB is built, it would be
  // better to use a background thread for computing LB updates. This has the substantial benefit
  // that if the LB computation thread falls behind, host set updates can be trivially collapsed.
  // I will look into doing this in a follow up. Doing everything using a background thread heavily
  // complicated initialization as the load balancer would need its own initialized callback. I
  // think the synchronous/asynchronous split is probably the best option.
  if (Runtime::runtimeFeatureEnabled(
          "envoy.reloadable_features.coalesce_lb_rebuilds_on_batch_update")) {
    // Coalesced path: rebuild once per member-update batch instead of once per priority.
    member_update_cb_ =
        priority_set_.addMemberUpdateCb([this](const HostVector&, const HostVector&) {
          processDirtyPriorities();
          refresh();
        });

    // PriorityUpdateCb can fire before initialize() during batch host updates, while MemberUpdateCb
    // (which flushes dirty priorities) is deferred until the batch completes. If initialize() is
    // invoked mid-batch, process any queued priorities now so per_priority_panic_ is sized for all
    // current priorities before refresh() indexes into it.
    processDirtyPriorities();
  } else {
    // Legacy path: rebuild on every individual priority update.
    priority_update_cb_ = priority_set_.addPriorityUpdateCb(
        [this](uint32_t, const HostVector&, const HostVector&) { refresh(); });
  }

  // Build the initial LB state synchronously so the factory has a snapshot before any worker
  // calls create().
  refresh();
  return absl::OkStatus();
}
151

            
152
604
// Rebuilds the per-priority LB state (panic flag, normalized weights, concrete hashing LB) into a
// fresh snapshot and publishes it to the factory under the write lock. Workers copy the
// shared_ptrs out under the read lock, so previous snapshots stay alive until the last worker
// releases them.
void ThreadAwareLoadBalancerBase::refresh() {
  auto per_priority_state_vector = std::make_shared<std::vector<PerPriorityStatePtr>>(
      priority_set_.hostSetsPerPriority().size());
  auto healthy_per_priority_load =
      std::make_shared<HealthyLoad>(per_priority_load_.healthy_priority_load_);
  auto degraded_per_priority_load =
      std::make_shared<DegradedLoad>(per_priority_load_.degraded_priority_load_);

  for (const auto& host_set : priority_set_.hostSetsPerPriority()) {
    const uint32_t priority = host_set->priority();
    (*per_priority_state_vector)[priority] = std::make_unique<PerPriorityState>();
    const auto& per_priority_state = (*per_priority_state_vector)[priority];
    // Copy panic flag from LoadBalancerBase. It is calculated when there is a change
    // in hosts set or hosts' health.
    ASSERT(priority < per_priority_panic_.size());
    per_priority_state->global_panic_ = per_priority_panic_[priority];

    // Normalize host and locality weights such that the sum of all normalized weights is 1.
    NormalizedHostWeightVector normalized_host_weights;
    double min_normalized_weight = 1.0;
    double max_normalized_weight = 0.0;
    normalizeWeights(*host_set, per_priority_state->global_panic_, normalized_host_weights,
                     min_normalized_weight, max_normalized_weight, locality_weighted_balancing_);
    // Let the concrete implementation (ring hash, maglev, ...) build its lookup structure from
    // the normalized weights.
    per_priority_state->current_lb_ = createLoadBalancer(
        std::move(normalized_host_weights), min_normalized_weight, max_normalized_weight);
  }

  {
    // Publish the new snapshot atomically under the factory's write lock so workers never
    // observe a partially-updated state.
    absl::WriterMutexLock lock(&factory_->mutex_);
    factory_->healthy_per_priority_load_ = healthy_per_priority_load;
    factory_->degraded_per_priority_load_ = degraded_per_priority_load;
    factory_->per_priority_state_ = per_priority_state_vector;
  }
}
186

            
187
// Chooses a host by (1) computing a hash from the configured hash policy or the context, falling
// back to a random value, (2) picking a priority level from the per-priority load vectors, and
// (3) delegating to that priority's precomputed hashing LB, retrying while the context's host
// selection filter rejects the candidate.
HostSelectionResponse
ThreadAwareLoadBalancerBase::LoadBalancerImpl::chooseHost(LoadBalancerContext* context) {
  // Make sure we correctly return nullptr for any early chooseHost() calls.
  if (per_priority_state_ == nullptr) {
    return {nullptr};
  }

  HostConstSharedPtr host;

  // If there is no hash in the context, just choose a random value (this effectively becomes
  // the random LB but it won't crash if someone configures it this way).
  // computeHashKey() may be computed on demand, so get it only once.
  absl::optional<uint64_t> hash;
  if (context) {
    // If there is a hash policy, use the hash policy in the load balancer first.
    if (hash_policy_ != nullptr) {
      // The callback lets a cookie-based hash policy mint a new cookie (and matching Set-Cookie
      // response header) when the request carries none.
      hash = hash_policy_->generateHash(
          makeOptRefFromPtr(context->downstreamHeaders()),
          makeOptRefFromPtr(context->requestStreamInfo()),
          [context](absl::string_view name, absl::string_view path, std::chrono::seconds ttl,
                    absl::Span<const Http::CookieAttribute> attributes) -> std::string {
            return generateCookie(context, name, path, ttl, attributes);
          });
    } else {
      hash = context->computeHashKey();
    }
  }

  const uint64_t h = hash ? hash.value() : random_.random();

  // Use the same hash for the priority pick so the overall choice is sticky for a given key.
  const uint32_t priority =
      LoadBalancerBase::choosePriority(h, *healthy_per_priority_load_, *degraded_per_priority_load_)
          .first;
  const auto& per_priority_state = (*per_priority_state_)[priority];
  if (per_priority_state->global_panic_) {
    stats_.lb_healthy_panic_.inc();
  }

  // At least one attempt; additional attempts come from the context's retry configuration.
  const uint32_t max_attempts = context ? context->hostSelectionRetryCount() + 1 : 1;
  for (uint32_t i = 0; i < max_attempts; ++i) {
    host = LoadBalancer::onlyAllowSynchronousHostSelection(
        per_priority_state->current_lb_->chooseHost(h, i));

    // If host selection failed or the host is accepted by the filter, return.
    // Otherwise, try again.
    if (!host || !context || !context->shouldSelectAnotherHost(*host)) {
      return host;
    }
  }
  // Every attempt was rejected by the filter; return the last candidate anyway.
  return host;
}
238

            
239
1507
// Creates a per-worker LB that shares the factory's precomputed per-priority snapshot.
LoadBalancerPtr ThreadAwareLoadBalancerBase::LoadBalancerFactoryImpl::create(LoadBalancerParams) {
  auto worker_lb = std::make_unique<LoadBalancerImpl>(stats_, random_, hash_policy_);

  // We must protect current_lb_ via a RW lock since it is accessed and written to by multiple
  // threads. All complex processing has already been precalculated however.
  absl::ReaderMutexLock lock(mutex_);
  worker_lb->per_priority_state_ = per_priority_state_;
  worker_lb->healthy_per_priority_load_ = healthy_per_priority_load_;
  worker_lb->degraded_per_priority_load_ = degraded_per_priority_load_;
  return worker_lb;
}
250

            
251
double ThreadAwareLoadBalancerBase::BoundedLoadHashingLoadBalancer::hostOverloadFactor(
252
57
    const Host& host, double weight) const {
253
  // TODO(scheler): This will not work if rq_active cluster stat is disabled, need to detect
254
  // and alert the user if that's the case.
255

            
256
57
  const uint32_t overall_active = host.cluster().trafficStats()->upstream_rq_active_.value();
257
57
  const uint32_t host_active = host.stats().rq_active_.value();
258

            
259
57
  const uint32_t total_slots = ((overall_active + 1) * hash_balance_factor_ + 99) / 100;
260
57
  const uint32_t slots =
261
57
      std::max(static_cast<uint32_t>(std::ceil(total_slots * weight)), static_cast<uint32_t>(1));
262

            
263
57
  if (host.stats().rq_active_.value() > slots) {
264
1
    ENVOY_LOG_MISC(
265
1
        debug,
266
1
        "ThreadAwareLoadBalancerBase::BoundedLoadHashingLoadBalancer::chooseHost: "
267
1
        "host {} overloaded; overall_active {}, host_weight {}, host_active {} > slots {}",
268
1
        host.address()->asString(), overall_active, weight, host_active, slots);
269
1
  }
270
57
  return static_cast<double>(host.stats().rq_active_.value()) / slots;
271
57
}
272

            
273
// Chooses a host via the underlying hashing LB, then applies the bounded-load rule: if the chosen
// host is over its capacity, probes the remaining hosts in a hash-seeded random order and returns
// the first non-overloaded one, or the least-overloaded host seen if all are over capacity.
HostSelectionResponse
ThreadAwareLoadBalancerBase::BoundedLoadHashingLoadBalancer::chooseHost(uint64_t hash,
                                                                        uint32_t attempt) const {

  // This is implemented based on the method described in the paper
  // https://arxiv.org/abs/1608.01350. For the specified `hash_balance_factor`, requests to any
  // upstream host are capped at `hash_balance_factor/100` times the average number of requests
  // across the cluster. When a request arrives for an upstream host that is currently serving at
  // its max capacity, linear probing is used to identify an eligible host. Further, the linear
  // probe is implemented using a random jump on hosts ring/table to identify the eligible host
  // (this technique is as described in the paper https://arxiv.org/abs/1908.08762 - the random jump
  // avoids the cascading overflow effect when choosing the next host on the ring/table).
  //
  // If weights are specified on the hosts, they are respected.
  //
  // This is an O(N) algorithm, unlike other load balancers. Using a lower `hash_balance_factor`
  // results in more hosts being probed, so use a higher value if you require better performance.

  if (normalized_host_weights_.empty()) {
    return {nullptr};
  }

  // First candidate comes from the wrapped hashing LB (e.g. ring hash or maglev).
  HostConstSharedPtr host =
      LoadBalancer::onlyAllowSynchronousHostSelection(hashing_lb_ptr_->chooseHost(hash, attempt));
  if (host == nullptr) {
    return {nullptr};
  }
  const double weight = normalized_host_weights_map_.at(host);
  double overload_factor = hostOverloadFactor(*host, weight);
  if (overload_factor <= 1.0) {
    // Fast path: the hashed host is within its capacity.
    ENVOY_LOG_MISC(debug,
                   "ThreadAwareLoadBalancerBase::BoundedLoadHashingLoadBalancer::chooseHost: "
                   "selected host #{} (attempt:1)",
                   host->address()->asString());
    return host;
  }

  // When a host is overloaded, we choose the next host in a random manner rather than picking the
  // next one in the ring. The random sequence is seeded by the hash, so the same input gets the
  // same sequence of hosts all the time.
  const uint32_t num_hosts = normalized_host_weights_.size();
  auto host_index = std::vector<uint32_t>(num_hosts);
  for (uint32_t i = 0; i < num_hosts; i++) {
    host_index[i] = i;
  }

  // Not using Random::RandomGenerator as it does not take a seed. Seeded RNG is a requirement
  // here as we need the same shuffle sequence for the same hash every time.
  // Further, not using std::default_random_engine and std::uniform_int_distribution as they
  // are not consistent across Linux and Windows platforms.
  const uint64_t seed = hash;
  std::mt19937 random(seed);

  // generates a random number in the range [0,k) uniformly.
  // Rejection sampling: retry until the value falls below k to avoid modulo bias.
  auto uniform_int = [](std::mt19937& random, uint32_t k) -> uint32_t {
    uint32_t x = k;
    while (x >= k) {
      x = random() / ((static_cast<uint64_t>(random.max()) + 1u) / k);
    }
    return x;
  };

  // Track the least-overloaded host so we can fall back to it if every probe is over capacity.
  HostConstSharedPtr alt_host, least_overloaded_host = host;
  double least_overload_factor = overload_factor;
  for (uint32_t i = 0; i < num_hosts; i++) {
    // The random shuffle algorithm
    const uint32_t j = uniform_int(random, num_hosts - i);
    std::swap(host_index[i], host_index[i + j]);

    const uint32_t k = host_index[i];
    alt_host = normalized_host_weights_[k].first;
    if (alt_host == host) {
      // Skip the already-rejected hashed host.
      continue;
    }

    const double alt_host_weight = normalized_host_weights_[k].second;
    overload_factor = hostOverloadFactor(*alt_host, alt_host_weight);

    if (overload_factor <= 1.0) {
      ENVOY_LOG_MISC(debug,
                     "ThreadAwareLoadBalancerBase::BoundedLoadHashingLoadBalancer::chooseHost: "
                     "selected host #{}:{} (attempt:{})",
                     k, alt_host->address()->asString(), i + 2);
      return alt_host;
    }

    if (least_overload_factor > overload_factor) {
      least_overloaded_host = alt_host;
      least_overload_factor = overload_factor;
    }
  }

  // All hosts are over capacity; return the one that is least overloaded.
  return least_overloaded_host;
}
367

            
368
// Builds the optional request-hash policy from the proto configuration. On failure the error is
// propagated through `creation_status`; with an empty policy list, hash_policy_ stays unset.
TypedHashLbConfigBase::TypedHashLbConfigBase(absl::Span<const HashPolicyProto* const> hash_policy,
                                             Regex::Engine& regex_engine,
                                             absl::Status& creation_status) {
  if (!hash_policy.empty()) {
    auto policy_or_error = Http::HashPolicyImpl::create(hash_policy, regex_engine);
    SET_AND_RETURN_IF_NOT_OK(policy_or_error.status(), creation_status);
    hash_policy_ = std::move(policy_or_error).value();
  }
}
378

            
379
465
// Validates that, for every priority, neither the host weights nor the locality weights sum past
// the uint32_t max. Accumulation is done in uint64_t so the limit check itself cannot overflow.
absl::Status TypedHashLbConfigBase::validateEndpoints(const PriorityState& priorities) const {
  constexpr uint64_t kMaxWeightSum = std::numeric_limits<uint32_t>::max();

  for (const auto& [hosts, locality_weights_map] : priorities) {
    uint64_t host_weight_sum = 0;
    for (const auto& host : *hosts) {
      host_weight_sum += host->weight();
      if (host_weight_sum > kMaxWeightSum) {
        return absl::InvalidArgumentError(
            fmt::format("The sum of weights of all upstream hosts in a locality exceeds {}",
                        std::numeric_limits<uint32_t>::max()));
      }
    }

    uint64_t locality_weight_sum = 0;
    for (const auto& [_, weight] : locality_weights_map) {
      locality_weight_sum += weight;
      if (locality_weight_sum > kMaxWeightSum) {
        return absl::InvalidArgumentError(
            fmt::format("The sum of weights of all localities at the same priority exceeds {}",
                        std::numeric_limits<uint32_t>::max()));
      }
    }
  }

  return absl::OkStatus();
}
407

            
408
} // namespace Upstream
409
} // namespace Envoy