Coverage Report

Created: 2025-06-13 06:27

/src/pdns/pdns/dnsdistdist/dnsdist-ecs.cc
Line
Count
Source (jump to first uncovered line)
1
/*
2
 * This file is part of PowerDNS or dnsdist.
3
 * Copyright -- PowerDNS.COM B.V. and its contributors
4
 *
5
 * This program is free software; you can redistribute it and/or modify
6
 * it under the terms of version 2 of the GNU General Public License as
7
 * published by the Free Software Foundation.
8
 *
9
 * In addition, for the avoidance of any doubt, permission is granted to
10
 * link this program with OpenSSL and to (re)distribute the binaries
11
 * produced as the result of such linking.
12
 *
13
 * This program is distributed in the hope that it will be useful,
14
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16
 * GNU General Public License for more details.
17
 *
18
 * You should have received a copy of the GNU General Public License
19
 * along with this program; if not, write to the Free Software
20
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21
 */
22
#include "dolog.hh"
23
#include "dnsdist.hh"
24
#include "dnsdist-dnsparser.hh"
25
#include "dnsdist-ecs.hh"
26
#include "dnsparser.hh"
27
#include "dnswriter.hh"
28
#include "ednsoptions.hh"
29
#include "ednssubnet.hh"
30
31
int rewriteResponseWithoutEDNS(const PacketBuffer& initialPacket, PacketBuffer& newContent)
32
0
{
33
0
  if (initialPacket.size() < sizeof(dnsheader)) {
34
0
    return ENOENT;
35
0
  }
36
37
0
  const dnsheader_aligned dnsHeader(initialPacket.data());
38
39
0
  if (ntohs(dnsHeader->arcount) == 0) {
40
0
    return ENOENT;
41
0
  }
42
43
0
  if (ntohs(dnsHeader->qdcount) == 0) {
44
0
    return ENOENT;
45
0
  }
46
47
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
48
0
  PacketReader packetReader(std::string_view(reinterpret_cast<const char*>(initialPacket.data()), initialPacket.size()));
49
50
0
  size_t idx = 0;
51
0
  uint16_t qdcount = ntohs(dnsHeader->qdcount);
52
0
  uint16_t ancount = ntohs(dnsHeader->ancount);
53
0
  uint16_t nscount = ntohs(dnsHeader->nscount);
54
0
  uint16_t arcount = ntohs(dnsHeader->arcount);
55
0
  string blob;
56
0
  dnsrecordheader recordHeader{};
57
58
0
  auto rrname = packetReader.getName();
59
0
  auto rrtype = packetReader.get16BitInt();
60
0
  auto rrclass = packetReader.get16BitInt();
61
62
0
  GenericDNSPacketWriter<PacketBuffer> packetWriter(newContent, rrname, rrtype, rrclass, dnsHeader->opcode);
63
0
  packetWriter.getHeader()->id = dnsHeader->id;
64
0
  packetWriter.getHeader()->qr = dnsHeader->qr;
65
0
  packetWriter.getHeader()->aa = dnsHeader->aa;
66
0
  packetWriter.getHeader()->tc = dnsHeader->tc;
67
0
  packetWriter.getHeader()->rd = dnsHeader->rd;
68
0
  packetWriter.getHeader()->ra = dnsHeader->ra;
69
0
  packetWriter.getHeader()->ad = dnsHeader->ad;
70
0
  packetWriter.getHeader()->cd = dnsHeader->cd;
71
0
  packetWriter.getHeader()->rcode = dnsHeader->rcode;
72
73
  /* consume remaining qd if any */
74
0
  if (qdcount > 1) {
75
0
    for (idx = 1; idx < qdcount; idx++) {
76
0
      rrname = packetReader.getName();
77
0
      rrtype = packetReader.get16BitInt();
78
0
      rrclass = packetReader.get16BitInt();
79
0
      (void)rrtype;
80
0
      (void)rrclass;
81
0
    }
82
0
  }
83
84
  /* copy AN and NS */
85
0
  for (idx = 0; idx < ancount; idx++) {
86
0
    rrname = packetReader.getName();
87
0
    packetReader.getDnsrecordheader(recordHeader);
88
89
0
    packetWriter.startRecord(rrname, recordHeader.d_type, recordHeader.d_ttl, recordHeader.d_class, DNSResourceRecord::ANSWER, true);
90
0
    packetReader.xfrBlob(blob);
91
0
    packetWriter.xfrBlob(blob);
92
0
  }
93
94
0
  for (idx = 0; idx < nscount; idx++) {
95
0
    rrname = packetReader.getName();
96
0
    packetReader.getDnsrecordheader(recordHeader);
97
98
0
    packetWriter.startRecord(rrname, recordHeader.d_type, recordHeader.d_ttl, recordHeader.d_class, DNSResourceRecord::AUTHORITY, true);
99
0
    packetReader.xfrBlob(blob);
100
0
    packetWriter.xfrBlob(blob);
101
0
  }
102
  /* consume AR, looking for OPT */
103
0
  for (idx = 0; idx < arcount; idx++) {
104
0
    rrname = packetReader.getName();
105
0
    packetReader.getDnsrecordheader(recordHeader);
106
107
0
    if (recordHeader.d_type != QType::OPT) {
108
0
      packetWriter.startRecord(rrname, recordHeader.d_type, recordHeader.d_ttl, recordHeader.d_class, DNSResourceRecord::ADDITIONAL, true);
109
0
      packetReader.xfrBlob(blob);
110
0
      packetWriter.xfrBlob(blob);
111
0
    }
112
0
    else {
113
114
0
      packetReader.skip(recordHeader.d_clen);
115
0
    }
116
0
  }
117
0
  packetWriter.commit();
118
119
0
  return 0;
120
0
}
121
122
static bool addOrReplaceEDNSOption(std::vector<std::pair<uint16_t, std::string>>& options, uint16_t optionCode, bool& optionAdded, bool overrideExisting, const string& newOptionContent)
123
0
{
124
0
  for (auto it = options.begin(); it != options.end();) {
125
0
    if (it->first == optionCode) {
126
0
      optionAdded = false;
127
128
0
      if (!overrideExisting) {
129
0
        return false;
130
0
      }
131
132
0
      it = options.erase(it);
133
0
    }
134
0
    else {
135
0
      ++it;
136
0
    }
137
0
  }
138
139
0
  options.emplace_back(optionCode, std::string(&newOptionContent.at(EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), newOptionContent.size() - (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE)));
140
0
  return true;
141
0
}
142
143
bool slowRewriteEDNSOptionInQueryWithRecords(const PacketBuffer& initialPacket, PacketBuffer& newContent, bool& ednsAdded, uint16_t optionToReplace, bool& optionAdded, bool overrideExisting, const string& newOptionContent)
144
0
{
145
0
  if (initialPacket.size() < sizeof(dnsheader)) {
146
0
    return false;
147
0
  }
148
149
0
  const dnsheader_aligned dnsHeader(initialPacket.data());
150
151
0
  if (ntohs(dnsHeader->qdcount) == 0) {
152
0
    return false;
153
0
  }
154
155
0
  if (ntohs(dnsHeader->ancount) == 0 && ntohs(dnsHeader->nscount) == 0 && ntohs(dnsHeader->arcount) == 0) {
156
0
    throw std::runtime_error("slowRewriteEDNSOptionInQueryWithRecords should not be called for queries that have no records");
157
0
  }
158
159
0
  optionAdded = false;
160
0
  ednsAdded = true;
161
162
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
163
0
  PacketReader packetReader(std::string_view(reinterpret_cast<const char*>(initialPacket.data()), initialPacket.size()));
164
165
0
  size_t idx = 0;
166
0
  uint16_t qdcount = ntohs(dnsHeader->qdcount);
167
0
  uint16_t ancount = ntohs(dnsHeader->ancount);
168
0
  uint16_t nscount = ntohs(dnsHeader->nscount);
169
0
  uint16_t arcount = ntohs(dnsHeader->arcount);
170
0
  string blob;
171
0
  dnsrecordheader recordHeader{};
172
173
0
  auto rrname = packetReader.getName();
174
0
  auto rrtype = packetReader.get16BitInt();
175
0
  auto rrclass = packetReader.get16BitInt();
176
177
0
  GenericDNSPacketWriter<PacketBuffer> packetWriter(newContent, rrname, rrtype, rrclass, dnsHeader->opcode);
178
0
  packetWriter.getHeader()->id = dnsHeader->id;
179
0
  packetWriter.getHeader()->qr = dnsHeader->qr;
180
0
  packetWriter.getHeader()->aa = dnsHeader->aa;
181
0
  packetWriter.getHeader()->tc = dnsHeader->tc;
182
0
  packetWriter.getHeader()->rd = dnsHeader->rd;
183
0
  packetWriter.getHeader()->ra = dnsHeader->ra;
184
0
  packetWriter.getHeader()->ad = dnsHeader->ad;
185
0
  packetWriter.getHeader()->cd = dnsHeader->cd;
186
0
  packetWriter.getHeader()->rcode = dnsHeader->rcode;
187
188
  /* consume remaining qd if any */
189
0
  if (qdcount > 1) {
190
0
    for (idx = 1; idx < qdcount; idx++) {
191
0
      rrname = packetReader.getName();
192
0
      rrtype = packetReader.get16BitInt();
193
0
      rrclass = packetReader.get16BitInt();
194
0
      (void)rrtype;
195
0
      (void)rrclass;
196
0
    }
197
0
  }
198
199
  /* copy AN and NS */
200
0
  for (idx = 0; idx < ancount; idx++) {
201
0
    rrname = packetReader.getName();
202
0
    packetReader.getDnsrecordheader(recordHeader);
203
204
0
    packetWriter.startRecord(rrname, recordHeader.d_type, recordHeader.d_ttl, recordHeader.d_class, DNSResourceRecord::ANSWER, true);
205
0
    packetReader.xfrBlob(blob);
206
0
    packetWriter.xfrBlob(blob);
207
0
  }
208
209
0
  for (idx = 0; idx < nscount; idx++) {
210
0
    rrname = packetReader.getName();
211
0
    packetReader.getDnsrecordheader(recordHeader);
212
213
0
    packetWriter.startRecord(rrname, recordHeader.d_type, recordHeader.d_ttl, recordHeader.d_class, DNSResourceRecord::AUTHORITY, true);
214
0
    packetReader.xfrBlob(blob);
215
0
    packetWriter.xfrBlob(blob);
216
0
  }
217
218
  /* consume AR, looking for OPT */
219
0
  for (idx = 0; idx < arcount; idx++) {
220
0
    rrname = packetReader.getName();
221
0
    packetReader.getDnsrecordheader(recordHeader);
222
223
0
    if (recordHeader.d_type != QType::OPT) {
224
0
      packetWriter.startRecord(rrname, recordHeader.d_type, recordHeader.d_ttl, recordHeader.d_class, DNSResourceRecord::ADDITIONAL, true);
225
0
      packetReader.xfrBlob(blob);
226
0
      packetWriter.xfrBlob(blob);
227
0
    }
228
0
    else {
229
230
0
      ednsAdded = false;
231
0
      packetReader.xfrBlob(blob);
232
233
0
      std::vector<std::pair<uint16_t, std::string>> options;
234
0
      getEDNSOptionsFromContent(blob, options);
235
236
      /* getDnsrecordheader() has helpfully converted the TTL for us, which we do not want in that case */
237
0
      uint32_t ttl = htonl(recordHeader.d_ttl);
238
0
      EDNS0Record edns0{};
239
0
      static_assert(sizeof(edns0) == sizeof(ttl), "sizeof(EDNS0Record) must match sizeof(uint32_t) AKA RR TTL size");
240
0
      memcpy(&edns0, &ttl, sizeof(edns0));
241
242
      /* addOrReplaceEDNSOption will set it to false if there is already an existing option */
243
0
      optionAdded = true;
244
0
      addOrReplaceEDNSOption(options, optionToReplace, optionAdded, overrideExisting, newOptionContent);
245
0
      packetWriter.addOpt(recordHeader.d_class, edns0.extRCode, ntohs(edns0.extFlags), options, edns0.version);
246
0
    }
247
0
  }
248
249
0
  if (ednsAdded) {
250
0
    packetWriter.addOpt(dnsdist::configuration::s_EdnsUDPPayloadSize, 0, 0, {{optionToReplace, std::string(&newOptionContent.at(EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), newOptionContent.size() - (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE))}}, 0);
251
0
    optionAdded = true;
252
0
  }
253
254
0
  packetWriter.commit();
255
256
0
  return true;
257
0
}
258
259
static bool slowParseEDNSOptions(const PacketBuffer& packet, EDNSOptionViewMap& options)
260
0
{
261
0
  if (packet.size() < sizeof(dnsheader)) {
262
0
    return false;
263
0
  }
264
265
0
  const dnsheader_aligned dnsHeader(packet.data());
266
267
0
  if (ntohs(dnsHeader->qdcount) == 0) {
268
0
    return false;
269
0
  }
270
271
0
  if (ntohs(dnsHeader->arcount) == 0) {
272
0
    throw std::runtime_error("slowParseEDNSOptions() should not be called for queries that have no EDNS");
273
0
  }
274
275
0
  try {
276
0
    uint64_t numrecords = ntohs(dnsHeader->ancount) + ntohs(dnsHeader->nscount) + ntohs(dnsHeader->arcount);
277
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast,cppcoreguidelines-pro-type-const-cast)
278
0
    DNSPacketMangler dpm(const_cast<char*>(reinterpret_cast<const char*>(packet.data())), packet.size());
279
0
    uint64_t index{};
280
0
    for (index = 0; index < ntohs(dnsHeader->qdcount); ++index) {
281
0
      dpm.skipDomainName();
282
      /* type and class */
283
0
      dpm.skipBytes(4);
284
0
    }
285
286
0
    for (index = 0; index < numrecords; ++index) {
287
0
      dpm.skipDomainName();
288
289
0
      uint8_t section = index < ntohs(dnsHeader->ancount) ? 1 : (index < (ntohs(dnsHeader->ancount) + ntohs(dnsHeader->nscount)) ? 2 : 3);
290
0
      uint16_t dnstype = dpm.get16BitInt();
291
0
      dpm.get16BitInt();
292
0
      dpm.skipBytes(4); /* TTL */
293
294
0
      if (section == 3 && dnstype == QType::OPT) {
295
0
        uint32_t offset = dpm.getOffset();
296
0
        if (offset >= packet.size()) {
297
0
          return false;
298
0
        }
299
        /* if we survive this call, we can parse it safely */
300
0
        dpm.skipRData();
301
        // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
302
0
        return getEDNSOptions(reinterpret_cast<const char*>(&packet.at(offset)), packet.size() - offset, options) == 0;
303
0
      }
304
0
      dpm.skipRData();
305
0
    }
306
0
  }
