Coverage Report

Created: 2023-03-26 07:17

/src/pdns/pdns/iputils.hh
Line
Count
Source (jump to first uncovered line)
1
/*
2
 * This file is part of PowerDNS or dnsdist.
3
 * Copyright -- PowerDNS.COM B.V. and its contributors
4
 *
5
 * This program is free software; you can redistribute it and/or modify
6
 * it under the terms of version 2 of the GNU General Public License as
7
 * published by the Free Software Foundation.
8
 *
9
 * In addition, for the avoidance of any doubt, permission is granted to
10
 * link this program with OpenSSL and to (re)distribute the binaries
11
 * produced as the result of such linking.
12
 *
13
 * This program is distributed in the hope that it will be useful,
14
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16
 * GNU General Public License for more details.
17
 *
18
 * You should have received a copy of the GNU General Public License
19
 * along with this program; if not, write to the Free Software
20
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21
 */
22
#pragma once
23
#include <string>
24
#include <sys/socket.h>
25
#include <netinet/in.h>
26
#include <arpa/inet.h>
27
#include <iostream>
28
#include <stdio.h>
29
#include <functional>
30
#include <bitset>
31
#include "pdnsexception.hh"
32
#include "misc.hh"
33
#include <netdb.h>
34
#include <sstream>
35
36
#include "namespaces.hh"
37
38
#ifdef __APPLE__
39
#include <libkern/OSByteOrder.h>
40
41
#define htobe16(x) OSSwapHostToBigInt16(x)
42
#define htole16(x) OSSwapHostToLittleInt16(x)
43
#define be16toh(x) OSSwapBigToHostInt16(x)
44
#define le16toh(x) OSSwapLittleToHostInt16(x)
45
46
#define htobe32(x) OSSwapHostToBigInt32(x)
47
#define htole32(x) OSSwapHostToLittleInt32(x)
48
#define be32toh(x) OSSwapBigToHostInt32(x)
49
#define le32toh(x) OSSwapLittleToHostInt32(x)
50
51
#define htobe64(x) OSSwapHostToBigInt64(x)
52
#define htole64(x) OSSwapHostToLittleInt64(x)
53
#define be64toh(x) OSSwapBigToHostInt64(x)
54
#define le64toh(x) OSSwapLittleToHostInt64(x)
55
#endif
56
57
#ifdef __sun
58
59
#define htobe16(x) BE_16(x)
60
#define htole16(x) LE_16(x)
61
#define be16toh(x) BE_IN16(&(x))
62
#define le16toh(x) LE_IN16(&(x))
63
64
#define htobe32(x) BE_32(x)
65
#define htole32(x) LE_32(x)
66
#define be32toh(x) BE_IN32(&(x))
67
#define le32toh(x) LE_IN32(&(x))
68
69
#define htobe64(x) BE_64(x)
70
#define htole64(x) LE_64(x)
71
#define be64toh(x) BE_IN64(&(x))
72
#define le64toh(x) LE_IN64(&(x))
73
74
#endif
75
76
#ifdef __FreeBSD__
77
#include <sys/endian.h>
78
#endif
79
80
#if defined(__NetBSD__) && defined(IP_PKTINFO) && !defined(IP_SENDSRCADDR)
81
// The IP_PKTINFO option in NetBSD was incompatible with Linux until a
82
// change that also introduced IP_SENDSRCADDR for FreeBSD compatibility.
83
#undef IP_PKTINFO
84
#endif
85
86
union ComboAddress {
87
  struct sockaddr_in sin4;
88
  struct sockaddr_in6 sin6;
89
90
  bool operator==(const ComboAddress& rhs) const
91
0
  {
92
0
    if(std::tie(sin4.sin_family, sin4.sin_port) != std::tie(rhs.sin4.sin_family, rhs.sin4.sin_port))
93
0
      return false;
94
0
    if(sin4.sin_family == AF_INET)
95
0
      return sin4.sin_addr.s_addr == rhs.sin4.sin_addr.s_addr;
96
0
    else
97
0
      return memcmp(&sin6.sin6_addr.s6_addr, &rhs.sin6.sin6_addr.s6_addr, sizeof(sin6.sin6_addr.s6_addr))==0;
98
0
  }
99
100
  bool operator!=(const ComboAddress& rhs) const
101
0
  {
102
0
    return(!operator==(rhs));
103
0
  }
104
105
  bool operator<(const ComboAddress& rhs) const
106
0
  {
107
0
    if(sin4.sin_family == 0) {
108
0
      return false;
109
0
    }
110
0
    if(std::tie(sin4.sin_family, sin4.sin_port) < std::tie(rhs.sin4.sin_family, rhs.sin4.sin_port))
111
0
      return true;
112
0
    if(std::tie(sin4.sin_family, sin4.sin_port) > std::tie(rhs.sin4.sin_family, rhs.sin4.sin_port))
113
0
      return false;
114
0
115
0
    if(sin4.sin_family == AF_INET)
116
0
      return sin4.sin_addr.s_addr < rhs.sin4.sin_addr.s_addr;
117
0
    else
118
0
      return memcmp(&sin6.sin6_addr.s6_addr, &rhs.sin6.sin6_addr.s6_addr, sizeof(sin6.sin6_addr.s6_addr)) < 0;
119
0
  }
120
121
  bool operator>(const ComboAddress& rhs) const
122
0
  {
123
0
    return rhs.operator<(*this);
124
0
  }
125
126
  struct addressOnlyHash
127
  {
128
    uint32_t operator()(const ComboAddress& ca) const
129
0
    {
130
0
      const unsigned char* start = nullptr;
131
0
      uint32_t len = 0;
132
0
      if (ca.sin4.sin_family == AF_INET) {
133
0
        start = reinterpret_cast<const unsigned char*>(&ca.sin4.sin_addr.s_addr);
134
0
        len = 4;
135
0
      }
136
0
      else {
137
0
        start = reinterpret_cast<const unsigned char*>(&ca.sin6.sin6_addr.s6_addr);
138
0
        len = 16;
139
0
      }
140
0
      return burtle(start, len, 0);
141
0
    }
142
  };
143
144
  struct addressOnlyLessThan
145
  {
146
    bool operator()(const ComboAddress& a, const ComboAddress& b) const
147
0
    {
148
0
      if(a.sin4.sin_family < b.sin4.sin_family)
149
0
        return true;
150
0
      if(a.sin4.sin_family > b.sin4.sin_family)
151
0
        return false;
152
0
      if(a.sin4.sin_family == AF_INET)
153
0
        return a.sin4.sin_addr.s_addr < b.sin4.sin_addr.s_addr;
154
0
      else
155
0
        return memcmp(&a.sin6.sin6_addr.s6_addr, &b.sin6.sin6_addr.s6_addr, sizeof(a.sin6.sin6_addr.s6_addr)) < 0;
156
0
    }
157
  };
158
159
  struct addressOnlyEqual
160
  {
161
    bool operator()(const ComboAddress& a, const ComboAddress& b) const
162
0
    {
163
0
      if(a.sin4.sin_family != b.sin4.sin_family)
164
0
        return false;
165
0
      if(a.sin4.sin_family == AF_INET)
166
0
        return a.sin4.sin_addr.s_addr == b.sin4.sin_addr.s_addr;
167
0
      else
168
0
        return !memcmp(&a.sin6.sin6_addr.s6_addr, &b.sin6.sin6_addr.s6_addr, sizeof(a.sin6.sin6_addr.s6_addr));
169
0
    }
170
  };
171
172
173
  socklen_t getSocklen() const
174
0
  {
175
0
    if(sin4.sin_family == AF_INET)
176
0
      return sizeof(sin4);
177
0
    else
178
0
      return sizeof(sin6);
179
0
  }
180
181
  ComboAddress()
182
139k
  {
183
139k
    sin4.sin_family=AF_INET;
184
139k
    sin4.sin_addr.s_addr=0;
185
139k
    sin4.sin_port=0;
186
139k
    sin6.sin6_scope_id = 0;
187
139k
    sin6.sin6_flowinfo = 0;
188
139k
  }
189
190
0
  ComboAddress(const struct sockaddr *sa, socklen_t salen) {
191
0
    setSockaddr(sa, salen);
192
0
  };
193
194
0
  ComboAddress(const struct sockaddr_in6 *sa) {
195
0
    setSockaddr((const struct sockaddr*)sa, sizeof(struct sockaddr_in6));
196
0
  };
197
198
0
  ComboAddress(const struct sockaddr_in *sa) {
199
0
    setSockaddr((const struct sockaddr*)sa, sizeof(struct sockaddr_in));
200
0
  };
201
202
0
  void setSockaddr(const struct sockaddr *sa, socklen_t salen) {
203
0
    if (salen > sizeof(struct sockaddr_in6)) throw PDNSException("ComboAddress can't handle other than sockaddr_in or sockaddr_in6");
204
0
    memcpy(this, sa, salen);
205
0
  }
206
207
  // 'port' sets a default value in case 'str' does not set a port
208
  explicit ComboAddress(const string& str, uint16_t port=0)
209
0
  {
210
0
    memset(&sin6, 0, sizeof(sin6));
211
0
    sin4.sin_family = AF_INET;
212
0
    sin4.sin_port = 0;
213
0
    if(makeIPv4sockaddr(str, &sin4)) {
214
0
      sin6.sin6_family = AF_INET6;
215
0
      if(makeIPv6sockaddr(str, &sin6) < 0)
216
0
        throw PDNSException("Unable to convert presentation address '"+ str +"'");
217
218
0
    }
219
0
    if(!sin4.sin_port) // 'str' overrides port!
220
0
      sin4.sin_port=htons(port);
221
0
  }
222
223
  bool isIPv6() const
224
5.47k
  {
225
5.47k
    return sin4.sin_family == AF_INET6;
226
5.47k
  }
227
  bool isIPv4() const
228
7.35k
  {
229
7.35k
    return sin4.sin_family == AF_INET;
230
7.35k
  }
231
232
  bool isMappedIPv4()  const
233
0
  {
234
0
    if(sin4.sin_family!=AF_INET6)
235
0
      return false;
236
0
237
0
    int n=0;
238
0
    const unsigned char* ptr = reinterpret_cast<const unsigned char*>(&sin6.sin6_addr.s6_addr);
239
0
    for(n=0; n < 10; ++n)
240
0
      if(ptr[n])
241
0
        return false;
242
0
243
0
    for(; n < 12; ++n)
244
0
      if(ptr[n]!=0xff)
245
0
        return false;
246
0
247
0
    return true;
248
0
  }
249
250
  ComboAddress mapToIPv4() const
251
0
  {
252
0
    if(!isMappedIPv4())
253
0
      throw PDNSException("ComboAddress can't map non-mapped IPv6 address back to IPv4");
254
0
    ComboAddress ret;
255
0
    ret.sin4.sin_family=AF_INET;
256
0
    ret.sin4.sin_port=sin4.sin_port;
257
0
258
0
    const unsigned char* ptr = reinterpret_cast<const unsigned char*>(&sin6.sin6_addr.s6_addr);
259
0
    ptr+=(sizeof(sin6.sin6_addr.s6_addr) - sizeof(ret.sin4.sin_addr.s_addr));
260
0
    memcpy(&ret.sin4.sin_addr.s_addr, ptr, sizeof(ret.sin4.sin_addr.s_addr));
261
0
    return ret;
262
0
  }
263
264
  string toString() const
265
0
  {
266
0
    char host[1024];
267
0
    int retval = 0;
268
0
    if(sin4.sin_family && !(retval = getnameinfo(reinterpret_cast<const struct sockaddr*>(this), getSocklen(), host, sizeof(host),0, 0, NI_NUMERICHOST)))
269
0
      return string(host);
270
0
    else
271
0
      return "invalid "+string(gai_strerror(retval));
272
0
  }
273
274
  //! Ignores any interface specifiers possibly available in the sockaddr data.
275
  string toStringNoInterface() const
276
0
  {
277
0
    char host[1024];
278
0
    if(sin4.sin_family == AF_INET && (nullptr != inet_ntop(sin4.sin_family, &sin4.sin_addr, host, sizeof(host))))
279
0
      return string(host);
280
0
    else if(sin4.sin_family == AF_INET6 && (nullptr != inet_ntop(sin4.sin_family, &sin6.sin6_addr, host, sizeof(host))))
281
0
      return string(host);
282
0
    else
283
0
      return "invalid "+stringerror();
284
0
  }
285
286
  [[nodiscard]] string toStringReversed() const
287
0
  {
288
0
    if (isIPv4()) {
289
0
      const auto ip = ntohl(sin4.sin_addr.s_addr);
290
0
      auto a = (ip >> 0) & 0xFF;
291
0
      auto b = (ip >> 8) & 0xFF;
292
0
      auto c = (ip >> 16) & 0xFF;
293
0
      auto d = (ip >> 24) & 0xFF;
294
0
      return std::to_string(a) + "." + std::to_string(b) + "." + std::to_string(c) + "." + std::to_string(d);
295
0
    }
296
0
    else {
297
0
      const auto* addr = &sin6.sin6_addr;
298
0
      std::stringstream res{};
299
0
      res << std::hex;
300
0
      for (int i = 15; i >= 0; i--) {
301
0
        auto byte = addr->s6_addr[i];
302
0
        res << ((byte >> 0) & 0xF) << ".";
303
0
        res << ((byte >> 4) & 0xF);
304
0
        if (i != 0) {
305
0
          res << ".";
306
0
        }
307
0
      }
308
0
      return res.str();
309
0
    }
310
0
  }
311
312
  string toStringWithPort() const
313
0
  {
314
0
    if(sin4.sin_family==AF_INET)
315
0
      return toString() + ":" + std::to_string(ntohs(sin4.sin_port));
316
0
    else
317
0
      return "["+toString() + "]:" + std::to_string(ntohs(sin4.sin_port));
318
0
  }
319
320
  string toStringWithPortExcept(int port) const
321
0
  {
322
0
    if(ntohs(sin4.sin_port) == port)
323
0
      return toString();
324
0
    if(sin4.sin_family==AF_INET)
325
0
      return toString() + ":" + std::to_string(ntohs(sin4.sin_port));
326
0
    else
327
0
      return "["+toString() + "]:" + std::to_string(ntohs(sin4.sin_port));
328
0
  }
329
330
  string toLogString() const
331
0
  {
332
0
    return toStringWithPortExcept(53);
333
0
  }
334
335
  string toByteString() const
336
0
  {
337
0
    if (isIPv4()) {
338
0
      return string(reinterpret_cast<const char*>(&sin4.sin_addr.s_addr), sizeof(sin4.sin_addr.s_addr));
339
0
    }
340
0
    return string(reinterpret_cast<const char*>(&sin6.sin6_addr.s6_addr), sizeof(sin6.sin6_addr.s6_addr));
341
0
  }
342
343
  void truncate(unsigned int bits) noexcept;
344
345
  uint16_t getPort() const
346
0
  {
347
0
    return ntohs(sin4.sin_port);
348
0
  }
349
350
  void setPort(uint16_t port)
351
0
  {
352
0
    sin4.sin_port = htons(port);
353
0
  }
354
355
  void reset()
356
0
  {
357
0
    memset(&sin4, 0, sizeof(sin4));
358
0
    memset(&sin6, 0, sizeof(sin6));
359
0
  }
360
361
  //! Get the total number of address bits (either 32 or 128 depending on IP version)
362
  uint8_t getBits() const
363
0
  {
364
0
    if (isIPv4())
365
0
      return 32;
366
0
    if (isIPv6())
367
0
      return 128;
368
0
    return 0;
369
0
  }
370
  /** Get the value of the bit at the provided bit index. When the index >= 0,
371
      the index is relative to the LSB starting at index zero. When the index < 0,
372
      the index is relative to the MSB starting at index -1 and counting down.
373
   */
374
  bool getBit(int index) const
375
0
  {
376
0
    if(isIPv4()) {
377
0
      if (index >= 32)
378
0
        return false;
379
0
      if (index < 0) {
380
0
        if (index < -32)
381
0
          return false;
382
0
        index = 32 + index;
383
0
      }
384
0
385
0
      uint32_t ls_addr = ntohl(sin4.sin_addr.s_addr);
386
0
387
0
      return ((ls_addr & (1U<<index)) != 0x00000000);
388
0
    }
389
0
    if(isIPv6()) {
390
0
      if (index >= 128)
391
0
        return false;
392
0
      if (index < 0) {
393
0
        if (index < -128)
394
0
          return false;
395
0
        index = 128 + index;
396
0
      }
397
0
398
0
      const uint8_t* ls_addr = reinterpret_cast<const uint8_t*>(sin6.sin6_addr.s6_addr);
399
0
      uint8_t byte_idx = index / 8;
400
0
      uint8_t bit_idx = index % 8;
401
0
402
0
      return ((ls_addr[15-byte_idx] & (1U << bit_idx)) != 0x00);
403
0
    }
404
0
    return false;
405
0
  }
406
407
  /*! Returns a comma-separated string of IP addresses
408
   *
409
   * \param c  An stl container with ComboAddresses
410
   * \param withPort  Also print the port (default true)
411
   * \param portExcept  Print the port, except when this is the port (default 53)
412
   */
413
  template < template < class ... > class Container, class ... Args >
414
0
  static string caContainerToString(const Container<ComboAddress, Args...>& c, const bool withPort = true, const uint16_t portExcept = 53) {
415
0
  vector<string> strs;
416
0
  for (const auto& ca : c) {
417
0
    if (withPort) {
418
0
      strs.push_back(ca.toStringWithPortExcept(portExcept));
419
0
      continue;
420
0
    }
421
0
    strs.push_back(ca.toString());
422
0
  }
423
0
  return boost::join(strs, ",");
424
0
  };
425
};
426
427
/** This exception is thrown by the Netmask class and by extension by the NetmaskGroup class */
428
class NetmaskException: public PDNSException
429
{
430
public:
431
0
  NetmaskException(const string &a) : PDNSException(a) {}
432
};
433
434
/** Parses a textual IP address into a ComboAddress.
 *
 *  The string is first tried as IPv4 with inet_pton(); when that does not
 *  succeed it is handed to makeIPv6sockaddr().
 *
 *  \throws NetmaskException when the IPv6 attempt fails as well
 */
