Coverage Report

Created: 2025-06-13 06:27

/src/pdns/pdns/dnsname.hh
Line
Count
Source (jump to first uncovered line)
1
/*
2
 * This file is part of PowerDNS or dnsdist.
3
 * Copyright -- PowerDNS.COM B.V. and its contributors
4
 *
5
 * This program is free software; you can redistribute it and/or modify
6
 * it under the terms of version 2 of the GNU General Public License as
7
 * published by the Free Software Foundation.
8
 *
9
 * In addition, for the avoidance of any doubt, permission is granted to
10
 * link this program with OpenSSL and to (re)distribute the binaries
11
 * produced as the result of such linking.
12
 *
13
 * This program is distributed in the hope that it will be useful,
14
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16
 * GNU General Public License for more details.
17
 *
18
 * You should have received a copy of the GNU General Public License
19
 * along with this program; if not, write to the Free Software
20
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21
 */
22
#pragma once
23
#include <array>
24
#include <cstring>
25
#include <optional>
26
#include <string>
27
#include <utility>
28
#include <vector>
29
#include <set>
30
#include <strings.h>
31
#include <stdexcept>
32
#include <sstream>
33
#include <iterator>
34
#include <unordered_set>
35
#include <string_view>
36
37
using namespace std::string_view_literals;
38
39
#include <boost/version.hpp>
40
#include <boost/container/string.hpp>
41
42
// Locale-independent whitespace test: only space, horizontal tab, carriage
// return and line feed are accepted ('\v' and '\f' are NOT whitespace here).
inline bool dns_isspace(char c)
{
  switch (c) {
  case ' ':
  case '\t':
  case '\r':
  case '\n':
    return true;
  default:
    return false;
  }
}
46
47
// Byte-indexed case-mapping tables; declared extern here, defined elsewhere.
extern const unsigned char dns_toupper_table[256];
extern const unsigned char dns_tolower_table[256];

// Map a single byte through dns_toupper_table.
inline unsigned char dns_toupper(unsigned char c)
{
  const unsigned char mapped = dns_toupper_table[c];
  return mapped;
}