307
0
  catch (...) {
308
0
    return false;
309
0
  }
310
311
0
  return true;
312
0
}
313
314
int locateEDNSOptRR(const PacketBuffer& packet, uint16_t* optStart, size_t* optLen, bool* last)
315
0
{
316
0
  if (optStart == nullptr || optLen == nullptr || last == nullptr) {
317
0
    throw std::runtime_error("Invalid values passed to locateEDNSOptRR");
318
0
  }
319
320
0
  const dnsheader_aligned dnsHeader(packet.data());
321
322
0
  if (ntohs(dnsHeader->arcount) == 0) {
323
0
    return ENOENT;
324
0
  }
325
326
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
327
0
  PacketReader packetReader(std::string_view(reinterpret_cast<const char*>(packet.data()), packet.size()));
328
329
0
  size_t idx = 0;
330
0
  DNSName rrname;
331
0
  uint16_t qdcount = ntohs(dnsHeader->qdcount);
332
0
  uint16_t ancount = ntohs(dnsHeader->ancount);
333
0
  uint16_t nscount = ntohs(dnsHeader->nscount);
334
0
  uint16_t arcount = ntohs(dnsHeader->arcount);
335
0
  uint16_t rrtype{};
336
0
  uint16_t rrclass{};
337
0
  dnsrecordheader recordHeader{};
338
339
  /* consume qd */
340
0
  for (idx = 0; idx < qdcount; idx++) {
341
0
    rrname = packetReader.getName();
342
0
    rrtype = packetReader.get16BitInt();
343
0
    rrclass = packetReader.get16BitInt();
344
0
    (void)rrtype;
345
0
    (void)rrclass;
346
0
  }
347
348
  /* consume AN and NS */
349
0
  for (idx = 0; idx < ancount + nscount; idx++) {
350
0
    rrname = packetReader.getName();
351
0
    packetReader.getDnsrecordheader(recordHeader);
352
0
    packetReader.skip(recordHeader.d_clen);
353
0
  }
354
355
  /* consume AR, looking for OPT */
356
0
  for (idx = 0; idx < arcount; idx++) {
357
0
    uint16_t start = packetReader.getPosition();
358
0
    rrname = packetReader.getName();
359
0
    packetReader.getDnsrecordheader(recordHeader);
360
361
0
    if (recordHeader.d_type == QType::OPT) {
362
0
      *optStart = start;
363
0
      *optLen = (packetReader.getPosition() - start) + recordHeader.d_clen;
364
365
0
      if (packet.size() < (*optStart + *optLen)) {
366
0
        throw std::range_error("Opt record overflow");
367
0
      }
368
369
0
      if (idx == ((size_t)arcount - 1)) {
370
0
        *last = true;
371
0
      }
372
0
      else {
373
0
        *last = false;
374
0
      }
375
0
      return 0;
376
0
    }
377
0
    packetReader.skip(recordHeader.d_clen);
378
0
  }
379
380
0
  return ENOENT;
381
0
}
382
383
namespace dnsdist
384
{
385
/* extract the start of the OPT RR in a QUERY packet if any */
386
int getEDNSOptionsStart(const PacketBuffer& packet, const size_t qnameWireLength, uint16_t* optRDPosition, size_t* remaining)
387
973
{
388
973
  if (optRDPosition == nullptr || remaining == nullptr) {
389
0
    throw std::runtime_error("Invalid values passed to getEDNSOptionsStart");
390
0
  }
391
392
973
  const dnsheader_aligned dnsHeader(packet.data());
393
394
973
  if (qnameWireLength >= packet.size()) {
395
0
    return ENOENT;
396
0
  }
397
398
973
  if (ntohs(dnsHeader->qdcount) != 1 || ntohs(dnsHeader->ancount) != 0 || ntohs(dnsHeader->arcount) != 1 || ntohs(dnsHeader->nscount) != 0) {
399
165
    return ENOENT;
400
165
  }
401
402
808
  size_t pos = sizeof(dnsheader) + qnameWireLength;
403
808
  pos += DNS_TYPE_SIZE + DNS_CLASS_SIZE;
404
405
808
  if (pos >= packet.size()) {
406
9
    return ENOENT;
407
9
  }
408
409
799
  if ((pos + /* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE) >= packet.size()) {
410
21
    return ENOENT;
411
21
  }
412
413
778
  if (packet[pos] != 0) {
414
    /* not the root so not an OPT record */
415
115
    return ENOENT;
416
115
  }
417
663
  pos += 1;
418
419
663
  uint16_t qtype = packet.at(pos) * 256 + packet.at(pos + 1);
420
663
  pos += DNS_TYPE_SIZE;
421
663
  pos += DNS_CLASS_SIZE;
422
423
663
  if (qtype != QType::OPT || (packet.size() - pos) < (DNS_TTL_SIZE + DNS_RDLENGTH_SIZE)) {
424
136
    return ENOENT;
425
136
  }
426
427
527
  pos += DNS_TTL_SIZE;
428
527
  *optRDPosition = pos;
429
527
  *remaining = packet.size() - pos;
430
431
527
  return 0;
432
663
}
433
}
434
435
void generateECSOption(const ComboAddress& source, string& res, uint16_t ECSPrefixLength)
436
0
{
437
0
  Netmask sourceNetmask(source, ECSPrefixLength);
438
0
  EDNSSubnetOpts ecsOpts;
439
0
  ecsOpts.setSource(sourceNetmask);
440
0
  string payload = ecsOpts.makeOptString();
441
0
  generateEDNSOption(EDNSOptionCode::ECS, payload, res);
442
0
}
443
444
bool generateOptRR(const std::string& optRData, PacketBuffer& res, size_t maximumSize, uint16_t udpPayloadSize, uint8_t ednsrcode, bool dnssecOK)
445
0
{
446
0
  const uint8_t name = 0;
447
0
  dnsrecordheader dnsHeader{};
448
0
  EDNS0Record edns0{};
449
0
  edns0.extRCode = ednsrcode;
450
0
  edns0.version = 0;
451
0
  edns0.extFlags = dnssecOK ? htons(EDNS_HEADER_FLAG_DO) : 0;
452
453
0
  if ((maximumSize - res.size()) < (sizeof(name) + sizeof(dnsHeader) + optRData.length())) {
454
0
    return false;
455
0
  }
456
457
0
  dnsHeader.d_type = htons(QType::OPT);
458
0
  dnsHeader.d_class = htons(udpPayloadSize);
459
0
  static_assert(sizeof(EDNS0Record) == sizeof(dnsHeader.d_ttl), "sizeof(EDNS0Record) must match sizeof(dnsrecordheader.d_ttl)");
460
0
  memcpy(&dnsHeader.d_ttl, &edns0, sizeof edns0);
461
0
  dnsHeader.d_clen = htons(static_cast<uint16_t>(optRData.length()));
462
463
0
  res.reserve(res.size() + sizeof(name) + sizeof(dnsHeader) + optRData.length());
464
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast,cppcoreguidelines-pro-bounds-pointer-arithmetic)
465
0
  res.insert(res.end(), reinterpret_cast<const uint8_t*>(&name), reinterpret_cast<const uint8_t*>(&name) + sizeof(name));
466
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast,cppcoreguidelines-pro-bounds-pointer-arithmetic)
467
0
  res.insert(res.end(), reinterpret_cast<const uint8_t*>(&dnsHeader), reinterpret_cast<const uint8_t*>(&dnsHeader) + sizeof(dnsHeader));
