Coverage Report

Created: 2025-06-13 06:28

/src/pdns/pdns/dnsparser.cc
Line
Count
Source (jump to first uncovered line)
1
/*
2
 * This file is part of PowerDNS or dnsdist.
3
 * Copyright -- PowerDNS.COM B.V. and its contributors
4
 *
5
 * This program is free software; you can redistribute it and/or modify
6
 * it under the terms of version 2 of the GNU General Public License as
7
 * published by the Free Software Foundation.
8
 *
9
 * In addition, for the avoidance of any doubt, permission is granted to
10
 * link this program with OpenSSL and to (re)distribute the binaries
11
 * produced as the result of such linking.
12
 *
13
 * This program is distributed in the hope that it will be useful,
14
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16
 * GNU General Public License for more details.
17
 *
18
 * You should have received a copy of the GNU General Public License
19
 * along with this program; if not, write to the Free Software
20
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21
 */
22
#include "dnsparser.hh"
23
#include "dnswriter.hh"
24
#include <boost/algorithm/string.hpp>
25
#include <boost/format.hpp>
26
27
#include "namespaces.hh"
28
#include "noinitvector.hh"
29
30
std::atomic<bool> DNSRecordContent::d_locked{false};
31
32
UnknownRecordContent::UnknownRecordContent(const string& zone)
33
0
{
34
  // parse the input
35
0
  vector<string> parts;
36
0
  stringtok(parts, zone);
37
  // we need exactly 3 parts, except if the length field is set to 0 then we only need 2
38
0
  if (parts.size() != 3 && !(parts.size() == 2 && boost::equals(parts.at(1), "0"))) {
39
0
    throw MOADNSException("Unknown record was stored incorrectly, need 3 fields, got " + std::to_string(parts.size()) + ": " + zone);
40
0
  }
41
42
0
  if (parts.at(0) != "\\#") {
43
0
    throw MOADNSException("Unknown record was stored incorrectly, first part should be '\\#', got '" + parts.at(0) + "'");
44
0
  }
45
46
0
  const string& relevant = (parts.size() > 2) ? parts.at(2) : "";
47
0
  auto total = pdns::checked_stoi<unsigned int>(parts.at(1));
48
0
  if (relevant.size() % 2 || (relevant.size() / 2) != total) {
49
0
    throw MOADNSException((boost::format("invalid unknown record length: size not equal to length field (%d != 2 * %d)") % relevant.size() % total).str());
50
0
  }
51
52
0
  string out;
53
0
  out.reserve(total + 1);
54
55
0
  for (unsigned int n = 0; n < total; ++n) {
56
0
    int c;
57
0
    if (sscanf(&relevant.at(2*n), "%02x", &c) != 1) {
58
0
      throw MOADNSException("unable to read data at position " + std::to_string(2 * n) + " from unknown record of size " + std::to_string(relevant.size()));
59
0
    }
60
0
    out.append(1, (char)c);
61
0
  }
62
63
0
  d_record.insert(d_record.end(), out.begin(), out.end());
64
0
}
65
66
string UnknownRecordContent::getZoneRepresentation(bool /* noDot */) const
67
0
{
68
0
  ostringstream str;
69
0
  str<<"\\# "<<(unsigned int)d_record.size()<<" ";
70
0
  char hex[4];
71
0
  for (unsigned char n : d_record) {
72
0
    snprintf(hex, sizeof(hex), "%02x", n);
73
0
    str << hex;
74
0
  }
75
0
  return str.str();
76
0
}
77
78
void UnknownRecordContent::toPacket(DNSPacketWriter& pw) const
79
0
{
80
0
  pw.xfrBlob(string(d_record.begin(),d_record.end()));
81
0
}
82
83
shared_ptr<DNSRecordContent> DNSRecordContent::deserialize(const DNSName& qname, uint16_t qtype, const string& serialized, uint16_t qclass, bool internalRepresentation)
84
0
{
85
0
  dnsheader dnsheader;
86
0
  memset(&dnsheader, 0, sizeof(dnsheader));
87
0
  dnsheader.qdcount=htons(1);
88
0
  dnsheader.ancount=htons(1);
89
90
0
  PacketBuffer packet; // build pseudo packet
91
  /* will look like: dnsheader, 5 bytes, encoded qname, dns record header, serialized data */
92
0
  const auto& encoded = qname.getStorage();
93
0
  packet.resize(sizeof(dnsheader) + 5 + encoded.size() + sizeof(struct dnsrecordheader) + serialized.size());
94
95
0
  uint16_t pos=0;
96
0
  memcpy(&packet[0], &dnsheader, sizeof(dnsheader)); pos+=sizeof(dnsheader);
97
98
0
  constexpr std::array<uint8_t, 5> tmp= {'\x0', '\x0', '\x1', '\x0', '\x1' }; // root question for ns_t_a
99
0
  memcpy(&packet[pos], tmp.data(), tmp.size()); pos += tmp.size();
100
101
0
  memcpy(&packet[pos], encoded.c_str(), encoded.size()); pos+=(uint16_t)encoded.size();
102
103
0
  struct dnsrecordheader drh;
104
0
  drh.d_type=htons(qtype);
105
0
  drh.d_class=htons(qclass);
106
0
  drh.d_ttl=0;
107
0
  drh.d_clen=htons(serialized.size());
108
109
0
  memcpy(&packet[pos], &drh, sizeof(drh)); pos+=sizeof(drh);
110
0
  if (!serialized.empty()) {
111
0
    memcpy(&packet[pos], serialized.c_str(), serialized.size());
112
0
    pos += (uint16_t) serialized.size();
113
0
    (void) pos;
114
0
  }
115
116
0
  DNSRecord dr;
117
0
  dr.d_class = qclass;
118
0
  dr.d_type = qtype;
119
0
  dr.d_name = qname;
120
0
  dr.d_clen = serialized.size();
121
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): packet.data() is uint8_t *
122
0
  PacketReader reader(std::string_view(reinterpret_cast<const char*>(packet.data()), packet.size()), packet.size() - serialized.size() - sizeof(dnsrecordheader), internalRepresentation);
123
  /* needed to get the record boundaries right */
124
0
  reader.getDnsrecordheader(drh);
125
0
  auto content = DNSRecordContent::make(dr, reader, Opcode::Query);
126
0
  return content;
127
0
}
128
129
std::shared_ptr<DNSRecordContent> DNSRecordContent::make(const DNSRecord& dr,
130
                                                         PacketReader& pr)
131
0
{
132
0
  uint16_t searchclass = (dr.d_type == QType::OPT) ? 1 : dr.d_class; // class is invalid for OPT
133
134
0
  auto i = getTypemap().find(pair(searchclass, dr.d_type));
135
0
  if(i==getTypemap().end() || !i->second) {
136
0
    return std::make_shared<UnknownRecordContent>(dr, pr);
137
0
  }
138
139
0
  return i->second(dr, pr);
140
0
}
141
142
std::shared_ptr<DNSRecordContent> DNSRecordContent::make(uint16_t qtype, uint16_t qclass,
143
                                                         const string& content)
