1
#pragma once
2

            
3
#include <algorithm>
#include <climits>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "envoy/common/exception.h"
#include "envoy/common/platform.h"
#include "envoy/network/address.h"

#include "source/common/common/assert.h"
#include "source/common/common/utility.h"
#include "source/common/network/address_impl.h"
#include "source/common/network/cidr_range.h"
#include "source/common/network/utility.h"

#include "absl/container/node_hash_set.h"
#include "absl/numeric/int128.h"
#include "fmt/format.h"
20

            
21
namespace Envoy {
22
namespace Network {
23
namespace LcTrie {
24

            
25
/**
26
 * Maximum number of nodes an LC trie can hold.
27
 * @note If the size of LcTrieInternal::LcNode::address_ ever changes, this constant
28
 *       should be changed to match.
29
 */
30
// 2^20 nodes, expressed as an unsigned size_t shift.
constexpr size_t MaxLcTrieNodes = size_t{1} << 20;
31

            
32
/**
33
 * Level Compressed Trie for associating data with CIDR ranges. Both IPv4 and IPv6 addresses are
34
 * supported within this class with no calling pattern changes.
35
 *
36
 * The algorithm to build the LC-Trie is described in the paper 'IP-address lookup using LC-tries'
37
 * by 'S. Nilsson' and 'G. Karlsson'. The paper and reference C implementation can be found here:
38
 * https://www.nada.kth.se/~snilsson/publications/IP-address-lookup-using-LC-tries/
39
 *
40
 * Refer to LcTrieInternal for implementation and algorithm details.
41
 */
42
template <class T> class LcTrie {
43
public:
44
  /**
45
   * @param data supplies a vector of data and CIDR ranges.
46
   * @param exclusive if true then only data for the most specific subnet will be returned
47
                      (i.e. data isn't inherited from wider ranges).
48
   * @param fill_factor supplies the fraction of completeness to use when calculating the branch
49
   *                    value for a sub-trie.
50
   * @param root_branching_factor supplies the branching factor at the root.
51
   *
52
   * TODO(ccaraman): Investigate if a non-zero root branching factor should be the default. The
53
   * paper suggests for large LC-Tries to use the value '16'. It reduces the depth of the trie.
54
   * However, there is no suggested values for smaller LC-Tries. With perf tests, it is possible to
55
   * get this data for smaller LC-Tries. Another option is to expose this in the configuration and
56
   * let consumers decide.
57
   */
58
  LcTrie(const std::vector<std::pair<T, std::vector<Address::CidrRange>>>& data,
59
58204
         bool exclusive = false, double fill_factor = 0.5, uint32_t root_branching_factor = 0) {
60

            
61
    // The LcTrie implementation uses 20-bit "pointers" in its compact internal representation,
62
    // so it cannot hold more than 2^20 nodes. But the number of nodes can be greater than the
63
    // number of supported prefixes. Given N prefixes in the data input list, step 2 below can
64
    // produce a new list of up to 2*N prefixes to insert in the LC trie. And the LC trie can
65
    // use up to 2*N/fill_factor nodes.
66
58204
    size_t num_prefixes = 0;
67
58301
    for (const auto& pair_data : data) {
68
35068
      num_prefixes += pair_data.second.size();
69
35068
    }
70
58204
    const size_t max_prefixes = MaxLcTrieNodes * fill_factor / 2;
71
58204
    if (num_prefixes > max_prefixes) {
72
2
      ExceptionUtil::throwEnvoyException(
73
2
          fmt::format("The input vector has '{0}' CIDR range entries. LC-Trie "
74
2
                      "can only support '{1}' CIDR ranges with the specified "
75
2
                      "fill factor.",
76
2
                      num_prefixes, max_prefixes));
77
2
    }
78

            
79
    // Step 1: separate the provided prefixes by protocol (IPv4 vs IPv6),
80
    // and build a Binary Trie per protocol.
81
    //
82
    // For example, if the input prefixes are
83
    //   A: 0.0.0.0/0
84
    //   B: 128.0.0.0/2  (10000000.0.0.0/2 in binary)
85
    //   C: 192.0.0.0/2  (11000000.0.0.0/2)
86
    // the Binary Trie for IPv4 will look like this at the end of step 1:
87
    //          +---+
88
    //          | A |
89
    //          +---+
90
    //               \ 1
91
    //              +---+
92
    //              |   |
93
    //              +---+
94
    //            0/     \1
95
    //          +---+   +---+
96
    //          | B |   | C |
97
    //          +---+   +---+
98
    //
99
    // Note that the prefixes in this example are nested: any IPv4 address
100
    // that matches B or C will also match A. Unfortunately, the classic LC Trie
101
    // algorithm does not support nested prefixes. The next step will solve that
102
    // problem.
103

            
104
58204
    BinaryTrie<Ipv4> ipv4_temp(exclusive);
105
58204
    BinaryTrie<Ipv6> ipv6_temp(exclusive);
106
58299
    for (const auto& pair_data : data) {
107
69879
      for (const auto& cidr_range : pair_data.second) {
108
69879
        if (cidr_range.ip()->version() == Address::IpVersion::v4) {
109
35032
          IpPrefix<Ipv4> ip_prefix(ntohl(cidr_range.ip()->ipv4()->address()), cidr_range.length(),
110
35032
                                   pair_data.first);
111
35032
          ipv4_temp.insert(ip_prefix);
112
35032
        } else {
113
34847
          IpPrefix<Ipv6> ip_prefix(Utility::Ip6ntohl(cidr_range.ip()->ipv6()->address()),
114
34847
                                   cidr_range.length(), pair_data.first);
115
34847
          ipv6_temp.insert(ip_prefix);
116
34847
        }
117
69879
      }
118
35066
    }
119

            
120
    // Step 2: push each Binary Trie's prefixes to its leaves.
121
    //
122
    // Continuing the previous example, the Binary Trie will look like this
123
    // at the end of step 2:
124
    //          +---+
125
    //          |   |
126
    //          +---+
127
    //        0/     \ 1
128
    //      +---+   +---+
129
    //      | A |   |   |
130
    //      +---+   +---+
131
    //            0/     \1
132
    //          +---+   +---+
133
    //          |A,B|   |A,C|
134
    //          +---+   +---+
135
    //
136
    // This trie yields the same match results as the original trie from
137
    // step 1. But it has a useful new property: now that all the prefixes
138
    // are at the leaves, they are disjoint: no prefix is nested under another.
139

            
140
58204
    std::vector<IpPrefix<Ipv4>> ipv4_prefixes = ipv4_temp.pushLeaves();
141
58204
    std::vector<IpPrefix<Ipv6>> ipv6_prefixes = ipv6_temp.pushLeaves();
142

            
143
    // Step 3: take the disjoint prefixes from the leaves of each Binary Trie
144
    // and use them to construct an LC Trie.
145
    //
146
    // Example inputs (from the leaves of the Binary Trie at the end of step 2)
147
    //   A:   0.0.0.0/1
148
    //   A,B: 128.0.0.0/2
149
    //   A,C: 192.0.0.0/2
150
    //
151
    // The LC Trie generated from these inputs with fill_factor=0.5 and root_branching_factor=0
152
    // will be:
153
    //
154
    //       +---------------------------+
155
    //       | branch_factor=2, skip = 0 |
156
    //       +---------------------------+
157
    //    00/       01|         |10       \11
158
    //   +---+      +---+     +---+      +---+
159
    //   | A |      | A |     |A,B|      |A,C|
160
    //   +---+      +---+     +---+      +---+
161
    //
162
    // Or, in the internal vector form that the LcTrie class uses for memory-efficiency,
163
    //    # | branch | skip | first_child | data | note
164
    //   ---+--------+------+-------------+------+--------------------------------------------------
165
    //    0 |      2 |    0 |           1 |  -   | (1 << branch) == 4 children, starting at offset 1
166
    //    1 |      - |    0 |           - |  A   | 1st child of node 0, reached if next bits are 00
167
    //    2 |      - |    0 |           - |  A   |   .
168
    //    3 |      - |    0 |           - |  A,B |   .
169
    //    4 |      - |    0 |           - |  A,C | 4th child of node 0, reached if next bits are 11
170
    //
171
    // The Nilsson and Karlsson paper linked in lc_trie.h has a more thorough example.
172

            
173
58204
    ipv4_trie_.reset(new LcTrieInternal<Ipv4>(ipv4_prefixes, fill_factor, root_branching_factor));
174
58204
    ipv6_trie_.reset(new LcTrieInternal<Ipv6>(ipv6_prefixes, fill_factor, root_branching_factor));
175
58204
  }
176

            
177
  /**
178
   * Retrieve data associated with the CIDR range that contains `ip_address`. Both IPv4 and IPv6
179
   * addresses are supported.
180
   * @param  ip_address supplies the IP address.
181
   * @return a vector of data from the CIDR ranges and IP addresses that contains 'ip_address'. An
182
   * empty vector is returned if no prefix contains 'ip_address' or there is no data for the IP
183
   * version of the ip_address.
184
   */
185
82663
  std::vector<T> getData(const Network::Address::InstanceConstSharedPtr& ip_address) const {
186
82663
    if (ip_address->ip()->version() == Address::IpVersion::v4) {
187
82351
      Ipv4 ip = ntohl(ip_address->ip()->ipv4()->address());
188
82351
      return ipv4_trie_->getData(ip);
189
82609
    } else {
190
312
      Ipv6 ip = Utility::Ip6ntohl(ip_address->ip()->ipv6()->address());
191
312
      return ipv6_trie_->getData(ip);
192
312
    }
193
82663
  }
194

            
195
private:
196
  /**
   * Extract n bits from input starting at position p. Position 0 is the most significant bit of
   * the address.
   * @param p supplies the position.
   * @param n supplies the number of bits to extract.
   * @param input supplies the IP address to extract bits from. The IP address is stored in host
   *              byte order.
   * @return extracted bits in the format of IpType, right-aligned. Zero is returned when n is 0
   *         or when p lies past the end of the address.
   */
  template <class IpType, uint32_t address_size = CHAR_BIT * sizeof(IpType)>
  static IpType extractBits(uint32_t p, uint32_t n, IpType input) {
    // Shifting by the full width of the type (or more) is undefined behavior, so the cases that
    // would require such a shift are answered directly: no bits requested, or nothing left to
    // read at position p.
    if (n == 0 || p >= address_size) {
      return IpType(0);
    }
    // A request for more bits than the address holds is clamped to the address width so that
    // (address_size - width) can never wrap around.
    const uint32_t width = std::min(n, address_size);
    // The IP's are stored in host byte order.
    // Shifting left by p bits discards the bits between 0 and p-1. Shifting right by the
    // address_size minus the number of desired bits then leaves the requested bits right-aligned.
    return input << p >> (address_size - width);
  }
215

            
216
  /**
   * Removes n bits from input starting at 0. Bit 0 is the most significant bit of the address.
   * @param n supplies the number of bits to remove.
   * @param input supplies the IP address to remove bits from. The IP address is stored in host
   *              byte order.
   * @return input with 0 through n-1 bits cleared. When n covers the whole address the result is
   *         zero.
   */
  template <class IpType, uint32_t address_size = CHAR_BIT * sizeof(IpType)>
  static IpType removeBits(uint32_t n, IpType input) {
    // Shifting by the full width of the type (or more) is undefined behavior; clearing every bit
    // simply yields zero.
    if (n >= address_size) {
      return IpType(0);
    }
    // The IP's are stored in host byte order.
    // By shifting the value to the left by n bits and back, the bits between 0 and n-1
    // (inclusively) are zero'd out.
    return input << n >> n;
  }
230

            
231
  // IP addresses are stored in host byte order to simplify the bit shifting done by
  // extractBits() and removeBits().
232
  using Ipv4 = uint32_t;
233
  using Ipv6 = absl::uint128;
234

            
235
  using DataSet = absl::node_hash_set<T>;
236
  using DataSetSharedPtr = std::shared_ptr<DataSet>;
237

            
238
  /**
239
   * Structure to hold a CIDR range and the data associated with it.
240
   */
241
  template <class IpType, uint32_t address_size = CHAR_BIT * sizeof(IpType)> struct IpPrefix {
242

            
243
    IpPrefix() = default;
244

            
245
69879
    IpPrefix(const IpType& ip, uint32_t length, const T& data) : ip_(ip), length_(length) {
246
69879
      data_.insert(data);
247
69879
    }
248

            
249
    IpPrefix(const IpType& ip, int length, const DataSet& data)
250
71422
        : ip_(ip), length_(length), data_(data) {}
251

            
252
    /**
253
     * @return -1 if the current object is less than other. 0 if they are the same. 1
254
     * if other is smaller than the current object.
255
     */
256
3255
    int compare(const IpPrefix& other) const {
257
3255
      {
258
3255
        if (ip_ < other.ip_) {
259
783
          return -1;
260
2472
        } else if (ip_ > other.ip_) {
261
2472
          return 1;
262
2472
        } else if (length_ < other.length_) {
263
          return -1;
264
        } else if (length_ > other.length_) {
265
          return 1;
266
        } else {
267
          return 0;
268
        }
269
3255
      }
270
3255
    }
271

            
272
3255
    bool operator<(const IpPrefix& other) const { return (this->compare(other) == -1); }
273

            
274
    bool operator!=(const IpPrefix& other) const { return (this->compare(other) != 0); }
275

            
276
    /**
277
     * @return true if other is a prefix of this.
278
     */
279
    bool isPrefix(const IpPrefix& other) {
280
      return (length_ == 0 || (length_ <= other.length_ && contains(other.ip_)));
281
    }
282

            
283
    /**
284
     * @param address supplies an IP address to check against this prefix.
285
     * @return bool true if this prefix contains the address.
286
     */
287
82657
    bool contains(const IpType& address) const {
288
82657
      return (extractBits<IpType, address_size>(0, length_, ip_) ==
289
82657
              extractBits<IpType, address_size>(0, length_, address));
290
82657
    }
291

            
292
    std::string asString() { return fmt::format("{}/{}", toString(ip_), length_); }
293

            
294
    // The address represented either in Ipv4(uint32_t) or Ipv6(absl::uint128).
295
    IpType ip_{0};
296
    // Length of the cidr range.
297
    uint32_t length_{0};
298
    // Data for this entry.
299
    DataSet data_;
300
  };
301

            
302
  /**
303
   * Binary trie used to simplify the construction of Level Compressed Tries.
304
   * This data type supports two operations:
305
   *   1. Add a prefix to the trie.
306
   *   2. Push the prefixes to the leaves of the trie.
307
   * That second operation produces a new set of prefixes that yield the same
308
   * match results as the original set of prefixes from which the BinaryTrie
309
   * was constructed, but with an important difference: the new prefixes are
310
   * guaranteed not to be nested within each other. That allows the use of the
311
   * classic LC Trie construction algorithm, which is fast and (relatively)
312
   * simple but does not work properly with nested prefixes.
313
   */
314
  template <class IpType, uint32_t address_size = CHAR_BIT * sizeof(IpType)> class BinaryTrie {
315
  public:
316
116404
    BinaryTrie(bool exclusive) : root_(std::make_unique<Node>()), exclusive_(exclusive) {}
317

            
318
    /**
319
     * Add a CIDR prefix and associated data to the binary trie. If an entry already
320
     * exists for the prefix, merge the data into the existing entry.
321
     */
322
69879
    void insert(const IpPrefix<IpType>& prefix) {
323
69879
      Node* node = root_.get();
324
76542
      for (uint32_t i = 0; i < prefix.length_; i++) {
325
6663
        auto bit = static_cast<uint32_t>(extractBits(i, 1, prefix.ip_));
326
6663
        NodePtr& next_node = node->children[bit];
327
6663
        if (next_node == nullptr) {
328
5194
          next_node = std::make_unique<Node>();
329
5194
        }
330
6663
        node = next_node.get();
331
6663
      }
332
69879
      if (node->data == nullptr) {
333
69869
        node->data = std::make_shared<DataSet>();
334
69869
      }
335
69879
      node->data->insert(prefix.data_.begin(), prefix.data_.end());
336
69879
    }
337

            
338
    /**
339
     * Update each node in the trie to inherit/override its ancestors' data,
340
     * and then push the prefixes in the binary trie to the leaves so that:
341
     *  1) each leaf contains a prefix, and
342
     *  2) given the set of prefixes now located at the leaves, a useful
343
     *     new property applies: no prefix in that set is nested under any
344
     *     other prefix in the set (since, by definition, no leaf of the
345
     *     trie can be nested under another leaf)
346
     * @return the prefixes associated with the leaf nodes.
347
     */
348
116404
    std::vector<IpPrefix<IpType>> pushLeaves() {
349
116404
      std::vector<IpPrefix<IpType>> prefixes;
350
116404
      std::function<void(Node*, DataSetSharedPtr, unsigned, IpType)> visit =
351
126648
          [&](Node* node, DataSetSharedPtr data, unsigned depth, IpType prefix) {
352
            // Inherit any data set by ancestor nodes.
353
126648
            if (data != nullptr) {
354
3230
              if (node->data == nullptr) {
355
3168
                node->data = data;
356
3168
              } else if (!exclusive_) {
357
39
                node->data->insert(data->begin(), data->end());
358
39
              }
359
3230
            }
360
            // If a node has exactly one child, create a second child node
361
            // that inherits the union of all data set by any ancestor nodes.
362
            // This gives the trie an important new property: all the configured
363
            // prefixes end up at the leaves of the trie. As no leaf is nested
364
            // under another leaf (or one of them would not be a leaf!), the
365
            // leaves of the trie upon completion of this leaf-push operation
366
            // will form a set of disjoint prefixes (no nesting) that can be
367
            // used to build an LC trie.
368
126648
            if (node->children[0] != nullptr && node->children[1] == nullptr) {
369
3997
              node->children[1] = std::make_unique<Node>();
370
122651
            } else if (node->children[0] == nullptr && node->children[1] != nullptr) {
371
1053
              node->children[0] = std::make_unique<Node>();
372
1053
            }
373
126648
            if (node->children[0] != nullptr) {
374
5122
              visit(node->children[0].get(), node->data, depth + 1, (prefix << 1) + IpType(0));
375
5122
              visit(node->children[1].get(), node->data, depth + 1, (prefix << 1) + IpType(1));
376
121526
            } else {
377
121526
              if (node->data != nullptr) {
378
                // Compute the CIDR prefix from the path we've taken to get to this point in the
379
                // tree.
380
71422
                IpType ip = prefix;
381
71422
                if (depth != 0) {
382
1802
                  ip <<= (address_size - depth);
383
1802
                }
384
71422
                prefixes.emplace_back(IpPrefix<IpType>(ip, depth, *node->data));
385
71422
              }
386
121526
            }
387
126648
          };
388
116404
      visit(root_.get(), nullptr, 0, IpType(0));
389
116404
      return prefixes;
390
116404
    }
391

            
392
  private:
393
    struct Node {
394
      std::unique_ptr<Node> children[2];
395
      DataSetSharedPtr data;
396
    };
397
    using NodePtr = std::unique_ptr<Node>;
398
    NodePtr root_;
399
    bool exclusive_;
400
  };
401

            
402
  /**
403
   * Level Compressed Trie (LC-Trie) that contains CIDR ranges and its corresponding data.
404
   *
405
   * The following is an implementation of the algorithm described in the paper
406
   * 'IP-address lookup using LC-tries' by 'S. Nilsson' and 'G. Karlsson'.
407
   *
408
   * 'https://github.com/beevek/libkrb/blob/master/krb/lc_trie.hpp' and
409
   * 'http://www.csc.kth.se/~snilsson/software/router/C/' were used as reference during
410
   * implementation.
411
   *
412
   * Note: The trie can only support up to 524288 (2^19) prefixes with a fill_factor of 1 and
413
   * root_branching_factor not set. Refer to LcTrieInternal::build() method for more details.
414
   */
415
  template <class IpType, uint32_t address_size = CHAR_BIT * sizeof(IpType)> class LcTrieInternal {
416
  public:
417
    /**
418
     * Construct a LC-Trie for IpType.
419
     * @param data supplies a vector of data and CIDR ranges (in IpPrefix format).
420
     * @param fill_factor supplies the fraction of completeness to use when calculating the branch
421
     *                    value for a sub-trie.
422
     * @param root_branching_factor supplies the branching factor at the root. The paper suggests
423
     *                              for large LC-Tries to use the value '16' for the root
424
     *                              branching factor. It reduces the depth of the trie.
425
     */
426
    LcTrieInternal(std::vector<IpPrefix<IpType>>& data, double fill_factor,
427
                   uint32_t root_branching_factor);
428

            
429
    /**
430
     * Retrieve the data associated with the CIDR range that contains `ip_address`.
431
     * @param  ip_address supplies the IP address in host byte order.
432
     * @return a vector of data from the CIDR ranges and IP addresses that encompasses the input.
433
     * An empty vector is returned if the LC Trie is empty.
434
     */
435
    std::vector<T> getData(const IpType& ip_address) const;
436

            
437
  private:
438
    /**
439
     * Builds the Level Compressed Trie, by first sorting the data, removing duplicated
440
     * prefixes and invoking buildRecursive() to build the trie.
441
     */
442
116404
    void build(std::vector<IpPrefix<IpType>>& data) {
443
116404
      if (data.empty()) {
444
46656
        return;
445
46656
      }
446

            
447
69748
      ip_prefixes_ = data;
448
69748
      std::sort(ip_prefixes_.begin(), ip_prefixes_.end());
449

            
450
      // Build the trie_.
451
69748
      trie_.reserve(static_cast<size_t>(ip_prefixes_.size() / fill_factor_));
452
69748
      uint32_t next_free_index = 1;
453
69748
      buildRecursive(0u, 0u, ip_prefixes_.size(), 0u, next_free_index);
454

            
455
      // The value of next_free_index is the final size of the trie_.
456
69748
      ASSERT(next_free_index <= trie_.size());
457
69748
      trie_.resize(next_free_index);
458
69748
      trie_.shrink_to_fit();
459
69748
    }
460

            
461
    // Thin wrapper around computeBranchAndSkip output to facilitate code readability.
    struct ComputePair {
      // The parameters are unsigned to match the members they initialize.
      ComputePair(uint32_t branch, uint32_t prefix) : branch_(branch), prefix_(prefix) {}

      // The branching factor, expressed as a number of bits (a node has 2^branch_ children).
      uint32_t branch_;
      // The total number of bits that have the same prefix for subset of ip_prefixes_.
      uint32_t prefix_;
    };
469

            
470
    /**
471
     * Compute the branch and skip values for the trie starting at position 'first' through
472
     * 'first+n-1' while disregarding the prefix.
473
     * @param prefix supplies the common prefix in the ip_prefixes_ array.
474
     * @param first supplies the index where computing the branch should begin with.
475
     * @param n supplies the number of nodes to use while computing the branch.
476
     * @return pair of integers for the branching factor and the skip.
477
     */
478
587
    ComputePair computeBranchAndSkip(uint32_t prefix, uint32_t first, uint32_t n) const {
479
587
      ComputePair compute(0, 0);
480

            
481
      // Compute the new prefix for the range between ip_prefixes_[first] and
482
      // ip_prefixes_[first + n - 1].
483
587
      IpType high = removeBits<IpType, address_size>(prefix, ip_prefixes_[first].ip_);
484
587
      IpType low = removeBits<IpType, address_size>(prefix, ip_prefixes_[first + n - 1].ip_);
485
587
      uint32_t index = prefix;
486

            
487
      // Find the index at which low and high diverge to get the skip.
488
984
      while (extractBits<IpType, address_size>(index, 1, low) ==
489
984
             extractBits<IpType, address_size>(index, 1, high)) {
490
397
        ++index;
491
397
      }
492
587
      compute.prefix_ = index;
493

            
494
      // For 2 elements, use a branching factor of 2(2^1).
495
587
      if (n == 2) {
496
39
        compute.branch_ = 1;
497
39
        return compute;
498
39
      }
499

            
500
      // According to the original LC-Trie paper, a large branching factor(suggested value: 16)
501
      // at the root increases performance.
502
548
      if (root_branching_factor_ > 0 && prefix == 0 && first == 0) {
503
1
        compute.branch_ = root_branching_factor_;
504
1
        return compute;
505
1
      }
506

            
507
      // Compute the number of bits required for branching by checking all patterns in the set are
508
      // covered. Ex (b=2 {00, 01, 10, 11}; b=3 {000,001,010,011,100,101,110,111}, etc)
509
547
      uint32_t branch = 1;
510
547
      uint32_t count;
511
1608
      do {
512
1608
        ++branch;
513

            
514
        // Check if the current branch factor with the fill factor can contain all of the nodes
515
        // in the current range or if the current branching factor is larger than the
516
        // IP address_size.
517
1608
        if (n < fill_factor_ * (1 << branch) ||
518
1608
            static_cast<uint32_t>(compute.prefix_ + branch) > address_size) {
519
72
          break;
520
72
        }
521

            
522
        // Start by checking the bit patterns at ip_prefixes_[first] through
523
        // ip_prefixes_[first + n-1].
524
1536
        index = first;
525
        // Pattern to search for.
526
1536
        uint32_t pattern = 0;
527
        // Number of patterns found while looping through the list.
528
1536
        count = 0;
529

            
530
        // Search for all patterns(values) within 1<<branch.
531
15468
        while (pattern < static_cast<uint32_t>(1 << branch)) {
532
13932
          bool pattern_found = false;
533
          // Keep on looping until either all nodes in the range have been visited or
534
          // an IP prefix doesn't match the pattern.
535
92980
          while (index < first + n &&
536
92980
                 static_cast<uint32_t>(extractBits<IpType, address_size>(
537
86629
                     compute.prefix_, branch, ip_prefixes_[index].ip_)) == pattern) {
538
79048
            ++index;
539
79048
            pattern_found = true;
540
79048
          }
541

            
542
13932
          if (pattern_found) {
543
6120
            ++count;
544
6120
          }
545
13932
          ++pattern;
546
13932
        }
547
        // Stop iterating once the size of the branch (with the fill factor ratio)
548
        // can no longer contain all of the prefixes within the current range of
549
        // ip_prefixes_[first] to ip_prefixes_[first+n-1].
550
1536
      } while (count >= fill_factor_ * (1 << branch));
551

            
552
      // The branching factor is decremented because the algorithm requires the largest branching
553
      // factor that covers all(most when a fill factor is specified) of the CIDR ranges in the
554
      // current sub-trie. When the loops above exits, the branch factor value is
555
      // 1. greater than the address size with the prefix.
556
      // 2. greater than the number of entries.
557
      // 3. less than the total number of patterns seen in the range.
558
      // In all of the cases above, branch - 1 is guaranteed to cover all of CIDR
559
      // ranges in the sub-trie.
560
      compute.branch_ = branch - 1;
561
547
      return compute;
562
548
    }
563

            
564
    /**
565
     * Recursively build a trie for IP prefixes from position 'first' to 'first+n-1'.
566
     * @param prefix supplies the prefix to ignore when building the sub-trie.
567
     * @param first supplies the index into ip_prefixes_ for this sub-trie.
568
     * @param n supplies the number of entries for the sub-trie.
569
     * @param position supplies the root for this sub-trie.
570
     * @param next_free_index supplies the next available index in the trie_.
571
     */
572
    void buildRecursive(uint32_t prefix, uint32_t first, uint32_t n, uint32_t position,
573
139622
                        uint32_t& next_free_index) {
574
139622
      if (position >= trie_.size()) {
575
        // There is no way to predictably determine the number of trie nodes required to build a
576
        // LC-Trie. If while building the trie the position that is being set exceeds the maximum
577
        // number of supported trie_ entries, throw an Envoy Exception.
578
136428
        if (position >= MaxLcTrieNodes) {
579
          // Adding 1 to the position to count how many nodes are trying to be set.
580
          ExceptionUtil::throwEnvoyException(
581
              fmt::format("The number of internal nodes required for the LC-Trie "
582
                          "exceeded the maximum number of "
583
                          "supported nodes. Minimum number of internal nodes required: "
584
                          "'{0}'. Maximum number of supported nodes: '{1}'.",
585
                          (position + 1), MaxLcTrieNodes));
586
        }
587
136428
        trie_.resize(position + 1);
588
136428
      }
589
      // Setting a leaf, the branch and skip are 0.
590
139622
      if (n == 1) {
591
139035
        trie_[position].address_ = first;
592
139035
        return;
593
139035
      }
594

            
595
587
      ComputePair output = computeBranchAndSkip(prefix, first, n);
596

            
597
587
      uint32_t address = next_free_index;
598
587
      trie_[position].branch_ = output.branch_;
599
      // The skip value is the number of bits between the newly calculated prefix(output.prefix_)
600
      // and the previous prefix(prefix).
601
587
      trie_[position].skip_ = output.prefix_ - prefix;
602
587
      trie_[position].address_ = address;
603

            
604
      // The next available free index to populate in the trie_ is at next_free_index +
605
      // 2^(branching factor).
606
587
      next_free_index += 1 << output.branch_;
607

            
608
587
      uint32_t new_position = first;
609

            
610
      // Build the subtrees.
611
16697
      for (uint32_t bit_pattern = 0; bit_pattern < static_cast<uint32_t>(1 << output.branch_);
612
16110
           ++bit_pattern) {
613

            
614
        // count is the number of entries in the ip_prefixes_ vector that have the same bit
615
        // pattern as the ip_prefixes_[new_position].
616
16110
        int count = 0;
617
42443
        while (new_position + count < first + n &&
618
42443
               static_cast<uint32_t>(extractBits<IpType, address_size>(
619
36217
                   output.prefix_, output.branch_, ip_prefixes_[new_position + count].ip_)) ==
620
36217
                   bit_pattern) {
621
26333
          ++count;
622
26333
        }
623

            
624
        // This logic was taken from
625
        // https://github.com/beevek/libkrb/blob/24a224d3ea840e2e7d2926e17d8849aefecc1101/krb/lc_trie.hpp#L396.
626
        // When there are no entries that match the current pattern, set a leaf at trie_[address +
627
        // bit_pattern].
628
16110
        if (count == 0) {
629
          // This case is hit when the last CIDR range(ip_prefixes_[first+n-1]) is being inserted
630
          // into the trie_. new_position is decremented by one because the count added to
631
          // new_position at line 445 are the number of entries already visited.
632
13849
          if (new_position == first + n) {
633
5639
            buildRecursive(output.prefix_ + output.branch_, new_position - 1, 1,
634
5639
                           address + bit_pattern, next_free_index);
635
8210
          } else {
636
8210
            buildRecursive(output.prefix_ + output.branch_, new_position, 1, address + bit_pattern,
637
8210
                           next_free_index);
638
8210
          }
639
15415
        } else if (count == 1 &&
640
2261
                   ip_prefixes_[new_position].length_ - output.prefix_ < output.branch_) {
641
          // All Ip address that have the prefix of `bit_pattern` will map to the only CIDR range
642
          // with the bit_pattern as a prefix.
643
1060
          uint32_t bits = output.branch_ + output.prefix_ - ip_prefixes_[new_position].length_;
644
55884
          for (uint32_t i = bit_pattern; i < bit_pattern + (1 << bits); ++i) {
645
54824
            buildRecursive(output.prefix_ + output.branch_, new_position, 1, address + i,
646
54824
                           next_free_index);
647
54824
          }
648
          // Update the bit_pattern to skip over the trie_ entries initialized above.
649
1060
          bit_pattern += (1 << bits) - 1;
650
1201
        } else {
651
          // Recursively build sub-tries for ip_prefixes_[new_position] to
652
          // ip_prefixes_[new_position+count-1].
653
1201
          buildRecursive(output.prefix_ + output.branch_, new_position, count,
654
1201
                         address + bit_pattern, next_free_index);
655
1201
        }
656
16110
        new_position += count;
657
16110
      }
658
587
    }
659

            
660
    /**
     * LcNode is a uint32_t. A wrapper is provided to simplify getting/setting the branch, the
     * skip and the address values held within the structure.
     *
     * The LcNode has three parts to it
     * - Branch: the first 5 bits represent the branching factor. The branching factor is used to
     * determine the number of descendants for the current node. The number represents a power of
     * 2, so there can be at most 2^31 descendant nodes.
     * - Skip: the next 7 bits represent the number of bits to skip when looking at an IP address.
     * This value can be between 0 and 127, so IPv6 is supported.
     * - Address: the remaining 20 bits represent an index either into the trie_ or the
     * ip_prefixes_. If branch_ != 0, the index is for the trie_. If branch == zero, the index is
     * for the ip_prefixes_.
     *
     * Note: If more than 2^19-1 CIDR ranges are to be stored in trie_, uint64_t should be used
     * instead.
     */
    struct LcNode {
      uint32_t branch_ : 5;
      uint32_t skip_ : 7;
      uint32_t address_ : 20; // If this 20-bit size changes, please change MaxLcTrieNodes too.
    };
    // The three fields add up to exactly 32 bits; fail the build if a node ever stops fitting
    // in a single uint32_t.
    static_assert(sizeof(LcNode) == sizeof(uint32_t), "LcNode must pack into 32 bits");
682

            
683
    // The CIDR range and data needs to be maintained separately from the LC-Trie. A LC-Trie skips
684
    // chunks of data while searching for a match. This means that the node found in the LC-Trie
685
    // is not guaranteed to have the IP address in range. The last step prior to returning
686
    // associated data is to check the CIDR range pointed to by the node in the LC-Trie has
687
    // the IP address in range.
688
    std::vector<IpPrefix<IpType>> ip_prefixes_;
689

            
690
    // Main trie search structure.
691
    std::vector<LcNode> trie_;
692

            
693
    const double fill_factor_;
694
    const uint32_t root_branching_factor_;
695
  };
696

            
697
  std::unique_ptr<LcTrieInternal<Ipv4>> ipv4_trie_;
698
  std::unique_ptr<LcTrieInternal<Ipv6>> ipv6_trie_;
699
};
700

            
701
template <class T>
702
template <class IpType, uint32_t address_size>
703
LcTrie<T>::LcTrieInternal<IpType, address_size>::LcTrieInternal(std::vector<IpPrefix<IpType>>& data,
704
                                                                double fill_factor,
705
                                                                uint32_t root_branching_factor)
706
116404
    : fill_factor_(fill_factor), root_branching_factor_(root_branching_factor) {
707
116404
  build(data);
708
116404
}
709

            
710
template <class T>
711
template <class IpType, uint32_t address_size>
712
std::vector<T>
713
82663
LcTrie<T>::LcTrieInternal<IpType, address_size>::getData(const IpType& ip_address) const {
714
82663
  std::vector<T> return_vector;
715
82663
  if (trie_.empty()) {
716
6
    return return_vector;
717
6
  }
718

            
719
82657
  LcNode node = trie_[0];
720
82657
  uint32_t branch = node.branch_;
721
82657
  uint32_t position = node.skip_;
722
82657
  uint32_t address = node.address_;
723

            
724
  // branch == 0 is a leaf node.
725
83146
  while (branch != 0) {
726
    // branch is at most 2^5-1= 31 bits to extract, so we can safely cast the
727
    // output of extractBits to uint32_t without any data loss.
728
489
    node = trie_[address + static_cast<uint32_t>(
729
489
                               extractBits<IpType, address_size>(position, branch, ip_address))];
730
489
    position += branch + node.skip_;
731
489
    branch = node.branch_;
732
489
    address = node.address_;
733
489
  }
734

            
735
  // The path taken through the trie to match the ip_address may have contained skips,
736
  // so it is necessary to check whether the matched prefix really contains the
737
  // ip_address.
738
82657
  const auto& prefix = ip_prefixes_[address];
739
82657
  if (prefix.contains(ip_address)) {
740
82607
    return std::vector<T>(prefix.data_.begin(), prefix.data_.end());
741
82607
  }
742
50
  return std::vector<T>();
743
82657
}
744

            
745
} // namespace LcTrie
746
} // namespace Network
747
} // namespace Envoy