inline ComboAddress makeComboAddress(const string& str)
{
  ComboAddress result;
  result.sin4.sin_family = AF_INET;

  const bool parsedAsIPv4 = inet_pton(AF_INET, str.c_str(), &result.sin4.sin_addr) > 0;
  if (parsedAsIPv4) {
    return result;
  }

  result.sin4.sin_family = AF_INET6;
  if (makeIPv6sockaddr(str, &result.sin6) < 0) {
    throw NetmaskException("Unable to convert '"+str+"' to a netmask");
  }
  return result;
}
445
446
/** Builds a ComboAddress from raw address bytes.
 *
 *  \param version  4 for IPv4, 6 for IPv6
 *  \param raw      pointer to the address bytes, copied as-is
 *  \param len      number of bytes at raw; must equal the size of the
 *                  address field for the requested version
 *  \throws NetmaskException on a length mismatch or any other version
 */
inline ComboAddress makeComboAddressFromRaw(uint8_t version, const char* raw, size_t len)
{
  ComboAddress address;

  switch (version) {
  case 4:
    if (len != sizeof(address.sin4.sin_addr)) {
      throw NetmaskException("invalid raw address length");
    }
    address.sin4.sin_family = AF_INET;
    memcpy(&address.sin4.sin_addr, raw, sizeof(address.sin4.sin_addr));
    break;
  case 6:
    if (len != sizeof(address.sin6.sin6_addr)) {
      throw NetmaskException("invalid raw address length");
    }
    address.sin6.sin6_family = AF_INET6;
    memcpy(&address.sin6.sin6_addr, raw, sizeof(address.sin6.sin6_addr));
    break;
  default:
    throw NetmaskException("invalid address family");
  }

  return address;
}
464
465
//! Convenience overload: uses the bytes held by 'str' as the raw address.
inline ComboAddress makeComboAddressFromRaw(uint8_t version, const string &str)
{
  const char* bytes = str.c_str();
  const size_t count = str.size();
  return makeComboAddressFromRaw(version, bytes, count);
}
469
470
/** This class represents a netmask and can be queried to see if a certain
471
    IP address is matched by this mask */