144
0
{
145
0
  auto i = getZmakermap().find(pair(qclass, qtype));
146
0
  if(i==getZmakermap().end()) {
147
0
    return std::make_shared<UnknownRecordContent>(content);
148
0
  }
149
150
0
  return i->second(content);
151
0
}
152
153
std::shared_ptr<DNSRecordContent> DNSRecordContent::make(const DNSRecord& dr, PacketReader& pr, uint16_t oc)
154
0
{
155
  // For opcode UPDATE and where the DNSRecord is an answer record, we don't care about content, because this is
156
  // not used within the prerequisite section of RFC2136, so - we can simply use unknownrecordcontent.
157
  // For section 3.2.3, we do need content so we need to get it properly. But only for the correct QClasses.
158
0
  if (oc == Opcode::Update && dr.d_place == DNSResourceRecord::ANSWER && dr.d_class != 1)
159
0
    return std::make_shared<UnknownRecordContent>(dr, pr);
160
161
0
  uint16_t searchclass = (dr.d_type == QType::OPT) ? 1 : dr.d_class; // class is invalid for OPT
162
163
0
  auto i = getTypemap().find(pair(searchclass, dr.d_type));
164
0
  if(i==getTypemap().end() || !i->second) {
165
0
    return std::make_shared<UnknownRecordContent>(dr, pr);
166
0
  }
167
168
0
  return i->second(dr, pr);
169
0
}
170
171
0
string DNSRecordContent::upgradeContent(const DNSName& qname, const QType& qtype, const string& content) {
172
  // seamless upgrade for previously unsupported but now implemented types.
173
0
  UnknownRecordContent unknown_content(content);
174
0
  shared_ptr<DNSRecordContent> rc = DNSRecordContent::deserialize(qname, qtype.getCode(), unknown_content.serialize(qname));
175
0
  return rc->getZoneRepresentation();
176
0
}
177
178
DNSRecordContent::typemap_t& DNSRecordContent::getTypemap()
179
107
{
180
107
  static DNSRecordContent::typemap_t typemap;
181
107
  return typemap;
182
107
}
183
184
DNSRecordContent::n2typemap_t& DNSRecordContent::getN2Typemap()
185
363k
{
186
363k
  static DNSRecordContent::n2typemap_t n2typemap;
187
363k
  return n2typemap;
188
363k
}
189
190
DNSRecordContent::t2namemap_t& DNSRecordContent::getT2Namemap()
191
110
{
192
110
  static DNSRecordContent::t2namemap_t t2namemap;
193
110
  return t2namemap;
194
110
}
195
196
DNSRecordContent::zmakermap_t& DNSRecordContent::getZmakermap()
197
107
{
198
107
  static DNSRecordContent::zmakermap_t zmakermap;
199
107
  return zmakermap;
200
107
}
201
202
bool DNSRecordContent::isRegisteredType(uint16_t rtype, uint16_t rclass)
203
0
{
204
0
  return getTypemap().count(pair(rclass, rtype)) != 0;
205
0
}
206
207
0
DNSRecord::DNSRecord(const DNSResourceRecord& rr): d_name(rr.qname)
208
0
{
209
0
  d_type = rr.qtype.getCode();
210
0
  d_ttl = rr.ttl;
211
0
  d_class = rr.qclass;
212
0
  d_place = DNSResourceRecord::ANSWER;
213
0
  d_clen = 0;
214
0
  d_content = DNSRecordContent::make(d_type, rr.qclass, rr.content);
215
0
}
216
217
// If you call this and you are not parsing a packet coming from a socket, you are doing it wrong.
218
DNSResourceRecord DNSResourceRecord::fromWire(const DNSRecord& wire)
219
0
{
220
0
  DNSResourceRecord resourceRecord;
221
0
  resourceRecord.qname = wire.d_name;
222
0
  resourceRecord.qtype = QType(wire.d_type);
223
0
  resourceRecord.ttl = wire.d_ttl;
224
0
  resourceRecord.content = wire.getContent()->getZoneRepresentation(true);
225
0
  resourceRecord.auth = false;
226
0
  resourceRecord.qclass = wire.d_class;
227
0
  return resourceRecord;
228
0
}
229
230
void MOADNSParser::init(bool query, const std::string_view& packet)
231
0
{
232
0
  if (packet.size() < sizeof(dnsheader))
233
0
    throw MOADNSException("Packet shorter than minimal header");
234
235
0
  memcpy(&d_header, packet.data(), sizeof(dnsheader));
236
237
0
  if(d_header.opcode != Opcode::Query && d_header.opcode != Opcode::Notify && d_header.opcode != Opcode::Update)
238
0
    throw MOADNSException("Can't parse non-query packet with opcode="+ std::to_string(d_header.opcode));
239
240
0
  d_header.qdcount=ntohs(d_header.qdcount);
241
0
  d_header.ancount=ntohs(d_header.ancount);
242
0
  d_header.nscount=ntohs(d_header.nscount);
243
0
  d_header.arcount=ntohs(d_header.arcount);
244
245
0
  if (query && (d_header.qdcount > 1))
246
0
    throw MOADNSException("Query with QD > 1 ("+std::to_string(d_header.qdcount)+")");
247
248
0
  unsigned int n=0;
249
250
0
  PacketReader pr(packet);
251
0
  bool validPacket=false;
252
0
  try {
253
0
    d_qtype = d_qclass = 0; // sometimes replies come in with no question, don't present garbage then
254
255
0
    for(n=0;n < d_header.qdcount; ++n) {
256
0
      d_qname=pr.getName();
257
0
      d_qtype=pr.get16BitInt();
258
0
      d_qclass=pr.get16BitInt();
259
0
    }
260
261
0
    struct dnsrecordheader ah;
262
0
    vector<unsigned char> record;
263
0
    bool seenTSIG = false;
264
0
    validPacket=true;
265
0
    d_answers.reserve((unsigned int)(d_header.ancount + d_header.nscount + d_header.arcount));
266
0
    for(n=0;n < (unsigned int)(d_header.ancount + d_header.nscount + d_header.arcount); ++n) {
267
0
      DNSRecord dr;
268
269
0
      if(n < d_header.ancount)
270
0
        dr.d_place=DNSResourceRecord::ANSWER;
271
0
      else if(n < d_header.ancount + d_header.nscount)
272
0
        dr.d_place=DNSResourceRecord::AUTHORITY;
273
0
      else
274
0
        dr.d_place=DNSResourceRecord::ADDITIONAL;
275
276
0
      unsigned int recordStartPos=pr.getPosition();
277
278
0
      DNSName name=pr.getName();
279
280
0
      pr.getDnsrecordheader(ah);
281
0
      dr.d_ttl=ah.d_ttl;
282
0
      dr.d_type=ah.d_type;
283
0
      dr.d_class=ah.d_class;
284
285
0
      dr.d_name = std::move(name);
286
0
      dr.d_clen = ah.d_clen;
287
288
0
      if (query &&
289
0
          !(d_qtype == QType::IXFR && dr.d_place == DNSResourceRecord::AUTHORITY && dr.d_type == QType::SOA) && // IXFR queries have a SOA in their AUTHORITY section
290
0
          (dr.d_place == DNSResourceRecord::ANSWER || dr.d_place == DNSResourceRecord::AUTHORITY || (dr.d_type != QType::OPT && dr.d_type != QType::TSIG && dr.d_type != QType::SIG && dr.d_type != QType::TKEY) || ((dr.d_type == QType::TSIG || dr.d_type == QType::SIG || dr.d_type == QType::TKEY) && dr.d_class != QClass::ANY))) {
291
//        cerr<<"discarding RR, query is "<<query<<", place is "<<dr.d_place<<", type is "<<dr.d_type<<", class is "<<dr.d_class<<endl;
292
0
        dr.setContent(std::make_shared<UnknownRecordContent>(dr, pr));
293
0
      }
294
0
      else {
295
//        cerr<<"parsing RR, query is "<<query<<", place is "<<dr.d_place<<", type is "<<dr.d_type<<", class is "<<dr.d_class<<endl;
296
0
        dr.setContent(DNSRecordContent::make(dr, pr, d_header.opcode));
297
0
      }
298
299
0
      if (dr.d_place == DNSResourceRecord::ADDITIONAL && seenTSIG) {
300
0
        throw MOADNSException("Packet ("+d_qname.toString()+"|#"+std::to_string(d_qtype)+") has an unexpected record ("+std::to_string(dr.d_type)+") after a TSIG one.");
301
0
      }
302
303
0
      if(dr.d_type == QType::TSIG && dr.d_class == QClass::ANY) {
304
0
        if(seenTSIG || dr.d_place != DNSResourceRecord::ADDITIONAL) {
305
0
          throw MOADNSException("Packet ("+d_qname.toLogString()+"|#"+std::to_string(d_qtype)+") has a TSIG record in an invalid position.");
306
0
        }
307
0
        seenTSIG = true;
308
0
        d_tsigPos = recordStartPos;
309
0
      }
310
311
0
      d_answers.emplace_back(std::move(dr));
312
0
    }
313
314
#if 0
315
    if(pr.getPosition()!=packet.size()) {
316
      throw MOADNSException("Packet ("+d_qname+"|#"+std::to_string(d_qtype)+") has trailing garbage ("+ std::to_string(pr.getPosition()) + " < " +
317
                            std::to_string(packet.size()) + ")");
318
    }
319
#endif
320
0
  }
321
0
  catch(const std::out_of_range &re) {
322
0
    if(validPacket && d_header.tc) { // don't sweat it over truncated packets, but do adjust an, ns and arcount
323
0
      if(n < d_header.ancount) {
324
0
        d_header.ancount=n; d_header.nscount = d_header.arcount = 0;
325
0
      }
326
0
      else if(n < d_header.ancount + d_header.nscount) {
327
0
        d_header.nscount = n - d_header.ancount; d_header.arcount=0;
328
0
      }
329
0
      else {
330
0
        d_header.arcount = n - d_header.ancount - d_header.nscount;
331
0
      }
332
0
    }
333
0
    else {
334
0
      throw MOADNSException("Error parsing packet of "+std::to_string(packet.size())+" bytes (rd="+
335
0
                            std::to_string(d_header.rd)+
336
0
                            "), out of bounds: "+string(re.what()));
337
0
    }
338
0
  }
339
0
}
340
341
bool MOADNSParser::hasEDNS() const
342
0
{
343
0
  if (d_header.arcount == 0 || d_answers.empty()) {
344
0
    return false;
345
0
  }
346
347
0
  for (const auto& record : d_answers) {
348
0
    if (record.d_place == DNSResourceRecord::ADDITIONAL && record.d_type == QType::OPT) {
349
0
      return true;
350
0
    }
351
0
  }
352
353
0
  return false;
354
0
}
355
356
void PacketReader::getDnsrecordheader(struct dnsrecordheader &ah)
357
0
{
358
0
  unsigned char *p = reinterpret_cast<unsigned char*>(&ah);
359
360
0
  for(unsigned int n = 0; n < sizeof(dnsrecordheader); ++n) {
361
0
    p[n] = d_content.at(d_pos++);
362
0
  }
363
364
0
  ah.d_type = ntohs(ah.d_type);
365
0
  ah.d_class = ntohs(ah.d_class);
366
0
  ah.d_clen = ntohs(ah.d_clen);
367
0
  ah.d_ttl = ntohl(ah.d_ttl);
368
369
0
  d_startrecordpos = d_pos; // needed for getBlob later on
370
0
  d_recordlen = ah.d_clen;
371
0
}
372
373
374
void PacketReader::copyRecord(vector<unsigned char>& dest, uint16_t len)
375
0
{
376
0
  if (len == 0) {
377
0
    return;
378
0
  }
379
0
  if ((d_pos + len) > d_content.size()) {
380
0
    throw std::out_of_range("Attempt to copy outside of packet");
381
0
  }
382
383
0
  dest.resize(len);
384
385
0
  for (uint16_t n = 0; n < len; ++n) {
386
0
    dest.at(n) = d_content.at(d_pos++);
387
0
  }
388
0
}
389
390
void PacketReader::copyRecord(unsigned char* dest, uint16_t len)
391
0
{
392
0
  if (d_pos + len > d_content.size()) {
393
0
    throw std::out_of_range("Attempt to copy outside of packet");
394
0
  }
395
396
0
  memcpy(dest, &d_content.at(d_pos), len);
397
0
  d_pos += len;
398
0
}
399
400
void PacketReader::xfrNodeOrLocatorID(NodeOrLocatorID& ret)
401
0
{
402
0
  if (d_pos + sizeof(ret) > d_content.size()) {
403
0
    throw std::out_of_range("Attempt to read 64 bit value outside of packet");
404
0
  }
405
0
  memcpy(&ret.content, &d_content.at(d_pos), sizeof(ret.content));
406
0
  d_pos += sizeof(ret);
407
0
}
408
409
void PacketReader::xfr48BitInt(uint64_t& ret)
410
0
{
411
0
  ret=0;
412
0
  ret+=static_cast<uint8_t>(d_content.at(d_pos++));
413
0
  ret<<=8;
414
0
  ret+=static_cast<uint8_t>(d_content.at(d_pos++));
415
0
  ret<<=8;
416
0
  ret+=static_cast<uint8_t>(d_content.at(d_pos++));
417
0
  ret<<=8;
418
0
  ret+=static_cast<uint8_t>(d_content.at(d_pos++));
419
0
  ret<<=8;
420
0
  ret+=static_cast<uint8_t>(d_content.at(d_pos++));
421
0
  ret<<=8;
422
0
  ret+=static_cast<uint8_t>(d_content.at(d_pos++));
423
0
}
424
425
uint32_t PacketReader::get32BitInt()
426
0
{
427
0
  uint32_t ret=0;
428
0
  ret+=static_cast<uint8_t>(d_content.at(d_pos++));
429
0
  ret<<=8;
430
0
  ret+=static_cast<uint8_t>(d_content.at(d_pos++));
431
0
  ret<<=8;
432
0
  ret+=static_cast<uint8_t>(d_content.at(d_pos++));
433
0
  ret<<=8;
434
0
  ret+=static_cast<uint8_t>(d_content.at(d_pos++));
435
436
0
  return ret;
437
0
}
438
439
440
uint16_t PacketReader::get16BitInt()
441
0
{
442
0
  uint16_t ret=0;
443
0
  ret+=static_cast<uint8_t>(d_content.at(d_pos++));
444
0
  ret<<=8;
445
0
  ret+=static_cast<uint8_t>(d_content.at(d_pos++));
446
447
0
  return ret;
448
0
}
449
450
uint8_t PacketReader::get8BitInt()
451
0
{
452
0
  return d_content.at(d_pos++);
453
0
}
454
455
DNSName PacketReader::getName()
456
0
{
457
0
  unsigned int consumed;
458
0
  try {
459
0
    DNSName dn((const char*) d_content.data(), d_content.size(), d_pos, true /* uncompress */, nullptr /* qtype */, nullptr /* qclass */, &consumed, sizeof(dnsheader));
460
461
0
    d_pos+=consumed;
462
0
    return dn;
463
0
  }
464
0
  catch(const std::range_error& re) {
465
0
    throw std::out_of_range(string("dnsname issue: ")+re.what());
466
0
  }
467
0
  catch(...) {
468
0
    throw std::out_of_range("dnsname issue");
469
0
  }
470
0
  throw PDNSException("PacketReader::getName(): name is empty");
471
0
}
472
473
static string txtEscape(const string &name)
474
0
{
475
0
  string ret;
476
0
  char ebuf[5];
477
478
0
  for(char i : name) {
479
0
    if((unsigned char) i >= 127 || (unsigned char) i < 32) {
480
0
      snprintf(ebuf, sizeof(ebuf), "\\%03u", (unsigned char)i);
481
0
      ret += ebuf;
482
0
    }
483
0
    else if(i=='"' || i=='\\'){
484
0
      ret += '\\';
485
0
      ret += i;
486
0
    }
487
0
    else
488
0
      ret += i;
489
0
  }
490
0
  return ret;
491
0
}
492
493
// exceptions thrown here do not result in logging in the main pdns auth server - just so you know!
494
string PacketReader::getText(bool multi, bool lenField)
495
0
{
496
0
  string ret;
497
0
  ret.reserve(40);
498
0
  while(d_pos < d_startrecordpos + d_recordlen ) {
499
0
    if(!ret.empty()) {
500
0
      ret.append(1,' ');
501
0
    }
502
0
    uint16_t labellen;
503
0
    if(lenField)
504
0
      labellen=static_cast<uint8_t>(d_content.at(d_pos++));
505
0
    else
506
0
      labellen=d_recordlen - (d_pos - d_startrecordpos);
507
508
0
    ret.append(1,'"');
509
0
    if(labellen) { // no need to do anything for an empty string
510
0
      string val(&d_content.at(d_pos), &d_content.at(d_pos+labellen-1)+1);
511
0
      ret.append(txtEscape(val)); // the end is one beyond the packet
512
0
    }
513
0
    ret.append(1,'"');
514
0
    d_pos+=labellen;
515
0
    if(!multi)
516
0
      break;
517
0
  }
518
519
0
  if (ret.empty() && !lenField) {
520
    // all lenField == false cases (CAA and URI at the time of this writing) want that emptiness to be explicit
521
0
    return "\"\"";
522
0
  }
523
0
  return ret;
524
0
}
525
526
string PacketReader::getUnquotedText(bool lenField)
527
0
{
528
0
  uint16_t stop_at;
529
0
  if(lenField)
530
0
    stop_at = static_cast<uint8_t>(d_content.at(d_pos)) + d_pos + 1;
531
0
  else
532
0
    stop_at = d_recordlen;
533
534
  /* think unsigned overflow */
535
0
  if (stop_at < d_pos) {
536
0
    throw std::out_of_range("getUnquotedText out of record range");
537
0
  }
538
539
0
  if(stop_at == d_pos)
540
0
    return "";
541
542
0
  d_pos++;
543
0
  string ret(d_content.substr(d_pos, stop_at-d_pos));
544
0
  d_pos = stop_at;
545
0
  return ret;
546
0
}
547
548
void PacketReader::xfrBlob(string& blob)
549
0
{
550
0
  try {
551
0
    if(d_recordlen && !(d_pos == (d_startrecordpos + d_recordlen))) {
552
0
      if (d_pos > (d_startrecordpos + d_recordlen)) {
553
0
        throw std::out_of_range("xfrBlob out of record range");
554
0
      }
555
0
      blob.assign(&d_content.at(d_pos), &d_content.at(d_startrecordpos + d_recordlen - 1 ) + 1);
556
0
    }
557
0
    else {
558
0
      blob.clear();
559
0
    }
560
561
0
    d_pos = d_startrecordpos + d_recordlen;
562
0
  }
563
0
  catch(...)
564
0
  {
565
0
    throw std::out_of_range("xfrBlob out of range");
566
0
  }
567
0
}
568
569
0
void PacketReader::xfrBlobNoSpaces(string& blob, int length) {
570
0
  xfrBlob(blob, length);
571
0
}
572
573
void PacketReader::xfrBlob(string& blob, int length)
574
0
{
575
0
  if(length) {
576
0
    if (length < 0) {
577
0
      throw std::out_of_range("xfrBlob out of range (negative length)");
578
0
    }
579
580
0
    blob.assign(&d_content.at(d_pos), &d_content.at(d_pos + length - 1 ) + 1 );
581
582
0
    d_pos += length;
583
0
  }
584
0
  else {
585
0
    blob.clear();
586
0
  }
587
0
}
588
589
0
void PacketReader::xfrSvcParamKeyVals(set<SvcParam> &kvs) {
590
0
  while (d_pos < (d_startrecordpos + d_recordlen)) {
591
0
    if (d_pos + 2 > (d_startrecordpos + d_recordlen)) {
592
0
      throw std::out_of_range("incomplete key");
593
0
    }
594
0
    uint16_t keyInt;
595
0
    xfr16BitInt(keyInt);
596
0
    auto key = static_cast<SvcParam::SvcParamKey>(keyInt);
597
0
    uint16_t len;
598
0
    xfr16BitInt(len);
599
600
0
    if (d_pos + len > (d_startrecordpos + d_recordlen)) {
601
0
      throw std::out_of_range("record is shorter than SVCB lengthfield implies");
602
0
    }
603
604
0
    switch (key)
605
0
    {
606
0
    case SvcParam::mandatory: {
607
0
      if (len % 2 != 0) {
608
0
        throw std::out_of_range("mandatory SvcParam has invalid length");
609
0
      }
610
0
      if (len == 0) {
611
0
        throw std::out_of_range("empty 'mandatory' values");
612
0
      }
613
0
      std::set<SvcParam::SvcParamKey> paramKeys;
614
0
      size_t stop = d_pos + len;
615
0
      while (d_pos < stop) {
616
0
        uint16_t keyval;
617
0
        xfr16BitInt(keyval);
618
0
        paramKeys.insert(static_cast<SvcParam::SvcParamKey>(keyval));
619
0
      }
620
0
      kvs.insert(SvcParam(key, std::move(paramKeys)));
621
0
      break;
622
0
    }
623
0
    case SvcParam::alpn: {
624
0
      size_t stop = d_pos + len;
625
0
      std::vector<string> alpns;
626
0
      while (d_pos < stop) {
627
0
        string alpn;
628
0
        uint8_t alpnLen = 0;
629
0
        xfr8BitInt(alpnLen);
630
0
        if (alpnLen == 0) {
631
0
          throw std::out_of_range("alpn length of 0");
632
0
        }
633
0
        xfrBlob(alpn, alpnLen);
634
0
        alpns.push_back(std::move(alpn));
635
0
      }
636
0
      kvs.insert(SvcParam(key, std::move(alpns)));
637
0
      break;
638
0
    }
639
0
    case SvcParam::no_default_alpn: {
640
0
      if (len != 0) {
641
0
        throw std::out_of_range("invalid length for no-default-alpn");
642
0
      }
643
0
      kvs.insert(SvcParam(key));
644
0
      break;
645
0
    }
646
0
    case SvcParam::port: {
647
0
      if (len != 2) {
648
0
        throw std::out_of_range("invalid length for port");
649
0
      }
650
0
      uint16_t port;
651
0
      xfr16BitInt(port);
652
0
      kvs.insert(SvcParam(key, port));
653
0
      break;
654
0
    }
655
0
    case SvcParam::ipv4hint: /* fall-through */
656
0
    case SvcParam::ipv6hint: {
657
0
      size_t addrLen = (key == SvcParam::ipv4hint ? 4 : 16);
658
0
      if (len % addrLen != 0) {
659
0
        throw std::out_of_range("invalid length for " + SvcParam::keyToString(key));
660
0
      }
661
0
      vector<ComboAddress> addresses;
662
0
      auto stop = d_pos + len;
663
0
      while (d_pos < stop)
664
0
      {
665
0
        ComboAddress addr;
666
0
        xfrCAWithoutPort(key, addr);
667
0
        addresses.push_back(addr);
668
0
      }
669
      // If there were no addresses, and the input comes from internal
670
      // representation, we can reasonably assume this is the serialization
671
      // of "auto".
672
0
      bool doAuto{d_internal && len == 0};
673
0
      auto param = SvcParam(key, std::move(addresses));
674
0
      param.setAutoHint(doAuto);
675
0
      kvs.insert(std::move(param));
676
0
      break;
677
0
    }
678
0
    case SvcParam::ech: {
679
0
      std::string blob;
680
0
      blob.reserve(len);
681
0
      xfrBlobNoSpaces(blob, len);
682
0
      kvs.insert(SvcParam(key, blob));
683
0
      break;
684
0
    }
685
0
    default: {
686
0
      std::string blob;
687
0
      blob.reserve(len);
688
0
      xfrBlob(blob, len);
689
0
      kvs.insert(SvcParam(key, blob));
690
0
      break;
691
0
    }
692
0
    }
693
0
  }
694
0
}
695
696
697
void PacketReader::xfrHexBlob(string& blob, bool /* keepReading */)
698
0
{
699
0
  xfrBlob(blob);
700
0
}
701
702
//FIXME400 remove this method completely
703
string simpleCompress(const string& elabel, const string& root)
704
0
{
705
0
  string label=elabel;
706
  // FIXME400: this relies on the semi-canonical escaped output from getName
707
0
  if(strchr(label.c_str(), '\\')) {
708
0
    boost::replace_all(label, "\\.", ".");
709
0
    boost::replace_all(label, "\\032", " ");
710
0
    boost::replace_all(label, "\\\\", "\\");
711
0
  }
712
0
  typedef vector<pair<unsigned int, unsigned int> > parts_t;
713
0
  parts_t parts;
714
0
  vstringtok(parts, label, ".");
715
0
  string ret;
716
0
  ret.reserve(label.size()+4);
717
0
  for(const auto & part : parts) {
718
0
    if(!root.empty() && !strncasecmp(root.c_str(), label.c_str() + part.first, 1 + label.length() - part.first)) { // also match trailing 0, hence '1 +'
719
0
      const unsigned char rootptr[2]={0xc0,0x11};
720
0
      ret.append((const char *) rootptr, 2);
721
0
      return ret;
722
0
    }
723
0
    ret.append(1, (char)(part.second - part.first));
724
0
    ret.append(label.c_str() + part.first, part.second - part.first);
725
0
  }
726
0
  ret.append(1, (char)0);
727
0
  return ret;
728
0
}
729
730
// method of operation: silently fail if it doesn't work - we're only trying to be nice, don't fall over on it
731
void editDNSPacketTTL(char* packet, size_t length, const std::function<uint32_t(uint8_t, uint16_t, uint16_t, uint32_t)>& visitor)
732
0
{
733
0
  if(length < sizeof(dnsheader))
734
0
    return;
735
0
  try
736
0
  {
737
0
    dnsheader dh;
738
0
    memcpy((void*)&dh, (const dnsheader*)packet, sizeof(dh));
739
0
    uint64_t numrecords = ntohs(dh.ancount) + ntohs(dh.nscount) + ntohs(dh.arcount);
740
0
    DNSPacketMangler dpm(packet, length);
741
742
0
    uint64_t n;
743
0
    for(n=0; n < ntohs(dh.qdcount) ; ++n) {
744
0
      dpm.skipDomainName();
745
      /* type and class */
746
0
      dpm.skipBytes(4);
747
0
    }
748
749
0
    for(n=0; n < numrecords; ++n) {
750
0
      dpm.skipDomainName();
751
752
0
      uint8_t section = n < ntohs(dh.ancount) ? 1 : (n < (ntohs(dh.ancount) + ntohs(dh.nscount)) ? 2 : 3);
753
0
      uint16_t dnstype = dpm.get16BitInt();
754
0
      uint16_t dnsclass = dpm.get16BitInt();
755
756
0
      if(dnstype == QType::OPT) // not getting near that one with a stick
757
0
        break;
758
759
0
      uint32_t dnsttl = dpm.get32BitInt();
760
0
      uint32_t newttl = visitor(section, dnsclass, dnstype, dnsttl);
761
0
      if (newttl) {
762
0
        dpm.rewindBytes(sizeof(newttl));
763
0
        dpm.setAndSkip32BitInt(newttl);
764
0
      }
765
0
      dpm.skipRData();
766
0
    }
767
0
  }
768
0
  catch(...)
769
0
  {
770
0
    return;
771
0
  }
772
0
}
773
774
static bool checkIfPacketContainsRecords(const PacketBuffer& packet, const std::unordered_set<QType>& qtypes)
775
0
{
776
0
  auto length = packet.size();
777
0
  if (length < sizeof(dnsheader)) {
778
0
    return false;
779
0
  }
780
781
0
  try {
782
0
    const dnsheader_aligned dh(packet.data());
783
0
    DNSPacketMangler dpm(const_cast<char*>(reinterpret_cast<const char*>(packet.data())), length);
784
785
0
    const uint16_t qdcount = ntohs(dh->qdcount);
786
0
    for (size_t n = 0; n < qdcount; ++n) {
787
0
      dpm.skipDomainName();
788
      /* type and class */
789
0
      dpm.skipBytes(4);
790
0
    }
791
0
    const size_t recordsCount = static_cast<size_t>(ntohs(dh->ancount)) + ntohs(dh->nscount) + ntohs(dh->arcount);
792
0
    for (size_t n = 0; n < recordsCount; ++n) {
793
0
      dpm.skipDomainName();
794
0
      uint16_t dnstype = dpm.get16BitInt();
795
0
      uint16_t dnsclass = dpm.get16BitInt();
796
0
      if (dnsclass == QClass::IN && qtypes.count(dnstype) > 0) {
797
0
        return true;
798
0
      }
799
      /* ttl */
800
0
      dpm.skipBytes(4);
801
0
      dpm.skipRData();
802
0
    }
803
0
  }
804
0
  catch (...) {
805
0
  }
806
807
0
  return false;
808
0
}
809
810
static int rewritePacketWithoutRecordTypes(const PacketBuffer& initialPacket, PacketBuffer& newContent, const std::unordered_set<QType>& qtypes)
811
0
{
812
0
  static const std::unordered_set<QType>& safeTypes{QType::A, QType::AAAA, QType::DHCID, QType::TXT, QType::OPT, QType::HINFO, QType::DNSKEY, QType::CDNSKEY, QType::DS, QType::CDS, QType::DLV, QType::SSHFP, QType::KEY, QType::CERT, QType::TLSA, QType::SMIMEA, QType::OPENPGPKEY, QType::SVCB, QType::HTTPS, QType::NSEC3, QType::CSYNC, QType::NSEC3PARAM, QType::LOC, QType::NID, QType::L32, QType::L64, QType::EUI48, QType::EUI64, QType::URI, QType::CAA};
813
814
0
  if (initialPacket.size() < sizeof(dnsheader)) {
815
0
    return EINVAL;
816
0
  }
817
0
  try {
818
0
    const dnsheader_aligned dh(initialPacket.data());
819
820
0
    if (ntohs(dh->qdcount) == 0)
821
0
      return ENOENT;
822
0
    auto packetView = std::string_view(reinterpret_cast<const char*>(initialPacket.data()), initialPacket.size());
823
824
0
    PacketReader pr(packetView);
825
826
0
    size_t idx = 0;
827
0
    DNSName rrname;
828
0
    uint16_t qdcount = ntohs(dh->qdcount);
829
0
    uint16_t ancount = ntohs(dh->ancount);
830
0
    uint16_t nscount = ntohs(dh->nscount);
831
0
    uint16_t arcount = ntohs(dh->arcount);
832
0
    uint16_t rrtype;
833
0
    uint16_t rrclass;
834
0
    string blob;
835
0
    struct dnsrecordheader ah;
836
837
0
    rrname = pr.getName();
838
0
    rrtype = pr.get16BitInt();
839
0
    rrclass = pr.get16BitInt();
840
841
0
    GenericDNSPacketWriter<PacketBuffer> pw(newContent, rrname, rrtype, rrclass, dh->opcode);
842
0
    pw.getHeader()->id=dh->id;
843
0
    pw.getHeader()->qr=dh->qr;
844
0
    pw.getHeader()->aa=dh->aa;
845
0
    pw.getHeader()->tc=dh->tc;
846
0
    pw.getHeader()->rd=dh->rd;
847
0
    pw.getHeader()->ra=dh->ra;
848
0
    pw.getHeader()->ad=dh->ad;
849
0
    pw.getHeader()->cd=dh->cd;
850
0
    pw.getHeader()->rcode=dh->rcode;
851
852
    /* consume remaining qd if any */
853
0
    if (qdcount > 1) {
854
0
      for(idx = 1; idx < qdcount; idx++) {
855
0
        rrname = pr.getName();
856
0
        rrtype = pr.get16BitInt();
857
0
        rrclass = pr.get16BitInt();
858
0
        (void) rrtype;
859
0
        (void) rrclass;
860
0
      }
861
0
    }
862
863
    /* copy AN */
864
0
    for (idx = 0; idx < ancount; idx++) {
865
0
      rrname = pr.getName();
866
0
      pr.getDnsrecordheader(ah);
867
0
      pr.xfrBlob(blob);
868
869
0
      if (qtypes.find(ah.d_type) == qtypes.end()) {
870
        // if this is not a safe type
871
0
        if (safeTypes.find(ah.d_type) == safeTypes.end()) {
872
          // "unsafe" types might countain compressed data, so cancel rewrite
873
0
          newContent.clear();
874
0
          return EIO;
875
0
        }
876
0
        pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ANSWER, true);
877
0
        pw.xfrBlob(blob);
878
0
      }
879
0
    }
880
881
    /* copy NS */
882
0
    for (idx = 0; idx < nscount; idx++) {
883
0
      rrname = pr.getName();
884
0
      pr.getDnsrecordheader(ah);
885
0
      pr.xfrBlob(blob);
886
887
0
      if (qtypes.find(ah.d_type) == qtypes.end()) {
888
0
        if (safeTypes.find(ah.d_type) == safeTypes.end()) {
889
          // "unsafe" types might countain compressed data, so cancel rewrite
890
0
          newContent.clear();
891
0
          return EIO;
892
0
        }
893
0
        pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::AUTHORITY, true);
894
0
        pw.xfrBlob(blob);
895
0
      }
896
0
    }
