Coverage Report

Created: 2026-03-08 06:22

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/pdns/pdns/dnsdistdist/dnsparser.cc
Line
Count
Source
1
/*
2
 * This file is part of PowerDNS or dnsdist.
3
 * Copyright -- PowerDNS.COM B.V. and its contributors
4
 *
5
 * This program is free software; you can redistribute it and/or modify
6
 * it under the terms of version 2 of the GNU General Public License as
7
 * published by the Free Software Foundation.
8
 *
9
 * In addition, for the avoidance of any doubt, permission is granted to
10
 * link this program with OpenSSL and to (re)distribute the binaries
11
 * produced as the result of such linking.
12
 *
13
 * This program is distributed in the hope that it will be useful,
14
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16
 * GNU General Public License for more details.
17
 *
18
 * You should have received a copy of the GNU General Public License
19
 * along with this program; if not, write to the Free Software
20
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21
 */
22
#include "dnsparser.hh"
23
#include "dnswriter.hh"
24
#include <boost/algorithm/string.hpp>
25
#include <boost/format.hpp>
26
#include <cstdint>
27
#include <stdexcept>
28
#include <string>
29
30
#include "dns_random.hh"
31
#include "namespaces.hh"
32
#include "noinitvector.hh"
33
34
std::atomic<bool> DNSRecordContent::d_locked{false};
35
36
UnknownRecordContent::UnknownRecordContent(const string& zone)
37
0
{
38
  // parse the input
39
0
  vector<string> parts;
40
0
  stringtok(parts, zone);
41
  // we need exactly 3 parts, except if the length field is set to 0 then we only need 2
42
0
  if (parts.size() != 3 && !(parts.size() == 2 && boost::equals(parts.at(1), "0"))) {
43
0
    throw MOADNSException("Unknown record was stored incorrectly, need 3 fields, got " + std::to_string(parts.size()) + ": " + zone);
44
0
  }
45
46
0
  if (parts.at(0) != "\\#") {
47
0
    throw MOADNSException("Unknown record was stored incorrectly, first part should be '\\#', got '" + parts.at(0) + "'");
48
0
  }
49
50
0
  const string& relevant = (parts.size() > 2) ? parts.at(2) : "";
51
0
  auto total = pdns::checked_stoi<unsigned int>(parts.at(1));
52
0
  if (relevant.size() % 2 || (relevant.size() / 2) != total) {
53
0
    throw MOADNSException((boost::format("invalid unknown record length: size not equal to length field (%d != 2 * %d)") % relevant.size() % total).str());
54
0
  }
55
56
0
  string out;
57
0
  out.reserve(total + 1);
58
59
0
  for (unsigned int n = 0; n < total; ++n) {
60
0
    int c;
61
0
    if (sscanf(&relevant.at(2*n), "%02x", &c) != 1) {
62
0
      throw MOADNSException("unable to read data at position " + std::to_string(2 * n) + " from unknown record of size " + std::to_string(relevant.size()));
63
0
    }
64
0
    out.append(1, (char)c);
65
0
  }
66
67
0
  d_record.insert(d_record.end(), out.begin(), out.end());
68
0
}
69
70
string UnknownRecordContent::getZoneRepresentation(bool /* noDot */) const
71
0
{
72
0
  ostringstream str;
73
0
  str<<"\\# "<<(unsigned int)d_record.size()<<" ";
74
0
  char hex[4];
75
0
  for (unsigned char n : d_record) {
76
0
    snprintf(hex, sizeof(hex), "%02x", n);
77
0
    str << hex;
78
0
  }
79
0
  return str.str();
80
0
}
81
82
void UnknownRecordContent::toPacket(DNSPacketWriter& pw) const
83
0
{
84
0
  pw.xfrBlob(string(d_record.begin(),d_record.end()));
85
0
}
86
87
shared_ptr<DNSRecordContent> DNSRecordContent::deserialize(const DNSName& qname, uint16_t qtype, const string& serialized, uint16_t qclass, bool internalRepresentation)
88
0
{
89
0
  dnsheader dnsheader;
90
0
  memset(&dnsheader, 0, sizeof(dnsheader));
91
0
  dnsheader.qdcount=htons(1);
92
0
  dnsheader.ancount=htons(1);
93
94
0
  PacketBuffer packet; // build pseudo packet
95
  /* will look like: dnsheader, 5 bytes, encoded qname, dns record header, serialized data */
96
0
  const auto& encoded = qname.getStorage();
97
0
  packet.resize(sizeof(dnsheader) + 5 + encoded.size() + sizeof(struct dnsrecordheader) + serialized.size());
98
99
0
  uint16_t pos=0;
100
0
  memcpy(&packet[0], &dnsheader, sizeof(dnsheader)); pos+=sizeof(dnsheader);
101
102
0
  constexpr std::array<uint8_t, 5> tmp= {'\x0', '\x0', '\x1', '\x0', '\x1' }; // root question for ns_t_a
103
0
  memcpy(&packet[pos], tmp.data(), tmp.size()); pos += tmp.size();
104
105
0
  memcpy(&packet[pos], encoded.c_str(), encoded.size()); pos+=(uint16_t)encoded.size();
106
107
0
  struct dnsrecordheader drh;
108
0
  drh.d_type=htons(qtype);
109
0
  drh.d_class=htons(qclass);
110
0
  drh.d_ttl=0;
111
0
  drh.d_clen=htons(serialized.size());
112
113
0
  memcpy(&packet[pos], &drh, sizeof(drh)); pos+=sizeof(drh);
114
0
  if (!serialized.empty()) {
115
0
    memcpy(&packet[pos], serialized.c_str(), serialized.size());
116
0
    pos += (uint16_t) serialized.size();
117
0
    (void) pos;
118
0
  }
119
120
0
  DNSRecord dr;
121
0
  dr.d_class = qclass;
122
0
  dr.d_type = qtype;
123
0
  dr.d_name = qname;
124
0
  dr.d_clen = serialized.size();
125
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): packet.data() is uint8_t *
126
0
  PacketReader reader(std::string_view(reinterpret_cast<const char*>(packet.data()), packet.size()), packet.size() - serialized.size() - sizeof(dnsrecordheader), internalRepresentation);
127
  /* needed to get the record boundaries right */
128
0
  reader.getDnsrecordheader(drh);
129
0
  auto content = DNSRecordContent::make(dr, reader, Opcode::Query);
130
0
  return content;
131
0
}
132
133
std::shared_ptr<DNSRecordContent> DNSRecordContent::make(const DNSRecord& dr,
134
                                                         PacketReader& pr)
135
0
{
136
0
  uint16_t searchclass = (dr.d_type == QType::OPT) ? 1 : dr.d_class; // class is invalid for OPT
137
138
0
  auto i = getTypemap().find(pair(searchclass, dr.d_type));
139
0
  if(i==getTypemap().end() || !i->second) {
140
0
    return std::make_shared<UnknownRecordContent>(dr, pr);
141
0
  }
142
143
0
  return i->second(dr, pr);
144
0
}
145
146
std::shared_ptr<DNSRecordContent> DNSRecordContent::make(uint16_t qtype, uint16_t qclass,
147
                                                         const string& content)