468
0
  res.insert(res.end(), optRData.begin(), optRData.end());
469
470
0
  return true;
471
0
}
472
473
static bool replaceEDNSClientSubnetOption(PacketBuffer& packet, size_t maximumSize, size_t const oldEcsOptionStartPosition, size_t const oldEcsOptionSize, size_t const optRDLenPosition, const string& newECSOption)
474
0
{
475
0
  if (oldEcsOptionStartPosition >= packet.size() || optRDLenPosition >= packet.size()) {
476
0
    throw std::runtime_error("Invalid values passed to replaceEDNSClientSubnetOption");
477
0
  }
478
479
0
  if (newECSOption.size() == oldEcsOptionSize) {
480
    /* same size as the existing option */
481
0
    memcpy(&packet.at(oldEcsOptionStartPosition), newECSOption.c_str(), oldEcsOptionSize);
482
0
  }
483
0
  else {
484
    /* different size than the existing option */
485
0
    const unsigned int newPacketLen = packet.size() + (newECSOption.length() - oldEcsOptionSize);
486
0
    const size_t beforeOptionLen = oldEcsOptionStartPosition;
487
0
    const size_t dataBehindSize = packet.size() - beforeOptionLen - oldEcsOptionSize;
488
489
    /* check that it fits in the existing buffer */
490
0
    if (newPacketLen > packet.size()) {
491
0
      if (newPacketLen > maximumSize) {
492
0
        return false;
493
0
      }
494
495
0
      packet.resize(newPacketLen);
496
0
    }
497
498
    /* fix the size of ECS Option RDLen */
499
0
    uint16_t newRDLen = (packet.at(optRDLenPosition) * 256) + packet.at(optRDLenPosition + 1);
500
0
    newRDLen += (newECSOption.size() - oldEcsOptionSize);
501
0
    packet.at(optRDLenPosition) = newRDLen / 256;
502
0
    packet.at(optRDLenPosition + 1) = newRDLen % 256;
503
504
0
    if (dataBehindSize > 0) {
505
0
      memmove(&packet.at(oldEcsOptionStartPosition), &packet.at(oldEcsOptionStartPosition + oldEcsOptionSize), dataBehindSize);
506
0
    }
507
0
    memcpy(&packet.at(oldEcsOptionStartPosition + dataBehindSize), newECSOption.c_str(), newECSOption.size());
508
0
    packet.resize(newPacketLen);
509
0
  }
510
511
0
  return true;
512
0
}
513
514
/* This function looks for an OPT RR, return true if a valid one was found (even if there was no options)
515
   and false otherwise. */
