1
#include "source/common/router/weighted_cluster_specifier.h"
2

            
3
#include "source/common/config/well_known_names.h"
4
#include "source/common/router/config_utility.h"
5

            
6
namespace Envoy {
7
namespace Router {
8

            
9
108
// Checks that a weighted cluster entry identifies its target in exactly one way: either through
// `name` or through `cluster_header`, never both and never neither.
absl::Status validateWeightedClusterSpecifier(const ClusterWeightProto& cluster) {
  const bool name_missing = cluster.name().empty();
  const bool header_missing = cluster.cluster_header().empty();

  // Exactly one of the two fields is set precisely when the "missing" flags disagree.
  if (name_missing != header_missing) {
    return absl::OkStatus();
  }

  // The flags agree from here on: either both fields are missing or both are present.
  if (name_missing) {
    return absl::InvalidArgumentError(
        "At least one of name or cluster_header need to be specified");
  }
  return absl::InvalidArgumentError("Only one of name or cluster_header can be specified");
}
20

            
21
// Factory for WeightedClustersConfigEntry. The cluster specifier is validated first; when it is
// malformed the validation error is returned and no entry is constructed.
absl::StatusOr<std::shared_ptr<WeightedClustersConfigEntry>> WeightedClustersConfigEntry::create(
    const ClusterWeightProto& cluster, const MetadataMatchCriteria* parent_metadata_match,
    std::string&& runtime_key, Server::Configuration::ServerFactoryContext& context) {
  if (const absl::Status specifier_status = validateWeightedClusterSpecifier(cluster);
      !specifier_status.ok()) {
    return specifier_status;
  }

  // The entry is created with an explicit `new` and handed to a smart pointer right away; the
  // unique_ptr converts to the shared_ptr carried by the returned StatusOr.
  return std::unique_ptr<WeightedClustersConfigEntry>(new WeightedClustersConfigEntry(
      cluster, parent_metadata_match, std::move(runtime_key), context));
}
28

            
29
// Builds the configuration of a single weighted cluster from its proto: runtime key, weight,
// per-filter configs, host rewrite, target cluster name / header name, and - only when
// configured - request/response header parsers and metadata match criteria.
WeightedClustersConfigEntry::WeightedClustersConfigEntry(
    const envoy::config::route::v3::WeightedCluster::ClusterWeight& cluster,
    const MetadataMatchCriteria* parent_metadata_match, std::string&& runtime_key,
    Server::Configuration::ServerFactoryContext& context)
    : runtime_key_(std::move(runtime_key)),
      cluster_weight_(PROTOBUF_GET_WRAPPED_REQUIRED(cluster, weight)),
      per_filter_configs_(
          THROW_OR_RETURN_VALUE(PerFilterConfigs::create(cluster.typed_per_filter_config(), context,
                                                         context.messageValidationVisitor()),
                                std::unique_ptr<PerFilterConfigs>)),
      host_rewrite_(cluster.host_rewrite_literal()), cluster_name_(cluster.name()),
      cluster_header_name_(cluster.cluster_header()) {
  // A header parser is only created when there is at least one header to add or remove;
  // otherwise the corresponding member is left untouched.
  const bool mutates_request_headers =
      !(cluster.request_headers_to_add().empty() && cluster.request_headers_to_remove().empty());
  if (mutates_request_headers) {
    request_headers_parser_ =
        THROW_OR_RETURN_VALUE(HeaderParser::configure(cluster.request_headers_to_add(),
                                                      cluster.request_headers_to_remove()),
                              Router::HeaderParserPtr);
  }

  const bool mutates_response_headers =
      !(cluster.response_headers_to_add().empty() && cluster.response_headers_to_remove().empty());
  if (mutates_response_headers) {
    response_headers_parser_ =
        THROW_OR_RETURN_VALUE(HeaderParser::configure(cluster.response_headers_to_add(),
                                                      cluster.response_headers_to_remove()),
                              Router::HeaderParserPtr);
  }

  // Metadata match criteria are derived only from the ENVOY_LB filter metadata of the cluster's
  // metadata_match; without it there is nothing left to do.
  if (!cluster.has_metadata_match()) {
    return;
  }
  const auto& filter_metadata = cluster.metadata_match().filter_metadata();
  const auto lb_filter = filter_metadata.find(Envoy::Config::MetadataFilters::get().ENVOY_LB);
  if (lb_filter == filter_metadata.end()) {
    return;
  }

  if (parent_metadata_match == nullptr) {
    // No parent criteria: the cluster's own criteria are used as-is.
    cluster_metadata_match_criteria_ =
        std::make_unique<MetadataMatchCriteriaImpl>(lb_filter->second);
    return;
  }
  // Parent criteria exist: merge the cluster's criteria into them.
  cluster_metadata_match_criteria_ = parent_metadata_match->mergeMatchCriteria(lb_filter->second);
}
68

            
69
// Builds the plugin from a weighted_clusters proto: one config entry per configured cluster plus
// the running sum of their weights.
//
// A weight sum above the uint32_t range, or a weight sum of zero, is reported through
// `creation_status` and construction stops early. A failure to create an individual cluster entry
// goes through THROW_OR_RETURN_VALUE rather than `creation_status`.
WeightedClusterSpecifierPlugin::WeightedClusterSpecifierPlugin(
    const WeightedClusterProto& weighted_clusters,
    const MetadataMatchCriteria* parent_metadata_match, absl::string_view route_name,
    Server::Configuration::ServerFactoryContext& context, absl::Status& creation_status)
    : loader_(context.runtime()), random_value_header_(weighted_clusters.header_name()),
      runtime_key_prefix_(weighted_clusters.runtime_key_prefix()),
      // The hash policy is used only when the use_hash_policy oneof case is selected and its
      // value is true.
      use_hash_policy_(weighted_clusters.random_value_specifier_case() ==
                           WeightedClusterProto::kUseHashPolicy &&
                       weighted_clusters.use_hash_policy().value()) {
  constexpr auto weight_limit = std::numeric_limits<uint32_t>::max();
  const absl::string_view key_prefix = weighted_clusters.runtime_key_prefix();
  const auto& cluster_protos = weighted_clusters.clusters();

  weighted_clusters_.reserve(cluster_protos.size());

  for (const ClusterWeightProto& cluster_proto : cluster_protos) {
    // Each cluster's runtime key is "<runtime_key_prefix>.<cluster name>".
    auto entry = THROW_OR_RETURN_VALUE(
        WeightedClustersConfigEntry::create(cluster_proto, parent_metadata_match,
                                            absl::StrCat(key_prefix, ".", cluster_proto.name()),
                                            context),
        std::shared_ptr<WeightedClustersConfigEntry>);
    weighted_clusters_.emplace_back(std::move(entry));

    total_cluster_weight_ += weighted_clusters_.back()->clusterWeight(loader_);
    if (total_cluster_weight_ <= weight_limit) {
      continue;
    }

    creation_status = absl::InvalidArgumentError(
        fmt::format("The sum of weights of all weighted clusters of route {} exceeds {}",
                    route_name, weight_limit));
    return;
  }

  // A configuration whose weights add up to zero can never select a cluster, so reject it.
  if (total_cluster_weight_ != 0) {
    return;
  }
  creation_status =
      absl::InvalidArgumentError("Sum of weights in the weighted_cluster must be greater than 0.");
}
107

            
108
/**
109
 * Route entry implementation for weighted clusters. The RouteEntryImplBase object holds
110
 * one or more weighted cluster objects, where each object has a back pointer to the parent
111
 * RouteEntryImplBase object. Almost all functions in this class forward calls back to the
112
 * parent, with the exception of clusterName, routeEntry, and metadataMatchCriteria.
113
 */
114
class WeightedClusterEntry : public DynamicRouteEntry {
115
public:
116
  WeightedClusterEntry(RouteConstSharedPtr route, std::string&& cluster_name,
117
                       WeightedClustersConfigEntryConstSharedPtr config)
118
209
      : DynamicRouteEntry(route, std::move(cluster_name)), config_(std::move(config)) {
119
209
    ASSERT(config_ != nullptr);
120
209
  }
121

            
122
339
  const std::string& clusterName() const override {
123
339
    if (!config_->cluster_name_.empty()) {
124
293
      return config_->cluster_name_;
125
293
    }
126
46
    return DynamicRouteEntry::clusterName();
127
339
  }
128

            
129
3
  const MetadataMatchCriteria* metadataMatchCriteria() const override {
130
3
    if (config_->cluster_metadata_match_criteria_ != nullptr) {
131
2
      return config_->cluster_metadata_match_criteria_.get();
132
2
    }
133
1
    return DynamicRouteEntry::metadataMatchCriteria();
134
3
  }
135

            
136
  void finalizeRequestHeaders(Http::RequestHeaderMap& headers, const Formatter::Context& context,
137
                              const StreamInfo::StreamInfo& stream_info,
138
144
                              bool insert_envoy_original_path) const override {
139
144
    requestHeaderParser().evaluateHeaders(headers, context, stream_info);
140
144
    if (!config_->host_rewrite_.empty()) {
141
2
      headers.setHost(config_->host_rewrite_);
142
2
    }
143
144
    DynamicRouteEntry::finalizeRequestHeaders(headers, context, stream_info,
144
144
                                              insert_envoy_original_path);
145
144
  }
146
  Http::HeaderTransforms requestHeaderTransforms(const StreamInfo::StreamInfo& stream_info,
147
1
                                                 bool do_formatting) const override {
148
1
    auto transforms = requestHeaderParser().getHeaderTransforms(stream_info, do_formatting);
149
1
    mergeTransforms(transforms,
150
1
                    DynamicRouteEntry::requestHeaderTransforms(stream_info, do_formatting));
151
1
    return transforms;
152
1
  }
153
  void finalizeResponseHeaders(Http::ResponseHeaderMap& headers, const Formatter::Context& context,
154
145
                               const StreamInfo::StreamInfo& stream_info) const override {
155
145
    responseHeaderParser().evaluateHeaders(headers, context, stream_info);
156
145
    DynamicRouteEntry::finalizeResponseHeaders(headers, context, stream_info);
157
145
  }
158
  Http::HeaderTransforms responseHeaderTransforms(const StreamInfo::StreamInfo& stream_info,
159
1
                                                  bool do_formatting) const override {
160
1
    auto transforms = responseHeaderParser().getHeaderTransforms(stream_info, do_formatting);
161
1
    mergeTransforms(transforms,
162
1
                    DynamicRouteEntry::responseHeaderTransforms(stream_info, do_formatting));
163
1
    return transforms;
164
1
  }
165

            
166
165
  absl::optional<bool> filterDisabled(absl::string_view config_name) const override {
167
165
    absl::optional<bool> result = config_->per_filter_configs_->disabled(config_name);
168
165
    if (result.has_value()) {
169
1
      return result.value();
170
1
    }
171
164
    return DynamicRouteEntry::filterDisabled(config_name);
172
165
  }
173
  const RouteSpecificFilterConfig*
174
3
  mostSpecificPerFilterConfig(absl::string_view name) const override {
175
3
    const auto* config = config_->per_filter_configs_->get(name);
176
3
    return config ? config : DynamicRouteEntry::mostSpecificPerFilterConfig(name);
177
3
  }
178
4
  RouteSpecificFilterConfigs perFilterConfigs(absl::string_view filter_name) const override {
179
4
    auto result = DynamicRouteEntry::perFilterConfigs(filter_name);
180
4
    const auto* cfg = config_->per_filter_configs_->get(filter_name);
181
4
    if (cfg != nullptr) {
182
2
      result.push_back(cfg);
183
2
    }
184
4
    return result;
185
4
  }
186

            
187
private:
188
145
  const HeaderParser& requestHeaderParser() const {
189
145
    if (config_->request_headers_parser_ != nullptr) {
190
13
      return *config_->request_headers_parser_;
191
13
    }
192
132
    return HeaderParser::defaultParser();
193
145
  }
194
146
  const HeaderParser& responseHeaderParser() const {
195
146
    if (config_->response_headers_parser_ != nullptr) {
196
12
      return *config_->response_headers_parser_;
197
12
    }
198
134
    return HeaderParser::defaultParser();
199
146
  }
200

            
201
  WeightedClustersConfigEntryConstSharedPtr config_;
202
};
203

            
204
// Selects a cluster depending on weight parameters from configuration or from headers.
205
// This function takes into account the weights set through configuration or through
206
// runtime parameters.
207
// Returns selected cluster, or nullptr if weighted configuration is invalid.
208
RouteConstSharedPtr WeightedClusterSpecifierPlugin::pickWeightedCluster(
209
    RouteEntryAndRouteConstSharedPtr parent, const Http::RequestHeaderMap& headers,
210
210
    const StreamInfo::StreamInfo& stream_info, const uint64_t random_value) const {
211
210
  absl::optional<uint64_t> hash_value;
212

            
213
  // Only use hash policy if explicitly enabled via use_hash_policy field
214
210
  if (use_hash_policy_) {
215
110
    const auto* route_hash_policy = parent->hashPolicy();
216
110
    if (route_hash_policy != nullptr) {
217
110
      hash_value = route_hash_policy->generateHash(
218
110
          OptRef<const Http::RequestHeaderMap>(headers),
219
110
          OptRef<const StreamInfo::StreamInfo>(stream_info), nullptr);
220
110
    }
221
110
  }
222

            
223
210
  const uint64_t selection_value = hash_value.has_value() ? hash_value.value() : random_value;
224

            
225
210
  absl::optional<uint64_t> random_value_from_header;
226
  // Retrieve the random value from the header if corresponding header name is specified.
227
  // weighted_clusters_config_ is known not to be nullptr here. If it were, pickWeightedCluster
228
  // would not be called.
229
210
  if (!random_value_header_.get().empty()) {
230
2
    const auto header_value = headers.get(random_value_header_);
231
2
    if (header_value.size() == 1) {
232
      // We expect single-valued header here, otherwise it will potentially cause inconsistent
233
      // weighted cluster picking throughout the process because different values are used to
234
      // compute the selected value. So, we treat multi-valued header as invalid input and fall back
235
      // to use internally generated random number.
236
1
      uint64_t random_value = 0;
237
1
      if (absl::SimpleAtoi(header_value[0]->value().getStringView(), &random_value)) {
238
1
        random_value_from_header = random_value;
239
1
      }
240
1
    }
241

            
242
2
    if (!random_value_from_header.has_value()) {
243
      // Random value should be found here. But if it is not set due to some errors, log the
244
      // information and fallback to the random value that is set by stream id.
245
1
      ENVOY_LOG(debug, "The random value can not be found from the header and it will fall back to "
246
1
                       "the value that is set by stream id");
247
1
    }
248
2
  }
249

            
250
210
  const bool runtime_key_prefix_configured = !runtime_key_prefix_.empty();
251
210
  uint32_t total_cluster_weight = total_cluster_weight_;
252
210
  absl::InlinedVector<uint32_t, 4> cluster_weights;
253

            
254
  // if runtime config is used, we need to recompute total_weight.
255
210
  if (runtime_key_prefix_configured) {
256
    // Temporary storage to hold consistent cluster weights. Since cluster weight
257
    // can be changed with runtime keys, we need a way to gather all the weight
258
    // and aggregate the total without a change in between.
259
    // The InlinedVector will be able to handle at least 4 cluster weights
260
    // without allocation. For cases when more clusters are needed, it is
261
    // reserved to ensure at most a single allocation.
262
16
    cluster_weights.reserve(weighted_clusters_.size());
263

            
264
16
    total_cluster_weight = 0;
265
48
    for (const auto& cluster : weighted_clusters_) {
266
48
      auto cluster_weight = cluster->clusterWeight(loader_);
267
48
      cluster_weights.push_back(cluster_weight);
268
48
      if (cluster_weight > std::numeric_limits<uint32_t>::max() - total_cluster_weight) {
269
        IS_ENVOY_BUG("Sum of weight cannot overflow 2^32");
270
        return nullptr;
271
      }
272
48
      total_cluster_weight += cluster_weight;
273
48
    }
274
16
  }
275

            
276
210
  if (total_cluster_weight == 0) {
277
1
    IS_ENVOY_BUG("Sum of weight cannot be zero");
278
1
    return nullptr;
279
1
  }
280
209
  const uint64_t selected_value =
281
209
      (random_value_from_header.has_value() ? random_value_from_header.value() : selection_value) %
282
209
      total_cluster_weight;
283
209
  uint64_t begin = 0;
284
209
  uint64_t end = 0;
285
209
  auto cluster_weight = cluster_weights.begin();
286

            
287
  // Find the right cluster to route to based on the interval in which
288
  // the selected value falls. The intervals are determined as
289
  // [0, cluster1_weight), [cluster1_weight, cluster1_weight+cluster2_weight),..
290
306
  for (const auto& cluster : weighted_clusters_) {
291

            
292
306
    if (runtime_key_prefix_configured) {
293
31
      end = begin + *cluster_weight++;
294
275
    } else {
295
275
      end = begin + cluster->clusterWeight(loader_);
296
275
    }
297

            
298
306
    if (selected_value >= begin && selected_value < end) {
299
209
      if (!cluster->cluster_name_.empty()) {
300
178
        return std::make_shared<WeightedClusterEntry>(std::move(parent), "", cluster);
301
178
      }
302
31
      ASSERT(!cluster->cluster_header_name_.get().empty());
303

            
304
31
      const auto entries = headers.get(cluster->cluster_header_name_);
305
31
      absl::string_view cluster_name =
306
31
          entries.empty() ? absl::string_view{} : entries[0]->value().getStringView();
307
31
      return std::make_shared<WeightedClusterEntry>(std::move(parent), std::string(cluster_name),
308
31
                                                    cluster);
309
209
    }
310
97
    begin = end;
311
97
  }
312

            
313
  IS_ENVOY_BUG("unexpected");
314
  return nullptr;
315
209
}
316

            
317
// Plugin entry point for routing: forwards all arguments unchanged to pickWeightedCluster() and
// returns its result, which may be nullptr.
RouteConstSharedPtr
WeightedClusterSpecifierPlugin::route(RouteEntryAndRouteConstSharedPtr parent,
                                      const Http::RequestHeaderMap& headers,
                                      const StreamInfo::StreamInfo& stream_info,
                                      uint64_t random) const {
  return pickWeightedCluster(std::move(parent), headers, stream_info, random);
}
323

            
324
// Verifies that every explicitly named weighted cluster is known to the cluster manager.
// Entries without an explicit name (those resolved from a request header) are not checked.
// Returns an InvalidArgument status naming the first unknown cluster, or OK.
absl::Status
WeightedClusterSpecifierPlugin::validateClusters(const Upstream::ClusterManager& cm) const {
  for (const auto& entry : weighted_clusters_) {
    const auto& explicit_name = entry->cluster_name_;
    if (!explicit_name.empty() && !cm.hasCluster(explicit_name)) {
      return absl::InvalidArgumentError(
          fmt::format("route: unknown weighted cluster '{}'", explicit_name));
    }
  }
  return absl::OkStatus();
}
335

            
336
} // namespace Router
337
} // namespace Envoy