1
#include "source/extensions/clusters/redis/redis_cluster_lb.h"
2

            
3
#include <string>
4

            
5
namespace Envoy {
6
namespace Extensions {
7
namespace Clusters {
8
namespace Redis {
9

            
10
3
// Two cluster slots are equal when they cover the same [start, end] range, are served by the same
// primary address, and list the same replica addresses (compared by key, in iteration order).
bool ClusterSlot::operator==(const Envoy::Extensions::Clusters::Redis::ClusterSlot& rhs) const {
  const bool same_range = (start_ == rhs.start_) && (end_ == rhs.end_);
  if (!same_range) {
    return false;
  }
  if (*primary_ != *rhs.primary_) {
    return false;
  }
  if (replicas_.size() != rhs.replicas_.size()) {
    return false;
  }

  // The mapped values are shared_ptrs, which differ between two slots even for the same ip:port,
  // so only the string keys are compared. Sizes are equal, so walking both maps in lockstep is safe.
  auto other = rhs.replicas_.begin();
  for (const auto& entry : replicas_) {
    if (entry.first != other->first) {
      return false;
    }
    ++other;
  }
  return true;
}
20

            
21
// RedisClusterLoadBalancerFactory
22
bool RedisClusterLoadBalancerFactory::onClusterSlotUpdate(ClusterSlotsSharedPtr&& slots,
23
18
                                                          Envoy::Upstream::HostMap& all_hosts) {
24
  // The slots is sorted, allowing for a quick comparison to make sure we need to update the slot
25
  // array sort based on start and end to enable efficient comparison
26
18
  std::sort(
27
24
      slots->begin(), slots->end(), [](const ClusterSlot& lhs, const ClusterSlot& rhs) -> bool {
28
18
        return lhs.start() < rhs.start() || (!(lhs.start() < rhs.start()) && lhs.end() < rhs.end());
29
18
      });
30

            
31
18
  if (current_cluster_slot_ && *current_cluster_slot_ == *slots) {
32
1
    return false;
33
1
  }
34

            
35
17
  auto updated_slots = std::make_shared<SlotArray>();
36
17
  auto shard_vector = std::make_shared<std::vector<RedisShardSharedPtr>>();
37
17
  absl::flat_hash_map<std::string, uint64_t> shards;
38

            
39
33
  for (const ClusterSlot& slot : *slots) {
40
    // look in the updated map
41
33
    const std::string primary_address = slot.primary()->asString();
42

            
43
33
    auto result = shards.try_emplace(primary_address, shard_vector->size());
44
33
    if (result.second) {
45
32
      auto primary_host = all_hosts.find(primary_address);
46
32
      ASSERT(primary_host != all_hosts.end(),
47
32
             "we expect all address to be found in the updated_hosts");
48

            
49
32
      Upstream::HostVectorSharedPtr primary_and_replicas = std::make_shared<Upstream::HostVector>();
50
32
      Upstream::HostVectorSharedPtr replicas = std::make_shared<Upstream::HostVector>();
51
32
      primary_and_replicas->push_back(primary_host->second);
52

            
53
32
      for (auto const& replica : slot.replicas()) {
54
14
        auto replica_host = all_hosts.find(replica.first);
55
14
        ASSERT(replica_host != all_hosts.end(),
56
14
               "we expect all address to be found in the updated_hosts");
57
14
        replicas->push_back(replica_host->second);
58
14
        primary_and_replicas->push_back(replica_host->second);
59
14
      }
60

            
61
32
      shard_vector->emplace_back(std::make_shared<RedisShard>(primary_host->second, replicas,
62
32
                                                              primary_and_replicas, random_));
63
32
    }
64

            
65
278562
    for (auto i = slot.start(); i <= slot.end(); ++i) {
66
278529
      updated_slots->at(i) = result.first->second;
67
278529
    }
68
33
  }
69

            
70
17
  {
71
17
    absl::WriterMutexLock lock(mutex_);
72
17
    current_cluster_slot_ = std::move(slots);
73
17
    slot_array_ = std::move(updated_slots);
74
17
    shard_vector_ = std::move(shard_vector);
75
17
  }
76
17
  return true;
77
18
}
78

            
79
12
void RedisClusterLoadBalancerFactory::onHostHealthUpdate() {
80
12
  ShardVectorSharedPtr current_shard_vector;
81
12
  {
82
12
    absl::ReaderMutexLock lock(mutex_);
83
12
    current_shard_vector = shard_vector_;
84
12
  }
85

            
86
  // This can get called by cluster initialization before the Redis Cluster topology is resolved.
87
12
  if (!current_shard_vector) {
88
10
    return;
89
10
  }
90

            
91
2
  auto shard_vector = std::make_shared<std::vector<RedisShardSharedPtr>>();
92

            
93
4
  for (auto const& shard : *current_shard_vector) {
94
4
    shard_vector->emplace_back(std::make_shared<RedisShard>(
95
4
        shard->primary(), shard->replicas().hostsPtr(), shard->allHosts().hostsPtr(), random_));
96
4
  }
97

            
98
2
  {
99
2
    absl::WriterMutexLock lock(mutex_);
100
2
    shard_vector_ = std::move(shard_vector);
101
2
  }
102
2
}
103

            
104
59
// Creates a per-worker load balancer bound to the current slot array and shard vector. Both are
// read under the same reader lock, so the new balancer sees state from a single published update.
Upstream::LoadBalancerPtr RedisClusterLoadBalancerFactory::create(Upstream::LoadBalancerParams) {
  absl::ReaderMutexLock lock(mutex_);
  Upstream::LoadBalancerPtr balancer =
      std::make_unique<RedisClusterLoadBalancer>(slot_array_, shard_vector_, random_);
  return balancer;
}
108

            
109
namespace {
110
Upstream::HostConstSharedPtr chooseRandomHost(const Upstream::HostSetImpl& host_set,
111
122
                                              Random::RandomGenerator& random) {
112
122
  auto hosts = host_set.healthyHosts();
113
122
  if (hosts.empty()) {
114
9
    hosts = host_set.degradedHosts();
115
9
  }
116

            
117
122
  if (hosts.empty()) {
118
9
    hosts = host_set.hosts();
119
9
  }
120

            
121
122
  if (!hosts.empty()) {
122
121
    return hosts[random.random() % hosts.size()];
123
121
  } else {
124
1
    return nullptr;
125
1
  }
126
122
}
127
} // namespace
128

            
129
// Selects the upstream host for a request.
//
// The context's hash key selects the shard. For a RedisSpecifyShardContextImpl it is used directly
// as a shard index; otherwise it is reduced modulo MaxSlot and mapped through the slot array. For
// read commands on a RedisLoadBalancerContext, the context's read policy then picks between the
// shard's primary and replicas. All other requests go to the primary.
// Returns a null host if there is no topology, no hash key, or an out-of-range shard index.
Upstream::HostSelectionResponse
RedisClusterLoadBalancerFactory::RedisClusterLoadBalancer::chooseHost(
    Envoy::Upstream::LoadBalancerContext* context) {
  // No topology yet, or a topology with no shards. In the latter case any slot lookup below would
  // index an empty shard vector and at() would throw, so report "no host" instead.
  if (!slot_array_ || !shard_vector_ || shard_vector_->empty()) {
    return {nullptr};
  }

  absl::optional<uint64_t> hash;
  if (context) {
    hash = context->computeHashKey();
  }
  if (!hash) {
    return {nullptr};
  }

  RedisShardSharedPtr shard;
  if (dynamic_cast<const RedisSpecifyShardContextImpl*>(context)) {
    // The hash key is treated as a direct shard index; reject indices past the end.
    if (hash.value() >= shard_vector_->size()) {
      return {nullptr};
    }
    shard = shard_vector_->at(hash.value());
  } else {
    shard = shard_vector_->at(
        slot_array_->at(hash.value() % Envoy::Extensions::Clusters::Redis::MaxSlot));
  }

  auto redis_context = dynamic_cast<RedisLoadBalancerContext*>(context);
  if (redis_context && redis_context->isReadCommand()) {
    switch (redis_context->readPolicy()) {
    case NetworkFilters::Common::Redis::Client::ReadPolicy::Primary:
      return shard->primary();
    case NetworkFilters::Common::Redis::Client::ReadPolicy::PreferPrimary:
      // Use the primary while it is healthy, otherwise any host of the shard.
      if (shard->primary()->coarseHealth() == Upstream::Host::Health::Healthy) {
        return shard->primary();
      } else {
        return chooseRandomHost(shard->allHosts(), random_);
      }
    case NetworkFilters::Common::Redis::Client::ReadPolicy::Replica:
      return chooseRandomHost(shard->replicas(), random_);
    case NetworkFilters::Common::Redis::Client::ReadPolicy::PreferReplica:
      // Use a replica if any is healthy, otherwise any host of the shard.
      if (!shard->replicas().healthyHosts().empty()) {
        return chooseRandomHost(shard->replicas(), random_);
      } else {
        return chooseRandomHost(shard->allHosts(), random_);
      }
    case NetworkFilters::Common::Redis::Client::ReadPolicy::Any:
      return chooseRandomHost(shard->allHosts(), random_);
    }
  }
  // Writes, non-Redis contexts, and any read policy not handled above go to the primary.
  return shard->primary();
}
181

            
182
bool RedisLoadBalancerContextImpl::isReadRequest(
183
37980
    const NetworkFilters::Common::Redis::RespValue& request) {
184
37980
  const NetworkFilters::Common::Redis::RespValue* command = nullptr;
185
37980
  if (request.type() == NetworkFilters::Common::Redis::RespType::Array) {
186
103
    command = &(request.asArray()[0]);
187
37895
  } else if (request.type() == NetworkFilters::Common::Redis::RespType::CompositeArray) {
188
4
    command = request.asCompositeArray().command();
189
4
  }
190
37980
  if (!command) {
191
37873
    return false;
192
37873
  }
193
107
  if (command->type() != NetworkFilters::Common::Redis::RespType::SimpleString &&
194
107
      command->type() != NetworkFilters::Common::Redis::RespType::BulkString) {
195
1
    return false;
196
1
  }
197
106
  std::string to_lower_string = absl::AsciiStrToLower(command->asString());
198
106
  return NetworkFilters::Common::Redis::SupportedCommands::isReadCommand(to_lower_string);
199
107
}
200

            
201
// Computes the routing hash for `key` and records whether `request` is a read.
// In Redis Cluster mode the key is hashed with CRC16 and hash tags are always honoured.
// Otherwise MurmurHash2 is used and hash tags are honoured only if `enabled_hashtagging` is set.
RedisLoadBalancerContextImpl::RedisLoadBalancerContextImpl(
    const std::string& key, bool enabled_hashtagging, bool is_redis_cluster,
    const NetworkFilters::Common::Redis::RespValue& request,
    NetworkFilters::Common::Redis::Client::ReadPolicy read_policy)
    : hash_key_(!is_redis_cluster ? MurmurHash::murmurHash2(hashtag(key, enabled_hashtagging))
                                  : Crc16::crc16(hashtag(key, true))),
      is_read_(isReadRequest(request)), read_policy_(read_policy) {}
208

            
209
// Inspired by the redis-cluster hashtagging algorithm
210
// https://redis.io/topics/cluster-spec#keys-hash-tags
211
37980
absl::string_view RedisLoadBalancerContextImpl::hashtag(absl::string_view v, bool enabled) {
212
37980
  if (!enabled) {
213
86
    return v;
214
86
  }
215

            
216
37894
  auto start = v.find('{');
217
37894
  if (start == std::string::npos) {
218
37887
    return v;
219
37887
  }
220

            
221
7
  auto end = v.find('}', start);
222
7
  if (end == std::string::npos || end == start + 1) {
223
1
    return v;
224
1
  }
225

            
226
6
  return v.substr(start + 1, end - start - 1);
227
7
}
228
// A load-balancer context that targets a specific shard by index. The base context is initialised
// in Redis Cluster mode with the decimal shard index as its key. The index itself is stored in
// shard_index_.
RedisSpecifyShardContextImpl::RedisSpecifyShardContextImpl(
    uint64_t shard_index, const NetworkFilters::Common::Redis::RespValue& request,
    NetworkFilters::Common::Redis::Client::ReadPolicy read_policy)
    : RedisLoadBalancerContextImpl(std::to_string(shard_index), /*enabled_hashtagging=*/true,
                                   /*is_redis_cluster=*/true, request, read_policy),
      shard_index_(shard_index) {}
233

            
234
} // namespace Redis
235
} // namespace Clusters
236
} // namespace Extensions
237
} // namespace Envoy