516
bool parseEDNSOptions(const DNSQuestion& dnsQuestion)
517
0
{
518
0
  const auto dnsHeader = dnsQuestion.getHeader();
519
0
  if (dnsQuestion.ednsOptions != nullptr) {
520
0
    return true;
521
0
  }
522
523
  // dnsQuestion.ednsOptions is mutable
524
0
  dnsQuestion.ednsOptions = std::make_unique<EDNSOptionViewMap>();
525
526
0
  if (ntohs(dnsHeader->arcount) == 0) {
527
    /* nothing in additional so no EDNS */
528
0
    return false;
529
0
  }
530
531
0
  if (ntohs(dnsHeader->ancount) != 0 || ntohs(dnsHeader->nscount) != 0 || ntohs(dnsHeader->arcount) > 1) {
532
0
    return slowParseEDNSOptions(dnsQuestion.getData(), *dnsQuestion.ednsOptions);
533
0
  }
534
535
0
  size_t remaining = 0;
536
0
  uint16_t optRDPosition{};
537
0
  int res = dnsdist::getEDNSOptionsStart(dnsQuestion.getData(), dnsQuestion.ids.qname.wirelength(), &optRDPosition, &remaining);
538
539
0
  if (res == 0) {
540
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
541
0
    res = getEDNSOptions(reinterpret_cast<const char*>(&dnsQuestion.getData().at(optRDPosition)), remaining, *dnsQuestion.ednsOptions);
542
0
    return (res == 0);
543
0
  }
544
545
0
  return false;
546
0
}
547
548
static bool addECSToExistingOPT(PacketBuffer& packet, size_t maximumSize, const string& newECSOption, size_t optRDLenPosition, bool& ecsAdded)
549
0
{
550
  /* we need to add one EDNS0 ECS option, fixing the size of EDNS0 RDLENGTH */
551
  /* getEDNSOptionsStart has already checked that there is exactly one AR,
552
     no NS and no AN */
553
0
  uint16_t oldRDLen = (packet.at(optRDLenPosition) * 256) + packet.at(optRDLenPosition + 1);
554
0
  if (packet.size() != (optRDLenPosition + sizeof(uint16_t) + oldRDLen)) {
555
    /* we are supposed to be the last record, do we have some trailing data to remove? */
556
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
557
0
    uint32_t realPacketLen = getDNSPacketLength(reinterpret_cast<const char*>(packet.data()), packet.size());
558
0
    packet.resize(realPacketLen);
559
0
  }
560
561
0
  if ((maximumSize - packet.size()) < newECSOption.size()) {
562
0
    return false;
563
0
  }
564
565
0
  uint16_t newRDLen = oldRDLen + newECSOption.size();
566
0
  packet.at(optRDLenPosition) = newRDLen / 256;
567
0
  packet.at(optRDLenPosition + 1) = newRDLen % 256;
568
569
0
  packet.insert(packet.end(), newECSOption.begin(), newECSOption.end());
570
0
  ecsAdded = true;
571
572
0
  return true;
573
0
}
574
575
static bool addEDNSWithECS(PacketBuffer& packet, size_t maximumSize, const string& newECSOption, bool& ednsAdded, bool& ecsAdded)
576
0
{
577
0
  if (!generateOptRR(newECSOption, packet, maximumSize, dnsdist::configuration::s_EdnsUDPPayloadSize, 0, false)) {
578
0
    return false;
579
0
  }
580
581
0
  dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [](dnsheader& header) {
582
0
    uint16_t arcount = ntohs(header.arcount);
583
0
    arcount++;
584
0
    header.arcount = htons(arcount);
585
0
    return true;
586
0
  });
587
0
  ednsAdded = true;
588
0
  ecsAdded = true;
589
590
0
  return true;