472
class Netmask
473
{
474
public:
475
  Netmask()
476
0
  {
477
0
    d_network.sin4.sin_family = 0; // disable this doing anything useful
478
0
    d_network.sin4.sin_port = 0; // this guarantees d_network compares identical
479
0
    d_mask = 0;
480
0
    d_bits = 0;
481
0
  }
482
483
  Netmask(const ComboAddress& network, uint8_t bits=0xff): d_network(network)
484
0
  {
485
0
    d_network.sin4.sin_port = 0;
486
0
    setBits(network.isIPv4() ? std::min(bits, static_cast<uint8_t>(32)) : std::min(bits, static_cast<uint8_t>(128)));
487
0
  }
488
489
  Netmask(const sockaddr_in* network, uint8_t bits = 0xff): d_network(network)
490
0
  {
491
0
    d_network.sin4.sin_port = 0;
492
0
    setBits(std::min(bits, static_cast<uint8_t>(32)));
493
0
  }
494
  Netmask(const sockaddr_in6* network, uint8_t bits = 0xff): d_network(network)
495
0
  {
496
0
    d_network.sin4.sin_port = 0;
497
0
    setBits(std::min(bits, static_cast<uint8_t>(128)));
498
0
  }
499
  void setBits(uint8_t value)
500
0
  {
501
0
    d_bits = value;
502
503
0
    if (d_bits < 32) {
504
0
      d_mask = ~(0xFFFFFFFF >> d_bits);
505
0
    }
506
0
    else {
507
      // note that d_mask is unused for IPv6
508
0
      d_mask = 0xFFFFFFFF;
509
0
    }
510
511
0
    if (isIPv4()) {
512
0
      d_network.sin4.sin_addr.s_addr = htonl(ntohl(d_network.sin4.sin_addr.s_addr) & d_mask);
513
0
    }
514
0
    else if (isIPv6()) {
515
0
      uint8_t bytes = d_bits/8;
516
0
      uint8_t *us = (uint8_t*) &d_network.sin6.sin6_addr.s6_addr;
517
0
      uint8_t bits = d_bits % 8;
518
0
      uint8_t mask = (uint8_t) ~(0xFF>>bits);
519
520
0
      if (bytes < sizeof(d_network.sin6.sin6_addr.s6_addr)) {
521
0
        us[bytes] &= mask;
522
0
      }
523
524
0
      for(size_t idx = bytes + 1; idx < sizeof(d_network.sin6.sin6_addr.s6_addr); ++idx) {
525
0
        us[idx] = 0;
526
0
      }
527
0
    }
528
0
  }
529
530
  //! Constructor supplies the mask, which cannot be changed
531
  Netmask(const string &mask)
532
0
  {
533
0
    pair<string,string> split = splitField(mask,'/');
534
0
    d_network = makeComboAddress(split.first);
535
536
0
    if (!split.second.empty()) {
537
0
      setBits(pdns::checked_stoi<uint8_t>(split.second));
538
0
    }
539
0
    else if (d_network.sin4.sin_family == AF_INET) {
540
0
      setBits(32);
541
0
    }
542
0
    else {
543
0
      setBits(128);
544
0
    }
545
0
  }
546
547
  bool match(const ComboAddress& ip) const
548
0
  {
549
0
    return match(&ip);
550
0
  }
551
552
  //! If this IP address in socket address matches
553
  bool match(const ComboAddress *ip) const
554
0
  {
555
0
    if(d_network.sin4.sin_family != ip->sin4.sin_family) {
556
0
      return false;
557
0
    }
558
0
    if(d_network.sin4.sin_family == AF_INET) {
559
0
      return match4(htonl((unsigned int)ip->sin4.sin_addr.s_addr));
560
0
    }
561
0
    if(d_network.sin6.sin6_family == AF_INET6) {
562
0
      uint8_t bytes=d_bits/8, n;
563
0
      const uint8_t *us=(const uint8_t*) &d_network.sin6.sin6_addr.s6_addr;
564
0
      const uint8_t *them=(const uint8_t*) &ip->sin6.sin6_addr.s6_addr;
565
0
566
0
      for(n=0; n < bytes; ++n) {
567
0
        if(us[n]!=them[n]) {
568
0
          return false;
569
0
        }
570
0
      }
571
0
      // still here, now match remaining bits
572
0
      uint8_t bits= d_bits % 8;
573
0
      uint8_t mask= (uint8_t) ~(0xFF>>bits);
574
0
575
0
      return((us[n]) == (them[n] & mask));
576
0
    }
577
0
    return false;
578
0
  }
579
580
  //! If this ASCII IP address matches
581
  bool match(const string &ip) const
582
0
  {
583
0
    ComboAddress address=makeComboAddress(ip);
584
0
    return match(&address);
585
0
  }
586
587
  //! If this IP address in native format matches
588
  bool match4(uint32_t ip) const
589
0
  {
590
0
    return (ip & d_mask) == (ntohl(d_network.sin4.sin_addr.s_addr));
591
0
  }
592
593
  string toString() const
594
0
  {
595
0
    return d_network.toStringNoInterface()+"/"+std::to_string((unsigned int)d_bits);
596
0
  }
597
598
  string toStringNoMask() const
599
0
  {
600
0
    return d_network.toStringNoInterface();
601
0
  }
602
603
  const ComboAddress& getNetwork() const
604
0
  {
605
0
    return d_network;
606
0
  }
607
608
  const ComboAddress& getMaskedNetwork() const
609
0
  {
610
0
    return getNetwork();
611
0
  }
612
613
  uint8_t getBits() const
614
0
  {
615
0
    return d_bits;
616
0
  }
617
618
  bool isIPv6() const
619
0
  {
620
0
    return d_network.sin6.sin6_family == AF_INET6;
621
0
  }
622
623
  bool isIPv4() const
624
0
  {
625
0
    return d_network.sin4.sin_family == AF_INET;
626
0
  }
627
628
  bool operator<(const Netmask& rhs) const
629
0
  {
630
0
    if (empty() && !rhs.empty())
631
0
      return false;
632
0
633
0
    if (!empty() && rhs.empty())
634
0
      return true;
635
0
636
0
    if (d_bits > rhs.d_bits)
637
0
      return true;
638
0
    if (d_bits < rhs.d_bits)
639
0
      return false;
640
0
641
0
    return d_network < rhs.d_network;
642
0
  }
643
644
  bool operator>(const Netmask& rhs) const
645
0
  {
646
0
    return rhs.operator<(*this);
647
0
  }
648
649
  bool operator==(const Netmask& rhs) const
650
0
  {
651
0
    return std::tie(d_network, d_bits) == std::tie(rhs.d_network, rhs.d_bits);
652
0
  }
653
654
  bool empty() const
655
0
  {
656
0
    return d_network.sin4.sin_family==0;
657
0
  }
658
659
  //! Get normalized version of the netmask. This means that all address bits below the network bits are zero.
660
0
  Netmask getNormalized() const {
661
0
    return Netmask(getMaskedNetwork(), d_bits);
662
0
  }
663
  //! Get Netmask for super network of this one (i.e. with fewer network bits)
664
0
  Netmask getSuper(uint8_t bits) const {
665
0
    return Netmask(d_network, std::min(d_bits, bits));
666
0
  }
667
668
  //! Get the total number of address bits for this netmask (either 32 or 128 depending on IP version)
669
  uint8_t getFullBits() const
670
0
  {
671
0
    return d_network.getBits();
672
0
  }
673
674
  /** Get the value of the bit at the provided bit index. When the index >= 0,
675
      the index is relative to the LSB starting at index zero. When the index < 0,
676
      the index is relative to the MSB starting at index -1 and counting down.
677
      When the index points outside the network bits, it always yields zero.
678
   */
679
  bool getBit(int bit) const
680
0
  {
681
0
    if (bit < -d_bits)
682
0
      return false;
683
0
    if (bit >= 0) {
684
0
      if(isIPv4()) {
685
0
        if (bit >= 32 || bit < (32 - d_bits))
686
0
          return false;
687
0
      }
688
0
      if(isIPv6()) {
689
0
        if (bit >= 128 || bit < (128 - d_bits))
690
0
          return false;
691
0
      }
692
0
    }
693
0
    return d_network.getBit(bit);
694
0
  }
695
696
  struct Hash {
697
    size_t operator()(const Netmask& nm) const
698
0
    {
699
0
      return burtle(&nm.d_bits, 1, ComboAddress::addressOnlyHash()(nm.d_network));
700
0
    }
701
  };
702
703
private:
704
  ComboAddress d_network;
705
  uint32_t d_mask;
706
  uint8_t d_bits;
707
};
708
709
namespace std {
710
  template<>
711
  struct hash<Netmask> {
712
0
    auto operator()(const Netmask& nm) const {
713
0
      return Netmask::Hash{}(nm);
714
0
    }
715
  };
716
}
717
718
/** Binary tree map implementation with <Netmask,T> pair.
719
 *
720
 * This is a binary tree implementation for storing attributes for IPv4 and IPv6 prefixes.
721
 * The simplest use case is the NetmaskTree<bool> used by NetmaskGroup, which only
722
 * wants to know if given IP address is matched in the prefixes stored.
723
 *
724
 * This element is useful for anything that needs to *STORE* prefixes, and *MATCH* IP addresses
725
 * to a *LIST* of *PREFIXES*. Not the other way round.
726
 *
727
 * You can store IPv4 and IPv6 addresses to same tree, separate payload storage is kept per AFI.
728
 * Network prefixes (Netmasks) are always recorded in normalized fashion, meaning that only
729
 * the network bits are set. This is what is returned in the insert() and lookup() return
730
 * values.
731
 *
732
 * Use swap if you need to move the tree to another NetmaskTree instance, it is WAY faster
733
 * than using copy ctor or assignment operator, since it moves the nodes and tree root to
734
 * new home instead of actually recreating the tree.
735
 *
736
 * Please see NetmaskGroup for example of simple use case. Other usecases can be found
737
 * from GeoIPBackend and Sortlist, and from dnsdist.
738
 */