897
    /* copy AR */
898
0
    for (idx = 0; idx < arcount; idx++) {
899
0
      rrname = pr.getName();
900
0
      pr.getDnsrecordheader(ah);
901
0
      pr.xfrBlob(blob);
902
903
0
      if (qtypes.find(ah.d_type) == qtypes.end()) {
904
0
        if (safeTypes.find(ah.d_type) == safeTypes.end()) {
905
          // "unsafe" types might countain compressed data, so cancel rewrite
906
0
          newContent.clear();
907
0
          return EIO;
908
0
        }
909
0
        pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ADDITIONAL, true);
910
0
        pw.xfrBlob(blob);
911
0
      }
912
0
    }
913
0
    pw.commit();
914
915
0
  }
916
0
  catch (...)
917
0
  {
918
0
    newContent.clear();
919
0
    return EIO;
920
0
  }
921
0
  return 0;
922
0
}
923
924
/* Overload for plain byte vectors: removes, in place, every record whose type is
   listed in `qtypes` (see the PacketBuffer overload).
   PacketBuffer is a distinct type from vector<uint8_t> (otherwise the two
   overloads could not coexist). Accessing a vector<uint8_t> through a
   PacketBuffer& obtained with reinterpret_cast is therefore undefined
   behaviour. Instead, the bytes are copied into a real PacketBuffer, processed,
   and copied back. */