148
0
{
149
0
  auto i = getZmakermap().find(pair(qclass, qtype));
150
0
  if(i==getZmakermap().end()) {
151
0
    return std::make_shared<UnknownRecordContent>(content);
152
0
  }
153
154
0
  return i->second(content);
155
0
}
156
157
std::shared_ptr<DNSRecordContent> DNSRecordContent::make(const DNSRecord& dr, PacketReader& pr, uint16_t oc)
158
133k
{
159
  // For opcode UPDATE and where the DNSRecord is an answer record, we don't care about content, because this is
160
  // not used within the prerequisite section of RFC2136, so - we can simply use unknownrecordcontent.
161
  // For section 3.2.3, we do need content so we need to get it properly. But only for the correct QClasses.
162
133k
  if (oc == Opcode::Update && dr.d_place == DNSResourceRecord::ANSWER && dr.d_class != 1)
163
2.42k
    return std::make_shared<UnknownRecordContent>(dr, pr);
164
165
130k
  uint16_t searchclass = (dr.d_type == QType::OPT) ? 1 : dr.d_class; // class is invalid for OPT
166
167
130k
  auto i = getTypemap().find(pair(searchclass, dr.d_type));
168
130k
  if(i==getTypemap().end() || !i->second) {
169
32.5k
    return std::make_shared<UnknownRecordContent>(dr, pr);
170
32.5k
  }
171
172
98.1k
  return i->second(dr, pr);
173
130k
}
174
175
0
string DNSRecordContent::upgradeContent(const DNSName& qname, const QType& qtype, const string& content) {
176
  // seamless upgrade for previously unsupported but now implemented types.
177
0
  UnknownRecordContent unknown_content(content);
178
0
  shared_ptr<DNSRecordContent> rc = DNSRecordContent::deserialize(qname, qtype.getCode(), unknown_content.serialize(qname));
179
0
  return rc->getZoneRepresentation();
180
0
}
181
182
DNSRecordContent::typemap_t& DNSRecordContent::getTypemap()
183
261k
{
184
261k
  static DNSRecordContent::typemap_t typemap;
185
261k
  return typemap;
186
261k
}
187
188
DNSRecordContent::n2typemap_t& DNSRecordContent::getN2Typemap()
189
353k
{
190
353k
  static DNSRecordContent::n2typemap_t n2typemap;
191
353k
  return n2typemap;
192
353k
}
193
194
DNSRecordContent::t2namemap_t& DNSRecordContent::getT2Namemap()
195
232
{
196
232
  static DNSRecordContent::t2namemap_t t2namemap;
197
232
  return t2namemap;
198
232
}
199
200
DNSRecordContent::zmakermap_t& DNSRecordContent::getZmakermap()
201
226
{
202
226
  static DNSRecordContent::zmakermap_t zmakermap;
203
226
  return zmakermap;
204
226
}
205
206
bool DNSRecordContent::isRegisteredType(uint16_t rtype, uint16_t rclass)
207
0
{
208
0
  return getTypemap().count(pair(rclass, rtype)) != 0;
209
0
}
210
211
0
DNSRecord::DNSRecord(const DNSResourceRecord& rr): d_name(rr.qname)
212
0
{
213
0
  d_type = rr.qtype.getCode();
214
0
  d_ttl = rr.ttl;
215
0
  d_class = rr.qclass;
216
0
  d_place = DNSResourceRecord::ANSWER;
217
0
  d_clen = 0;
218
0
  d_content = DNSRecordContent::make(d_type, rr.qclass, rr.content);
219
0
}
220
221
// If you call this and you are not parsing a packet coming from a socket, you are doing it wrong.
222
DNSResourceRecord DNSResourceRecord::fromWire(const DNSRecord& wire)
223
0
{
224
0
  DNSResourceRecord resourceRecord;
225
0
  resourceRecord.qname = wire.d_name;
226
0
  resourceRecord.qtype = QType(wire.d_type);
227
0
  resourceRecord.ttl = wire.d_ttl;
228
0
  resourceRecord.content = wire.getContent()->getZoneRepresentation(true);
229
0
  resourceRecord.auth = false;
230
0
  resourceRecord.qclass = wire.d_class;
231
0
  return resourceRecord;
232
0
}
233
234
void MOADNSParser::init(bool query, const std::string_view& packet)
235
10.7k
{
236
10.7k
  if (packet.size() < sizeof(dnsheader))
237
16
    throw MOADNSException("Packet shorter than minimal header");
238
239
10.7k
  memcpy(&d_header, packet.data(), sizeof(dnsheader));
240
241
10.7k
  if(d_header.opcode != Opcode::Query && d_header.opcode != Opcode::Notify && d_header.opcode != Opcode::Update)
242
12
    throw MOADNSException("Can't parse non-query packet with opcode="+ std::to_string(d_header.opcode));
243
244
10.7k
  d_header.qdcount=ntohs(d_header.qdcount);
245
10.7k
  d_header.ancount=ntohs(d_header.ancount);
246
10.7k
  d_header.nscount=ntohs(d_header.nscount);
247
10.7k
  d_header.arcount=ntohs(d_header.arcount);
248
249
10.7k
  if (query && (d_header.qdcount > 1))
250
317
    throw MOADNSException("Query with QD > 1 ("+std::to_string(d_header.qdcount)+")");
251
252
10.4k
  unsigned int n=0;
253
254
10.4k
  PacketReader pr(packet);
255
10.4k
  bool validPacket=false;
256
10.4k
  try {
257
10.4k
    d_qtype = d_qclass = 0; // sometimes replies come in with no question, don't present garbage then
258
259
43.8k
    for(n=0;n < d_header.qdcount; ++n) {
260
33.4k
      d_qname=pr.getName();
261
33.4k
      d_qtype=pr.get16BitInt();
262
33.4k
      d_qclass=pr.get16BitInt();
263
33.4k
    }
264
265
10.4k
    struct dnsrecordheader ah;
266
10.4k
    vector<unsigned char> record;
267
10.4k
    bool seenTSIG = false;
268
10.4k
    validPacket=true;
269
10.4k
    d_answers.reserve((unsigned int)(d_header.ancount + d_header.nscount + d_header.arcount));
270
186k
    for(n=0;n < (unsigned int)(d_header.ancount + d_header.nscount + d_header.arcount); ++n) {
271
176k
      DNSRecord dr;
272
273
176k
      if(n < d_header.ancount)
274
92.9k
        dr.d_place=DNSResourceRecord::ANSWER;
275
83.4k
      else if(n < d_header.ancount + d_header.nscount)
276
42.8k
        dr.d_place=DNSResourceRecord::AUTHORITY;
277
40.6k
      else
278
40.6k
        dr.d_place=DNSResourceRecord::ADDITIONAL;
279
280
176k
      unsigned int recordStartPos=pr.getPosition();
281
282
176k
      DNSName name=pr.getName();
283
284
176k
      pr.getDnsrecordheader(ah);
285
176k
      dr.d_ttl=ah.d_ttl;
286
176k
      dr.d_type=ah.d_type;
287
176k
      dr.d_class=ah.d_class;
288
289
176k
      dr.d_name = std::move(name);
290
176k
      dr.d_clen = ah.d_clen;
291
292
176k
      if (query &&
293
42.6k
          !(d_qtype == QType::IXFR && dr.d_place == DNSResourceRecord::AUTHORITY && dr.d_type == QType::SOA) && // IXFR queries have a SOA in their AUTHORITY section
294
41.4k
          (dr.d_place == DNSResourceRecord::ANSWER || dr.d_place == DNSResourceRecord::AUTHORITY || (dr.d_type != QType::OPT && dr.d_type != QType::TSIG && dr.d_type != QType::SIG && dr.d_type != QType::TKEY) || ((dr.d_type == QType::TSIG || dr.d_type == QType::SIG || dr.d_type == QType::TKEY) && dr.d_class != QClass::ANY))) {
295
//        cerr<<"discarding RR, query is "<<query<<", place is "<<dr.d_place<<", type is "<<dr.d_type<<", class is "<<dr.d_class<<endl;
296
39.5k
        dr.setContent(std::make_shared<UnknownRecordContent>(dr, pr));
297
39.5k
      }
298
136k
      else {
299
//        cerr<<"parsing RR, query is "<<query<<", place is "<<dr.d_place<<", type is "<<dr.d_type<<", class is "<<dr.d_class<<endl;
300
136k
        dr.setContent(DNSRecordContent::make(dr, pr, d_header.opcode));
301
136k
      }
302
303
176k
      if (dr.d_place == DNSResourceRecord::ADDITIONAL && seenTSIG) {
304
161
        throw MOADNSException("Packet ("+d_qname.toString()+"|#"+std::to_string(d_qtype)+") has an unexpected record ("+std::to_string(dr.d_type)+") after a TSIG one.");
305
161
      }
306
307
176k
      if(dr.d_type == QType::TSIG && dr.d_class == QClass::ANY) {
308
444
        if(seenTSIG || dr.d_place != DNSResourceRecord::ADDITIONAL) {
309
208
          throw MOADNSException("Packet ("+d_qname.toLogString()+"|#"+std::to_string(d_qtype)+") has a TSIG record in an invalid position.");
310
208
        }
311
236
        seenTSIG = true;
312
236
        d_tsigPos = recordStartPos;
313
236
      }
314
315
176k
      d_answers.emplace_back(std::move(dr));
316
176k
    }
317
318
#if 0
319
    if(pr.getPosition()!=packet.size()) {
320
      throw MOADNSException("Packet ("+d_qname+"|#"+std::to_string(d_qtype)+") has trailing garbage ("+ std::to_string(pr.getPosition()) + " < " +
321
                            std::to_string(packet.size()) + ")");
322
    }
323
#endif
324
10.4k
  }
325
10.4k
  catch(const std::out_of_range &re) {
326
9.99k
    if(validPacket && d_header.tc) { // don't sweat it over truncated packets, but do adjust an, ns and arcount
327
2.54k
      if(n < d_header.ancount) {
328
1.53k
        d_header.ancount=n; d_header.nscount = d_header.arcount = 0;
329
1.53k
      }
330
1.00k
      else if(n < d_header.ancount + d_header.nscount) {
331
557
        d_header.nscount = n - d_header.ancount; d_header.arcount=0;
332
557
      }
333
450
      else {
334
450
        d_header.arcount = n - d_header.ancount - d_header.nscount;
335
450
      }
336
2.54k
    }
337
7.45k
    else {
338
7.45k
      throw MOADNSException("Error parsing packet of "+std::to_string(packet.size())+" bytes (rd="+
339
7.45k
                            std::to_string(d_header.rd)+
340
7.45k
                            "), out of bounds: "+string(re.what()));
341
7.45k
    }
342
9.99k
  }
343
10.4k
}
344
345
bool MOADNSParser::hasEDNS() const
346
0
{
347
0
  if (d_header.arcount == 0 || d_answers.empty()) {
348
0
    return false;
349
0
  }
350
351
0
  for (const auto& record : d_answers) {
352
0
    if (record.d_place == DNSResourceRecord::ADDITIONAL && record.d_type == QType::OPT) {
353
0
      return true;
354
0
    }
355
0
  }
356
357
0
  return false;
358
0
}
359
360
void PacketReader::getDnsrecordheader(struct dnsrecordheader &ah)
361
173k
{
362
173k
  unsigned char *p = reinterpret_cast<unsigned char*>(&ah);
363
364
1.90M
  for(unsigned int n = 0; n < sizeof(dnsrecordheader); ++n) {
365
1.72M
    p[n] = d_content.at(d_pos++);
366
1.72M
  }
367
368
173k
  ah.d_type = ntohs(ah.d_type);
369
173k
  ah.d_class = ntohs(ah.d_class);
370
173k
  ah.d_clen = ntohs(ah.d_clen);
371
173k
  ah.d_ttl = ntohl(ah.d_ttl);
372
373
173k
  d_startrecordpos = d_pos; // needed for getBlob later on
374
173k
  d_recordlen = ah.d_clen;
375
173k
}
376
377
378
void PacketReader::copyRecord(vector<unsigned char>& dest, uint16_t len)
379
74.6k
{
380
74.6k
  if (len == 0) {
381
46.8k
    return;
382
46.8k
  }
383
27.7k
  if ((d_pos + len) > d_content.size()) {
384
3.07k
    throw std::out_of_range("Attempt to copy outside of packet");
385
3.07k
  }
386
387
24.6k
  dest.resize(len);
388
389
4.00M
  for (uint16_t n = 0; n < len; ++n) {
390
3.97M
    dest.at(n) = d_content.at(d_pos++);
391
3.97M
  }
392
24.6k
}
393
394
void PacketReader::copyRecord(unsigned char* dest, uint16_t len)
395
1.60k
{
396
1.60k
  if (d_pos + len > d_content.size()) {
397
12
    throw std::out_of_range("Attempt to copy outside of packet");
398
12
  }
399
400
1.59k
  memcpy(dest, &d_content.at(d_pos), len);
401
1.59k
  d_pos += len;
402
1.59k
}
403
404
void PacketReader::xfrNodeOrLocatorID(NodeOrLocatorID& ret)
405
3.07k
{
406
3.07k
  if (d_pos + sizeof(ret) > d_content.size()) {
407
24
    throw std::out_of_range("Attempt to read 64 bit value outside of packet");
408
24
  }
409
3.04k
  memcpy(&ret.content, &d_content.at(d_pos), sizeof(ret.content));
410
3.04k
  d_pos += sizeof(ret);
411
3.04k
}
412
413
void PacketReader::xfr48BitInt(uint64_t& ret)
414
494
{
415
494
  ret=0;
416
494
  ret+=static_cast<uint8_t>(d_content.at(d_pos++));
417
494
  ret<<=8;
418
494
  ret+=static_cast<uint8_t>(d_content.at(d_pos++));
419
494
  ret<<=8;
420
494
  ret+=static_cast<uint8_t>(d_content.at(d_pos++));
421
494
  ret<<=8;
422
494
  ret+=static_cast<uint8_t>(d_content.at(d_pos++));
423
494
  ret<<=8;
424
494
  ret+=static_cast<uint8_t>(d_content.at(d_pos++));
425
494
  ret<<=8;
426
494
  ret+=static_cast<uint8_t>(d_content.at(d_pos++));
427
494
}
428
429
uint32_t PacketReader::get32BitInt()
430
30.7k
{
431
30.7k
  uint32_t ret=0;
432
30.7k
  ret+=static_cast<uint8_t>(d_content.at(d_pos++));
433
30.7k
  ret<<=8;
434
30.7k
  ret+=static_cast<uint8_t>(d_content.at(d_pos++));
435
30.7k
  ret<<=8;
436
30.7k
  ret+=static_cast<uint8_t>(d_content.at(d_pos++));
437
30.7k
  ret<<=8;
438
30.7k
  ret+=static_cast<uint8_t>(d_content.at(d_pos++));
439
440
30.7k
  return ret;
441
30.7k
}
442
443
444
uint16_t PacketReader::get16BitInt()
445
419k
{
446
419k
  uint16_t ret=0;
447
419k
  ret+=static_cast<uint8_t>(d_content.at(d_pos++));
448
419k
  ret<<=8;
449
419k
  ret+=static_cast<uint8_t>(d_content.at(d_pos++));
450
451
419k
  return ret;
452
419k
}
453
454
uint8_t PacketReader::get8BitInt()
455
200k
{
456
200k
  return d_content.at(d_pos++);
457
200k
}
458
459
DNSName PacketReader::getName()
460
259k
{
461
259k
  unsigned int consumed;
462
259k
  try {
463
259k
    DNSName dn((const char*) d_content.data(), d_content.size(), d_pos, true /* uncompress */, nullptr /* qtype */, nullptr /* qclass */, &consumed, sizeof(dnsheader));
464
465
259k
    d_pos+=consumed;
466
259k
    return dn;
467
259k
  }
468
259k
  catch(const std::range_error& re) {
469
3.35k
    throw std::out_of_range(string("dnsname issue: ")+re.what());
470
3.35k
  }
471
259k
  catch(...) {
472
0
    throw std::out_of_range("dnsname issue");
473
0
  }
474
0
  throw PDNSException("PacketReader::getName(): name is empty");
475
259k
}
476
477
// FIXME see #6010 and #3503 if you want a proper solution
478
string txtEscape(const string &name)
479
15.8k
{
480
15.8k
  string ret;
481
15.8k
  std::array<char, 5> ebuf{};
482
483
702k
  for (char letter : name) {
484
702k
    const unsigned uch = static_cast<unsigned char>(letter);
485
702k
    if (uch >= 127 || uch < 32) {
486
626k
      snprintf(ebuf.data(), ebuf.size(), "\\%03u", uch);
487
626k
      ret += ebuf.data();
488
626k
    }
489
75.9k
    else if (letter == '"' || letter == '\\'){
490
2.02k
      ret += '\\';
491
2.02k
      ret += letter;
492
2.02k
    }
493
73.9k
    else {
494
73.9k
      ret += letter;
495
73.9k
    }
496
702k
  }
497
15.8k
  return ret;
498
15.8k
}
499
500
// exceptions thrown here do not result in logging in the main pdns auth server - just so you know!
501
string PacketReader::getText(bool multi, bool lenField)
502
14.6k
{
503
14.6k
  string ret;
504
14.6k
  ret.reserve(40);
505
49.2k
  while(d_pos < d_startrecordpos + d_recordlen ) {
506
37.7k
    if(!ret.empty()) {
507
31.4k
      ret.append(1,' ');
508
31.4k
    }
509
37.7k
    uint16_t labellen;
510
37.7k
    if(lenField)
511
36.5k
      labellen=static_cast<uint8_t>(d_content.at(d_pos++));
512
1.11k
    else
513
1.11k
      labellen=d_recordlen - (d_pos - d_startrecordpos);
514
515
37.7k
    ret.append(1,'"');
516
37.7k
    if(labellen) { // no need to do anything for an empty string
517
16.0k
      string val(&d_content.at(d_pos), &d_content.at(d_pos+labellen-1)+1);
518
16.0k
      ret.append(txtEscape(val)); // the end is one beyond the packet
519
16.0k
    }
520
37.7k
    ret.append(1,'"');
521
37.7k
    d_pos+=labellen;
522
37.7k
    if(!multi)
523
3.10k
      break;
524
37.7k
  }
525
526
14.6k
  if (ret.empty() && !lenField) {
527
    // all lenField == false cases (CAA and URI at the time of this writing) want that emptiness to be explicit
528
4.16k
    return "\"\"";
529
4.16k
  }
530
10.4k
  return ret;
531
14.6k
}
532
533
string PacketReader::getUnquotedText(bool lenField)
534
3.15k
{
535
3.15k
  uint16_t stop_at;
536
3.15k
  if(lenField)
537
3.15k
    stop_at = static_cast<uint8_t>(d_content.at(d_pos)) + d_pos + 1;
538
0
  else
539
0
    stop_at = d_recordlen;
540
541
  /* think unsigned overflow */
542
3.15k
  if (stop_at < d_pos) {
543
1
    throw std::out_of_range("getUnquotedText out of record range");
544
1
  }
545
546
3.15k
  if(stop_at == d_pos)
547
0
    return "";
548
549
3.15k
  d_pos++;
550
3.15k
  string ret(d_content.substr(d_pos, stop_at-d_pos));
551
3.15k
  d_pos = stop_at;
552
3.15k
  return ret;
553
3.15k
}
554
555
void PacketReader::xfrBlob(string& blob)
556
38.2k
{
557
38.2k
  try {
558
38.2k
    if(d_recordlen && !(d_pos == (d_startrecordpos + d_recordlen))) {
559
9.58k
      if (d_pos > (d_startrecordpos + d_recordlen)) {
560
7
        throw std::out_of_range("xfrBlob out of record range");
561
7
      }
562
9.57k
      blob.assign(&d_content.at(d_pos), &d_content.at(d_startrecordpos + d_recordlen - 1 ) + 1);
563
9.57k
    }
564
28.6k
    else {
565
28.6k
      blob.clear();
566
28.6k
    }
567
568
38.2k
    d_pos = d_startrecordpos + d_recordlen;
569
38.2k
  }
570
38.2k
  catch(...)
571
38.2k
  {
572
177
    throw std::out_of_range("xfrBlob out of range");
573
177
  }
574
38.2k
}
575
576
2.71k
void PacketReader::xfrBlobNoSpaces(string& blob, int length) {
577
2.71k
  xfrBlob(blob, length);
578
2.71k
}
579
580
void PacketReader::xfrBlob(string& blob, int length)
581
147k
{
582
147k
  if(length) {
583
144k
    if (length < 0) {
584
0
      throw std::out_of_range("xfrBlob out of range (negative length)");
585
0
    }
586
587
144k
    blob.assign(&d_content.at(d_pos), &d_content.at(d_pos + length - 1 ) + 1 );
588
589
144k
    d_pos += length;
590
144k
  }
591
3.48k
  else {
592
3.48k
    blob.clear();
593
3.48k
  }
594
147k
}
595
596
7.91k
void PacketReader::xfrSvcParamKeyVals(set<SvcParam> &kvs) {
597
7.91k
  int32_t lastKey{-1}; // Keep track of the last key, as ordering should be strict
598
599
15.9k
  while (d_pos < (d_startrecordpos + d_recordlen)) {
600
8.87k
    if (d_pos + 2 > (d_startrecordpos + d_recordlen)) {
601
5
      throw std::out_of_range("incomplete key");
602
5
    }
603
8.86k
    uint16_t keyInt;
604
8.86k
    xfr16BitInt(keyInt);
605
606
8.86k
    if (keyInt <= lastKey) {
607
52
      throw std::out_of_range("Found SVCParamKey " + std::to_string(keyInt) + " after SVCParamKey " + std::to_string(lastKey));
608
52
    }
609
8.81k
    lastKey = keyInt;
610
611
8.81k
    auto key = static_cast<SvcParam::SvcParamKey>(keyInt);
612
613
8.81k
    uint16_t len;
614
8.81k
    xfr16BitInt(len);
615
616
8.81k
    if (d_pos + len > (d_startrecordpos + d_recordlen)) {
617
31
      throw std::out_of_range("record is shorter than SVCB lengthfield implies");
618
31
    }
619
620
8.78k
    switch (key)
621
8.78k
    {
622
462
    case SvcParam::mandatory: {
623
462
      if (len % 2 != 0) {
624
1
        throw std::out_of_range("mandatory SvcParam has invalid length");
625
1
      }
626
461
      if (len == 0) {
627
2
        throw std::out_of_range("empty 'mandatory' values");
628
2
      }
629
459
      std::set<SvcParam::SvcParamKey> paramKeys;
630
459
      size_t stop = d_pos + len;
631
180k
      while (d_pos < stop) {
632
179k
        uint16_t keyval;
633
179k
        xfr16BitInt(keyval);
634
179k
        paramKeys.insert(static_cast<SvcParam::SvcParamKey>(keyval));
635
179k
      }
636
459
      kvs.insert(SvcParam(key, std::move(paramKeys)));
637
459
      break;
638
461
    }
639
1.23k
    case SvcParam::alpn: {
640
1.23k
      size_t stop = d_pos + len;
641
1.23k
      std::vector<string> alpns;
642
111k
      while (d_pos < stop) {
643
110k
        string alpn;
644
110k
        uint8_t alpnLen = 0;
645
110k
        xfr8BitInt(alpnLen);
646
110k
        if (alpnLen == 0) {
647
9
          throw std::out_of_range("alpn length of 0");
648
9
        }
649
110k
        if (d_pos + alpnLen > stop) {
650
15
          throw std::out_of_range("alpn length is larger than rest of alpn SVC Param");
651
15
        }
652
110k
        xfrBlob(alpn, alpnLen);
653
110k
        alpns.push_back(std::move(alpn));
654
110k
      }
655
1.21k
      kvs.insert(SvcParam(key, std::move(alpns)));
656
1.21k
      break;
657
1.23k
    }
658
199
    case SvcParam::ohttp:
659
415
    case SvcParam::no_default_alpn: {
660
415
      if (len != 0) {
661
2
        throw std::out_of_range("invalid length for " + SvcParam::keyToString(key));
662
2
      }
663
413
      kvs.insert(SvcParam(key));
664
413
      break;
665
415
    }
666
1.36k
    case SvcParam::port: {
667
1.36k
      if (len != 2) {
668
17
        throw std::out_of_range("invalid length for port");
669
17
      }
670
1.34k
      uint16_t port;
671
1.34k
      xfr16BitInt(port);
672
1.34k
      kvs.insert(SvcParam(key, port));
673
1.34k
      break;
674
1.36k
    }
675
467
    case SvcParam::ipv4hint:
676
1.17k
    case SvcParam::ipv6hint: {
677
1.17k
      size_t addrLen = (key == SvcParam::ipv4hint ? 4 : 16);
678
1.17k
      if (len % addrLen != 0) {
679
5
        throw std::out_of_range("invalid length for " + SvcParam::keyToString(key));
680
5
      }
681
1.17k
      vector<ComboAddress> addresses;
682
1.17k
      auto stop = d_pos + len;
683
26.1k
      while (d_pos < stop)
684
24.9k
      {
685
24.9k
        ComboAddress addr;
686
24.9k
        xfrCAWithoutPort(key, addr);
687
24.9k
        addresses.push_back(addr);
688
24.9k
      }
689
      // If there were no addresses, and the input comes from internal
690
      // representation, we can reasonably assume this is the serialization
691
      // of "auto".
692
1.17k
      bool doAuto{d_internal && len == 0};
693
1.17k
      auto param = SvcParam(key, std::move(addresses));
694
1.17k
      param.setAutoHint(doAuto);
695
1.17k
      kvs.insert(std::move(param));
696
1.17k
      break;
697
1.17k
    }
698
592
    case SvcParam::ech: {
699
592
      std::string blob;
700
592
      blob.reserve(len);
701
592
      xfrBlobNoSpaces(blob, len);
702
592
      kvs.insert(SvcParam(key, blob));
703
592
      break;
704
1.17k
    }
705
743
    case SvcParam::tls_supported_groups: {
706
743
      if (len % 2 != 0) {
707
2
        throw std::out_of_range("invalid length for " + SvcParam::keyToString(key));
708
2
      }
709
741
      vector<uint16_t> groups;
710
741
      groups.reserve(len / 2);
711
741
      auto stop = d_pos + len;
712
88.2k
      while (d_pos < stop)
713
87.5k
      {
714
87.5k
        uint16_t group = 0;
715
87.5k
        xfr16BitInt(group);
716
87.5k
        groups.push_back(group);
717
87.5k
      }
718
741
      auto param = SvcParam(key, std::move(groups));
719
741
      kvs.insert(std::move(param));
720
741
      break;
721
743
    }
722
2.54k
    default: {
723
2.54k
      std::string blob;
724
2.54k
      blob.reserve(len);
725
2.54k
      xfrBlob(blob, len);
726
2.54k
      kvs.insert(SvcParam(key, blob));
727
2.54k
      break;
728
743
    }
729
8.78k
    }
730
8.78k
  }
731
7.91k
}
732
733
734
void PacketReader::xfrHexBlob(string& blob, bool /* keepReading */)
735
12.9k
{
736
12.9k
  xfrBlob(blob);
737
12.9k
}
738
739
//FIXME400 remove this method completely
740
string simpleCompress(const string& elabel, const string& root)
741
0
{
742
0
  string label=elabel;
743
  // FIXME400: this relies on the semi-canonical escaped output from getName
744
0
  if(strchr(label.c_str(), '\\')) {
745
0
    boost::replace_all(label, "\\.", ".");
746
0
    boost::replace_all(label, "\\032", " ");
747
0
    boost::replace_all(label, "\\\\", "\\");
748
0
  }
749
0
  typedef vector<pair<unsigned int, unsigned int> > parts_t;
750
0
  parts_t parts;
751
0
  vstringtok(parts, label, ".");
752
0
  string ret;
753
0
  ret.reserve(label.size()+4);
754
0
  for(const auto & part : parts) {
755
0
    if(!root.empty() && !strncasecmp(root.c_str(), label.c_str() + part.first, 1 + label.length() - part.first)) { // also match trailing 0, hence '1 +'
756
0
      const unsigned char rootptr[2]={0xc0,0x11};
757
0
      ret.append((const char *) rootptr, 2);
758
0
      return ret;
759
0
    }
760
0
    ret.append(1, (char)(part.second - part.first));
761
0
    ret.append(label.c_str() + part.first, part.second - part.first);
762
0
  }
763
0
  ret.append(1, (char)0);
764
0
  return ret;
765
0
}
766
767
// method of operation: silently fail if it doesn't work - we're only trying to be nice, don't fall over on it
768
void editDNSPacketTTL(char* packet, size_t length, const std::function<uint32_t(uint8_t, uint16_t, uint16_t, uint32_t)>& visitor)
769
0
{
770
0
  if(length < sizeof(dnsheader))
771
0
    return;
772
0
  try
773
0
  {
774
0
    dnsheader dh;
775
0
    memcpy((void*)&dh, (const dnsheader*)packet, sizeof(dh));
776
0
    uint64_t numrecords = ntohs(dh.ancount) + ntohs(dh.nscount) + ntohs(dh.arcount);
777
0
    DNSPacketMangler dpm(packet, length);
778
779
0
    uint64_t n;
780
0
    for(n=0; n < ntohs(dh.qdcount) ; ++n) {
781
0
      dpm.skipDomainName();
782
      /* type and class */
783
0
      dpm.skipBytes(4);
784
0
    }
785
786
0
    for(n=0; n < numrecords; ++n) {
787
0
      dpm.skipDomainName();
788
789
0
      uint8_t section = n < ntohs(dh.ancount) ? 1 : (n < (ntohs(dh.ancount) + ntohs(dh.nscount)) ? 2 : 3);
790
0
      uint16_t dnstype = dpm.get16BitInt();
791
0
      uint16_t dnsclass = dpm.get16BitInt();
792
793
0
      if(dnstype == QType::OPT) // not getting near that one with a stick
794
0
        break;
795
796
0
      uint32_t dnsttl = dpm.get32BitInt();
797
0
      uint32_t newttl = visitor(section, dnsclass, dnstype, dnsttl);
798
0
      if (newttl) {
799
0
        dpm.rewindBytes(sizeof(newttl));
800
0
        dpm.setAndSkip32BitInt(newttl);
801
0
      }
802
0
      dpm.skipRData();
803
0
    }
804
0
  }
805
0
  catch(...)
806
0
  {
807
0
    return;
808
0
  }
809
0
}
810
811
static bool checkIfPacketContainsRecords(const PacketBuffer& packet, const std::unordered_set<QType>& qtypes)
812
0
{
813
0
  auto length = packet.size();
814
0
  if (length < sizeof(dnsheader)) {
815
0
    return false;
816
0
  }
817
818
0
  try {
819
0
    const dnsheader_aligned dh(packet.data());
820
0
    DNSPacketMangler dpm(const_cast<char*>(reinterpret_cast<const char*>(packet.data())), length);
821
822
0
    const uint16_t qdcount = ntohs(dh->qdcount);
823
0
    for (size_t n = 0; n < qdcount; ++n) {
824
0
      dpm.skipDomainName();
825
      /* type and class */
826
0
      dpm.skipBytes(4);
827
0
    }
828
0
    const size_t recordsCount = static_cast<size_t>(ntohs(dh->ancount)) + ntohs(dh->nscount) + ntohs(dh->arcount);
829
0
    for (size_t n = 0; n < recordsCount; ++n) {
830
0
      dpm.skipDomainName();
831
0
      uint16_t dnstype = dpm.get16BitInt();
832
0
      uint16_t dnsclass = dpm.get16BitInt();
833
0
      if (dnsclass == QClass::IN && qtypes.count(dnstype) > 0) {
834
0
        return true;
835
0
      }
836
      /* ttl */
837
0
      dpm.skipBytes(4);
838
0
      dpm.skipRData();
839
0
    }
840
0
  }
841
0
  catch (...) {
842
0
  }
843
844
0
  return false;
845
0
}
846
847
static int rewritePacketWithoutRecordTypes(const PacketBuffer& initialPacket, PacketBuffer& newContent, const std::unordered_set<QType>& qtypes)
848
0
{
849
0
  static const std::unordered_set<QType>& safeTypes{QType::A, QType::AAAA, QType::DHCID, QType::TXT, QType::OPT, QType::HINFO, QType::DNSKEY, QType::CDNSKEY, QType::DS, QType::CDS, QType::DLV, QType::SSHFP, QType::KEY, QType::CERT, QType::TLSA, QType::SMIMEA, QType::OPENPGPKEY, QType::SVCB, QType::HTTPS, QType::NSEC3, QType::CSYNC, QType::NSEC3PARAM, QType::LOC, QType::NID, QType::L32, QType::L64, QType::EUI48, QType::EUI64, QType::URI, QType::CAA};
850
851
0
  if (initialPacket.size() < sizeof(dnsheader)) {
852
0
    return EINVAL;
853
0
  }
854
0
  try {
855
0
    const dnsheader_aligned dh(initialPacket.data());
856
857
0
    if (ntohs(dh->qdcount) == 0)
858
0
      return ENOENT;
859
0
    auto packetView = std::string_view(reinterpret_cast<const char*>(initialPacket.data()), initialPacket.size());
860
861
0
    PacketReader pr(packetView);
862
863
0
    size_t idx = 0;
864
0
    DNSName rrname;
865
0
    uint16_t qdcount = ntohs(dh->qdcount);
866
0
    uint16_t ancount = ntohs(dh->ancount);
867
0
    uint16_t nscount = ntohs(dh->nscount);
868
0
    uint16_t arcount = ntohs(dh->arcount);
869
0
    uint16_t rrtype;
870
0
    uint16_t rrclass;
871
0
    string blob;
872
0
    struct dnsrecordheader ah;
873
874
0
    rrname = pr.getName();
875
0
    rrtype = pr.get16BitInt();
876
0
    rrclass = pr.get16BitInt();
877
878
0
    GenericDNSPacketWriter<PacketBuffer> pw(newContent, rrname, rrtype, rrclass, dh->opcode);
879
0
    pw.getHeader()->id=dh->id;
880
0
    pw.getHeader()->qr=dh->qr;
881
0
    pw.getHeader()->aa=dh->aa;
882
0
    pw.getHeader()->tc=dh->tc;
883
0
    pw.getHeader()->rd=dh->rd;
884
0
    pw.getHeader()->ra=dh->ra;
885
0
    pw.getHeader()->ad=dh->ad;
886
0
    pw.getHeader()->cd=dh->cd;
887
0
    pw.getHeader()->rcode=dh->rcode;
888
889
    /* consume remaining qd if any */
890
0
    if (qdcount > 1) {
891
0
      for(idx = 1; idx < qdcount; idx++) {
892
0
        rrname = pr.getName();
893
0
        rrtype = pr.get16BitInt();
894
0
        rrclass = pr.get16BitInt();
895
0
        (void) rrtype;
896
0
        (void) rrclass;
897
0
      }
898
0
    }
899
900
    /* copy AN */
901
0
    for (idx = 0; idx < ancount; idx++) {
902
0
      rrname = pr.getName();
903
0
      pr.getDnsrecordheader(ah);
904
0
      pr.xfrBlob(blob);
905
906
0
      if (qtypes.find(ah.d_type) == qtypes.end()) {
907
        // if this is not a safe type
908
0
        if (safeTypes.find(ah.d_type) == safeTypes.end()) {
909
          // "unsafe" types might contain compressed data, so cancel rewrite
910
0
          newContent.clear();
911
0
          return EIO;
912
0
        }
913
0
        pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ANSWER, true);
914
0
        pw.xfrBlob(blob);
915
0
      }
916
0
    }
917
918
    /* copy NS */
919
0
    for (idx = 0; idx < nscount; idx++) {
920
0
      rrname = pr.getName();
921
0
      pr.getDnsrecordheader(ah);
922
0
      pr.xfrBlob(blob);
923
924
0
      if (qtypes.find(ah.d_type) == qtypes.end()) {
925
0
        if (safeTypes.find(ah.d_type) == safeTypes.end()) {
926
          // "unsafe" types might contain compressed data, so cancel rewrite
927
0
          newContent.clear();
928
0
          return EIO;
929
0
        }
930
0
        pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::AUTHORITY, true);
931
0
        pw.xfrBlob(blob);
932
0
      }
933
0
    }