739
template <typename T, class K = Netmask>
740
class NetmaskTree {
741
public:
742
  class Iterator;
743
744
  typedef K key_type;
745
  typedef T value_type;
746
  typedef std::pair<const key_type,value_type> node_type;
747
  typedef size_t size_type;
748
  typedef class Iterator iterator;
749
750
private:
751
  /** Single node in tree, internal use only.
752
    */
753
  class TreeNode : boost::noncopyable {
754
  public:
755
    explicit TreeNode() noexcept :
756
      parent(nullptr), node(), assigned(false), d_bits(0) {
757
    }
758
    explicit TreeNode(const key_type& key) :
759
      parent(nullptr), node({key.getNormalized(), value_type()}),
760
      assigned(false), d_bits(key.getFullBits()) {
761
    }
762
763
    //<! Makes a left leaf node with specified key.
764
0
    TreeNode* make_left(const key_type& key) {
765
0
      d_bits = node.first.getBits();
766
0
      left = make_unique<TreeNode>(key);
767
0
      left->parent = this;
768
0
      return left.get();
769
0
    }
770
771
    //<! Makes a right leaf node with specified key.
772
0
    TreeNode* make_right(const key_type& key) {
773
0
      d_bits = node.first.getBits();
774
0
      right = make_unique<TreeNode>(key);
775
0
      right->parent = this;
776
0
      return right.get();
777
0
    }
778
779
    //<! Splits branch at indicated bit position by inserting key
780
0
    TreeNode* split(const key_type& key, int bits) {
781
0
      if (parent == nullptr) {
782
0
        // not to be called on the root node
783
0
        throw std::logic_error(
784
0
          "NetmaskTree::TreeNode::split(): must not be called on root node");
785
0
      }
786
0
787
0
      // determine reference from parent
788
0
      unique_ptr<TreeNode>& parent_ref =
789
0
        (parent->left.get() == this ? parent->left : parent->right);
790
0
      if (parent_ref.get() != this) {
791
0
        throw std::logic_error(
792
0
          "NetmaskTree::TreeNode::split(): parent node reference is invalid");
793
0
      }
794
0
795
0
      // create new tree node for the new key
796
0
      TreeNode* new_node = new TreeNode(key);
797
0
      new_node->d_bits = bits;
798
0
799
0
      // attach the new node under our former parent
800
0
      unique_ptr<TreeNode> new_child(new_node);
801
0
      std::swap(parent_ref, new_child); // hereafter new_child points to "this"
802
0
      new_node->parent = parent;
803
0
804
0
      // attach "this" node below the new node
805
0
      // (left or right depending on bit)
806
0
      new_child->parent = new_node;
807
0
      if (new_child->node.first.getBit(-1-bits)) {
808
0
        std::swap(new_node->right, new_child);
809
0
      } else {
810
0
        std::swap(new_node->left, new_child);
811
0
      }
812
0
813
0
      return new_node;
814
0
    }
815
816
    //<! Forks branch for new key at indicated bit position
817
0
    TreeNode* fork(const key_type& key, int bits) {
818
0
      if (parent == nullptr) {
819
0
        // not to be called on the root node
820
0
        throw std::logic_error(
821
0
          "NetmaskTree::TreeNode::fork(): must not be called on root node");
822
0
      }
823
0
824
0
      // determine reference from parent
825
0
      unique_ptr<TreeNode>& parent_ref =
826
0
        (parent->left.get() == this ? parent->left : parent->right);
827
0
      if (parent_ref.get() != this) {
828
0
        throw std::logic_error(
829
0
          "NetmaskTree::TreeNode::fork(): parent node reference is invalid");
830
0
      }
831
0
832
0
      // create new tree node for the branch point
833
0
      TreeNode* branch_node = new TreeNode(node.first.getSuper(bits));
834
0
      branch_node->d_bits = bits;
835
0
836
0
      // the current node will now be a child of the new branch node
837
0
      // (hereafter new_child1 points to "this")
838
0
      unique_ptr<TreeNode> new_child1 = std::move(parent_ref);
839
0
      // attach the branch node under our former parent
840
0
      parent_ref = std::unique_ptr<TreeNode>(branch_node);
841
0
      branch_node->parent = parent;
842
0
843
0
      // create second new leaf node for the new key
844
0
      unique_ptr<TreeNode> new_child2 = make_unique<TreeNode>(key);
845
0
      TreeNode* new_node = new_child2.get();
846
0
847
0
      // attach the new child nodes below the branch node
848
0
      // (left or right depending on bit)
849
0
      new_child1->parent = branch_node;
850
0
      new_child2->parent = branch_node;
851
0
      if (new_child1->node.first.getBit(-1-bits)) {
852
0
        branch_node->right = std::move(new_child1);
853
0
        branch_node->left = std::move(new_child2);
854
0
      } else {
855
0
        branch_node->right = std::move(new_child2);
856
0
        branch_node->left = std::move(new_child1);
857
0
      }
858
0
      // now we have attached the new unique pointers to the tree:
859
0
      // - branch_node is below its parent
860
0
      // - new_child1 (ourselves) is below branch_node
861
0
      // - new_child2, the new leaf node, is below branch_node as well
862
0
863
0
      return new_node;
864
0
    }
865
866
    //<! Traverse left branch depth-first
867
    TreeNode *traverse_l()
868
0
    {
869
0
      TreeNode *tnode = this;
870
0
871
0
      while (tnode->left)
872
0
        tnode = tnode->left.get();
873
0
      return tnode;
874
0
    }
875
876
    //<! Traverse tree depth-first and in-order (L-N-R)
877
    TreeNode *traverse_lnr()
878
0
    {
879
0
      TreeNode *tnode = this;
880
0
881
0
      // precondition: descended left as deep as possible
882
0
      if (tnode->right) {
883
0
        // descend right
884
0
        tnode = tnode->right.get();
885
0
        // descend left as deep as possible and return next node
886
0
        return tnode->traverse_l();
887
0
      }
888
0
889
0
      // ascend to parent
890
0
      while (tnode->parent != nullptr) {
891
0
        TreeNode *prev_child = tnode;
892
0
        tnode = tnode->parent;
893
0
894
0
        // return this node, but only when we come from the left child branch
895
0
        if (tnode->left && tnode->left.get() == prev_child)
896
0
          return tnode;
897
0
      }
898
0
      return nullptr;
899
0
    }
900
901
    //<! Traverse only assigned nodes
902
    TreeNode *traverse_lnr_assigned()
903
0
    {
904
0
      TreeNode *tnode = traverse_lnr();
905
0
906
0
      while (tnode != nullptr && !tnode->assigned)
907
0
        tnode = tnode->traverse_lnr();
908
0
      return tnode;
909
0
    }
910
911
    unique_ptr<TreeNode> left;
912
    unique_ptr<TreeNode> right;
913
    TreeNode* parent;
914
915
    node_type node;
916
    bool assigned; //<! Whether this node is assigned-to by the application
917
918
    int d_bits; //<! How many bits have been used so far
919
  };
920
921
  /** Removes 'node' when it has no children and is not assigned, then
   *  repeats the check on its parent, and so on upwards. A node without a
   *  parent is never removed.
   */
  void cleanup_tree(TreeNode* node)
  {
    while (!(node->left || node->right || node->assigned)) {
      TreeNode* pparent = node->parent;
      if (pparent == nullptr) {
        // reached a node without parent: nothing to detach it from
        return;
      }

      // dropping the owning pointer deletes the node
      if (pparent->left.get() == node) {
        pparent->left.reset();
      }
      else {
        pparent->right.reset();
      }

      // continue with the parent, which may have become removable
      node = pparent;
    }
  }
938
939
  void copyTree(const NetmaskTree& rhs)
940
  {
941
    try {
942
      TreeNode *node = rhs.d_root.get();
943
      if (node != nullptr)
944
        node = node->traverse_l();
945
      while (node != nullptr) {
946
        if (node->assigned)
947
          insert(node->node.first).second = node->node.second;
948
        node = node->traverse_lnr();
949
      }
950
    }
951
    catch (const NetmaskException&) {
952
      abort();
953
    }
954
    catch (const std::logic_error&) {
955
      abort();
956
    }
957
  }
958
959
public:
960
  class Iterator {
961
  public:
962
    typedef node_type value_type;
963
    typedef node_type& reference;
964
    typedef node_type* pointer;
965
    typedef std::forward_iterator_tag iterator_category;
966
    typedef size_type difference_type;
967
968
  private:
969
    friend class NetmaskTree;
970
971
    const NetmaskTree* d_tree;
972
    TreeNode* d_node;
973
974
    Iterator(const NetmaskTree* tree, TreeNode* node): d_tree(tree), d_node(node) {
975
    }
976
977
  public:
978
    Iterator(): d_tree(nullptr), d_node(nullptr) {}
979
980
    Iterator& operator++() // prefix
981
0
    {
982
0
      if (d_node == nullptr) {
983
0
        throw std::logic_error(
984
0
          "NetmaskTree::Iterator::operator++: iterator is invalid");
985
0
      }
986
0
      d_node = d_node->traverse_lnr_assigned();
987
0
      return *this;
988
0
    }
989
    Iterator operator++(int) // postfix
990
    {
991
      Iterator tmp(*this);
992
      operator++();
993
      return tmp;
994
    }
995
996
    reference operator*()
997
0
    {
998
0
      if (d_node == nullptr) {
999
0
        throw std::logic_error(
1000
0
          "NetmaskTree::Iterator::operator*: iterator is invalid");
1001
0
      }
1002
0
      return d_node->node;
1003
0
    }
1004
1005
    pointer operator->()
1006
0
    {
1007
0
      if (d_node == nullptr) {
1008
0
        throw std::logic_error(
1009
0
          "NetmaskTree::Iterator::operator->: iterator is invalid");
1010
0
      }
1011
0
      return &d_node->node;
1012
0
    }
1013
1014
    bool operator==(const Iterator& rhs)
1015
0
    {
1016
0
      return (d_tree == rhs.d_tree && d_node == rhs.d_node);
1017
0
    }
1018
    bool operator!=(const Iterator& rhs)
1019
0
    {
1020
0
      return !(*this == rhs);
1021
0
    }
1022
  };
1023
1024
public:
1025
  // Creates an empty tree: a freshly allocated root node, no recorded
  // left-most entry and a size of zero.
  NetmaskTree() noexcept:
    d_root(new TreeNode()),
    d_left(nullptr),
    d_size(0)
  {}
1027
1028
  NetmaskTree(const NetmaskTree& rhs): d_root(new TreeNode()), d_left(nullptr), d_size(0) {
1029
    copyTree(rhs);
1030
  }
1031
1032
  // Copy assignment: replaces the contents of this tree with a copy of the
  // assigned entries of `rhs`.
  // Self-assignment is a no-op; without the guard, clear() would wipe the
  // tree before copyTree() reads from it and all entries would be lost.
  NetmaskTree& operator=(const NetmaskTree& rhs) {
    if (this != &rhs) {
      clear();
      copyTree(rhs);
    }
    return *this;
  }
1037
1038
0
  // Iterator positioned on the node recorded as left-most (d_left); equal to
  // end() when d_left is null.
  const iterator begin() const {
    Iterator first(this, d_left);
    return first;
  }
1041
0
  // Past-the-end iterator of this tree: same tree pointer, null node.
  const iterator end() const {
    Iterator past(this, nullptr);
    return past;
  }
1044
  // Iterator positioned on the node recorded as left-most (d_left); equal to
  // end() when d_left is null.
  iterator begin() {
    Iterator first(this, d_left);
    return first;
  }
1047
  // Past-the-end iterator of this tree: same tree pointer, null node.
  iterator end() {
    Iterator past(this, nullptr);
    return past;
  }
1050
1051
  // Convenience overload: builds a key_type from the textual mask and inserts
  // it. Returns the (key, value) pair stored in the tree.
  node_type& insert(const string &mask) {
    const key_type parsed(mask);
    return insert(parsed);
  }
1054
1055
  //<! Creates new value-pair in tree and returns it.
1056
0
  // Finds or creates the tree node for `key`, marks it assigned and returns a
  // reference to its (key, value) pair. d_size grows only when the node was
  // not assigned before; d_left is kept pointing at the left-most node.
  // Throws NetmaskException when `key` is neither IPv4 nor IPv6, and
  // std::logic_error when the left-most bookkeeping is found inconsistent.
  node_type& insert(const key_type& key) {
    // IPv4 keys live below the root's left child, IPv6 keys below the right.
    const bool v4 = key.isIPv4();
    if (!v4 && !key.isIPv6()) {
      throw NetmaskException("invalid address family");
    }
    unique_ptr<TreeNode>& family = v4 ? d_root->left : d_root->right;

    if (!family) {
      // first entry of this address family: hang it directly below the root
      family.reset(new TreeNode(key));
      TreeNode* fresh = family.get();
      fresh->assigned = true;
      fresh->parent = d_root.get();
      d_size++;
      // an IPv4 entry is always left-most; an IPv6 one only without IPv4 data
      if (v4 || !d_root->left) {
        d_left = fresh;
      }
      return fresh->node;
    }

    TreeNode* cur = family.get();
    // Whether the node we end up on can still be the left-most one. An IPv6
    // key cannot be when an IPv4 subtree exists.
    bool leftmost = v4 || !d_root->left;

    // Walk the key MSB first: a 0 bit goes left, a 1 bit goes right.
    const int keyBits = key.getBits();
    for (int depth = 0; depth < keyBits; ++depth) {
      const bool keyBit = key.getBit(-1 - depth);

      if (depth >= cur->d_bits) {
        // the current node is used up; move on to (or create) the child
        if (keyBit) {
          if (cur->left || cur->assigned) {
            leftmost = false;
          }
          if (!cur->right) {
            // no right branch yet: the key becomes that branch
            cur = cur->make_right(key);
            break;
          }
          cur = cur->right.get();
        }
        else {
          if (!cur->left) {
            // no left branch yet: the key becomes that branch
            cur = cur->make_left(key);
            break;
          }
          cur = cur->left.get();
        }
        continue;
      }

      if (depth >= cur->node.first.getBits()) {
        // the stored prefix ends here while the key has more bits: add the
        // key as a child below the existing node
        if (keyBit) {
          if (cur->assigned) {
            leftmost = false;
          }
          cur = cur->make_right(key);
        }
        else {
          cur = cur->make_left(key);
        }
        break;
      }

      if (keyBit != cur->node.first.getBit(-1 - depth)) {
        if (keyBit) {
          leftmost = false;
        }
        // key and stored prefix agree up to this bit and then diverge:
        // fork the branch at this depth
        cur = cur->fork(key, depth);
        break;
      }
    }

    if (cur->node.first.getBits() > key.getBits()) {
      // the key is a super-network of the node we ended on: split the branch
      // and put a node for the key above it
      cur = cur->split(key, key.getBits());
    }

    if (cur->left) {
      leftmost = false;
    }

    node_type& value = cur->node;

    if (cur->assigned) {
      // the entry already existed; only verify the left-most bookkeeping
      if (leftmost && d_left != cur) {
        throw std::logic_error(
          "NetmaskTree::insert(): lost track of left-most node in tree");
      }
      return value;
    }

    // newly assigned entry
    d_size++;
    if (leftmost) {
      d_left = cur;
    }
    cur->assigned = true;
    return value;
  }
1168
1169
  //<! Creates or updates value
1170
  void insert_or_assign(const key_type& mask, const value_type& value) {
1171
    insert(mask).second = value;
1172
  }
1173
1174
  void insert_or_assign(const string& mask, const value_type& value) {
1175
    insert(key_type(mask)).second = value;
1176
  }
1177
1178
  //<! check if given key is present in TreeMap
1179
  bool has_key(const key_type& key) const {
1180
    const node_type *ptr = lookup(key);
1181
    return ptr && ptr->first == key;
1182
  }
1183
1184
  //<! Returns "best match" for key_type, which might not be value
1185
  // Best-match lookup that considers all bits of `value` (value.getBits()).
  // Returns nullptr when nothing matches.
  const node_type* lookup(const key_type& value) const {
    const uint8_t all_bits = value.getBits();
    return lookupImpl(value, all_bits);
  }
1189
1190
  //<! Perform best match lookup for value, using at most max_bits
1191
0
  // Best-match lookup for an address, looking at no more than `max_bits`
  // leading bits. A negative `max_bits`, or one larger than the address
  // width, is replaced by the address width. Returns nullptr when nothing
  // matches.
  const node_type* lookup(const ComboAddress& value, int max_bits = 128) const {
    const uint8_t addr_bits = value.getBits();
    const bool usable = (max_bits >= 0 && max_bits <= addr_bits);
    if (!usable) {
      max_bits = addr_bits;
    }

    const key_type probe(value, max_bits);
    return lookupImpl(probe, max_bits);
  }
1199
1200
  //<! Removes key from TreeMap.
1201
0
  void erase(const key_type& key) {
1202
0
    TreeNode *node = nullptr;
1203
0
1204
0
    if (key.isIPv4())
1205
0
      node = d_root->left.get();
1206
0
    else if (key.isIPv6())
1207
0
      node = d_root->right.get();
1208
0
    else
1209
0
      throw NetmaskException("invalid address family");
1210
0
    // no tree, no value
1211
0
    if (node == nullptr) return;
1212
0
1213
0
    int bits = 0;
1214
0
    for(; node && bits < key.getBits(); bits++) {
1215
0
      bool vall = key.getBit(-1-bits);
1216
0
      if (bits >= node->d_bits) {
1217
0
        // the end of the current node is reached; continue with the next
1218
0
        if (vall) {
1219
0
          node = node->right.get();
1220
0
        } else {
1221
0
          node = node->left.get();
1222
0
        }
1223
0
        continue;
1224
0
      }
1225
0
      if (bits >= node->node.first.getBits()) {
1226
0
        // the matching branch ends here
1227
0
        if (key.getBits() != node->node.first.getBits())
1228
0
          node = nullptr;
1229
0
        break;
1230
0
      }
1231
0
      bool valr = node->node.first.getBit(-1-bits);
1232
0
      if (vall != valr) {
1233
0
        // the branch matches just upto this point, yet continues in a different
1234
0
        // direction
1235
0
        node = nullptr;
1236
0
        break;
1237
0
      }
1238
0
    }
1239
0
    if (node) {
1240
0
      if (d_size == 0) {
1241
0
        throw std::logic_error(
1242
0
          "NetmaskTree::erase(): size of tree is zero before erase");
1243
0
      }
1244
0
      d_size--;
1245
0
      node->assigned = false;
1246
0
      node->node.second = value_type();
1247
0
1248
0
      if (node == d_left)
1249
0
        d_left = d_left->traverse_lnr_assigned();
1250
0
1251
0
      cleanup_tree(node);
1252
0
    }
1253
0
  }
1254
1255
  void erase(const string& key) {
1256
    erase(key_type(key));
1257
  }
1258
1259
  //<! checks whether the container is empty.
1260
0
  bool empty() const {
1261
0
    return (d_size == 0);
1262
0
  }
1263
1264
  //<! returns the number of elements
1265
0
  // Number of entries as tracked by d_size (updated by insert/erase/clear).
  size_type size() const
  {
    return d_size;
  }
1268
1269
  //<! See if given ComboAddress matches any prefix
1270
  bool match(const ComboAddress& value) const {
1271
    return (lookup(value) != nullptr);
1272
  }
1273
1274
  bool match(const std::string& value) const {
1275
    return match(ComboAddress(value));
1276
  }
1277
1278
  //<! Clean out the tree
1279
0
  void clear() {
1280
0
    d_root.reset(new TreeNode());
1281
0
    d_left = nullptr;
1282
0
    d_size = 0;
1283
0
  }
1284
1285
  //<! swaps the contents with another NetmaskTree
1286
  // Exchanges root, left-most pointer and size with `rhs`; no nodes are
  // copied.
  void swap(NetmaskTree& rhs) {
    d_root.swap(rhs.d_root);
    std::swap(d_size, rhs.d_size);
    std::swap(d_left, rhs.d_left);
  }
1291
1292
private:
1293
1294
0
  // Shared implementation of the lookup() overloads: walks the tree along the
  // first `max_bits` bits of `value` and returns the last assigned entry whose
  // prefix length equals the depth at which it was passed, i.e. the most
  // specific match. Returns nullptr when there is none.
  // Throws NetmaskException when `value` is neither IPv4 nor IPv6.
  const node_type* lookupImpl(const key_type& value, uint8_t max_bits) const {
    // IPv4 entries hang below the root's left child, IPv6 below the right.
    const bool v4 = value.isIPv4();
    if (!v4 && !value.isIPv6()) {
      throw NetmaskException("invalid address family");
    }
    TreeNode* cur = v4 ? d_root->left.get() : d_root->right.get();
    if (cur == nullptr) {
      return nullptr;
    }

    node_type* best = nullptr;

    int depth = 0;
    for (; depth < max_bits; ++depth) {
      const bool wanted = value.getBit(-1 - depth);

      if (depth >= cur->d_bits) {
        // the current node is used up; remember it when it is an entry of
        // exactly this length, then try to descend
        if (cur->assigned && depth == cur->node.first.getBits()) {
          best = &cur->node;
        }
        TreeNode* next = wanted ? cur->right.get() : cur->left.get();
        if (next == nullptr) {
          break;
        }
        cur = next;
        continue;
      }

      if (depth >= cur->node.first.getBits()) {
        // the stored prefix ends here
        break;
      }

      if (wanted != cur->node.first.getBit(-1 - depth)) {
        // the stored prefix matches up to this bit and then goes elsewhere
        break;
      }
    }

    // the node we stopped on may itself be the match
    if (cur->assigned && depth == cur->node.first.getBits()) {
      best = &cur->node;
    }

    // may be nullptr
    return best;
  }
1344
1345
  unique_ptr<TreeNode> d_root; //<! Root of our tree
1346
  TreeNode *d_left;
1347
  size_type d_size;
1348
};
1349
1350
/** This class represents a group of supplemental Netmask classes. An IP address matches
1351
    if it is matched by one or more of the Netmask objects within.
1352
*/
1353
class NetmaskGroup
1354
{
1355
public:
1356
0
  NetmaskGroup() noexcept {
1357
0
  }
1358
1359
  //! If this IP address is matched by any of the classes within
1360
1361
  bool match(const ComboAddress *ip) const
1362
0
  {
1363
0
    const auto &ret = tree.lookup(*ip);
1364
0
    if(ret) return ret->second;
1365
0
    return false;
1366
0
  }
1367
1368
  bool match(const ComboAddress& ip) const
1369
0
  {
1370
0
    return match(&ip);
1371
0
  }
1372
1373
  bool lookup(const ComboAddress* ip, Netmask* nmp) const
1374
0
  {
1375
0
    const auto &ret = tree.lookup(*ip);
1376
0
    if (ret) {
1377
0
      if (nmp != nullptr)
1378
0
        *nmp = ret->first;
1379
0
1380
0
      return ret->second;
1381
0
    }
1382
0
    return false;
1383
0
  }
1384
1385
  bool lookup(const ComboAddress& ip, Netmask* nmp) const
1386
0
  {
1387
0
    return lookup(&ip, nmp);
1388
0
  }
1389
1390
  //! Add this string to the list of possible matches
1391
  void addMask(const string &ip, bool positive=true)
1392
0
  {
1393
0
    if(!ip.empty() && ip[0] == '!') {
1394
0
      addMask(Netmask(ip.substr(1)), false);
1395
0
    } else {
1396
0
      addMask(Netmask(ip), positive);
1397
0
    }
1398
0
  }
1399
1400
  //! Add this Netmask to the list of possible matches
1401
  void addMask(const Netmask& nm, bool positive=true)
1402
0
  {
1403
0
    tree.insert(nm).second=positive;
1404
0
  }
1405
1406
  void addMasks(const NetmaskGroup& group, boost::optional<bool> positive)
1407
0
  {
1408
0
    for (const auto& entry : group.tree) {
1409
0
      addMask(entry.first, positive ? *positive : entry.second);
1410
0
    }
1411
0
  }
1412
1413
  //! Delete this Netmask from the list of possible matches
1414
  void deleteMask(const Netmask& nm)
1415
0
  {
1416
0
    tree.erase(nm);
1417
0
  }
1418
1419
  void deleteMask(const std::string& ip)
1420
0
  {
1421
0
    if (!ip.empty())
1422
0
      deleteMask(Netmask(ip));
1423
0
  }
1424
1425
  void clear()
1426
0
  {
1427
0
    tree.clear();
1428
0
  }
1429
1430
  bool empty() const
1431
0
  {
1432
0
    return tree.empty();
1433
0
  }
1434
1435
  size_t size() const
1436
0
  {
1437
0
    return tree.size();
1438
0
  }
1439
1440
  string toString() const
1441
0
  {
1442
0
    ostringstream str;
1443
0
    for(auto iter = tree.begin(); iter != tree.end(); ++iter) {
1444
0
      if(iter != tree.begin())
1445
0
        str <<", ";
1446
0
      if(!(iter->second))
1447
0
        str<<"!";
1448
0
      str<<iter->first.toString();
1449
0
    }
1450
0
    return str.str();
1451
0
  }
1452
1453
  void toStringVector(vector<string>* vec) const
1454
0
  {
1455
0
    for(auto iter = tree.begin(); iter != tree.end(); ++iter) {
1456
0
      vec->push_back((iter->second ? "" : "!") + iter->first.toString());
1457
0
    }
1458
0
  }
1459
1460
  void toMasks(const string &ips)
1461
0
  {
1462
0
    vector<string> parts;
1463
0
    stringtok(parts, ips, ", \t");
1464
0
1465
0
    for (vector<string>::const_iterator iter = parts.begin(); iter != parts.end(); ++iter)
1466
0
      addMask(*iter);
1467
0
  }
1468
1469
private:
1470
  NetmaskTree<bool> tree;
1471
};
1472
1473
struct SComboAddress
1474
{
1475
0
  SComboAddress(const ComboAddress& orig) : ca(orig) {}
1476
  ComboAddress ca;
1477
  bool operator<(const SComboAddress& rhs) const
1478
0
  {
1479
0
    return ComboAddress::addressOnlyLessThan()(ca, rhs.ca);
1480
0
  }
1481
  operator const ComboAddress&()
1482
0
  {
1483
0
    return ca;
1484
0
  }
1485
};
1486
1487
// Exception type for network-level failures; the message defaults to
// "Network Error".
class NetworkError : public std::runtime_error
{
public:
  // Only one of the two constructors may carry a default argument: with a
  // default on both, `NetworkError()` is an ambiguous call and does not
  // compile.
  NetworkError(const std::string& why) : std::runtime_error(why.c_str())
  {}
  NetworkError(const char *why="Network Error") : std::runtime_error(why)
  {}
};
1495
1496
class AddressAndPortRange
1497
{
1498
public:
1499
  AddressAndPortRange(): d_addrMask(0), d_portMask(0)
1500
0
  {
1501
0
    d_addr.sin4.sin_family = 0; // disable this doing anything useful
1502
0
    d_addr.sin4.sin_port = 0; // this guarantees d_network compares identical
1503
0
  }
1504
1505
  AddressAndPortRange(ComboAddress ca, uint8_t addrMask, uint8_t portMask = 0): d_addr(std::move(ca)), d_addrMask(addrMask), d_portMask(portMask)
1506
0
  {
1507
0
    if (!d_addr.isIPv4()) {
1508
0
      d_portMask = 0;
1509
0
    }
1510
0
1511
0
    uint16_t port = d_addr.getPort();
1512
0
    if (d_portMask < 16) {
1513
0
      uint16_t mask = ~(0xFFFF >> d_portMask);
1514
0
      port = port & mask;
1515
0
    }
1516
0
1517
0
    if (d_addrMask < d_addr.getBits()) {
1518
0
      if (d_portMask > 0) {
1519
0
        throw std::runtime_error("Trying to create a AddressAndPortRange with a reduced address mask (" + std::to_string(d_addrMask) + ") and a port range (" + std::to_string(d_portMask) + ")");
1520
0
      }
1521
0
      d_addr = Netmask(d_addr, d_addrMask).getMaskedNetwork();
1522
0
    }
1523
0
    d_addr.setPort(port);
1524
0
  }
1525
1526
  uint8_t getFullBits() const
1527
0
  {
1528
0
    return d_addr.getBits() + 16;
1529
0
  }
1530
1531
  uint8_t getBits() const
1532
0
  {
1533
0
    if (d_addrMask < d_addr.getBits()) {
1534
0
      return d_addrMask;
1535
0
    }
1536
0
1537
0
    return d_addr.getBits() + d_portMask;
1538
0
  }
1539
1540
  /** Get the value of the bit at the provided bit index. When the index >= 0,
1541
      the index is relative to the LSB starting at index zero. When the index < 0,
1542
      the index is relative to the MSB starting at index -1 and counting down.
1543
  */
1544
  bool getBit(int index) const
1545
0
  {
1546
0
    if (index >= getFullBits()) {
1547
0
      return false;
1548
0
    }
1549
0
    if (index < 0) {
1550
0
      index = getFullBits() + index;
1551
0
    }
1552
0
1553
0
    if (index < 16) {
1554
0
      /* we are into the port bits */
1555
0
      uint16_t port = d_addr.getPort();
1556
0
      return ((port & (1U<<index)) != 0x0000);
1557
0
    }
1558
0
1559
0
    index -= 16;
1560
0
1561
0
    return d_addr.getBit(index);
1562
0
  }
1563
1564
  bool isIPv4() const
1565
0
  {
1566
0
    return d_addr.isIPv4();
1567
0
  }
1568
1569
  bool isIPv6() const
1570
0
  {
1571
0
    return d_addr.isIPv6();
1572
0
  }
1573
1574
  AddressAndPortRange getNormalized() const
1575
0
  {
1576
0
    return AddressAndPortRange(d_addr, d_addrMask, d_portMask);
1577
0
  }
1578
1579
  AddressAndPortRange getSuper(uint8_t bits) const
1580
0
  {
1581
0
    if (bits <= d_addrMask) {
1582
0
      return AddressAndPortRange(d_addr, bits, 0);
1583
0
    }
1584
0
    if (bits <= d_addrMask + d_portMask) {
1585
0
      return AddressAndPortRange(d_addr, d_addrMask, d_portMask - (bits - d_addrMask));
1586
0
    }
1587
0
1588
0
    return AddressAndPortRange(d_addr, d_addrMask, d_portMask);
1589
0
  }
1590
1591
  const ComboAddress& getNetwork() const
1592
0
  {
1593
0
    return d_addr;
1594
0
  }
1595
1596
  string toString() const
1597
0
  {
1598
0
    if (d_addrMask < d_addr.getBits() || d_portMask == 0) {
1599
0
      return d_addr.toStringNoInterface() + "/" + std::to_string(d_addrMask);
1600
0
    }
1601
0
    return d_addr.toStringNoInterface() + ":" + std::to_string(d_addr.getPort()) + "/" + std::to_string(d_portMask);
1602
0
  }
1603
1604
  bool empty() const
1605
0
  {
1606
0
    return d_addr.sin4.sin_family == 0;
1607
0
  }
1608
1609
  bool operator==(const AddressAndPortRange& rhs) const
1610
0
  {
1611
0
    return std::tie(d_addr, d_addrMask, d_portMask) == std::tie(rhs.d_addr, rhs.d_addrMask, rhs.d_portMask);
1612
0
  }
1613
1614
  bool operator<(const AddressAndPortRange& rhs) const
1615
0
  {
1616
0
    if (empty() && !rhs.empty()) {
1617
0
      return false;
1618
0
    }
1619
0
1620
0
    if (!empty() && rhs.empty()) {
1621
0
      return true;
1622
0
    }
1623
0
1624
0
    if (d_addrMask > rhs.d_addrMask) {
1625
0
      return true;
1626
0
    }
1627
0
1628
0
    if (d_addrMask < rhs.d_addrMask) {
1629
0
      return false;
1630
0
    }
1631
0
1632
0
    if (d_addr < rhs.d_addr) {
1633
0
      return true;
1634
0
    }
1635
0
1636
0
    if (d_addr > rhs.d_addr) {
1637
0
      return false;
1638
0
    }
1639
0
1640
0
    if (d_portMask > rhs.d_portMask) {
1641
0
      return true;
1642
0
    }
1643
0
1644
0
    if (d_portMask < rhs.d_portMask) {
1645
0
      return false;
1646
0
    }
1647
0
1648
0
    return d_addr.getPort() < rhs.d_addr.getPort();
1649
0
  }
1650
1651
  bool operator>(const AddressAndPortRange& rhs) const
1652
0
  {
1653
0
    return rhs.operator<(*this);
1654
0
  }
1655
1656
  struct hash
1657
  {
1658
    uint32_t operator()(const AddressAndPortRange& apr) const
1659
0
    {
1660
0
      ComboAddress::addressOnlyHash hashOp;
1661
0
      uint16_t port = apr.d_addr.getPort();
1662
0
      /* it's fine to hash the whole address and port because the non-relevant parts have
1663
0
         been masked to 0 */
1664
0
      return burtle(reinterpret_cast<const unsigned char*>(&port), sizeof(port), hashOp(apr.d_addr));
1665
0
    }
1666
  };
1667
1668
private:
1669
  ComboAddress d_addr;
1670
  uint8_t d_addrMask;
1671
  /* only used for v4 addresses */
1672
  uint8_t d_portMask;
1673
};
1674
1675
int SSocket(int family, int type, int flags);
1676
int SConnect(int sockfd, const ComboAddress& remote);
1677
/* tries to connect to remote for a maximum of timeout seconds.
1678
   sockfd should be set to non-blocking beforehand.
1679
   returns 0 on success (the socket is writable), throws a
1680
   runtime_error otherwise */
