1
#pragma once
2

            
3
#include "envoy/matcher/matcher.h"
4
#include "envoy/network/filter.h"
5
#include "envoy/server/factory_context.h"
6

            
7
#include "source/common/matcher/matcher.h"
8
#include "source/common/network/lc_trie.h"
9
#include "source/common/network/utility.h"
10

            
11
#include "xds/type/matcher/v3/ip.pb.h"
12
#include "xds/type/matcher/v3/ip.pb.validate.h"
13

            
14
namespace Envoy {
15
namespace Extensions {
16
namespace Common {
17
namespace Matcher {
18

            
19
using ::Envoy::Matcher::ActionMatchResult;
20
using ::Envoy::Matcher::DataInputFactoryCb;
21
using ::Envoy::Matcher::DataInputGetResult;
22
using ::Envoy::Matcher::DataInputPtr;
23
using ::Envoy::Matcher::MatchTree;
24
using ::Envoy::Matcher::OnMatch;
25
using ::Envoy::Matcher::OnMatchFactory;
26
using ::Envoy::Matcher::OnMatchFactoryCb;
27
using ::Envoy::Matcher::SkippedMatchCb;
28

            
29
template <class DataType> struct IpRangeNode {
30
  size_t index_;
31
  uint32_t prefix_len_;
32
  bool exclusive_;
33
  std::shared_ptr<OnMatch<DataType>> on_match_;
34

            
35
32
  friend bool operator==(const IpRangeNode<DataType>& lhs, const IpRangeNode<DataType>& rhs) {
36
32
    return lhs.index_ == rhs.index_ && lhs.prefix_len_ == rhs.prefix_len_ &&
37
32
           lhs.exclusive_ == rhs.exclusive_ && lhs.on_match_ == rhs.on_match_;
38
32
  }
39
  template <typename H>
40
  friend H AbslHashValue(H h, // NOLINT(readability-identifier-naming)
41
832
                         const IpRangeNode<DataType>& node) {
42
832
    return H::combine(std::move(h), node.index_, node.prefix_len_, node.exclusive_, node.on_match_);
43
832
  }
44
};
45

            
46
template <class DataType> struct IpRangeNodeComparator {
47
26
  inline bool operator()(const IpRangeNode<DataType>& lhs, const IpRangeNode<DataType>& rhs) const {
48
26
    if (lhs.prefix_len_ > rhs.prefix_len_) {
49
9
      return true;
50
9
    }
51
17
    if (lhs.prefix_len_ == rhs.prefix_len_ && lhs.index_ < rhs.index_) {
52
2
      return true;
53
2
    }
54
15
    return false;
55
17
  }
56
};
57

            
58
/**
59
 * Implementation of a `sublinear` LC-trie matcher for IP ranges.
60
 */
61
template <class DataType> class IpRangeMatcher : public MatchTree<DataType> {
62
public:
63
  IpRangeMatcher(DataInputPtr<DataType>&& data_input, absl::optional<OnMatch<DataType>> on_no_match,
64
                 const std::shared_ptr<Network::LcTrie::LcTrie<IpRangeNode<DataType>>>& trie)
65
56
      : data_input_(std::move(data_input)), on_no_match_(std::move(on_no_match)), trie_(trie) {
66
56
    auto input_type = data_input_->dataInputType();
67
56
    if (input_type != Envoy::Matcher::DefaultMatchingDataType) {
68
1
      throw EnvoyException(
69
1
          absl::StrCat("Unsupported data input type: ", input_type,
70
1
                       ", currently only string type is supported in IP range matcher"));
71
1
    }
72
56
  }
73

            
74
  ActionMatchResult match(const DataType& data,
75
62
                          SkippedMatchCb skipped_match_cb = nullptr) override {
76
62
    const auto input = data_input_->get(data);
77
62
    if (input.availability() != Envoy::Matcher::DataAvailability::AllDataAvailable) {
78
2
      return ActionMatchResult::insufficientData();
79
2
    }
80
60
    auto string_data = input.stringData();
81
60
    if (!string_data) {
82
9
      return MatchTree<DataType>::handleRecursionAndSkips(on_no_match_, data, skipped_match_cb);
83
9
    }
84
51
    const Network::Address::InstanceConstSharedPtr addr =
85
51
        Network::Utility::parseInternetAddressNoThrow(std::string(*string_data));
86
51
    if (!addr) {
87
2
      return MatchTree<DataType>::handleRecursionAndSkips(on_no_match_, data, skipped_match_cb);
88
2
    }
89
49
    auto values = trie_->getData(addr);
90
    // The candidates returned by the LC trie are not in any specific order, so we
91
    // sort them by the prefix length first (longest first), and the order of declaration second.
92
49
    std::sort(values.begin(), values.end(), IpRangeNodeComparator<DataType>());
93
49
    bool first = true;
94
51
    for (const auto& node : values) {
95
42
      if (!first && node.exclusive_) {
96
1
        continue;
97
1
      }
98
      // handleRecursionAndSkips should only return match-failure, no-match, or an action cb.
99
41
      ActionMatchResult processed_match =
100
41
          MatchTree<DataType>::handleRecursionAndSkips(*node.on_match_, data, skipped_match_cb);
101

            
102
41
      if (processed_match.isMatch() || processed_match.isInsufficientData()) {
103
36
        return processed_match;
104
36
      }
105
      // No-match isn't definitive, so continue checking nodes.
106
5
      if (first) {
107
4
        first = false;
108
4
      }
109
5
    }
110
13
    return MatchTree<DataType>::handleRecursionAndSkips(on_no_match_, data, skipped_match_cb);
111
49
  }
112

            
113
private:
114
  const DataInputPtr<DataType> data_input_;
115
  const absl::optional<OnMatch<DataType>> on_no_match_;
116
  std::shared_ptr<Network::LcTrie::LcTrie<IpRangeNode<DataType>>> trie_;
117
};
118

            
119
template <class DataType>
120
class IpRangeMatcherFactoryBase : public ::Envoy::Matcher::CustomMatcherFactory<DataType> {
121
public:
122
  ::Envoy::Matcher::MatchTreeFactoryCb<DataType>
123
  createCustomMatcherFactoryCb(const Protobuf::Message& config,
124
                               Server::Configuration::ServerFactoryContext& factory_context,
125
                               DataInputFactoryCb<DataType> data_input,
126
                               absl::optional<OnMatchFactoryCb<DataType>> on_no_match,
127
57
                               OnMatchFactory<DataType>& on_match_factory) override {
128
57
    const auto& typed_config =
129
57
        MessageUtil::downcastAndValidate<const xds::type::matcher::v3::IPMatcher&>(
130
57
            config, factory_context.messageValidationVisitor());
131
57
    std::vector<OnMatchFactoryCb<DataType>> match_children;
132
57
    match_children.reserve(typed_config.range_matchers().size());
133
78
    for (const auto& range_matcher : typed_config.range_matchers()) {
134
78
      match_children.push_back(*on_match_factory.createOnMatch(range_matcher.on_match()));
135
78
    }
136
57
    std::vector<std::pair<IpRangeNode<DataType>, std::vector<Network::Address::CidrRange>>> data;
137
57
    data.reserve(match_children.size());
138
57
    size_t i = 0;
139
    // Ranges might have variable prefix length so we cannot combine them into one node because
140
    // then the matched prefix length cannot be determined.
141
78
    for (const auto& range_matcher : typed_config.range_matchers()) {
142
78
      auto on_match = std::make_shared<OnMatch<DataType>>(match_children[i++]());
143
90
      for (const auto& range : range_matcher.ranges()) {
144
90
        IpRangeNode<DataType> node = {i, range.prefix_len().value(), range_matcher.exclusive(),
145
90
                                      on_match};
146
90
        data.push_back({node,
147
90
                        {THROW_OR_RETURN_VALUE(Network::Address::CidrRange::create(range),
148
90
                                               Network::Address::CidrRange)}});
149
90
      }
150
78
    }
151
57
    auto lc_trie = std::make_shared<Network::LcTrie::LcTrie<IpRangeNode<DataType>>>(data);
152
57
    return [data_input, lc_trie, on_no_match]() {
153
56
      return std::make_unique<IpRangeMatcher<DataType>>(
154
56
          data_input(), on_no_match ? absl::make_optional(on_no_match.value()()) : absl::nullopt,
155
56
          lc_trie);
156
56
    };
157
57
  };
158
84
  ProtobufTypes::MessagePtr createEmptyConfigProto() override {
159
84
    return std::make_unique<xds::type::matcher::v3::IPMatcher>();
160
84
  }
161
546
  std::string name() const override { return "envoy.matching.custom_matchers.ip_range_matcher"; }
162
};
163

            
164
/** IP range matcher factory specialized for Network::MatchingData. */
class NetworkIpRangeMatcherFactory
    : public IpRangeMatcherFactoryBase<Network::MatchingData> {};

/** IP range matcher factory specialized for Network::UdpMatchingData. */
class UdpNetworkIpRangeMatcherFactory
    : public IpRangeMatcherFactoryBase<Network::UdpMatchingData> {};

/** IP range matcher factory specialized for Http::HttpMatchingData. */
class HttpIpRangeMatcherFactory
    : public IpRangeMatcherFactoryBase<Http::HttpMatchingData> {};
168

            
169
} // namespace Matcher
170
} // namespace Common
171
} // namespace Extensions
172
} // namespace Envoy