591
0
}
592
593
bool handleEDNSClientSubnet(PacketBuffer& packet, const size_t maximumSize, const size_t qnameWireLength, bool& ednsAdded, bool& ecsAdded, bool overrideExisting, const string& newECSOption)
594
0
{
595
0
  if (qnameWireLength > packet.size()) {
596
0
    throw std::runtime_error("Invalid value passed to handleEDNSClientSubnet");
597
0
  }
598
599
0
  const dnsheader_aligned dnsHeader(packet.data());
600
601
0
  if (ntohs(dnsHeader->ancount) != 0 || ntohs(dnsHeader->nscount) != 0 || (ntohs(dnsHeader->arcount) != 0 && ntohs(dnsHeader->arcount) != 1)) {
602
0
    PacketBuffer newContent;
603
0
    newContent.reserve(packet.size());
604
605
0
    if (!slowRewriteEDNSOptionInQueryWithRecords(packet, newContent, ednsAdded, EDNSOptionCode::ECS, ecsAdded, overrideExisting, newECSOption)) {
606
0
      return false;
607
0
    }
608
609
0
    if (newContent.size() > maximumSize) {
610
0
      ednsAdded = false;
611
0
      ecsAdded = false;
612
0
      return false;
613
0
    }
614
615
0
    packet = std::move(newContent);
616
0
    return true;
617
0
  }
618
619
0
  uint16_t optRDPosition = 0;
620
0
  size_t remaining = 0;
621
622
0
  int res = dnsdist::getEDNSOptionsStart(packet, qnameWireLength, &optRDPosition, &remaining);
623
624
0
  if (res != 0) {
625
    /* no EDNS but there might be another record in additional (TSIG?) */
626
    /* Careful, this code assumes that ANCOUNT == 0 && NSCOUNT == 0 */
627
0
    size_t minimumPacketSize = sizeof(dnsheader) + qnameWireLength + sizeof(uint16_t) + sizeof(uint16_t);
628
0
    if (packet.size() > minimumPacketSize) {
629
0
      if (ntohs(dnsHeader->arcount) == 0) {
630
        /* well now.. */
631
0
        packet.resize(minimumPacketSize);
632
0
      }
633
0
      else {
634
        // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
635
0
        uint32_t realPacketLen = getDNSPacketLength(reinterpret_cast<const char*>(packet.data()), packet.size());
636
0
        packet.resize(realPacketLen);
637
0
      }
638
0
    }
639
640
0
    return addEDNSWithECS(packet, maximumSize, newECSOption, ednsAdded, ecsAdded);
641
0
  }
642
643
0
  size_t ecsOptionStartPosition = 0;
644
0
  size_t ecsOptionSize = 0;
645
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
646
0
  res = getEDNSOption(reinterpret_cast<const char*>(&packet.at(optRDPosition)), remaining, EDNSOptionCode::ECS, &ecsOptionStartPosition, &ecsOptionSize);
647
648
0
  if (res == 0) {
649
    /* there is already an ECS value */
650
0
    if (!overrideExisting) {
651
0
      return true;
652
0
    }
653
654
0
    return replaceEDNSClientSubnetOption(packet, maximumSize, optRDPosition + ecsOptionStartPosition, ecsOptionSize, optRDPosition, newECSOption);
655
0
  }
656
657
  /* we have an EDNS OPT RR but no existing ECS option */
658
0
  return addECSToExistingOPT(packet, maximumSize, newECSOption, optRDPosition, ecsAdded);
659
0
}
660
661
bool handleEDNSClientSubnet(DNSQuestion& dnsQuestion, bool& ednsAdded, bool& ecsAdded)
662
0
{
663
0
  string newECSOption;
664
0
  generateECSOption(dnsQuestion.ecs ? dnsQuestion.ecs->getNetwork() : dnsQuestion.ids.origRemote, newECSOption, dnsQuestion.ecs ? dnsQuestion.ecs->getBits() : dnsQuestion.ecsPrefixLength);
665
666
0
  return handleEDNSClientSubnet(dnsQuestion.getMutableData(), dnsQuestion.getMaximumSize(), dnsQuestion.ids.qname.wirelength(), ednsAdded, ecsAdded, dnsQuestion.ecsOverride, newECSOption);
667
0
}
668
669
static int removeEDNSOptionFromOptions(unsigned char* optionsStart, const uint16_t optionsLen, const uint16_t optionCodeToRemove, uint16_t* newOptionsLen)
670
0
{
671
0
  const pdns::views::UnsignedCharView view(optionsStart, optionsLen);
672
0
  size_t pos = 0;
673
0
  while ((pos + 4) <= view.size()) {
674
0
    size_t optionBeginPos = pos;
675
0
    const uint16_t optionCode = 0x100 * view.at(pos) + view.at(pos + 1);
676
0
    pos += sizeof(optionCode);
677
0
    const uint16_t optionLen = 0x100 * view.at(pos) + view.at(pos + 1);
678
0
    pos += sizeof(optionLen);
679
0
    if ((pos + optionLen) > view.size()) {
680
0
      return EINVAL;
681
0
    }
682
0
    if (optionCode == optionCodeToRemove) {
683
0
      if (pos + optionLen < view.size()) {
684
        /* move remaining options over the removed one,
685
           if any */
686
        // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic)
687
0
        memmove(optionsStart + optionBeginPos, optionsStart + pos + optionLen, optionsLen - (pos + optionLen));
688
0
      }
689
0
      *newOptionsLen = optionsLen - (sizeof(optionCode) + sizeof(optionLen) + optionLen);
690
0
      return 0;
691
0
    }
692
0
    pos += optionLen;
693
0
  }
694
0
  return ENOENT;
695
0
}
696
697
int removeEDNSOptionFromOPT(char* optStart, size_t* optLen, const uint16_t optionCodeToRemove)
698
0
{
699
0
  if (*optLen < optRecordMinimumSize) {
700
0
    return EINVAL;
701
0
  }
702
0
  const pdns::views::UnsignedCharView view(optStart, *optLen);
703
  /* skip the root label, qtype, qclass and TTL */
704
0
  size_t position = 9;
705
0
  uint16_t rdLen = (0x100 * view.at(position) + view.at(position + 1));
706
0
  position += sizeof(rdLen);
707
0
  if (position + rdLen != view.size()) {
708
0
    return EINVAL;
709
0
  }
710
0
  uint16_t newRdLen = 0;
711
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast,cppcoreguidelines-pro-bounds-pointer-arithmetic)
712
0
  int res = removeEDNSOptionFromOptions(reinterpret_cast<unsigned char*>(optStart + position), rdLen, optionCodeToRemove, &newRdLen);
713
0
  if (res != 0) {
714
0
    return res;
715
0
  }
716
0
  *optLen -= (rdLen - newRdLen);
717
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast,cppcoreguidelines-pro-bounds-pointer-arithmetic)
718
0
  auto* rdLenPtr = reinterpret_cast<unsigned char*>(optStart + 9);
719
  // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic)
720
0
  rdLenPtr[0] = newRdLen / 0x100;
721
  // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic)
722
0
  rdLenPtr[1] = newRdLen % 0x100;
723
0
  return 0;