1681
int SConnectWithTimeout(int sockfd, const ComboAddress& remote, const struct timeval& timeout);
1682
int SBind(int sockfd, const ComboAddress& local);
1683
int SAccept(int sockfd, ComboAddress& remote);
1684
int SListen(int sockfd, int limit);
1685
int SSetsockopt(int sockfd, int level, int opname, int value);
1686
void setSocketIgnorePMTU(int sockfd, int family);
1687
bool setReusePort(int sockfd);
1688
1689
#if defined(IP_PKTINFO)
1690
  #define GEN_IP_PKTINFO IP_PKTINFO
1691
#elif defined(IP_RECVDSTADDR)
1692
  #define GEN_IP_PKTINFO IP_RECVDSTADDR
1693
#endif
1694
1695
bool IsAnyAddress(const ComboAddress& addr);
1696
bool HarvestDestinationAddress(const struct msghdr* msgh, ComboAddress* destination);
1697
bool HarvestTimestamp(struct msghdr* msgh, struct timeval* tv);
1698
void fillMSGHdr(struct msghdr* msgh, struct iovec* iov, cmsgbuf_aligned* cbuf, size_t cbufsize, char* data, size_t datalen, ComboAddress* addr);
1699
int sendOnNBSocket(int fd, const struct msghdr *msgh);
1700
size_t sendMsgWithOptions(int fd, const char* buffer, size_t len, const ComboAddress* dest, const ComboAddress* local, unsigned int localItf, int flags);
1701
1702
/* requires a non-blocking, connected TCP socket */
1703
bool isTCPSocketUsable(int sock);
1704
1705
extern template class NetmaskTree<bool>;
1706
ComboAddress parseIPAndPort(const std::string& input, uint16_t port);
1707
1708
std::set<std::string> getListOfNetworkInterfaces();
1709
std::vector<ComboAddress> getListOfAddressesOfNetworkInterface(const std::string& itf);
1710
std::vector<Netmask> getListOfRangesOfNetworkInterface(const std::string& itf);
1711
1712
/* These functions throw if the value was already set to a higher value,
1713
   or on error */
1714
void setSocketBuffer(int fd, int optname, uint32_t size);
1715
void setSocketReceiveBuffer(int fd, uint32_t size);
1716
void setSocketSendBuffer(int fd, uint32_t size);