void clearDNSPacketRecordTypes(vector<uint8_t>& packet, const std::unordered_set<QType>& qtypes)
{
  PacketBuffer buffer(packet.begin(), packet.end());
  clearDNSPacketRecordTypes(buffer, qtypes);
  packet.assign(buffer.begin(), buffer.end());
}
928
929
/* Removes, in place, every record whose type is listed in `qtypes`.
   The packet is left untouched in two cases:
   - it holds no IN-class record of those types (checkIfPacketContainsRecords);
   - rewritePacketWithoutRecordTypes reports any error. */
void clearDNSPacketRecordTypes(PacketBuffer& packet, const std::unordered_set<QType>& qtypes)
{
  if (checkIfPacketContainsRecords(packet, qtypes)) {
    PacketBuffer rewritten;
    const int status = rewritePacketWithoutRecordTypes(packet, rewritten, qtypes);
    if (status == 0) {
      packet = std::move(rewritten);
    }
  }
}
942
943
// method of operation: silently fail if it doesn't work - we're only trying to be nice, don't fall over on it
944
void ageDNSPacket(char* packet, size_t length, uint32_t seconds, const dnsheader_aligned& aligned_dh)
945
0
{
946
0
  if (length < sizeof(dnsheader)) {
947
0
    return;
948
0
  }
949
0
  try {
950
0
    const dnsheader* dhp = aligned_dh.get();
951
0
    const uint64_t dqcount = ntohs(dhp->qdcount);
952
0
    const uint64_t numrecords = ntohs(dhp->ancount) + ntohs(dhp->nscount) + ntohs(dhp->arcount);
953
0
    DNSPacketMangler dpm(packet, length);
954
955
0
    for (uint64_t rec = 0; rec < dqcount; ++rec) {
956
0
      dpm.skipDomainName();
957
      /* type and class */
958
0
      dpm.skipBytes(4);
959
0
    }
960
961
0
    for(uint64_t rec = 0; rec < numrecords; ++rec) {
962
0
      dpm.skipDomainName();
963
964
0
      uint16_t dnstype = dpm.get16BitInt();
965
      /* class */
966
0
      dpm.skipBytes(2);
967
968
0
      if (dnstype != QType::OPT) { // not aging that one with a stick
969
0
        dpm.decreaseAndSkip32BitInt(seconds);
970
0
      } else {
971
0
        dpm.skipBytes(4);
972
0
      }
973
0
      dpm.skipRData();
974
0
    }
975
0
  }
976
0
  catch(...) {
977
0
  }
978
0
}
979
980
void ageDNSPacket(std::string& packet, uint32_t seconds, const dnsheader_aligned& aligned_dh)
981
0
{
982
0
  ageDNSPacket(packet.data(), packet.length(), seconds, aligned_dh);
983
0
}
984
985
uint32_t getDNSPacketMinTTL(const char* packet, size_t length, bool* seenAuthSOA)
986
0
{
987
0
  uint32_t result = std::numeric_limits<uint32_t>::max();
988
0
  if(length < sizeof(dnsheader)) {
989
0
    return result;
990
0
  }
991
0
  try
992
0
  {
993
0
    const dnsheader_aligned dh(packet);
994
0
    DNSPacketMangler dpm(const_cast<char*>(packet), length);
995
996
0
    const uint16_t qdcount = ntohs(dh->qdcount);
997
0
    for(size_t n = 0; n < qdcount; ++n) {
998
0
      dpm.skipDomainName();
999
      /* type and class */
1000
0
      dpm.skipBytes(4);
1001
0
    }
1002
0
    const size_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
1003
0
    for(size_t n = 0; n < numrecords; ++n) {
1004
0
      dpm.skipDomainName();
1005
0
      const uint16_t dnstype = dpm.get16BitInt();
1006
      /* class */
1007
0
      const uint16_t dnsclass = dpm.get16BitInt();
1008
1009
0
      if(dnstype == QType::OPT) {
1010
0
        break;
1011
0
      }
1012
1013
      /* report it if we see a SOA record in the AUTHORITY section */
1014
0
      if(dnstype == QType::SOA && dnsclass == QClass::IN && seenAuthSOA != nullptr && n >= ntohs(dh->ancount) && n < (ntohs(dh->ancount) + ntohs(dh->nscount))) {
1015
0
        *seenAuthSOA = true;
1016
0
      }
1017
1018
0
      const uint32_t ttl = dpm.get32BitInt();
1019
0
      result = std::min(result, ttl);
1020
1021
0
      dpm.skipRData();
1022
0
    }
1023
0
  }
1024
0
  catch(...)
1025
0
  {
1026
0
  }
1027
0
  return result;
1028
0
}
1029
1030
uint32_t getDNSPacketLength(const char* packet, size_t length)
1031
0
{
1032
0
  uint32_t result = length;
1033
0
  if(length < sizeof(dnsheader)) {
1034
0
    return result;
1035
0
  }
1036
0
  try
1037
0
  {
1038
0
    const dnsheader_aligned dh(packet);
1039
0
    DNSPacketMangler dpm(const_cast<char*>(packet), length);
1040
1041
0
    const uint16_t qdcount = ntohs(dh->qdcount);
1042
0
    for(size_t n = 0; n < qdcount; ++n) {
1043
0
      dpm.skipDomainName();
1044
      /* type and class */
1045
0
      dpm.skipBytes(4);
1046
0
    }
1047
0
    const size_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
1048
0
    for(size_t n = 0; n < numrecords; ++n) {
1049
0
      dpm.skipDomainName();
1050
      /* type (2), class (2) and ttl (4) */
1051
0
      dpm.skipBytes(8);
1052
0
      dpm.skipRData();
1053
0
    }
1054
0
    result = dpm.getOffset();
1055
0
  }
1056
0
  catch(...)
1057
0
  {
1058
0
  }
1059
0
  return result;
1060
0
}
1061
1062
uint16_t getRecordsOfTypeCount(const char* packet, size_t length, uint8_t section, uint16_t type)
1063
0
{
1064
0
  uint16_t result = 0;
1065
0
  if(length < sizeof(dnsheader)) {
1066
0
    return result;
1067
0
  }
1068
0
  try
1069
0
  {
1070
0
    const dnsheader_aligned dh(packet);
1071
0
    DNSPacketMangler dpm(const_cast<char*>(packet), length);
1072
1073
0
    const uint16_t qdcount = ntohs(dh->qdcount);
1074
0
    for(size_t n = 0; n < qdcount; ++n) {
1075
0
      dpm.skipDomainName();
1076
0
      if (section == 0) {
1077
0
        uint16_t dnstype = dpm.get16BitInt();
1078
0
        if (dnstype == type) {
1079
0
          result++;
1080
0
        }
1081
        /* class */
1082
0
        dpm.skipBytes(2);
1083
0
      } else {
1084
        /* type and class */
1085
0
        dpm.skipBytes(4);
1086
0
      }
1087
0
    }
1088
0
    const uint16_t ancount = ntohs(dh->ancount);
1089
0
    for(size_t n = 0; n < ancount; ++n) {
1090
0
      dpm.skipDomainName();
1091
0
      if (section == 1) {
1092
0
        uint16_t dnstype = dpm.get16BitInt();
1093
0
        if (dnstype == type) {
1094
0
          result++;
1095
0
        }
1096
        /* class */
1097
0
        dpm.skipBytes(2);
1098
0
      } else {
1099
        /* type and class */
1100
0
        dpm.skipBytes(4);
1101
0
      }
1102
      /* ttl */
1103
0
      dpm.skipBytes(4);
1104
0
      dpm.skipRData();
1105
0
    }
1106
0
    const uint16_t nscount = ntohs(dh->nscount);
1107
0
    for(size_t n = 0; n < nscount; ++n) {
1108
0
      dpm.skipDomainName();
1109
0
      if (section == 2) {
1110
0
        uint16_t dnstype = dpm.get16BitInt();
1111
0
        if (dnstype == type) {
1112
0
          result++;
1113
0
        }
1114
        /* class */
1115
0
        dpm.skipBytes(2);
1116
0
      } else {
1117
        /* type and class */
1118
0
        dpm.skipBytes(4);
1119
0
      }
1120
      /* ttl */
1121
0
      dpm.skipBytes(4);
1122
0
      dpm.skipRData();
1123
0
    }
1124
0
    const uint16_t arcount = ntohs(dh->arcount);
1125
0
    for(size_t n = 0; n < arcount; ++n) {
1126
0
      dpm.skipDomainName();
1127
0
      if (section == 3) {
1128
0
        uint16_t dnstype = dpm.get16BitInt();
1129
0
        if (dnstype == type) {
1130
0
          result++;
1131
0
        }
1132
        /* class */
1133
0
        dpm.skipBytes(2);
1134
0
      } else {
1135
        /* type and class */
1136
0
        dpm.skipBytes(4);
1137
0
      }
1138
      /* ttl */
1139
0
      dpm.skipBytes(4);
1140
0
      dpm.skipRData();
1141
0
    }
1142
0
  }
1143
0
  catch(...)
1144
0
  {
1145
0
  }
1146
0
  return result;
1147
0
}
1148
1149
bool getEDNSUDPPayloadSizeAndZ(const char* packet, size_t length, uint16_t* payloadSize, uint16_t* z)
1150
0
{
1151
0
  if (length < sizeof(dnsheader)) {
1152
0
    return false;
1153
0
  }
1154
1155
0
  *payloadSize = 0;
1156
0
  *z = 0;
1157
1158
0
  try
1159
0
  {
1160
0
    const dnsheader_aligned dh(packet);
1161
0
    DNSPacketMangler dpm(const_cast<char*>(packet), length);
1162
1163
0
    const uint16_t qdcount = ntohs(dh->qdcount);
1164
0
    for(size_t n = 0; n < qdcount; ++n) {
1165
0
      dpm.skipDomainName();
1166
      /* type and class */
1167
0
      dpm.skipBytes(4);
1168
0
    }
1169
0
    const size_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
1170
0
    for(size_t n = 0; n < numrecords; ++n) {
1171
0
      dpm.skipDomainName();
1172
0
      const uint16_t dnstype = dpm.get16BitInt();
1173
0
      const uint16_t dnsclass = dpm.get16BitInt();
1174
1175
0
      if(dnstype == QType::OPT) {
1176
        /* skip extended rcode and version */
1177
0
        dpm.skipBytes(2);
1178
0
        *z = dpm.get16BitInt();
1179
0
        *payloadSize = dnsclass;
1180
0
        return true;
1181
0
      }
1182
1183
      /* TTL */
1184
0
      dpm.skipBytes(4);
1185
0
      dpm.skipRData();
1186
0
    }
1187
0
  }
1188
0
  catch(...)
1189
0
  {
1190
0
  }
1191
1192
0
  return false;
1193
0
}
1194
1195
bool visitDNSPacket(const std::string_view& packet, const std::function<bool(uint8_t, uint16_t, uint16_t, uint32_t, uint16_t, const char*)>& visitor)
1196
0
{
1197
0
  if (packet.size() < sizeof(dnsheader)) {
1198
0
    return false;
1199
0
  }
1200
1201
0
  try
1202
0
  {
1203
0
    const dnsheader_aligned dh(packet.data());
1204
0
    uint64_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
1205
0
    PacketReader reader(packet);
1206
1207
0
    uint64_t n;
1208
0
    for (n = 0; n < ntohs(dh->qdcount) ; ++n) {
1209
0
      (void) reader.getName();
1210
      /* type and class */
1211
0
      reader.skip(4);
1212
0
    }
1213
1214
0
    for (n = 0; n < numrecords; ++n) {
1215
0
      (void) reader.getName();
1216
1217
0
      uint8_t section = n < ntohs(dh->ancount) ? 1 : (n < (ntohs(dh->ancount) + ntohs(dh->nscount)) ? 2 : 3);
1218
0
      uint16_t dnstype = reader.get16BitInt();
1219
0
      uint16_t dnsclass = reader.get16BitInt();
1220
1221
0
      if (dnstype == QType::OPT) {
1222
        // not getting near that one with a stick
1223
0
        break;
1224
0
      }
1225
1226
0
      uint32_t dnsttl = reader.get32BitInt();
1227
0
      uint16_t contentLength = reader.get16BitInt();
1228
0
      uint16_t pos = reader.getPosition();
1229
1230
0
      bool done = visitor(section, dnsclass, dnstype, dnsttl, contentLength, &packet.at(pos));
1231
0
      if (done) {
1232
0
        return true;
1233
0
      }
1234
1235
0
      reader.skip(contentLength);
1236
0
    }
1237
0
  }
1238
0
  catch (...) {
1239
0
    return false;
1240
0
  }
1241
1242
0
  return true;
1243
0
}