724
0
}
725
726
bool isEDNSOptionInOpt(const PacketBuffer& packet, const size_t optStart, const size_t optLen, const uint16_t optionCodeToFind, size_t* optContentStart, uint16_t* optContentLen)
727
0
{
728
0
  if (optLen < optRecordMinimumSize) {
729
0
    return false;
730
0
  }
731
0
  size_t position = optStart + 9;
732
0
  uint16_t rdLen = (0x100 * static_cast<unsigned char>(packet.at(position)) + static_cast<unsigned char>(packet.at(position + 1)));
733
0
  position += sizeof(rdLen);
734
0
  if (rdLen > (optLen - optRecordMinimumSize)) {
735
0
    return false;
736
0
  }
737
738
0
  size_t rdEnd = position + rdLen;
739
0
  while ((position + 4) <= rdEnd) {
740
0
    const uint16_t optionCode = 0x100 * static_cast<unsigned char>(packet.at(position)) + static_cast<unsigned char>(packet.at(position + 1));
741
0
    position += sizeof(optionCode);
742
0
    const uint16_t optionLen = 0x100 * static_cast<unsigned char>(packet.at(position)) + static_cast<unsigned char>(packet.at(position + 1));
743
0
    position += sizeof(optionLen);
744
745
0
    if ((position + optionLen) > rdEnd) {
746
0
      return false;
747
0
    }
748
749
0
    if (optionCode == optionCodeToFind) {
750
0
      if (optContentStart != nullptr) {
751
0
        *optContentStart = position;
752
0
      }
753
754
0
      if (optContentLen != nullptr) {
755
0
        *optContentLen = optionLen;
756
0
      }
757
758
0
      return true;
759
0
    }
760
0
    position += optionLen;
761
0
  }
762
0
  return false;
763
0
}
764
765
int rewriteResponseWithoutEDNSOption(const PacketBuffer& initialPacket, const uint16_t optionCodeToSkip, PacketBuffer& newContent)
766
0
{
767
0
  if (initialPacket.size() < sizeof(dnsheader)) {
768
0
    return ENOENT;
769
0
  }
770
771
0
  const dnsheader_aligned dnsHeader(initialPacket.data());
772
773
0
  if (ntohs(dnsHeader->arcount) == 0) {
774
0
    return ENOENT;
775
0
  }
776
777
0
  if (ntohs(dnsHeader->qdcount) == 0) {
778
0
    return ENOENT;
779
0
  }
780
781
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
782
0
  PacketReader packetReader(std::string_view(reinterpret_cast<const char*>(initialPacket.data()), initialPacket.size()));
783
784
0
  size_t idx = 0;
785
0
  DNSName rrname;
786
0
  uint16_t qdcount = ntohs(dnsHeader->qdcount);
787
0
  uint16_t ancount = ntohs(dnsHeader->ancount);
788
0
  uint16_t nscount = ntohs(dnsHeader->nscount);
789
0
  uint16_t arcount = ntohs(dnsHeader->arcount);
790
0
  uint16_t rrtype = 0;
791
0
  uint16_t rrclass = 0;
792
0
  string blob;
793
0
  dnsrecordheader recordHeader{};
794
795
0
  rrname = packetReader.getName();
796
0
  rrtype = packetReader.get16BitInt();
797
0
  rrclass = packetReader.get16BitInt();
798
799
0
  GenericDNSPacketWriter<PacketBuffer> packetWriter(newContent, rrname, rrtype, rrclass, dnsHeader->opcode);
800
0
  packetWriter.getHeader()->id = dnsHeader->id;
801
0
  packetWriter.getHeader()->qr = dnsHeader->qr;
802
0
  packetWriter.getHeader()->aa = dnsHeader->aa;
803
0
  packetWriter.getHeader()->tc = dnsHeader->tc;
804
0
  packetWriter.getHeader()->rd = dnsHeader->rd;
805
0
  packetWriter.getHeader()->ra = dnsHeader->ra;
806
0
  packetWriter.getHeader()->ad = dnsHeader->ad;
807
0
  packetWriter.getHeader()->cd = dnsHeader->cd;
808
0
  packetWriter.getHeader()->rcode = dnsHeader->rcode;
809
810
  /* consume remaining qd if any */
811
0
  if (qdcount > 1) {
812
0
    for (idx = 1; idx < qdcount; idx++) {
813
0
      rrname = packetReader.getName();
814
0
      rrtype = packetReader.get16BitInt();
815
0
      rrclass = packetReader.get16BitInt();
816
0
      (void)rrtype;
817
0
      (void)rrclass;
818
0
    }
819
0
  }
820
821
  /* copy AN and NS */
822
0
  for (idx = 0; idx < ancount; idx++) {
823
0
    rrname = packetReader.getName();
824
0
    packetReader.getDnsrecordheader(recordHeader);
825
826
0
    packetWriter.startRecord(rrname, recordHeader.d_type, recordHeader.d_ttl, recordHeader.d_class, DNSResourceRecord::ANSWER, true);
827
0
    packetReader.xfrBlob(blob);
828
0
    packetWriter.xfrBlob(blob);
829
0
  }
830
831
0
  for (idx = 0; idx < nscount; idx++) {
832
0
    rrname = packetReader.getName();
833
0
    packetReader.getDnsrecordheader(recordHeader);
834
835
0
    packetWriter.startRecord(rrname, recordHeader.d_type, recordHeader.d_ttl, recordHeader.d_class, DNSResourceRecord::AUTHORITY, true);
836
0
    packetReader.xfrBlob(blob);
837
0
    packetWriter.xfrBlob(blob);
838
0
  }
839
840
  /* consume AR, looking for OPT */
841
0
  for (idx = 0; idx < arcount; idx++) {
842
0
    rrname = packetReader.getName();
843
0
    packetReader.getDnsrecordheader(recordHeader);
844
845
0
    if (recordHeader.d_type != QType::OPT) {
846
0
      packetWriter.startRecord(rrname, recordHeader.d_type, recordHeader.d_ttl, recordHeader.d_class, DNSResourceRecord::ADDITIONAL, true);
847
0
      packetReader.xfrBlob(blob);
848
0
      packetWriter.xfrBlob(blob);
849
0
    }
850
0
    else {
851
0
      packetWriter.startRecord(rrname, recordHeader.d_type, recordHeader.d_ttl, recordHeader.d_class, DNSResourceRecord::ADDITIONAL, false);
852
0
      packetReader.xfrBlob(blob);
853
0
      uint16_t rdLen = blob.length();
854
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
855
0
      removeEDNSOptionFromOptions(reinterpret_cast<unsigned char*>(blob.data()), rdLen, optionCodeToSkip, &rdLen);
856
      /* xfrBlob(string, size) completely ignores size.. */
857
0
      if (rdLen > 0) {
858
0
        blob.resize((size_t)rdLen);
859
0
        packetWriter.xfrBlob(blob);
860
0
      }
861
0
      else {
862
0
        packetWriter.commit();
863
0
      }
864
0
    }
865
0
  }
866
0
  packetWriter.commit();
867
868
0
  return 0;
869
0
}
870
871
/* Append an OPT record without any option to `packet` (via generateOptRR) and bump ARCOUNT.
   Returns false, without touching the header, when generateOptRR reports a failure. */
bool addEDNS(PacketBuffer& packet, size_t maximumSize, bool dnssecOK, uint16_t payloadSize, uint8_t ednsrcode)
{
  const bool appended = generateOptRR(std::string{}, packet, maximumSize, payloadSize, ednsrcode, dnssecOK);
  if (appended) {
    dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [](dnsheader& header) {
      header.arcount = htons(ntohs(header.arcount) + 1);
      return true;
    });
  }
  return appended;
}
884
885
/*
886
  This function keeps the existing header and DNSSECOK bit (if any) but wipes anything else,
887
  generating a NXD or NODATA answer with a SOA record in the additional section (or optionally the authority section for a full cacheable NXDOMAIN/NODATA).
888
*/
889
bool setNegativeAndAdditionalSOA(DNSQuestion& dnsQuestion, bool nxd, const DNSName& zone, uint32_t ttl, const DNSName& mname, const DNSName& rname, uint32_t serial, uint32_t refresh, uint32_t retry, uint32_t expire, uint32_t minimum, bool soaInAuthoritySection)
890
0
{
891
0
  auto& packet = dnsQuestion.getMutableData();
892
0
  auto dnsHeader = dnsQuestion.getHeader();
893
0
  if (ntohs(dnsHeader->qdcount) != 1) {
894
0
    return false;
895
0
  }
896
897
0
  size_t queryPartSize = sizeof(dnsheader) + dnsQuestion.ids.qname.wirelength() + DNS_TYPE_SIZE + DNS_CLASS_SIZE;
898
0
  if (packet.size() < queryPartSize) {
899
    /* something is already wrong, don't build on flawed foundations */
900
0
    return false;
901
0
  }
902
903
0
  uint16_t qtype = htons(QType::SOA);
904
0
  uint16_t qclass = htons(QClass::IN);
905
0
  uint16_t rdLength = mname.wirelength() + rname.wirelength() + sizeof(serial) + sizeof(refresh) + sizeof(retry) + sizeof(expire) + sizeof(minimum);
906
0
  size_t soaSize = zone.wirelength() + sizeof(qtype) + sizeof(qclass) + sizeof(ttl) + sizeof(rdLength) + rdLength;
907
0
  bool hadEDNS = false;
908
0
  bool dnssecOK = false;
909
910
0
  if (dnsdist::configuration::getCurrentRuntimeConfiguration().d_addEDNSToSelfGeneratedResponses) {
911
0
    uint16_t payloadSize = 0;
912
0
    uint16_t zValue = 0;
913
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
914
0
    hadEDNS = getEDNSUDPPayloadSizeAndZ(reinterpret_cast<const char*>(packet.data()), packet.size(), &payloadSize, &zValue);
915
0
    if (hadEDNS) {
916
0
      dnssecOK = (zValue & EDNS_HEADER_FLAG_DO) != 0;
917
0
    }
918
0
  }
919
920
  /* chop off everything after the question */
921
0
  packet.resize(queryPartSize);
922
0
  dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [nxd](dnsheader& header) {
923
0
    if (nxd) {
924
0
      header.rcode = RCode::NXDomain;
925
0
    }
926
0
    else {
927
0
      header.rcode = RCode::NoError;
928
0
    }
929
0
    header.qr = true;
930
0
    header.ancount = 0;
931
0
    header.nscount = 0;
932
0
    header.arcount = 0;
933
0
    return true;
934
0
  });