// Map a single byte through dns_tolower_table.
inline unsigned char dns_tolower(unsigned char c)
{
  const unsigned char mapped = dns_tolower_table[c];
  return mapped;
}
58
59
#include "burtle.hh"
60
#include "views.hh"
61
62
/* Quest in life:
63
     accept escaped ascii presentations of DNS names and store them "natively"
64
     accept a DNS packet with an offset, and extract a DNS name from it
65
     build up DNSNames with prepend and append of 'raw' unescaped labels
66
67
   Be able to turn them into ASCII and "DNS name in a packet" again on request
68
69
   Provide some common operators for comparison, detection of being part of another domain
70
71
   NOTE: For now, everything MUST be . terminated, otherwise it is an error
72
*/
73
74
// DNSName: represents a case-insensitive string, allowing for non-printable
75
// characters. It is used for all kinds of name (of hosts, domains, keys,
76
// algorithm...) overall the PowerDNS codebase.
77
//
78
// The following type traits are provided:
79
// - EqualityComparable
80
// - LessThanComparable
81
// - Hash
82
#if defined(PDNS_AUTH)
83
class ZoneName;
84
#endif
85
class DNSName
86
{
87
public:
88
  static const size_t s_maxDNSNameLength = 255;
89
90
2
  DNSName() = default; //!< Constructs an *empty* DNSName, NOT the root!
91
  // Work around assertion in some boost versions that do not like self-assignment of boost::container::string
92
  DNSName& operator=(const DNSName& rhs)
93
0
  {
94
0
    if (this != &rhs) {
95
0
      d_storage = rhs.d_storage;
96
0
    }
97
0
    return *this;
98
0
  }
99
  DNSName& operator=(DNSName&& rhs) noexcept
100
2
  {
101
2
    if (this != &rhs) {
102
2
      d_storage = std::move(rhs.d_storage);
103
2
    }
104
2
    return *this;
105
2
  }
106
0
  DNSName(const DNSName& a) = default;
107
0
  DNSName(DNSName&& a) = default;
108
109
  explicit DNSName(std::string_view sw); //!< Constructs from a human formatted, escaped presentation
110
  DNSName(const char* p, size_t len, size_t offset, bool uncompress, uint16_t* qtype = nullptr, uint16_t* qclass = nullptr, unsigned int* consumed = nullptr, uint16_t minOffset = 0); //!< Construct from a DNS Packet, taking the first question if offset=12. If supplied, consumed is set to the number of bytes consumed from the packet, which will not be equal to the wire length of the resulting name in case of compression.
111
112
  bool isPartOf(const DNSName& rhs) const;   //!< Are we part of the rhs name? Note that name.isPartOf(name).
113
  inline bool operator==(const DNSName& rhs) const; //!< DNS-native comparison (case insensitive) - empty compares to empty
114
0
  bool operator!=(const DNSName& other) const { return !(*this == other); }
115
116
  std::string toString(const std::string& separator=".", const bool trailing=true) const;              //!< Our human-friendly, escaped, representation
117
  void toString(std::string& output, const std::string& separator=".", const bool trailing=true) const;
118
  std::string toLogString() const; //!< like plain toString, but returns (empty) on empty names
119
0
  std::string toStringNoDot() const { return toString(".", false); }
120
0
  std::string toStringRootDot() const { if(isRoot()) return "."; else return toString(".", false); }
121
  std::string toDNSString() const;           //!< Our representation in DNS native format
122
  std::string toDNSStringLC() const;           //!< Our representation in DNS native format, lower cased
123
  void appendRawLabel(const std::string& str); //!< Append this unescaped label
124
  void appendRawLabel(const char* start, unsigned int length); //!< Append this unescaped label
125
  void prependRawLabel(const std::string& str); //!< Prepend this unescaped label
126
  std::vector<std::string> getRawLabels() const; //!< Individual raw unescaped labels
127
  std::string getRawLabel(unsigned int pos) const; //!< Get the specified raw unescaped label
128
  DNSName getLastLabel() const; //!< Get the DNSName of the last label
129
  bool chopOff();                               //!< Turn www.powerdns.com. into powerdns.com., returns false for .
130
  DNSName makeRelative(const DNSName& zone) const;
131
  DNSName makeLowerCase() const
132
0
  {
133
0
    DNSName ret(*this);
134
0
    ret.makeUsLowerCase();
135
0
    return ret;
136
0
  }
137
  void makeUsLowerCase()
138
0
  {
139
0
    for(auto & c : d_storage) {
140
0
      c=dns_tolower(c);
141
0
    }
142
0
  }
143
  void makeUsRelative(const DNSName& zone);
144
  DNSName getCommonLabels(const DNSName& other) const; //!< Return the list of common labels from the top, for example 'c.d' for 'a.b.c.d' and 'x.y.c.d'
145
  DNSName labelReverse() const;
146
  bool isWildcard() const;
147
  bool isHostname() const;
148
  unsigned int countLabels() const;
149
  size_t wirelength() const; //!< Number of total bytes in the name
150
0
  bool empty() const { return d_storage.empty(); }
151
0
  bool isRoot() const { return d_storage.size()==1 && d_storage[0]==0; }
152
0
  void clear() { d_storage.clear(); }
153
  void trimToLabels(unsigned int);
154
  size_t hash(size_t init=0) const
155
0
  {
156
0
    return burtleCI((const unsigned char*)d_storage.c_str(), d_storage.size(), init);
157
0
  }
158
  DNSName& operator+=(const DNSName& rhs)
159
0
  {
160
0
    if(d_storage.size() + rhs.d_storage.size() > s_maxDNSNameLength + 1) // one extra byte for the second root label
161
0
      throwSafeRangeError("resulting name too long", rhs.d_storage.data(), rhs.d_storage.size());
162
0
    if(rhs.empty())
163
0
      return *this;
164
0
165
0
    if(d_storage.empty())
166
0
      d_storage+=rhs.d_storage;
167
0
    else
168
0
      d_storage.replace(d_storage.length()-1, rhs.d_storage.length(), rhs.d_storage);
169
0
170
0
    return *this;
171
0
  }
172
173
  bool operator<(const DNSName& rhs)  const // this delivers _some_ kind of ordering, but not one useful in a DNS context. Really fast though.
174
0
  {
175
0
    return std::lexicographical_compare(d_storage.rbegin(), d_storage.rend(),
176
0
         rhs.d_storage.rbegin(), rhs.d_storage.rend(),
177
0
         [](const unsigned char& a, const unsigned char& b) {
178
0
            return dns_tolower(a) < dns_tolower(b);
179
0
          }); // note that this is case insensitive, including on the label lengths
180
0
  }
181
182
  inline bool canonCompare(const DNSName& rhs) const;
183
  bool slowCanonCompare(const DNSName& rhs) const;
184
185
  typedef boost::container::string string_t;
186
187
0
  const string_t& getStorage() const {
188
0
    return d_storage;
189
0
  }
190
191
  [[nodiscard]] size_t sizeEstimate() const
192
0
  {
193
0
    return d_storage.size(); // knowingly overestimating small strings as most string
194
0
                             // implementations have internal capacity and we always include
195
0
                             // sizeof(*this)
196
0
  }
197
198
  bool has8bitBytes() const; /* returns true if at least one byte of the labels forming the name is not included in [A-Za-z0-9_*./@ \\:-] */
199
200
  class RawLabelsVisitor
201
  {
202
  public:
203
    /* Zero-copy, zero-allocation raw labels visitor.
204
       The general idea is that we walk the labels in the constructor,
205
       filling up our array of labels position and setting the initial
206
       value of d_position at the number of labels.
207
       We then can easily provide string_view into the first and last label.
208
       pop_back() moves d_position one label closer to the start, so we
209
       can also easily walk back the labels in reverse order.
210
       There is no copy because we use a reference into the DNSName storage,
211
       so it is absolutely forbidden to alter the DNSName for as long as we
212
       exist, and no allocation because we use a static array (there cannot
213
       be more than 128 labels in a DNSName).
214
    */
215
    RawLabelsVisitor(const string_t& storage);
216
    std::string_view front() const;
217
    std::string_view back() const;
218
    bool pop_back();
219
    bool empty() const;
220
  private:
221
    std::array<uint8_t, 128> d_labelPositions;
222
    const string_t& d_storage;
223
    size_t d_position{0};
224
  };
225
  RawLabelsVisitor getRawLabelsVisitor() const;
226
227
#if defined(PDNS_AUTH) // [
228
  // Sugar while ZoneName::operator DNSName are made explicit
229
  bool isPartOf(const ZoneName& rhs) const;
230
  DNSName makeRelative(const ZoneName& zone) const;
231
  void makeUsRelative(const ZoneName& zone);
232
#endif // ]
233
234
private:
235
  string_t d_storage;
236
237
  void packetParser(const char* qpos, size_t len, size_t offset, bool uncompress, uint16_t* qtype, uint16_t* qclass, unsigned int* consumed, int depth, uint16_t minOffset);
238
  size_t parsePacketUncompressed(const pdns::views::UnsignedCharView& view, size_t position, bool uncompress);
239
  static void appendEscapedLabel(std::string& appendTo, const char* orig, size_t len);
240
  static std::string unescapeLabel(const std::string& orig);
241
  static void throwSafeRangeError(const std::string& msg, const char* buf, size_t length);
242
};
243
244
size_t hash_value(DNSName const& d);
245
246
247
inline bool DNSName::canonCompare(const DNSName& rhs) const
248
0
{
249
  //      01234567890abcd
250
  // us:  1a3www4ds9a2nl
251
  // rhs: 3www6online3com
252
  // to compare, we start at the back, is nl < com? no -> done
253
  //
254
  // 0,2,6,a
255
  // 0,4,a
256
257
0
  uint8_t ourpos[64], rhspos[64];
258
0
  uint8_t ourcount=0, rhscount=0;
259
  //cout<<"Asked to compare "<<toString()<<" to "<<rhs.toString()<<endl;
260
0
  for(const unsigned char* p = (const unsigned char*)d_storage.c_str(); p < (const unsigned char*)d_storage.c_str() + d_storage.size() && *p && ourcount < sizeof(ourpos); p+=*p+1)
261
0
    ourpos[ourcount++]=(p-(const unsigned char*)d_storage.c_str());
262
0
  for(const unsigned char* p = (const unsigned char*)rhs.d_storage.c_str(); p < (const unsigned char*)rhs.d_storage.c_str() + rhs.d_storage.size() && *p && rhscount < sizeof(rhspos); p+=*p+1)
263
0
    rhspos[rhscount++]=(p-(const unsigned char*)rhs.d_storage.c_str());
264
265
0
  if(ourcount == sizeof(ourpos) || rhscount==sizeof(rhspos)) {
266
0
    return slowCanonCompare(rhs);
267
0
  }
268
269
0
  for(;;) {
270
0
    if(ourcount == 0 && rhscount != 0)
271
0
      return true;
272
0
    if(rhscount == 0)
273
0
      return false;
274
0
    ourcount--;
275
0
    rhscount--;
276
277
0
    bool res=std::lexicographical_compare(
278
0
            d_storage.c_str() + ourpos[ourcount] + 1,
279
0
            d_storage.c_str() + ourpos[ourcount] + 1 + *(d_storage.c_str() + ourpos[ourcount]),
280
0
            rhs.d_storage.c_str() + rhspos[rhscount] + 1,
281
0
            rhs.d_storage.c_str() + rhspos[rhscount] + 1 + *(rhs.d_storage.c_str() + rhspos[rhscount]),
282
0
            [](const unsigned char& a, const unsigned char& b) {
283
0
              return dns_tolower(a) < dns_tolower(b);
284
0
            });
285
286
    //    cout<<"Forward: "<<res<<endl;
287
0
    if(res)
288
0
      return true;
289
290
0
    res=std::lexicographical_compare(   rhs.d_storage.c_str() + rhspos[rhscount] + 1,
291
0
            rhs.d_storage.c_str() + rhspos[rhscount] + 1 + *(rhs.d_storage.c_str() + rhspos[rhscount]),
292
0
            d_storage.c_str() + ourpos[ourcount] + 1,
293
0
            d_storage.c_str() + ourpos[ourcount] + 1 + *(d_storage.c_str() + ourpos[ourcount]),
294
0
            [](const unsigned char& a, const unsigned char& b) {
295
0
              return dns_tolower(a) < dns_tolower(b);
296
0
            });
297
    //    cout<<"Reverse: "<<res<<endl;
298
0
    if(res)
299
0
      return false;
300
0
  }
301
0
  return false;
302
0
}
303
304
305
struct CanonDNSNameCompare
306
{
307
  bool operator()(const DNSName&a, const DNSName& b) const
308
0
  {
309
0
    return a.canonCompare(b);
310
0
  }
311
};
312
313
// Concatenates two names: a copy of lhs with rhs appended via operator+=,
// which throws if the combined name would be too long.
inline DNSName operator+(const DNSName& lhs, const DNSName& rhs)
{
  DNSName combined(lhs);
  combined += rhs;
  return combined;
}

