1
#pragma once
2

            
3
#include "envoy/matcher/matcher.h"
4
#include "envoy/network/filter.h"
5
#include "envoy/server/factory_context.h"
6

            
7
#include "source/common/matcher/matcher.h"
8

            
9
#include "absl/status/status.h"
10
#include "xds/type/matcher/v3/domain.pb.h"
11
#include "xds/type/matcher/v3/domain.pb.validate.h"
12

            
13
namespace Envoy {
14
namespace Extensions {
15
namespace Common {
16
namespace Matcher {
17

            
18
using ::Envoy::Matcher::ActionMatchResult;
19
using ::Envoy::Matcher::DataInputFactoryCb;
20
using ::Envoy::Matcher::DataInputGetResult;
21
using ::Envoy::Matcher::DataInputPtr;
22
using ::Envoy::Matcher::MatchTree;
23
using ::Envoy::Matcher::OnMatch;
24
using ::Envoy::Matcher::OnMatchFactory;
25
using ::Envoy::Matcher::OnMatchFactoryCb;
26
using ::Envoy::Matcher::SkippedMatchCb;
27

            
28
/**
29
 * Configuration for domain matcher that holds all domain mappings and match actions.
30
 */
31
template <class DataType> struct DomainMatcherConfig {
32
  // Exact domain matches (e.g., "api.example.com")
33
  absl::flat_hash_map<std::string, std::shared_ptr<OnMatch<DataType>>> exact_matches_;
34

            
35
  // Wildcard matches stored without "*." prefix for efficient lookups.
36
  // Maps suffix (e.g., "example.com") to match action.
37
  absl::flat_hash_map<std::string, std::shared_ptr<OnMatch<DataType>>> wildcard_matches_;
38

            
39
  // Global wildcard "*" match. They are given lowest priority.
40
  std::shared_ptr<OnMatch<DataType>> global_wildcard_match_;
41
};
42

            
43
/**
44
 * Domain matcher which implements ServerNameMatcher specs. It matches domains
45
 * using exact lookups and wildcard patterns in the following order:
46
 * 1. Exact matches (highest priority)
47
 * 2. Wildcards by longest suffix match (not declaration order)
48
 * 3. Global wildcard "*" (lowest priority).
49
 */
50
template <class DataType> class DomainTrieMatcher : public MatchTree<DataType> {
51
public:
52
  DomainTrieMatcher(DataInputPtr<DataType>&& data_input,
53
                    absl::optional<OnMatch<DataType>> on_no_match,
54
                    std::shared_ptr<DomainMatcherConfig<DataType>> config)
55
33
      : data_input_(std::move(data_input)), on_no_match_(std::move(on_no_match)),
56
33
        config_(std::move(config)) {
57
33
    absl::Status validation_status = validateDataInputType();
58
33
    if (!validation_status.ok()) {
59
1
      throw EnvoyException(std::string(validation_status.message()));
60
1
    }
61
33
  }
62

            
63
  ActionMatchResult match(const DataType& data,
64
32
                          SkippedMatchCb skipped_match_cb = nullptr) override {
65
32
    const auto input = data_input_->get(data);
66
32
    if (input.availability() != Envoy::Matcher::DataAvailability::AllDataAvailable) {
67
1
      return ActionMatchResult::insufficientData();
68
1
    }
69

            
70
31
    absl::optional<absl::string_view> domain = input.stringData();
71
31
    if (!domain) {
72
1
      return MatchTree<DataType>::handleRecursionAndSkips(on_no_match_, data, skipped_match_cb);
73
1
    }
74
30
    if (domain->empty()) {
75
1
      return MatchTree<DataType>::handleRecursionAndSkips(on_no_match_, data, skipped_match_cb);
76
1
    }
77

            
78
    // 1. Try exact match first (highest priority).
79
29
    auto exact_it = config_->exact_matches_.find(*domain);
80
29
    if (exact_it != config_->exact_matches_.end()) {
81
8
      ActionMatchResult result =
82
8
          MatchTree<DataType>::handleRecursionAndSkips(*(exact_it->second), data, skipped_match_cb);
83

            
84
      // If ``keep_matching`` is used, treat as no match and continue to wildcards.
85
8
      if (result.isMatch() || result.isInsufficientData()) {
86
6
        return result;
87
6
      }
88
8
    }
89

            
90
    // 2. Try wildcard matches from longest suffix to shortest.
91
    // For "www.example.com", try "example.com", then "com".
92
23
    auto wildcard_matches = findAllWildcardMatches(*domain);
93
23
    for (const auto& wildcard_match : wildcard_matches) {
94
17
      ActionMatchResult result =
95
17
          MatchTree<DataType>::handleRecursionAndSkips(*wildcard_match, data, skipped_match_cb);
96

            
97
      // If ``keep_matching`` is used, treat as no match and continue to next wildcard.
98
17
      if (result.isMatch() || result.isInsufficientData()) {
99
15
        return result;
100
15
      }
101
17
    }
102

            
103
    // 3. Finally try global wildcard "*" (lowest priority).
104
8
    if (config_->global_wildcard_match_) {
105
5
      ActionMatchResult result = MatchTree<DataType>::handleRecursionAndSkips(
106
5
          *(config_->global_wildcard_match_), data, skipped_match_cb);
107

            
108
5
      if (result.isMatch() || result.isInsufficientData()) {
109
5
        return result;
110
5
      }
111
5
    }
112

            
113
3
    return MatchTree<DataType>::handleRecursionAndSkips(on_no_match_, data, skipped_match_cb);
114
8
  }
115

            
116
private:
117
  // Validate that the data input type is supported.
118
33
  absl::Status validateDataInputType() const {
119
33
    const auto input_type = data_input_->dataInputType();
120
33
    if (input_type != Envoy::Matcher::DefaultMatchingDataType) {
121
1
      return absl::InvalidArgumentError(
122
1
          absl::StrCat("Unsupported data input type: ", input_type,
123
1
                       ", currently only string type is supported in domain matcher"));
124
1
    }
125
32
    return absl::OkStatus();
126
33
  }
127

            
128
  // Find all wildcard matches for the given domain ordered from longest to shortest suffix.
129
  // Returns empty vector if no wildcard matches are found.
130
  std::vector<std::shared_ptr<OnMatch<DataType>>>
131
23
  findAllWildcardMatches(absl::string_view domain) const {
132
23
    std::vector<std::shared_ptr<OnMatch<DataType>>> matches;
133

            
134
23
    size_t dot_pos = domain.find('.');
135
63
    while (dot_pos != absl::string_view::npos) {
136
40
      const auto suffix = domain.substr(dot_pos + 1);
137

            
138
      // Direct lookup without creating temporary strings.
139
40
      auto wildcard_it = config_->wildcard_matches_.find(suffix);
140
40
      if (wildcard_it != config_->wildcard_matches_.end()) {
141
21
        matches.push_back(wildcard_it->second);
142
21
      }
143

            
144
      // Find next "dot" for shorter patterns.
145
40
      dot_pos = domain.find('.', dot_pos + 1);
146
40
    }
147

            
148
23
    return matches;
149
23
  }
150

            
151
  const DataInputPtr<DataType> data_input_;
152
  const absl::optional<OnMatch<DataType>> on_no_match_;
153
  const std::shared_ptr<DomainMatcherConfig<DataType>> config_;
154
};
155

            
156
template <class DataType>
157
class DomainTrieMatcherFactoryBase : public ::Envoy::Matcher::CustomMatcherFactory<DataType> {
158
public:
159
  ::Envoy::Matcher::MatchTreeFactoryCb<DataType>
160
  createCustomMatcherFactoryCb(const Protobuf::Message& config,
161
                               Server::Configuration::ServerFactoryContext& factory_context,
162
                               DataInputFactoryCb<DataType> data_input,
163
                               absl::optional<OnMatchFactoryCb<DataType>> on_no_match,
164
37
                               OnMatchFactory<DataType>& on_match_factory) override {
165
37
    auto typed_config = std::make_shared<xds::type::matcher::v3::ServerNameMatcher>(
166
37
        MessageUtil::downcastAndValidate<const xds::type::matcher::v3::ServerNameMatcher&>(
167
37
            config, factory_context.messageValidationVisitor()));
168

            
169
37
    absl::Status validation_status = validateConfiguration(*typed_config);
170
37
    if (!validation_status.ok()) {
171
4
      throw EnvoyException(std::string(validation_status.message()));
172
4
    }
173

            
174
33
    auto domain_config = buildDomainMatcherConfig(*typed_config, on_match_factory);
175

            
176
33
    return [data_input, domain_config, on_no_match]() {
177
33
      return std::make_unique<DomainTrieMatcher<DataType>>(
178
33
          data_input(), on_no_match ? absl::make_optional(on_no_match.value()()) : absl::nullopt,
179
33
          domain_config);
180
33
    };
181
37
  }
182

            
183
54
  ProtobufTypes::MessagePtr createEmptyConfigProto() override {
184
54
    return std::make_unique<xds::type::matcher::v3::ServerNameMatcher>();
185
54
  }
186

            
187
88
  std::string name() const override { return "envoy.matching.custom_matchers.domain_matcher"; }
188

            
189
private:
190
  absl::Status
191
37
  validateConfiguration(const xds::type::matcher::v3::ServerNameMatcher& config) const {
192
37
    absl::flat_hash_set<std::string> seen_domains;
193
37
    seen_domains.reserve(getTotalDomainCount(config));
194

            
195
75
    for (const auto& domain_matcher : config.domain_matchers()) {
196
84
      for (const auto& domain : domain_matcher.domains()) {
197
84
        if (!seen_domains.insert(domain).second) {
198
1
          return absl::InvalidArgumentError(
199
1
              absl::StrCat("Duplicate domain in ServerNameMatcher: ", domain));
200
1
        }
201

            
202
83
        absl::Status validation_status = validateDomainFormat(domain);
203
83
        if (!validation_status.ok()) {
204
3
          return validation_status;
205
3
        }
206
83
      }
207
75
    }
208

            
209
33
    return absl::OkStatus();
210
37
  }
211

            
212
83
  static absl::Status validateDomainFormat(absl::string_view domain) {
213
83
    if (domain == "*") {
214
11
      return absl::OkStatus(); // Global wildcard is valid.
215
11
    }
216

            
217
72
    if (domain.empty()) {
218
      return absl::InvalidArgumentError("Empty domain in ServerNameMatcher");
219
    }
220

            
221
    // Check for invalid wildcard patterns anywhere in the domain.
222
72
    const size_t wildcard_pos = domain.find('*');
223
72
    if (wildcard_pos != absl::string_view::npos) {
224
      // Only allow "*." at the beginning (prefix wildcard).
225
42
      if (wildcard_pos != 0 || domain.size() < 3 || domain[1] != '.') {
226
2
        return absl::InvalidArgumentError(
227
2
            absl::StrCat("Invalid wildcard domain format: ", domain,
228
2
                         ". Only '*' and '*.domain' patterns are supported"));
229
2
      }
230

            
231
      // Ensure no additional wildcards exist.
232
40
      if (domain.find('*', 1) != absl::string_view::npos) {
233
1
        return absl::InvalidArgumentError(absl::StrCat("Invalid wildcard domain format: ", domain,
234
1
                                                       ". Multiple wildcards are not supported"));
235
1
      }
236
40
    }
237

            
238
69
    return absl::OkStatus();
239
72
  }
240

            
241
37
  static size_t getTotalDomainCount(const xds::type::matcher::v3::ServerNameMatcher& config) {
242
37
    size_t count = 0;
243
75
    for (const auto& domain_matcher : config.domain_matchers()) {
244
75
      count += domain_matcher.domains().size();
245
75
    }
246
37
    return count;
247
37
  }
248

            
249
  std::shared_ptr<DomainMatcherConfig<DataType>>
250
  buildDomainMatcherConfig(const xds::type::matcher::v3::ServerNameMatcher& config,
251
33
                           OnMatchFactory<DataType>& on_match_factory) const {
252
33
    auto domain_config = std::make_shared<DomainMatcherConfig<DataType>>();
253

            
254
71
    for (const auto& domain_matcher : config.domain_matchers()) {
255
71
      auto on_match_factory_cb = *on_match_factory.createOnMatch(domain_matcher.on_match());
256
71
      auto on_match = std::make_shared<OnMatch<DataType>>(on_match_factory_cb());
257

            
258
79
      for (const auto& domain : domain_matcher.domains()) {
259
79
        if (domain == "*") {
260
          // Global wildcard. We use first declaration if multiple exist.
261
11
          if (!domain_config->global_wildcard_match_) {
262
11
            domain_config->global_wildcard_match_ = on_match;
263
11
          }
264
68
        } else if (domain[0] == '*') {
265
          // Wildcard pattern. We strip "*." prefix for efficient lookup.
266
39
          const auto suffix = domain.substr(2); // Remove "*."
267
39
          domain_config->wildcard_matches_.emplace(std::string(suffix), on_match);
268
39
        } else {
269
          // Exact match.
270
29
          domain_config->exact_matches_.emplace(domain, on_match);
271
29
        }
272
79
      }
273
71
    }
274

            
275
33
    return domain_config;
276
33
  }
277
};
278

            
279
// Domain matcher factory instantiated for network-level matching data (Network::MatchingData).
// NOTE(review): factory registration is not visible in this header; presumably done in a .cc.
class NetworkDomainMatcherFactory : public DomainTrieMatcherFactoryBase<Network::MatchingData> {};
280
// Domain matcher factory instantiated for HTTP matching data (Http::HttpMatchingData).
// NOTE(review): factory registration is not visible in this header; presumably done in a .cc.
class HttpDomainMatcherFactory : public DomainTrieMatcherFactoryBase<Http::HttpMatchingData> {};
281

            
282
} // namespace Matcher
283
} // namespace Common
284
} // namespace Extensions
285
} // namespace Envoy