935
936
0
  rdLength = htons(rdLength);
937
0
  ttl = htonl(ttl);
938
0
  serial = htonl(serial);
939
0
  refresh = htonl(refresh);
940
0
  retry = htonl(retry);
941
0
  expire = htonl(expire);
942
0
  minimum = htonl(minimum);
943
944
0
  std::string soa;
945
0
  soa.reserve(soaSize);
946
0
  soa.append(zone.toDNSString());
947
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
948
0
  soa.append(reinterpret_cast<const char*>(&qtype), sizeof(qtype));
949
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
950
0
  soa.append(reinterpret_cast<const char*>(&qclass), sizeof(qclass));
951
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
952
0
  soa.append(reinterpret_cast<const char*>(&ttl), sizeof(ttl));
953
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
954
0
  soa.append(reinterpret_cast<const char*>(&rdLength), sizeof(rdLength));
955
0
  soa.append(mname.toDNSString());
956
0
  soa.append(rname.toDNSString());
957
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
958
0
  soa.append(reinterpret_cast<const char*>(&serial), sizeof(serial));
959
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
960
0
  soa.append(reinterpret_cast<const char*>(&refresh), sizeof(refresh));
961
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
962
0
  soa.append(reinterpret_cast<const char*>(&retry), sizeof(retry));
963
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
964
0
  soa.append(reinterpret_cast<const char*>(&expire), sizeof(expire));
965
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
966
0
  soa.append(reinterpret_cast<const char*>(&minimum), sizeof(minimum));
967
968
0
  if (soa.size() != soaSize) {
969
0
    throw std::runtime_error("Unexpected SOA response size: " + std::to_string(soa.size()) + " vs " + std::to_string(soaSize));
970
0
  }
971
972
0
  packet.insert(packet.end(), soa.begin(), soa.end());
973
974
  /* We are populating a response with only the query in place, order of sections is QD,AN,NS,AR
975
     NS (authority) is before AR (additional) so we can just decide which section the SOA record is in here
976
     and have EDNS added to AR afterwards */
977
0
  dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [soaInAuthoritySection](dnsheader& header) {
978
0
    if (soaInAuthoritySection) {
979
0
      header.nscount = htons(1);
980
0
    }
981
0
    else {
982
0
      header.arcount = htons(1);
983
0
    }
984
0
    return true;
985
0
  });
986
987
0
  if (hadEDNS) {
988
    /* now we need to add a new OPT record */
989
0
    return addEDNS(packet, dnsQuestion.getMaximumSize(), dnssecOK, dnsdist::configuration::getCurrentRuntimeConfiguration().d_payloadSizeSelfGenAnswers, dnsQuestion.ednsRCode);
990
0
  }
991
992
0
  return true;
993
0
}
994
995
/* For a query being turned into a response in place: if dnsdist::getEDNSOptionsStart locates
   an OPT record in it, that record and everything following it are removed and ARCOUNT is
   reset to 0. When EDNS on self-generated responses is enabled, a new OPT record is then
   added, keeping the query's DNSSEC OK bit.
   Returns true when the query had no EDNS (nothing to do) or on success. Returns false if the
   located OPT record is inconsistent with the packet size or the new record cannot be added. */