934
    /* copy AR */
935
0
    for (idx = 0; idx < arcount; idx++) {
936
0
      rrname = pr.getName();
937
0
      pr.getDnsrecordheader(ah);
938
0
      pr.xfrBlob(blob);
939
940
0
      if (qtypes.find(ah.d_type) == qtypes.end()) {
941
0
        if (safeTypes.find(ah.d_type) == safeTypes.end()) {
942
          // "unsafe" types might contain compressed data, so cancel rewrite
943
0
          newContent.clear();
944
0
          return EIO;
945
0
        }
946
0
        pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ADDITIONAL, true);
947
0
        pw.xfrBlob(blob);
948
0
      }
949
0
    }
950
0
    pw.commit();
951
952
0
  }
953
0
  catch (...)
954
0
  {
955
0
    newContent.clear();
956
0
    return EIO;
957
0
  }
958
0
  return 0;
959
0
}
960
961
void clearDNSPacketRecordTypes(vector<uint8_t>& packet, const std::unordered_set<QType>& qtypes)
962
0
{
963
0
  return clearDNSPacketRecordTypes(reinterpret_cast<PacketBuffer&>(packet), qtypes);
964
0
}
965
966
void clearDNSPacketRecordTypes(PacketBuffer& packet, const std::unordered_set<QType>& qtypes)
967
0
{
968
0
  if (!checkIfPacketContainsRecords(packet, qtypes)) {
969
0
    return;
970
0
  }
971
972
0
  PacketBuffer newContent;
973
974
0
  auto result = rewritePacketWithoutRecordTypes(packet, newContent, qtypes);
975
0
  if (!result) {
976
0
    packet = std::move(newContent);
977
0
  }
978
0
}
979
980
// method of operation: silently fail if it doesn't work - we're only trying to be nice, don't fall over on it
981
void ageDNSPacket(char* packet, size_t length, uint32_t seconds, const dnsheader_aligned& aligned_dh)
982
0
{
983
0
  if (length < sizeof(dnsheader)) {
984
0
    return;
985
0
  }
986
0
  try {
987
0
    const dnsheader* dhp = aligned_dh.get();
988
0
    const uint64_t dqcount = ntohs(dhp->qdcount);
989
0
    const uint64_t numrecords = ntohs(dhp->ancount) + ntohs(dhp->nscount) + ntohs(dhp->arcount);
990
0
    DNSPacketMangler dpm(packet, length);
991
992
0
    for (uint64_t rec = 0; rec < dqcount; ++rec) {
993
0
      dpm.skipDomainName();
994
      /* type and class */
995
0
      dpm.skipBytes(4);
996
0
    }
997
998
0
    for(uint64_t rec = 0; rec < numrecords; ++rec) {
999
0
      dpm.skipDomainName();
1000
1001
0
      uint16_t dnstype = dpm.get16BitInt();
1002
      /* class */
1003
0
      dpm.skipBytes(2);
1004
1005
0
      if (dnstype != QType::OPT) { // not aging that one with a stick
1006
0
        dpm.decreaseAndSkip32BitInt(seconds);
1007
0
      } else {
1008
0
        dpm.skipBytes(4);
1009
0
      }
1010
0
      dpm.skipRData();
1011
0
    }
1012
0
  }
1013
0
  catch(...) {
1014
0
  }
1015
0
}
1016
1017
void ageDNSPacket(std::string& packet, uint32_t seconds, const dnsheader_aligned& aligned_dh)
1018
0
{
1019
0
  ageDNSPacket(packet.data(), packet.length(), seconds, aligned_dh);
1020
0
}
1021
1022
void shuffleDNSPacket(char* packet, size_t length, const dnsheader_aligned& aligned_dh)
1023
0
{
1024
0
  if (length < sizeof(dnsheader)) {
1025
0
    return;
1026
0
  }
1027
0
  try {
1028
0
    const dnsheader* dhp = aligned_dh.get();
1029
0
    const uint16_t ancount = ntohs(dhp->ancount);
1030
0
    if (ancount == 1) {
1031
      // quick exit, nothing to shuffle
1032
0
      return;
1033
0
    }
1034
1035
0
    DNSPacketMangler dpm(packet, length);
1036
1037
0
    const uint16_t qdcount = ntohs(dhp->qdcount);
1038
1039
0
    for(size_t iter = 0; iter < qdcount; ++iter) {
1040
0
      dpm.skipDomainName();
1041
      /* type and class */
1042
0
      dpm.skipBytes(4);
1043
0
    }
1044
1045
    // for now shuffle only first rrset, only As and AAAAs
1046
0
    uint16_t rrset_type = 0;
1047
0
    DNSName rrset_dnsname{};
1048
0
    std::vector<std::pair<uint32_t, uint32_t>> rrdata_indexes;
1049
0
    rrdata_indexes.reserve(ancount);
1050
1051
0
    for(size_t iter = 0; iter < ancount; ++iter) {
1052
0
      auto domain_start = dpm.getOffset();
1053
0
      dpm.skipDomainName();
1054
0
      const uint16_t dnstype = dpm.get16BitInt();
1055
0
      if (dnstype == QType::A || dnstype == QType::AAAA) {
1056
0
        if (rrdata_indexes.empty()) {
1057
0
          rrset_type = dnstype;
1058
0
          rrset_dnsname = DNSName(packet, length, domain_start, true);
1059
0
        } else {
1060
0
          if (dnstype != rrset_type) {
1061
0
            break;
1062
0
          }
1063
0
          if (DNSName(packet, length, domain_start, true) != rrset_dnsname) {
1064
0
            break;
1065
0
          }
1066
0
        }
1067
        /* class */
1068
0
        dpm.skipBytes(2);
1069
1070
        /* ttl */
1071
0
        dpm.skipBytes(4);
1072
0
        rrdata_indexes.push_back(dpm.skipRDataAndReturnOffsets());
1073
0
      } else {
1074
0
        if (!rrdata_indexes.empty()) {
1075
0
          break;
1076
0
        }
1077
        /* class */
1078
0
        dpm.skipBytes(2);
1079
1080
        /* ttl */
1081
0
        dpm.skipBytes(4);
1082
0
        dpm.skipRData();
1083
0
      }
1084
0
    }
1085
1086
0
    if (rrdata_indexes.size() >= 2) {
1087
0
      using uid = std::uniform_int_distribution<std::vector<std::pair<uint32_t, uint32_t>>::size_type>;
1088
0
      uid dist;
1089
1090
0
      pdns::dns_random_engine randomEngine;
1091
0
      for (auto swapped = rrdata_indexes.size() - 1; swapped > 0; --swapped) {
1092
0
        auto swapped_with = dist(randomEngine, uid::param_type(0, swapped));
1093
0
        if (swapped != swapped_with) {
1094
0
          dpm.swapInPlace(rrdata_indexes.at(swapped), rrdata_indexes.at(swapped_with));
1095
0
        }
1096
0
      }
1097
0
    }
1098
0
  }
1099
0
  catch(...) {
1100
0
  }
1101
0
}
1102
1103
uint32_t getDNSPacketMinTTL(const char* packet, size_t length, bool* seenAuthSOA)
1104
0
{
1105
0
  uint32_t result = std::numeric_limits<uint32_t>::max();
1106
0
  if(length < sizeof(dnsheader)) {
1107
0
    return result;
1108
0
  }
1109
0
  try
1110
0
  {
1111
0
    const dnsheader_aligned dh(packet);
1112
0
    DNSPacketMangler dpm(const_cast<char*>(packet), length);
1113
1114
0
    const uint16_t qdcount = ntohs(dh->qdcount);
1115
0
    for(size_t n = 0; n < qdcount; ++n) {
1116
0
      dpm.skipDomainName();
1117
      /* type and class */
1118
0
      dpm.skipBytes(4);
1119
0
    }
1120
0
    const size_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
1121
0
    for(size_t n = 0; n < numrecords; ++n) {
1122
0
      dpm.skipDomainName();
1123
0
      const uint16_t dnstype = dpm.get16BitInt();
1124
      /* class */
1125
0
      const uint16_t dnsclass = dpm.get16BitInt();
1126
1127
0
      if(dnstype == QType::OPT) {
1128
0
        break;
1129
0
      }
1130
1131
      /* report it if we see a SOA record in the AUTHORITY section */
1132
0
      if(dnstype == QType::SOA && dnsclass == QClass::IN && seenAuthSOA != nullptr && n >= ntohs(dh->ancount) && n < (ntohs(dh->ancount) + ntohs(dh->nscount))) {
1133
0
        *seenAuthSOA = true;
1134
0
      }
1135
1136
0
      const uint32_t ttl = dpm.get32BitInt();
1137
0
      result = std::min(result, ttl);
1138
1139
0
      dpm.skipRData();
1140
0
    }
1141
0
  }
1142
0
  catch(...)
1143
0
  {
1144
0
  }
1145
0
  return result;
1146
0
}
1147
1148
uint32_t getDNSPacketLength(const char* packet, size_t length)
1149
0
{
1150
0
  uint32_t result = length;
1151
0
  if(length < sizeof(dnsheader)) {
1152
0
    return result;
1153
0
  }
1154
0
  try
1155
0
  {
1156
0
    const dnsheader_aligned dh(packet);
1157
0
    DNSPacketMangler dpm(const_cast<char*>(packet), length);
1158
1159
0
    const uint16_t qdcount = ntohs(dh->qdcount);
1160
0
    for(size_t n = 0; n < qdcount; ++n) {
1161
0
      dpm.skipDomainName();
1162
      /* type and class */
1163
0
      dpm.skipBytes(4);
1164
0
    }
1165
0
    const size_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
1166
0
    for(size_t n = 0; n < numrecords; ++n) {
1167
0
      dpm.skipDomainName();
1168
      /* type (2), class (2) and ttl (4) */
1169
0
      dpm.skipBytes(8);
1170
0
      dpm.skipRData();
1171
0
    }
1172
0
    result = dpm.getOffset();
1173
0
  }
1174
0
  catch(...)
1175
0
  {
1176
0
  }
1177
0
  return result;
1178
0
}
1179
1180
uint16_t getRecordsOfTypeCount(const char* packet, size_t length, uint8_t section, uint16_t type)
1181
0
{
1182
0
  uint16_t result = 0;
1183
0
  if(length < sizeof(dnsheader)) {
1184
0
    return result;
1185
0
  }
1186
0
  try
1187
0
  {
1188
0
    const dnsheader_aligned dh(packet);
1189
0
    DNSPacketMangler dpm(const_cast<char*>(packet), length);
1190
1191
0
    const uint16_t qdcount = ntohs(dh->qdcount);
1192
0
    for(size_t n = 0; n < qdcount; ++n) {
1193
0
      dpm.skipDomainName();
1194
0
      if (section == 0) {
1195
0
        uint16_t dnstype = dpm.get16BitInt();
1196
0
        if (dnstype == type) {
1197
0
          result++;
1198
0
        }
1199
        /* class */
1200
0
        dpm.skipBytes(2);
1201
0
      } else {
1202
        /* type and class */
1203
0
        dpm.skipBytes(4);
1204
0
      }
1205
0
    }
1206
0
    const uint16_t ancount = ntohs(dh->ancount);
1207
0
    for(size_t n = 0; n < ancount; ++n) {
1208
0
      dpm.skipDomainName();
1209
0
      if (section == 1) {
1210
0
        uint16_t dnstype = dpm.get16BitInt();
1211
0
        if (dnstype == type) {
1212
0
          result++;
1213
0
        }
1214
        /* class */
1215
0
        dpm.skipBytes(2);
1216
0
      } else {
1217
        /* type and class */
1218
0
        dpm.skipBytes(4);
1219
0
      }
1220
      /* ttl */
1221
0
      dpm.skipBytes(4);
1222
0
      dpm.skipRData();
1223
0
    }
1224
0
    const uint16_t nscount = ntohs(dh->nscount);
1225
0
    for(size_t n = 0; n < nscount; ++n) {
1226
0
      dpm.skipDomainName();
1227
0
      if (section == 2) {
1228
0
        uint16_t dnstype = dpm.get16BitInt();
1229
0
        if (dnstype == type) {
1230
0
          result++;
1231
0
        }
1232
        /* class */
1233
0
        dpm.skipBytes(2);
1234
0
      } else {
1235
        /* type and class */
1236
0
        dpm.skipBytes(4);
1237
0
      }
1238
      /* ttl */
1239
0
      dpm.skipBytes(4);
1240
0
      dpm.skipRData();
1241
0
    }
1242
0
    const uint16_t arcount = ntohs(dh->arcount);
1243
0
    for(size_t n = 0; n < arcount; ++n) {
1244
0
      dpm.skipDomainName();
1245
0
      if (section == 3) {
1246
0
        uint16_t dnstype = dpm.get16BitInt();
1247
0
        if (dnstype == type) {
1248
0
          result++;
1249
0
        }
1250
        /* class */
1251
0
        dpm.skipBytes(2);
1252
0
      } else {
1253
        /* type and class */
1254
0
        dpm.skipBytes(4);
1255
0
      }
1256
      /* ttl */
1257
0
      dpm.skipBytes(4);
1258
0
      dpm.skipRData();
1259
0
    }
1260
0
  }
1261
0
  catch(...)
1262
0
  {
1263
0
  }
1264
0
  return result;
1265
0
}
1266
1267
bool getEDNSUDPPayloadSizeAndZ(const char* packet, size_t length, uint16_t* payloadSize, uint16_t* z)
1268
0
{
1269
0
  if (length < sizeof(dnsheader)) {
1270
0
    return false;
1271
0
  }
1272
1273
0
  *payloadSize = 0;
1274
0
  *z = 0;
1275
1276
0
  try
1277
0
  {
1278
0
    const dnsheader_aligned dh(packet);
1279
0
    if (dh->arcount == 0) {
1280
      // The OPT pseudo-RR, if present, has to be in the additional section (https://datatracker.ietf.org/doc/html/rfc6891#section-6.1.1)
1281
0
      return false;
1282
0
    }
1283
1284
0
    DNSPacketMangler dpm(const_cast<char*>(packet), length);
1285
1286
0
    const uint16_t qdcount = ntohs(dh->qdcount);
1287
0
    for(size_t n = 0; n < qdcount; ++n) {
1288
0
      dpm.skipDomainName();
1289
      /* type and class */
1290
0
      dpm.skipBytes(4);
1291
0
    }
1292
0
    const size_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
1293
0
    for(size_t n = 0; n < numrecords; ++n) {
1294
0
      dpm.skipDomainName();
1295
0
      const auto dnstype = dpm.get16BitInt();
1296
1297
0
      if (dnstype == QType::OPT) {
1298
0
        const auto dnsclass = dpm.get16BitInt();
1299
        /* skip extended rcode and version */
1300
0
        dpm.skipBytes(2);
1301
0
        *z = dpm.get16BitInt();
1302
0
        *payloadSize = dnsclass;
1303
0
        return true;
1304
0
      }
1305
      /* skip class */
1306
0
      dpm.skipBytes(2);
1307
      /* TTL */
1308
0
      dpm.skipBytes(4);
1309
0
      dpm.skipRData();
1310
0
    }
1311
0
  }
1312
0
  catch(...)
1313
0
  {
1314
0
  }
1315
1316
0
  return false;
1317
0
}
1318
1319
bool visitDNSPacket(const std::string_view& packet, const std::function<bool(uint8_t, uint16_t, uint16_t, uint32_t, uint16_t, const char*)>& visitor)
1320
0
{
1321
0
  if (packet.size() < sizeof(dnsheader)) {
1322
0
    return false;
1323
0
  }
1324
1325
0
  try
1326
0
  {
1327
0
    const dnsheader_aligned dh(packet.data());
1328
0
    uint64_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
1329
0
    PacketReader reader(packet);
1330
1331
0
    uint64_t n;
1332
0
    for (n = 0; n < ntohs(dh->qdcount) ; ++n) {
1333
0
      (void) reader.getName();
1334
      /* type and class */
1335
0
      reader.skip(4);
1336
0
    }
1337
1338
0
    for (n = 0; n < numrecords; ++n) {
1339
0
      (void) reader.getName();
1340
1341
0
      uint8_t section = n < ntohs(dh->ancount) ? 1 : (n < (ntohs(dh->ancount) + ntohs(dh->nscount)) ? 2 : 3);
1342
0
      uint16_t dnstype = reader.get16BitInt();
1343
0
      uint16_t dnsclass = reader.get16BitInt();
1344
1345
0
      if (dnstype == QType::OPT) {
1346
        // not getting near that one with a stick
1347
0
        break;
1348
0
      }
1349
1350
0
      uint32_t dnsttl = reader.get32BitInt();
1351
0
      uint16_t contentLength = reader.get16BitInt();
1352
0
      uint16_t pos = reader.getPosition();
1353
1354
0
      bool done = visitor(section, dnsclass, dnstype, dnsttl, contentLength, &packet.at(pos));
1355
0
      if (done) {
1356
0
        return true;
1357
0
      }
1358
1359
0
      reader.skip(contentLength);
1360
0
    }
1361
0
  }
1362
0
  catch (...) {
1363
0
    return false;
1364
0
  }
1365
1366
0
  return true;
1367
0
}