1
#include "source/extensions/load_balancing_policies/override_host/load_balancer.h"

#include <algorithm>
#include <memory>
#include <string>
#include <vector>

#include "envoy/common/exception.h"
#include "envoy/common/optref.h"
#include "envoy/config/core/v3/base.pb.h"
#include "envoy/extensions/load_balancing_policies/override_host/v3/override_host.pb.h"
#include "envoy/http/header_map.h"
#include "envoy/upstream/load_balancer.h"
#include "envoy/upstream/upstream.h"

#include "source/common/common/assert.h"
#include "source/common/common/logger.h"
#include "source/common/common/thread.h"
#include "source/common/config/metadata.h"
#include "source/common/config/utility.h"

#include "absl/container/inlined_vector.h"
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/strings/ascii.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "load_balancer.h"
31

            
32
namespace Envoy {
33
namespace Extensions {
34
namespace LoadBalancingPolicies {
35
namespace OverrideHost {
36

            
37
using ::envoy::extensions::load_balancing_policies::override_host::v3::OverrideHost;
38
using ::Envoy::Http::HeaderMap;
39
using ::Envoy::Server::Configuration::ServerFactoryContext;
40
using ::Envoy::Upstream::HostConstSharedPtr;
41
using ::Envoy::Upstream::HostMapConstSharedPtr;
42
using ::Envoy::Upstream::LoadBalancerConfig;
43
using ::Envoy::Upstream::LoadBalancerContext;
44
using ::Envoy::Upstream::LoadBalancerParams;
45
using ::Envoy::Upstream::LoadBalancerPtr;
46
using ::Envoy::Upstream::TypedLoadBalancerFactory;
47

            
48
OverrideHostLbConfig::OverrideHostLbConfig(std::vector<OverrideSource>&& override_host_sources,
49
                                           absl::optional<Config::MetadataKey>&& selected_host_key,
50
                                           TypedLoadBalancerFactory* fallback_load_balancer_factory,
51
                                           LoadBalancerConfigPtr&& fallback_load_balancer_config)
52
29
    : fallback_picker_lb_config_{fallback_load_balancer_factory,
53
29
                                 std::move(fallback_load_balancer_config)},
54
29
      override_host_sources_(std::move(override_host_sources)),
55
29
      selected_host_key_(std::move(selected_host_key)) {}
56

            
57
OverrideHostLbConfig::OverrideSource
58
57
OverrideHostLbConfig::OverrideSource::make(const OverrideHost::OverrideHostSource& config) {
59
57
  return OverrideHostLbConfig::OverrideSource{
60
57
      !config.header().empty()
61
57
          ? absl::optional<Http::LowerCaseString>(Http::LowerCaseString(config.header()))
62
57
          : absl::nullopt,
63
57
      config.has_metadata() ? absl::optional<Config::MetadataKey>(config.metadata())
64
57
                            : absl::nullopt};
65
57
}
66

            
67
// Converts every configured override source and checks that each one sets
// exactly one of `header` or `metadata`.
// Returns InvalidArgumentError for the first source that sets neither or both.
absl::StatusOr<std::vector<OverrideHostLbConfig::OverrideSource>>
OverrideHostLbConfig::makeOverrideSources(
    const Protobuf::RepeatedPtrField<OverrideHost::OverrideHostSource>& override_sources) {
  std::vector<OverrideSource> sources;
  sources.reserve(override_sources.size());

  for (const auto& proto_source : override_sources) {
    OverrideSource source = OverrideSource::make(proto_source);
    const bool has_header = source.header_name.has_value();
    const bool has_metadata = source.metadata_key.has_value();

    // Either header name or metadata key must be present, but not both.
    if (!has_header && !has_metadata) {
      return absl::InvalidArgumentError("Empty override source");
    }
    if (has_header && has_metadata) {
      return absl::InvalidArgumentError("Only one override source must be set");
    }

    sources.push_back(std::move(source));
  }
  return sources;
}
83

            
84
// Builds the override-host LB config from its proto.
// - Override sources are converted and validated first.
// - The optional selected-host key is copied.
// - Fallback policies are tried in order; the first one with a registered
//   TypedLoadBalancerFactory has its typed config translated, loaded, and used.
// - If no listed fallback policy has a registered factory, an
//   InvalidArgumentError naming all of them is returned.
absl::StatusOr<std::unique_ptr<OverrideHostLbConfig>>
OverrideHostLbConfig::make(const OverrideHost& config, ServerFactoryContext& context) {
  // Must be validated before calling this function.
  absl::StatusOr<std::vector<OverrideSource>> override_host_sources =
      makeOverrideSources(config.override_host_sources());
  RETURN_IF_NOT_OK(override_host_sources.status());

  absl::optional<Config::MetadataKey> selected_host_key;
  if (config.has_selected_host_key()) {
    selected_host_key.emplace(config.selected_host_key());
  }

  ASSERT(config.has_fallback_policy());
  // Names of fallback policies without a registered factory; used in the error below.
  absl::InlinedVector<absl::string_view, 4> missing_policies;
  for (const auto& policy : config.fallback_policy().policies()) {
    const auto& extension_config = policy.typed_extension_config();
    TypedLoadBalancerFactory* factory =
        Envoy::Config::Utility::getAndCheckFactory<TypedLoadBalancerFactory>(
            extension_config, /*is_optional=*/true);
    if (factory == nullptr) {
      missing_policies.push_back(extension_config.name());
      continue;
    }

    // The first registered policy wins: load and validate its configuration.
    auto proto_message = factory->createEmptyConfigProto();
    RETURN_IF_NOT_OK(Envoy::Config::Utility::translateOpaqueConfig(
        extension_config.typed_config(), context.messageValidationVisitor(), *proto_message));

    auto fallback_load_balancer_config = factory->loadConfig(context, *proto_message);
    RETURN_IF_NOT_OK_REF(fallback_load_balancer_config.status());
    return std::unique_ptr<OverrideHostLbConfig>(new OverrideHostLbConfig(
        std::move(override_host_sources).value(), std::move(selected_host_key), factory,
        std::move(fallback_load_balancer_config.value())));
  }

  return absl::InvalidArgumentError(
      absl::StrCat("dynamic forwarding LB: didn't find a registered fallback load balancer factory "
                   "with names from ",
                   absl::StrJoin(missing_policies, ", ")));
}
122

            
123
// Creates the thread-aware LB by delegating to the fallback policy's factory.
// The factory receives the fallback policy's own config, wrapped as an OptRef
// that is empty when that config pointer is null.
Upstream::ThreadAwareLoadBalancerPtr OverrideHostLbConfig::create(const ClusterInfo& cluster_info,
                                                                  const PrioritySet& priority_set,
                                                                  Loader& runtime,
                                                                  RandomGenerator& random,
                                                                  TimeSource& time_source) const {
  const auto& fallback = fallback_picker_lb_config_;
  auto fallback_config =
      makeOptRefFromPtr<const LoadBalancerConfig>(fallback.load_balancer_config.get());
  return fallback.load_balancer_factory->create(fallback_config, cluster_info, priority_set,
                                                runtime, random, time_source);
}
133

            
134
27
// Initialization is delegated to the wrapped fallback thread-aware LB; its
// status is returned unchanged.
absl::Status OverrideHostLoadBalancer::initialize() {
  // A fallback (locality picker) LB is always required.
  ASSERT(fallback_picker_lb_ != nullptr);
  absl::Status status = fallback_picker_lb_->initialize();
  return status;
}
138

            
139
27
// Returns the LB factory, creating it on the first call and caching it in
// `factory_`. The created factory wraps the fallback LB's own factory.
// Must be called on the main thread (or in tests).
LoadBalancerFactorySharedPtr OverrideHostLoadBalancer::factory() {
  ASSERT_IS_MAIN_OR_TEST_THREAD();
  if (factory_ == nullptr) {
    factory_ = std::make_shared<LoadBalancerFactoryImpl>(config_, fallback_picker_lb_->factory());
  }
  return factory_;
}
146

            
147
// Peeking does not consult override hosts yet; it is forwarded as-is to the
// fallback LB.
HostConstSharedPtr
OverrideHostLoadBalancer::LoadBalancerImpl::peekAnotherHost(LoadBalancerContext* context) {
  // TODO(yavlasov): Return a host from request metadata if present.
  HostConstSharedPtr peeked = fallback_picker_lb_->peekAnotherHost(context);
  return peeked;
}
152

            
153
// Picks a host for the request, preferring override hosts.
// - On the first call for a request, the override host list is computed via
//   getSelectedHosts() and stored as mutable filter state.
// - Later calls for the same request reuse that state and continue consuming
//   from where earlier calls stopped.
// - The fallback LB is used when there is no context or stream info, when the
//   override list is empty, or when no remaining override host is found in
//   the cluster.
HostSelectionResponse
OverrideHostLoadBalancer::LoadBalancerImpl::chooseHostInternal(LoadBalancerContext* context) {
  if (context == nullptr || context->requestStreamInfo() == nullptr) {
    // Without request stream info there is no metadata or filter state to
    // consult, so the fallback picker decides.
    return fallback_picker_lb_->chooseHost(context);
  }

  const auto& filter_state = context->requestStreamInfo()->filterState();
  OverrideHostFilterState* override_host_state =
      filter_state->getDataMutable<OverrideHostFilterState>(
          OverrideHostFilterState::kFilterStateKey);
  if (override_host_state == nullptr) {
    // First selection for this request: compute the override hosts once and
    // store them so that subsequent selections resume from the same list.
    auto new_state = std::make_shared<OverrideHostFilterState>(getSelectedHosts(context));
    override_host_state = new_state.get();
    filter_state->setData(OverrideHostFilterState::kFilterStateKey, std::move(new_state),
                          StreamInfo::FilterState::StateType::Mutable);
  }

  if (override_host_state->empty()) {
    ENVOY_LOG(trace, "No overridden hosts were found. Using fallback LB policy.");
    return fallback_picker_lb_->chooseHost(context);
  }

  HostConstSharedPtr host = getEndpoint(*override_host_state);
  if (host == nullptr) {
    // Override hosts were configured, but none of the remaining ones is in the
    // cluster's endpoint set (or they have all been consumed by earlier
    // attempts), so use the fallback LB policy.
    ENVOY_LOG(trace, "Failed to find any endpoints from metadata in the cluster. "
                     "Using fallback LB policy.");
    return fallback_picker_lb_->chooseHost(context);
  }
  return {host};
}
191

            
192
// Public entry point: selects a host via chooseHostInternal(), then lets
// addSelectedHostKey() record the choice (when configured) before returning it.
HostSelectionResponse
OverrideHostLoadBalancer::LoadBalancerImpl::chooseHost(LoadBalancerContext* context) {
  HostSelectionResponse selection = chooseHostInternal(context);
  addSelectedHostKey(context, selection);
  return selection;
}
198

            
199
void OverrideHostLoadBalancer::LoadBalancerImpl::addSelectedHostKey(
200
59
    LoadBalancerContext* context, HostSelectionResponse& response) {
201
59
  if (!config_.selectedHostKey().has_value()) {
202
55
    return;
203
55
  }
204

            
205
4
  if (response.host == nullptr) {
206
    return;
207
  }
208

            
209
4
  const std::string selected_endpoint = response.host->address()->asString();
210
4
  const Config::MetadataKey& metadata_key = config_.selectedHostKey().value();
211
4
  if (metadata_key.path_.size() < 1) {
212
    // Should not be possible based on proto validation, catching anyways.
213
    ENVOY_LOG(trace, "Path was not provided in selected_host_key.");
214
    return;
215
  }
216

            
217
4
  Protobuf::Struct updated_metadata;
218
4
  Protobuf::Struct* updated_metadata_ptr = &updated_metadata;
219

            
220
4
  for (size_t i = 0; i + 1 < metadata_key.path_.size(); i++) {
221
    Protobuf::Value& current_val = (*updated_metadata_ptr->mutable_fields())[metadata_key.path_[i]];
222
    updated_metadata_ptr = current_val.mutable_struct_value();
223
  }
224

            
225
4
  (*updated_metadata_ptr->mutable_fields())[metadata_key.path_.back()].set_string_value(
226
4
      selected_endpoint);
227

            
228
  // Set the value of the metadata key to be the host:port
229
4
  context->requestStreamInfo()->setDynamicMetadata(metadata_key.key_, updated_metadata);
230
4
}
231

            
232
// Reads the override host list stored in dynamic metadata at `metadata_key`.
// Returns the value only if it is a string, otherwise nullopt. The returned
// view refers to the string held by the metadata value.
absl::optional<absl::string_view>
OverrideHostLoadBalancer::LoadBalancerImpl::getSelectedHostsFromMetadata(
    const ::envoy::config::core::v3::Metadata& metadata, const Config::MetadataKey& metadata_key) {
  // TODO(yanavlasov): make it distinguish between not-present and invalid metadata.
  const Protobuf::Value& value = Config::Metadata::metadataValue(&metadata, metadata_key);
  if (!value.has_string_value()) {
    return absl::nullopt;
  }
  return absl::string_view{value.string_value()};
}
242

            
243
// Returns the value of `header_name` in `header_map`. Returns nullopt when
// the header map is null or the header is absent.
absl::optional<absl::string_view>
OverrideHostLoadBalancer::LoadBalancerImpl::getSelectedHostsFromHeader(
    const Envoy::Http::RequestHeaderMap* header_map, const Http::LowerCaseString& header_name) {
  if (header_map == nullptr) {
    return absl::nullopt;
  }

  HeaderMap::GetResult values = header_map->get(header_name);
  if (values.empty()) {
    return absl::nullopt;
  }

  // If the header happens to have multiple values, use only the first one.
  return values[0]->value().getStringView();
}
257

            
258
std::vector<std::string>
259
47
OverrideHostLoadBalancer::LoadBalancerImpl::getSelectedHosts(LoadBalancerContext* context) {
260
  // This is checked by config validation.
261
47
  ASSERT(!config_.overrideHostSources().empty());
262

            
263
47
  std::vector<std::string> selected_hosts;
264
47
  selected_hosts.reserve(4);
265

            
266
79
  for (const auto& override_source : config_.overrideHostSources()) {
267
    // This is checked by config validation
268
79
    ASSERT(override_source.header_name.has_value() != override_source.metadata_key.has_value());
269

            
270
79
    absl::optional<absl::string_view> hosts;
271
79
    if (override_source.header_name.has_value()) {
272
27
      hosts = getSelectedHostsFromHeader(context->downstreamHeaders(),
273
27
                                         override_source.header_name.value());
274
52
    } else if (override_source.metadata_key.has_value()) {
275
      // Lookup selected endpoints in the request metadata if the header based
276
      // selection is not enabled.
277
52
      hosts = getSelectedHostsFromMetadata(context->requestStreamInfo()->dynamicMetadata(),
278
52
                                           override_source.metadata_key.value());
279
52
    }
280

            
281
79
    if (!hosts.has_value()) {
282
35
      continue;
283
35
    }
284

            
285
46
    for (absl::string_view host : absl::StrSplit(hosts.value(), ',', absl::SkipWhitespace())) {
286
46
      selected_hosts.push_back(std::string(absl::StripAsciiWhitespace(host)));
287
46
    }
288
44
  }
289

            
290
47
  ENVOY_LOG(trace, "Selected endpoints: {}", absl::StrJoin(selected_hosts, ","));
291
47
  return selected_hosts;
292
47
}
293

            
294
// Looks `endpoint` up in the priority set's cross-priority host map.
// Returns nullptr when the map is unavailable or does not contain the endpoint.
HostConstSharedPtr
OverrideHostLoadBalancer::LoadBalancerImpl::findHost(absl::string_view endpoint) {
  const HostMapConstSharedPtr host_map = priority_set_.crossPriorityHostMap();
  if (host_map == nullptr) {
    return nullptr;
  }

  ENVOY_LOG(trace, "Looking up {} in {}", endpoint,
            absl::StrJoin(*host_map, ", ",
                          [](std::string* out, Envoy::Upstream::HostMap::const_reference entry) {
                            absl::StrAppend(out, entry.first);
                          }));

  const auto it = host_map->find(endpoint);
  if (it == host_map->end()) {
    return nullptr;
  }
  // TODO(yanavlasov): Validate that host health status did not change.
  return it->second;
}
313

            
314
// Consumes override hosts from the per-request state until one is found in
// the cluster, and returns that host.
// - Hosts not found in the cluster are consumed too, so later calls skip them.
// - An empty view from consumeNextHost() ends the search; nullptr is then returned.
HostConstSharedPtr OverrideHostLoadBalancer::LoadBalancerImpl::getEndpoint(
    OverrideHostFilterState& override_host_state) {
  for (absl::string_view candidate = override_host_state.consumeNextHost(); !candidate.empty();
       candidate = override_host_state.consumeNextHost()) {
    HostConstSharedPtr host = findHost(candidate);
    if (host != nullptr) {
      ENVOY_LOG(trace, "Selected endpoint: {}", candidate);
      return host;
    }
  }

  // Every remaining override host was either missing from the cluster or has
  // already been used by a previous attempt.
  ENVOY_LOG(trace, "Number of attempts has exceeded the number of override hosts.");
  return nullptr;
}
330

            
331
// Builds an override-host LoadBalancerImpl around a fallback LB created from
// the same params by the wrapped fallback factory.
LoadBalancerPtr
OverrideHostLoadBalancer::LoadBalancerFactoryImpl::create(LoadBalancerParams params) {
  LoadBalancerPtr fallback_lb = fallback_picker_lb_factory_->create(params);
  ASSERT(fallback_lb != nullptr); // Factory can not create null LB.
  const auto& priority_set = params.priority_set;
  return std::make_unique<LoadBalancerImpl>(config_, std::move(fallback_lb), priority_set);
}
338

            
339
} // namespace OverrideHost
340
} // namespace LoadBalancingPolicies
341
} // namespace Extensions
342
} // namespace Envoy