extern const DNSName g_rootdnsname, g_wildcarddnsname;
321
322
#if defined(PDNS_AUTH) // [
323
// ZoneName: this is equivalent to DNSName, but intended to only store zone
324
// names. In addition to the name, an optional variant is allowed. The
325
// variant is never part of a DNS packet; it can only be used by backends to
326
// perform specific extra processing.
327
// Variant names are limited to [a-z0-9_-].
328
// Conversions between DNSName and ZoneName are allowed, but must be explicit;
329
// conversions to DNSName lose the variant part.
330
class ZoneName
331
{
332
public:
333
  ZoneName() = default; //!< Constructs an *empty* ZoneName, NOT the root!
334
  // Work around assertion in some boost versions that do not like self-assignment of boost::container::string
335
  ZoneName& operator=(const ZoneName& rhs)
336
0
  {
337
0
    if (this != &rhs) {
338
0
      d_name = rhs.d_name;
339
0
      d_variant = rhs.d_variant;
340
0
    }
341
0
    return *this;
342
0
  }
343
  ZoneName& operator=(ZoneName&& rhs) noexcept
344
0
  {
345
0
    if (this != &rhs) {
346
0
      d_name = std::move(rhs.d_name);
347
0
      d_variant = std::move(rhs.d_variant);
348
0
    }
349
0
    return *this;
350
0
  }
351
  ZoneName(const ZoneName& a) = default;
352
  ZoneName(ZoneName&& a) = default;
353
354
  explicit ZoneName(std::string_view name);
355
0
  explicit ZoneName(std::string_view name, std::string_view variant) : d_name(name), d_variant(variant) {}
356
0
  explicit ZoneName(const DNSName& name, std::string_view variant = ""sv) : d_name(name), d_variant(variant) {}
357
  explicit ZoneName(std::string_view name, std::string_view::size_type sep);
358
359
0
  bool isPartOf(const ZoneName& rhs) const { return d_name.isPartOf(rhs.d_name); }
360
0
  bool isPartOf(const DNSName& rhs) const { return d_name.isPartOf(rhs); }
361
0
  bool operator==(const ZoneName& rhs) const { return d_name == rhs.d_name && d_variant == rhs.d_variant; }
362
0
  bool operator!=(const ZoneName& rhs) const { return !operator==(rhs); }
363
364
  // IMPORTANT! None of the "toString" routines will output the variant, but toLogString() and toStringFull().
365
0
  std::string toString(const std::string& separator=".", const bool trailing=true) const { return d_name.toString(separator, trailing); }
366
0
  void toString(std::string& output, const std::string& separator=".", const bool trailing=true) const { d_name.toString(output, separator, trailing); }
367
  std::string toLogString() const;
368
0
  std::string toStringNoDot() const { return d_name.toStringNoDot(); }
369
0
  std::string toStringRootDot() const { return d_name.toStringRootDot(); }
370
  std::string toStringFull(const std::string& separator=".", const bool trailing=true) const;
371
372
0
  bool chopOff() { return d_name.chopOff(); }
373
  ZoneName makeLowerCase() const
374
0
  {
375
0
    ZoneName ret(*this);
376
0
    ret.d_name.makeUsLowerCase();
377
0
    return ret;
378
0
  }
379
0
  void makeUsLowerCase() { d_name.makeUsLowerCase(); }
380
0
  bool empty() const { return d_name.empty(); }
381
0
  void clear() { d_name.clear(); d_variant.clear(); }
382
0
  void trimToLabels(unsigned int trim) { d_name.trimToLabels(trim); }
383
  size_t hash(size_t init=0) const;
384
385
  bool operator<(const ZoneName& rhs)  const;
386
387
  bool canonCompare(const ZoneName& rhs) const;
388
389
  // Conversion from ZoneName to DNSName
390
0
  explicit operator const DNSName&() const { return d_name; }
391
0
  explicit operator DNSName&() { return d_name; }
392
393
0
  bool hasVariant() const { return !d_variant.empty(); }
394
0
  std::string getVariant() const { return d_variant; }
395
  void setVariant(std::string_view);
396
397
  // Search for a variant separator: mandatory (when variants are used) trailing
398
  // dot followed by another dot and the variant name, and return the length of
399
  // the zone name without its variant part, or npos if there is no variant
400
  // present.
401
  static std::string_view::size_type findVariantSeparator(std::string_view name);
402
403
private:
404
  DNSName d_name;
405
  std::string d_variant{};
406
};
407
408
size_t hash_value(ZoneName const& zone);
409
410
std::ostream & operator<<(std::ostream &ostr, const ZoneName& zone);
411
namespace std {
412
    template <>
413
    struct hash<ZoneName> {
414
0
        size_t operator () (const ZoneName& dn) const { return dn.hash(0); }
415
    };
416
}
417
418
struct CanonZoneNameCompare
419
{
420
  bool operator()(const ZoneName& a, const ZoneName& b) const
421
0
  {
422
0
    return a.canonCompare(b);
423
0
  }
424
};
425
#else // ] [
426
using ZoneName = DNSName;
427
using CanonZoneNameCompare = CanonDNSNameCompare;
428
#endif // ]
429
430
extern const ZoneName g_rootzonename;
431
432
template<typename T>
433
struct SuffixMatchTree
434
{
435
  SuffixMatchTree(std::string name = "", bool endNode_ = false) :
436
    d_name(std::move(name)), endNode(endNode_)
437
  {}
438
439
  SuffixMatchTree(const SuffixMatchTree& rhs): d_name(rhs.d_name), children(rhs.children), endNode(rhs.endNode)
440
  {
441
    if (endNode) {
442
      d_value = rhs.d_value;
443
    }
444
  }
445
  SuffixMatchTree & operator=(const SuffixMatchTree &rhs)
446
  {
447
    d_name = rhs.d_name;
448
    children = rhs.children;
449
    endNode = rhs.endNode;
450
    if (endNode) {
451
      d_value = rhs.d_value;
452
    }
453
    return *this;
454
  }
455
  bool operator<(const SuffixMatchTree& rhs) const
456
0
  {
457
0
    return strcasecmp(d_name.c_str(), rhs.d_name.c_str()) < 0;
458
0
  }
459
460
  std::string d_name;
461
  mutable std::set<SuffixMatchTree, std::less<>> children;
462
  mutable bool endNode;
463
  mutable T d_value{};
464
465
  /* this structure is used to do a lookup without allocating and
466
     copying a string, using C++14's heterogeneous lookups in ordered
467
     containers */
468
  struct LightKey
469
  {
470
    std::string_view d_name;
471
    bool operator<(const SuffixMatchTree& smt) const
472
0
    {
473
0
      auto compareUpTo = std::min(this->d_name.size(), smt.d_name.size());
474
0
      auto ret = strncasecmp(this->d_name.data(), smt.d_name.data(), compareUpTo);
475
0
      if (ret != 0) {
476
0
        return ret < 0;
477
0
      }
478
0
      if (this->d_name.size() == smt.d_name.size()) {
479
0
        return ret < 0;
480
0
      }
481
0
      return this->d_name.size() < smt.d_name.size();
482
0
    }
483
  };
484
485
  bool operator<(const LightKey& lk) const
486
0
  {
487
0
    auto compareUpTo = std::min(this->d_name.size(), lk.d_name.size());
488
0
    auto ret = strncasecmp(this->d_name.data(), lk.d_name.data(), compareUpTo);
489
0
    if (ret != 0) {
490
0
      return ret < 0;
491
0
    }
492
0
    if (this->d_name.size() == lk.d_name.size()) {
493
0
      return ret < 0;
494
0
    }
495
0
    return this->d_name.size() < lk.d_name.size();
496
0
  }
497
498
  template<typename V>
499
  void visit(const V& v) const {
500
    for(const auto& c : children) {
501
      c.visit(v);
502
    }
503
504
    if (endNode) {
505
      v(*this);
506
    }
507
  }
508
509
  void add(const DNSName& name, T&& t)
510
0
  {
511
0
    auto labels = name.getRawLabels();
512
0
    add(labels, std::move(t));
513
0
  }
514
515
  void add(std::vector<std::string>& labels, T&& value) const
516
0
  {
517
0
    if (labels.empty()) { // this allows insertion of the root
518
0
      endNode = true;
519
0
      d_value = std::move(value);
520
0
    }
521
0
    else if(labels.size()==1) {
522
0
      auto res = children.emplace(*labels.begin(), true);
523
0
      if (!res.second) {
524
0
        // we might already have had the node as an
525
0
        // intermediary one, but it's now an end node
526
0
        if (!res.first->endNode) {
527
0
          res.first->endNode = true;
528
0
        }
529
0
      }
530
0
      res.first->d_value = std::move(value);
531
0
    }
532
0
    else {
533
0
      auto res = children.emplace(*labels.rbegin(), false);
534
0
      labels.pop_back();
535
0
      res.first->add(labels, std::move(value));
536
0
    }
537
0
  }
538
539
  void remove(const DNSName &name, bool subtree=false) const
540
0
  {
541
0
    auto labels = name.getRawLabels();
542
0
    remove(labels, subtree);
543
0
  }
544
545
  /* Removes the node at `labels`, also make sure that no empty
546
   * children will be left behind in memory
547
   */
548
  void remove(std::vector<std::string>& labels, bool subtree = false) const
549
0
  {
550
0
    if (labels.empty()) { // this allows removal of the root
551
0
      endNode = false;
552
0
      if (subtree) {
553
0
        children.clear();
554
0
      }
555
0
      return;
556
0
    }
557
0
558
0
    SuffixMatchTree smt(*labels.rbegin());
559
0
    auto child = children.find(smt);
560
0
    if (child == children.end()) {
561
0
      // No subnode found, we're done
562
0
      return;
563
0
    }
564
0
565
0
    // We have found a child
566
0
    labels.pop_back();
567
0
    if (labels.empty()) {
568
0
      // The child is no longer an endnode
569
0
      child->endNode = false;
570
0
571
0
      if (subtree) {
572
0
        child->children.clear();
573
0
      }
574
0
575
0
      // If the child has no further children, just remove it from the set.
576
0
      if (child->children.empty()) {
577
0
        children.erase(child);
578
0
      }
579
0
      return;
580
0
    }
581
0
582
0
    // We are not at the end, let the child figure out what to do
583
0
    child->remove(labels);
584
0
  }
585
586
  T* lookup(const DNSName& name) const
587
0
  {
588
0
    auto bestNode = getBestNode(name);
589
0
    if (bestNode) {
590
0
      return &bestNode->d_value;
591
0
    }
592
0
    return nullptr;
593
0
  }
594
595
  std::optional<DNSName> getBestMatch(const DNSName& name) const
596
0
  {
597
0
    if (children.empty()) { // speed up empty set
598
0
      return endNode ? std::optional<DNSName>(g_rootdnsname) : std::nullopt;
599
0
    }
600
0
601
0
    auto visitor = name.getRawLabelsVisitor();
602
0
    return getBestMatch(visitor);
603
0
  }
604
605
  // Returns all end-nodes, fully qualified (not as separate labels)
606
  std::vector<DNSName> getNodes() const {
607
    std::vector<DNSName> ret;
608
    if (endNode) {
609
      ret.push_back(DNSName(d_name));
610
    }
611
    for (const auto& child : children) {
612
      auto nodes = child.getNodes();
613
      ret.reserve(ret.size() + nodes.size());
614
      for (const auto &node: nodes) {
615
        ret.push_back(node + DNSName(d_name));
616
      }
617
    }
618
    return ret;
619
  }
620
621
private:
622
  const SuffixMatchTree* getBestNode(const DNSName& name)  const
623
0
  {
624
0
    if (children.empty()) { // speed up empty set
625
0
      if (endNode) {
626
0
        return this;
627
0
      }
628
0
      return nullptr;
629
0
    }
630
0
631
0
    auto visitor = name.getRawLabelsVisitor();
632
0
    return getBestNode(visitor);
633
0
  }
634
635
  const SuffixMatchTree* getBestNode(DNSName::RawLabelsVisitor& visitor) const
636
0
  {
637
0
    if (visitor.empty()) { // optimization
638
0
      if (endNode) {
639
0
        return this;
640
0
      }
641
0
      return nullptr;
642
0
    }
643
0
644
0
    const LightKey lk{visitor.back()};
645
0
    auto child = children.find(lk);
646
0
    if (child == children.end()) {
647
0
      if (endNode) {
648
0
        return this;
649
0
      }
650
0
      return nullptr;
651
0
    }
652
0
    visitor.pop_back();
653
0
    auto result = child->getBestNode(visitor);
654
0
    if (result) {
655
0
      return result;
656
0
    }
657
0
    return endNode ? this : nullptr;
658
0
  }
659
660
  std::optional<DNSName> getBestMatch(DNSName::RawLabelsVisitor& visitor) const
661
0
  {
662
0
    if (visitor.empty()) { // optimization
663
0
      if (endNode) {
664
0
        return std::optional<DNSName>(d_name);
665
0
      }
666
0
      return std::nullopt;
667
0
    }
668
0
669
0
    const LightKey lk{visitor.back()};
670
0
    auto child = children.find(lk);
671
0
    if (child == children.end()) {
672
0
      if (endNode) {
673
0
        return std::optional<DNSName>(d_name);
674
0
      }
675
0
      return std::nullopt;
676
0
    }
677
0
    visitor.pop_back();
678
0
    auto result = child->getBestMatch(visitor);
679
0
    if (result) {
680
0
      if (!d_name.empty()) {
681
0
        result->appendRawLabel(d_name);
682
0
      }
683
0
      return result;
684
0
    }
685
0
    return endNode ? std::optional<DNSName>(d_name) : std::nullopt;
686
0
  }
687
};
688
689
/* Quest in life: serve as a rapid block list. If you add a DNSName to a root SuffixMatchNode,
690
   anything part of that domain will return 'true' in check */