bool addEDNSToQueryTurnedResponse(DNSQuestion& dnsQuestion)
{
  auto& packet = dnsQuestion.getMutableData();

  uint16_t optRDPosition{};
  /* remaining is at least the size of the rdlen + the options if any + the following records if any */
  size_t remaining = 0;
  if (dnsdist::getEDNSOptionsStart(packet, dnsQuestion.ids.qname.wirelength(), &optRDPosition, &remaining) != 0) {
    /* if the initial query did not have EDNS0, we are done */
    return true;
  }

  /* fixed part of the OPT RR located before optRDPosition: root label, TYPE, CLASS,
     extended RCODE, version and Z */
  const size_t optFixedPrefix = /* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE + /* Z */ 2;

  const size_t existingOptLen = optFixedPrefix + remaining;
  if (existingOptLen >= packet.size()) {
    /* something is wrong, bail out */
    return false;
  }

  const size_t optPosition = optRDPosition - optFixedPrefix;
  const size_t zPosition = optPosition + /* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE;
  const uint16_t zValue = 0x100 * packet.at(zPosition) + packet.at(zPosition + 1);
  const bool dnssecOK = (zValue & EDNS_HEADER_FLAG_DO) != 0;

  /* remove the existing OPT record, and everything else that follows (any SIG or TSIG would be useless anyway) */
  packet.resize(packet.size() - existingOptLen);
  dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [](dnsheader& header) {
    header.arcount = 0;
    return true;
  });

  if (!dnsdist::configuration::getCurrentRuntimeConfiguration().d_addEDNSToSelfGeneratedResponses) {
    /* otherwise we are just fine */
    return true;
  }

  /* now we need to add a new OPT record */
  return addEDNS(packet, dnsQuestion.getMaximumSize(), dnssecOK, dnsdist::configuration::getCurrentRuntimeConfiguration().d_payloadSizeSelfGenAnswers, dnsQuestion.ednsRCode);
}
1036
1037
namespace dnsdist
1038
{
1039
static std::optional<size_t> getEDNSRecordPosition(const DNSQuestion& dnsQuestion)
1040
0
{
1041
0
  try {
1042
0
    const auto& packet = dnsQuestion.getData();
1043
0
    if (packet.size() <= sizeof(dnsheader)) {
1044
0
      return std::nullopt;
1045
0
    }
1046
1047
0
    uint16_t optRDPosition = 0;
1048
0
    size_t remaining = 0;
1049
0
    auto res = getEDNSOptionsStart(packet, dnsQuestion.ids.qname.wirelength(), &optRDPosition, &remaining);
1050
0
    if (res != 0) {
1051
0
      return std::nullopt;
1052
0
    }
1053
1054
0
    if (optRDPosition < DNS_TTL_SIZE) {
1055
0
      return std::nullopt;
1056
0
    }
1057
1058
0
    return optRDPosition - DNS_TTL_SIZE;
1059
0
  }
1060
0
  catch (...) {
1061
0
    return std::nullopt;
1062
0
  }
1063
0
}
1064
1065
// goal in life - if you send us a reasonably normal packet, we'll get Z for you, otherwise 0
1066
int getEDNSZ(const DNSQuestion& dnsQuestion)
1067
0
{
1068
0
  try {
1069
0
    auto position = getEDNSRecordPosition(dnsQuestion);
1070
1071
0
    if (!position) {
1072
0
      return 0;
1073
0
    }
1074
1075
0
    const auto& packet = dnsQuestion.getData();
1076
0
    if ((*position + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE + 1) >= packet.size()) {
1077
0
      return 0;
1078
0
    }
1079
1080
0
    return 0x100 * packet.at(*position + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE) + packet.at(*position + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE + 1);
1081
0
  }
1082
0
  catch (...) {
1083
0
    return 0;
1084
0
  }
1085
0
}
1086
1087
std::optional<uint8_t> getEDNSVersion(const DNSQuestion& dnsQuestion)
1088
0
{
1089
0
  try {
1090
0
    auto position = getEDNSRecordPosition(dnsQuestion);
1091
1092
0
    if (!position) {
1093
0
      return std::nullopt;
1094
0
    }
1095
1096
0
    const auto& packet = dnsQuestion.getData();
1097
0
    if ((*position + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE) >= packet.size()) {
1098
0
      return std::nullopt;
1099
0
    }
1100
1101
0
    return packet.at(*position + EDNS_EXTENDED_RCODE_SIZE);
1102
0
  }
1103
0
  catch (...) {
1104
0
    return std::nullopt;
1105
0
  }
1106
0
}
1107
1108
std::optional<uint8_t> getEDNSExtendedRCode(const DNSQuestion& dnsQuestion)
1109
0
{
1110
0
  try {
1111
0
    auto position = getEDNSRecordPosition(dnsQuestion);
1112
1113
0
    if (!position) {
1114
0
      return std::nullopt;
1115
0
    }
1116
1117
0
    const auto& packet = dnsQuestion.getData();
1118
0
    if ((*position + EDNS_EXTENDED_RCODE_SIZE) >= packet.size()) {
1119
0
      return std::nullopt;
1120
0
    }
1121
1122
0
    return packet.at(*position);
1123
0
  }
1124
0
  catch (...) {
1125
0
    return std::nullopt;
1126
0
  }
1127
0
}
1128
1129
}
1130
1131
bool queryHasEDNS(const DNSQuestion& dnsQuestion)
1132
0
{
1133
0
  uint16_t optRDPosition = 0;
1134
0
  size_t ecsRemaining = 0;
1135
1136
0
  int res = dnsdist::getEDNSOptionsStart(dnsQuestion.getData(), dnsQuestion.ids.qname.wirelength(), &optRDPosition, &ecsRemaining);
1137
0
  return res == 0;
1138
0
}
1139
1140
bool getEDNS0Record(const PacketBuffer& packet, EDNS0Record& edns0)
1141
0
{
1142
0
  uint16_t optStart = 0;
1143
0
  size_t optLen = 0;
1144
0
  bool last = false;
1145
0
  int res = locateEDNSOptRR(packet, &optStart, &optLen, &last);
1146
0
  if (res != 0) {
1147
    // no EDNS OPT RR
1148
0
    return false;
1149
0
  }
1150
1151
0
  if (optLen < optRecordMinimumSize) {
1152
0
    return false;
1153
0
  }
1154
1155
0
  if (optStart < packet.size() && packet.at(optStart) != 0) {
1156
    // OPT RR Name != '.'
1157
0
    return false;
1158
0
  }
1159
1160
0
  static_assert(sizeof(EDNS0Record) == sizeof(uint32_t), "sizeof(EDNS0Record) must match sizeof(uint32_t) AKA RR TTL size");
1161
  // copy out 4-byte "ttl" (really the EDNS0 record), after root label (1) + type (2) + class (2).
1162
0
  memcpy(&edns0, &packet.at(optStart + 5), sizeof edns0);
1163
0
  return true;
1164
0
}
1165
1166
/* Add the EDNS option `ednsCode` carrying `ednsData` to the packet held by dnsQuestion.
   - When the packet already has additional records, it is rebuilt through
     slowRewriteEDNSOptionInQueryWithRecords, and the rebuilt packet replaces the original
     unless it exceeds getMaximumSize().
   - Otherwise a new OPT record carrying only this option is appended and ARCOUNT set to 1.
   When isQuery is true and an OPT record had to be created, ids.ednsAdded is set so that
   EDNS sent back by the backend is removed before answering the client.
   Returns true when the option was added, false otherwise. */
bool setEDNSOption(DNSQuestion& dnsQuestion, uint16_t ednsCode, const std::string& ednsData, bool isQuery)
{
  std::string optRData;
  generateEDNSOption(ednsCode, ednsData, optRData);

  if (dnsQuestion.getHeader()->arcount != 0) {
    bool ednsAdded = false;
    bool optionAdded = false;
    PacketBuffer newContent;
    newContent.reserve(dnsQuestion.getData().size());

    if (!slowRewriteEDNSOptionInQueryWithRecords(dnsQuestion.getData(), newContent, ednsAdded, ednsCode, optionAdded, true, optRData)) {
      return false;
    }

    if (newContent.size() > dnsQuestion.getMaximumSize()) {
      return false;
    }

    dnsQuestion.getMutableData() = std::move(newContent);
    if (isQuery && !dnsQuestion.ids.ednsAdded && ednsAdded) {
      dnsQuestion.ids.ednsAdded = true;
    }

    return true;
  }

  auto& data = dnsQuestion.getMutableData();
  if (!generateOptRR(optRData, data, dnsQuestion.getMaximumSize(), dnsdist::configuration::s_EdnsUDPPayloadSize, 0, false)) {
    /* The OPT record could not be appended (presumably because the result would not fit in
       getMaximumSize()). Report the failure, as the rewrite path above does, instead of
       claiming the option was set. */
    return false;
  }

  dnsdist::PacketMangling::editDNSHeaderFromPacket(data, [](dnsheader& header) {
    header.arcount = htons(1);
    return true;
  });

  if (isQuery) {
    // make sure that any EDNS sent by the backend is removed before forwarding the response to the client
    dnsQuestion.ids.ednsAdded = true;
  }

  return true;
}
1208
1209
namespace dnsdist
1210
{
1211
bool setInternalQueryRCode(InternalQueryState& state, PacketBuffer& buffer, uint8_t rcode, bool clearAnswers)
1212
0
{
1213
0
  const auto qnameLength = state.qname.wirelength();
1214
0
  if (buffer.size() < sizeof(dnsheader) + qnameLength + sizeof(uint16_t) + sizeof(uint16_t)) {
1215
0
    return false;
1216
0
  }
1217
1218
0
  EDNS0Record edns0{};
1219
0
  bool hadEDNS = false;
1220
0
  if (clearAnswers) {
1221
0
    hadEDNS = getEDNS0Record(buffer, edns0);
1222
0
  }
1223
1224
0
  dnsdist::PacketMangling::editDNSHeaderFromPacket(buffer, [rcode, clearAnswers](dnsheader& header) {
1225
0
    header.rcode = rcode;
1226
0
    header.ad = false;
1227
0
    header.aa = false;
1228
0
    header.ra = header.rd;
1229
0
    header.qr = true;
1230
1231
0
    if (clearAnswers) {
1232
0
      header.ancount = 0;
1233
0
      header.nscount = 0;
1234
0
      header.arcount = 0;
1235
0
    }
1236
0
    return true;
1237
0
  });
1238
1239
0
  if (clearAnswers) {
1240
0
    buffer.resize(sizeof(dnsheader) + qnameLength + sizeof(uint16_t) + sizeof(uint16_t));
1241
0
    if (hadEDNS) {
1242
0
      DNSQuestion dnsQuestion(state, buffer);
1243
0
      if (!addEDNS(buffer, dnsQuestion.getMaximumSize(), (edns0.extFlags & htons(EDNS_HEADER_FLAG_DO)) != 0, dnsdist::configuration::getCurrentRuntimeConfiguration().d_payloadSizeSelfGenAnswers, 0)) {
1244
0
        return false;
1245
0
      }
1246
0
    }
1247
0
  }
1248
1249
0
  return true;
1250
0
}
1251
}