691
struct SuffixMatchNode
692
{
693
  public:
694
    SuffixMatchNode() = default;
695
    SuffixMatchTree<bool> d_tree;
696
697
    void add(const DNSName& dnsname)
698
0
    {
699
0
      d_tree.add(dnsname, true);
700
0
      d_nodes.insert(dnsname);
701
0
    }
702
703
    void add(const std::string& name)
704
0
    {
705
0
      add(DNSName(name));
706
0
    }
707
708
    void add(std::vector<std::string> labels)
709
0
    {
710
0
      d_tree.add(labels, true);
711
0
      DNSName tmp;
712
0
      while (!labels.empty()) {
713
0
        tmp.appendRawLabel(labels.back());
714
0
        labels.pop_back(); // This is safe because we have a copy of labels
715
0
      }
716
0
      d_nodes.insert(tmp);
717
0
    }
718
719
    void remove(const DNSName& name)
720
0
    {
721
0
      d_tree.remove(name);
722
0
      d_nodes.erase(name);
723
0
    }
724
725
    void remove(std::vector<std::string> labels)
726
0
    {
727
0
      d_tree.remove(labels);
728
0
      DNSName tmp;
729
0
      while (!labels.empty()) {
730
0
        tmp.appendRawLabel(labels.back());
731
0
        labels.pop_back(); // This is safe because we have a copy of labels
732
0
      }
733
0
      d_nodes.erase(tmp);
734
0
    }
735
736
    bool check(const DNSName& dnsname) const
737
0
    {
738
0
      return d_tree.lookup(dnsname) != nullptr;
739
0
    }
740
741
    std::optional<DNSName> getBestMatch(const DNSName& name) const
742
0
    {
743
0
      return d_tree.getBestMatch(name);
744
0
    }
745
746
    std::string toString() const
747
0
    {
748
0
      std::string ret;
749
0
      bool first = true;
750
0
      for (const auto& n : d_nodes) {
751
0
        if (!first) {
752
0
          ret += ", ";
753
0
        }
754
0
        first = false;
755
0
        ret += n.toString();
756
0
      }
757
0
      return ret;
758
0
    }
759
760
  private:
761
    mutable std::set<DNSName> d_nodes; // Only used for string generation
762
};
763
764
std::ostream & operator<<(std::ostream &os, const DNSName& d);
765
namespace std {
766
    template <>
767
    struct hash<DNSName> {
768
0
        size_t operator () (const DNSName& dn) const { return dn.hash(0); }
769
    };
770
}
771
772
DNSName::string_t segmentDNSNameRaw(const char* input, size_t inputlen); // from ragel
773
774
bool DNSName::operator==(const DNSName& rhs) const
775
0
{
776
0
  if (rhs.empty() != empty() || rhs.d_storage.size() != d_storage.size()) {
777
0
    return false;
778
0
  }
779
780
0
  const auto* us = d_storage.cbegin();
781
0
  const auto* p = rhs.d_storage.cbegin();
782
0
  for (; us != d_storage.cend() && p != rhs.d_storage.cend(); ++us, ++p) {
783
0
    if (dns_tolower(*p) != dns_tolower(*us)) {
784
0
      return false;
785
0
    }
786
0
  }
787
0
  return true;
788
0
}
789
790
struct DNSNameSet: public std::unordered_set<DNSName> {
791
0
    std::string toString() const {
792
0
        std::ostringstream oss;
793
0
        std::copy(begin(), end(), std::ostream_iterator<DNSName>(oss, "\n"));
794
0
        return oss.str();
795
0
    }
796
};