Coverage Report

Created: 2026-03-08 06:22

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/pdns/pdns/dnsdistdist/dnsdist-ecs.cc
Line
Count
Source
1
/*
2
 * This file is part of PowerDNS or dnsdist.
3
 * Copyright -- PowerDNS.COM B.V. and its contributors
4
 *
5
 * This program is free software; you can redistribute it and/or modify
6
 * it under the terms of version 2 of the GNU General Public License as
7
 * published by the Free Software Foundation.
8
 *
9
 * In addition, for the avoidance of any doubt, permission is granted to
10
 * link this program with OpenSSL and to (re)distribute the binaries
11
 * produced as the result of such linking.
12
 *
13
 * This program is distributed in the hope that it will be useful,
14
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16
 * GNU General Public License for more details.
17
 *
18
 * You should have received a copy of the GNU General Public License
19
 * along with this program; if not, write to the Free Software
20
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21
 */
22
#include "dns.hh"
23
#include "dolog.hh"
24
#include "dnsdist.hh"
25
#include "dnsdist-dnsparser.hh"
26
#include "dnsdist-ecs.hh"
27
#include "dnsparser.hh"
28
#include "dnswriter.hh"
29
#include "ednsoptions.hh"
30
#include "ednssubnet.hh"
31
32
int rewriteResponseWithoutEDNS(const PacketBuffer& initialPacket, PacketBuffer& newContent)
33
0
{
34
0
  if (initialPacket.size() < sizeof(dnsheader)) {
35
0
    return ENOENT;
36
0
  }
37
38
0
  const dnsheader_aligned dnsHeader(initialPacket.data());
39
40
0
  if (ntohs(dnsHeader->arcount) == 0) {
41
0
    return ENOENT;
42
0
  }
43
44
0
  if (ntohs(dnsHeader->qdcount) == 0) {
45
0
    return ENOENT;
46
0
  }
47
48
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
49
0
  PacketReader packetReader(std::string_view(reinterpret_cast<const char*>(initialPacket.data()), initialPacket.size()));
50
51
0
  size_t idx = 0;
52
0
  uint16_t qdcount = ntohs(dnsHeader->qdcount);
53
0
  uint16_t ancount = ntohs(dnsHeader->ancount);
54
0
  uint16_t nscount = ntohs(dnsHeader->nscount);
55
0
  uint16_t arcount = ntohs(dnsHeader->arcount);
56
0
  string blob;
57
0
  dnsrecordheader recordHeader{};
58
59
0
  auto rrname = packetReader.getName();
60
0
  auto rrtype = packetReader.get16BitInt();
61
0
  auto rrclass = packetReader.get16BitInt();
62
63
0
  GenericDNSPacketWriter<PacketBuffer> packetWriter(newContent, rrname, rrtype, rrclass, dnsHeader->opcode);
64
0
  packetWriter.getHeader()->id = dnsHeader->id;
65
0
  packetWriter.getHeader()->qr = dnsHeader->qr;
66
0
  packetWriter.getHeader()->aa = dnsHeader->aa;
67
0
  packetWriter.getHeader()->tc = dnsHeader->tc;
68
0
  packetWriter.getHeader()->rd = dnsHeader->rd;
69
0
  packetWriter.getHeader()->ra = dnsHeader->ra;
70
0
  packetWriter.getHeader()->ad = dnsHeader->ad;
71
0
  packetWriter.getHeader()->cd = dnsHeader->cd;
72
0
  packetWriter.getHeader()->rcode = dnsHeader->rcode;
73
74
  /* consume remaining qd if any */
75
0
  if (qdcount > 1) {
76
0
    for (idx = 1; idx < qdcount; idx++) {
77
0
      rrname = packetReader.getName();
78
0
      rrtype = packetReader.get16BitInt();
79
0
      rrclass = packetReader.get16BitInt();
80
0
      (void)rrtype;
81
0
      (void)rrclass;
82
0
    }
83
0
  }
84
85
  /* copy AN and NS */
86
0
  for (idx = 0; idx < ancount; idx++) {
87
0
    rrname = packetReader.getName();
88
0
    packetReader.getDnsrecordheader(recordHeader);
89
90
0
    packetWriter.startRecord(rrname, recordHeader.d_type, recordHeader.d_ttl, recordHeader.d_class, DNSResourceRecord::ANSWER, true);
91
0
    packetReader.xfrBlob(blob);
92
0
    packetWriter.xfrBlob(blob);
93
0
  }
94
95
0
  for (idx = 0; idx < nscount; idx++) {
96
0
    rrname = packetReader.getName();
97
0
    packetReader.getDnsrecordheader(recordHeader);
98
99
0
    packetWriter.startRecord(rrname, recordHeader.d_type, recordHeader.d_ttl, recordHeader.d_class, DNSResourceRecord::AUTHORITY, true);
100
0
    packetReader.xfrBlob(blob);
101
0
    packetWriter.xfrBlob(blob);
102
0
  }
103
  /* consume AR, looking for OPT */
104
0
  for (idx = 0; idx < arcount; idx++) {
105
0
    rrname = packetReader.getName();
106
0
    packetReader.getDnsrecordheader(recordHeader);
107
108
0
    if (recordHeader.d_type != QType::OPT) {
109
0
      packetWriter.startRecord(rrname, recordHeader.d_type, recordHeader.d_ttl, recordHeader.d_class, DNSResourceRecord::ADDITIONAL, true);
110
0
      packetReader.xfrBlob(blob);
111
0
      packetWriter.xfrBlob(blob);
112
0
    }
113
0
    else {
114
115
0
      packetReader.skip(recordHeader.d_clen);
116
0
    }
117
0
  }
118
0
  packetWriter.commit();
119
120
0
  return 0;
121
0
}
122
123
static bool addOrReplaceEDNSOption(std::vector<std::pair<uint16_t, std::string>>& options, uint16_t optionCode, bool& optionAdded, bool overrideExisting, bool allowMultiple, const string& newOptionContent)
124
0
{
125
0
  if (!allowMultiple) {
126
0
    for (auto it = options.begin(); it != options.end();) {
127
0
      if (it->first == optionCode) {
128
0
        optionAdded = false;
129
130
0
        if (!overrideExisting) {
131
0
          return false;
132
0
        }
133
134
0
        it = options.erase(it);
135
0
      }
136
0
      else {
137
0
        ++it;
138
0
      }
139
0
    }
140
0
  }
141
142
0
  options.emplace_back(optionCode, std::string(&newOptionContent.at(EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), newOptionContent.size() - (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE)));
143
0
  return true;
144
0
}
145
146
bool slowRewriteEDNSOptionInQueryWithRecords(const PacketBuffer& initialPacket, PacketBuffer& newContent, bool& ednsAdded, uint16_t optionToReplace, bool& optionAdded, bool overrideExisting, bool allowMultiple, const string& newOptionContent)
147
0
{
148
0
  if (initialPacket.size() < sizeof(dnsheader)) {
149
0
    return false;
150
0
  }
151
152
0
  const dnsheader_aligned dnsHeader(initialPacket.data());
153
154
0
  if (ntohs(dnsHeader->qdcount) == 0) {
155
0
    return false;
156
0
  }
157
158
0
  if (ntohs(dnsHeader->ancount) == 0 && ntohs(dnsHeader->nscount) == 0 && ntohs(dnsHeader->arcount) == 0) {
159
0
    throw std::runtime_error("slowRewriteEDNSOptionInQueryWithRecords should not be called for queries that have no records");
160
0
  }
161
162
0
  optionAdded = false;
163
0
  ednsAdded = true;
164
165
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
166
0
  PacketReader packetReader(std::string_view(reinterpret_cast<const char*>(initialPacket.data()), initialPacket.size()));
167
168
0
  size_t idx = 0;
169
0
  uint16_t qdcount = ntohs(dnsHeader->qdcount);
170
0
  uint16_t ancount = ntohs(dnsHeader->ancount);
171
0
  uint16_t nscount = ntohs(dnsHeader->nscount);
172
0
  uint16_t arcount = ntohs(dnsHeader->arcount);
173
0
  string blob;
174
0
  dnsrecordheader recordHeader{};
175
176
0
  auto rrname = packetReader.getName();
177
0
  auto rrtype = packetReader.get16BitInt();
178
0
  auto rrclass = packetReader.get16BitInt();
179
180
0
  GenericDNSPacketWriter<PacketBuffer> packetWriter(newContent, rrname, rrtype, rrclass, dnsHeader->opcode);
181
0
  packetWriter.getHeader()->id = dnsHeader->id;
182
0
  packetWriter.getHeader()->qr = dnsHeader->qr;
183
0
  packetWriter.getHeader()->aa = dnsHeader->aa;
184
0
  packetWriter.getHeader()->tc = dnsHeader->tc;
185
0
  packetWriter.getHeader()->rd = dnsHeader->rd;
186
0
  packetWriter.getHeader()->ra = dnsHeader->ra;
187
0
  packetWriter.getHeader()->ad = dnsHeader->ad;
188
0
  packetWriter.getHeader()->cd = dnsHeader->cd;
189
0
  packetWriter.getHeader()->rcode = dnsHeader->rcode;
190
191
  /* consume remaining qd if any */
192
0
  if (qdcount > 1) {
193
0
    for (idx = 1; idx < qdcount; idx++) {
194
0
      rrname = packetReader.getName();
195
0
      rrtype = packetReader.get16BitInt();
196
0
      rrclass = packetReader.get16BitInt();
197
0
      (void)rrtype;
198
0
      (void)rrclass;
199
0
    }
200
0
  }
201
202
  /* copy AN and NS */
203
0
  for (idx = 0; idx < ancount; idx++) {
204
0
    rrname = packetReader.getName();
205
0
    packetReader.getDnsrecordheader(recordHeader);
206
207
0
    packetWriter.startRecord(rrname, recordHeader.d_type, recordHeader.d_ttl, recordHeader.d_class, DNSResourceRecord::ANSWER, true);
208
0
    packetReader.xfrBlob(blob);
209
0
    packetWriter.xfrBlob(blob);
210
0
  }
211
212
0
  for (idx = 0; idx < nscount; idx++) {
213
0
    rrname = packetReader.getName();
214
0
    packetReader.getDnsrecordheader(recordHeader);
215
216
0
    packetWriter.startRecord(rrname, recordHeader.d_type, recordHeader.d_ttl, recordHeader.d_class, DNSResourceRecord::AUTHORITY, true);
217
0
    packetReader.xfrBlob(blob);
218
0
    packetWriter.xfrBlob(blob);
219
0
  }
220
221
  /* consume AR, looking for OPT */
222
0
  for (idx = 0; idx < arcount; idx++) {
223
0
    rrname = packetReader.getName();
224
0
    packetReader.getDnsrecordheader(recordHeader);
225
226
0
    if (recordHeader.d_type != QType::OPT) {
227
0
      packetWriter.startRecord(rrname, recordHeader.d_type, recordHeader.d_ttl, recordHeader.d_class, DNSResourceRecord::ADDITIONAL, true);
228
0
      packetReader.xfrBlob(blob);
229
0
      packetWriter.xfrBlob(blob);
230
0
    }
231
0
    else {
232
233
0
      ednsAdded = false;
234
0
      packetReader.xfrBlob(blob);
235
236
0
      std::vector<std::pair<uint16_t, std::string>> options;
237
0
      getEDNSOptionsFromContent(blob, options);
238
239
      /* getDnsrecordheader() has helpfully converted the TTL for us, which we do not want in that case */
240
0
      uint32_t ttl = htonl(recordHeader.d_ttl);
241
0
      EDNS0Record edns0{};
242
0
      static_assert(sizeof(edns0) == sizeof(ttl), "sizeof(EDNS0Record) must match sizeof(uint32_t) AKA RR TTL size");
243
0
      memcpy(&edns0, &ttl, sizeof(edns0));
244
245
      /* addOrReplaceEDNSOption will set it to false if there is already an existing option */
246
0
      optionAdded = true;
247
0
      addOrReplaceEDNSOption(options, optionToReplace, optionAdded, overrideExisting, allowMultiple, newOptionContent);
248
0
      packetWriter.addOpt(recordHeader.d_class, edns0.extRCode, ntohs(edns0.extFlags), options, edns0.version);
249
0
    }
250
0
  }
251
252
0
  if (ednsAdded) {
253
0
    packetWriter.addOpt(dnsdist::configuration::s_EdnsUDPPayloadSize, 0, 0, {{optionToReplace, std::string(&newOptionContent.at(EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), newOptionContent.size() - (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE))}}, 0);
254
0
    optionAdded = true;
255
0
  }
256
257
0
  packetWriter.commit();
258
259
0
  return true;
260
0
}
261
262
int locateEDNSOptRR(const PacketBuffer& packet, uint16_t* optStart, size_t* optLen, bool* last)
263
0
{
264
0
  if (optStart == nullptr || optLen == nullptr || last == nullptr) {
265
0
    throw std::runtime_error("Invalid values passed to locateEDNSOptRR");
266
0
  }
267
268
0
  const dnsheader_aligned dnsHeader(packet.data());
269
270
0
  if (ntohs(dnsHeader->arcount) == 0) {
271
0
    return ENOENT;
272
0
  }
273
274
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
275
0
  PacketReader packetReader(std::string_view(reinterpret_cast<const char*>(packet.data()), packet.size()));
276
277
0
  size_t idx = 0;
278
0
  DNSName rrname;
279
0
  uint16_t qdcount = ntohs(dnsHeader->qdcount);
280
0
  uint16_t ancount = ntohs(dnsHeader->ancount);
281
0
  uint16_t nscount = ntohs(dnsHeader->nscount);
282
0
  uint16_t arcount = ntohs(dnsHeader->arcount);
283
0
  uint16_t rrtype{};
284
0
  uint16_t rrclass{};
285
0
  dnsrecordheader recordHeader{};
286
287
  /* consume qd */
288
0
  for (idx = 0; idx < qdcount; idx++) {
289
0
    rrname = packetReader.getName();
290
0
    rrtype = packetReader.get16BitInt();
291
0
    rrclass = packetReader.get16BitInt();
292
0
    (void)rrtype;
293
0
    (void)rrclass;
294
0
  }
295
296
  /* consume AN and NS */
297
0
  for (idx = 0; idx < ancount + nscount; idx++) {
298
0
    rrname = packetReader.getName();
299
0
    packetReader.getDnsrecordheader(recordHeader);
300
0
    packetReader.skip(recordHeader.d_clen);
301
0
  }
302
303
  /* consume AR, looking for OPT */
304
0
  for (idx = 0; idx < arcount; idx++) {
305
0
    uint16_t start = packetReader.getPosition();
306
0
    rrname = packetReader.getName();
307
0
    packetReader.getDnsrecordheader(recordHeader);
308
309
0
    if (recordHeader.d_type == QType::OPT) {
310
0
      *optStart = start;
311
0
      *optLen = (packetReader.getPosition() - start) + recordHeader.d_clen;
312
313
0
      if (packet.size() < (*optStart + *optLen)) {
314
0
        throw std::range_error("Opt record overflow");
315
0
      }
316
317
0
      if (idx == ((size_t)arcount - 1)) {
318
0
        *last = true;
319
0
      }
320
0
      else {
321
0
        *last = false;
322
0
      }
323
0
      return 0;
324
0
    }
325
0
    packetReader.skip(recordHeader.d_clen);
326
0
  }
327
328
0
  return ENOENT;
329
0
}
330
331
namespace dnsdist
332
{
333
/* extract the start of the OPT RR in a QUERY packet if any
334
 * optRDPosition points to the first byte of the RDLEN field
335
 * remaining contains the number of bytes in the packet after optRDPosition (i.e. packet.size() - optRDPosition)
336
 */
337
int getEDNSOptionsStart(const PacketBuffer& packet, const size_t qnameWireLength, uint16_t* optRDPosition, size_t* remaining)
338
949
{
339
949
  if (optRDPosition == nullptr || remaining == nullptr) {
340
0
    throw std::runtime_error("Invalid values passed to getEDNSOptionsStart");
341
0
  }
342
343
949
  const dnsheader_aligned dnsHeader(packet.data());
344
345
949
  if (qnameWireLength >= packet.size()) {
346
0
    return ENOENT;
347
0
  }
348
349
949
  if (ntohs(dnsHeader->qdcount) != 1 || ntohs(dnsHeader->ancount) != 0 || ntohs(dnsHeader->arcount) != 1 || ntohs(dnsHeader->nscount) != 0) {
350
163
    return ENOENT;
351
163
  }
352
353
786
  size_t pos = sizeof(dnsheader) + qnameWireLength;
354
786
  pos += DNS_TYPE_SIZE + DNS_CLASS_SIZE;
355
356
786
  if (pos >= packet.size()) {
357
6
    return ENOENT;
358
6
  }
359
360
780
  if ((pos + /* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE) >= packet.size()) {
361
22
    return ENOENT;
362
22
  }
363
364
758
  if (packet[pos] != 0) {
365
    /* not the root so not an OPT record */
366
118
    return ENOENT;
367
118
  }
368
640
  pos += 1;
369
370
640
  uint16_t qtype = packet.at(pos) * 256 + packet.at(pos + 1);
371
640
  pos += DNS_TYPE_SIZE;
372
640
  pos += DNS_CLASS_SIZE;
373
374
640
  if (qtype != QType::OPT || (packet.size() - pos) < (DNS_TTL_SIZE + DNS_RDLENGTH_SIZE)) {
375
115
    return ENOENT;
376
115
  }
377
378
525
  pos += DNS_TTL_SIZE;
379
525
  *optRDPosition = pos;
380
525
  *remaining = packet.size() - pos;
381
382
525
  return 0;
383
640
}
384
}
385
386
void generateECSOption(const ComboAddress& source, string& res, uint16_t ECSPrefixLength)
387
0
{
388
0
  Netmask sourceNetmask(source, ECSPrefixLength);
389
0
  EDNSSubnetOpts ecsOpts;
390
0
  ecsOpts.setSource(sourceNetmask);
391
0
  string payload = ecsOpts.makeOptString();
392
0
  generateEDNSOption(EDNSOptionCode::ECS, payload, res);
393
0
}
394
395
bool generateOptRR(const std::string& optRData, PacketBuffer& res, size_t maximumSize, uint16_t udpPayloadSize, uint8_t ednsrcode, bool dnssecOK)
396
0
{
397
0
  const uint8_t name = 0;
398
0
  dnsrecordheader dnsHeader{};
399
0
  EDNS0Record edns0{};
400
0
  edns0.extRCode = ednsrcode;
401
0
  edns0.version = 0;
402
0
  edns0.extFlags = dnssecOK ? htons(EDNS_HEADER_FLAG_DO) : 0;
403
404
0
  if ((maximumSize - res.size()) < (sizeof(name) + sizeof(dnsHeader) + optRData.length())) {
405
0
    return false;
406
0
  }
407
408
0
  dnsHeader.d_type = htons(QType::OPT);
409
0
  dnsHeader.d_class = htons(udpPayloadSize);
410
0
  static_assert(sizeof(EDNS0Record) == sizeof(dnsHeader.d_ttl), "sizeof(EDNS0Record) must match sizeof(dnsrecordheader.d_ttl)");
411
0
  memcpy(&dnsHeader.d_ttl, &edns0, sizeof edns0);
412
0
  dnsHeader.d_clen = htons(static_cast<uint16_t>(optRData.length()));
413
414
0
  res.reserve(res.size() + sizeof(name) + sizeof(dnsHeader) + optRData.length());
415
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast,cppcoreguidelines-pro-bounds-pointer-arithmetic)
416
0
  res.insert(res.end(), reinterpret_cast<const uint8_t*>(&name), reinterpret_cast<const uint8_t*>(&name) + sizeof(name));
417
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast,cppcoreguidelines-pro-bounds-pointer-arithmetic)
418
0
  res.insert(res.end(), reinterpret_cast<const uint8_t*>(&dnsHeader), reinterpret_cast<const uint8_t*>(&dnsHeader) + sizeof(dnsHeader));
419
0
  res.insert(res.end(), optRData.begin(), optRData.end());
420
421
0
  return true;
422
0
}
423
424
/* Replace the ECS option occupying
   [oldEcsOptionStartPosition, oldEcsOptionStartPosition + oldEcsOptionSize)
   in `packet` with newECSOption, and update the OPT RDLENGTH stored at
   optRDLenPosition.
   When the sizes differ, every byte after the old option is shifted down and
   the new option is written after them, so the ECS option ends up last. The
   code treats all of those trailing bytes as part of the OPT RDATA.
   Returns false if the grown packet would exceed maximumSize.
   Throws std::runtime_error if the positions do not fit inside the packet. */
static bool replaceEDNSClientSubnetOption(PacketBuffer& packet, size_t maximumSize, size_t const oldEcsOptionStartPosition, size_t const oldEcsOptionSize, size_t const optRDLenPosition, const string& newECSOption)
{
  /* the whole old option must lie inside the packet, otherwise the memcpy in the
     same-size path and the dataBehindSize computation below would go out of bounds */
  if (oldEcsOptionStartPosition >= packet.size() || optRDLenPosition >= packet.size() || oldEcsOptionSize > (packet.size() - oldEcsOptionStartPosition)) {
    throw std::runtime_error("Invalid values passed to replaceEDNSClientSubnetOption");
  }

  if (newECSOption.size() == oldEcsOptionSize) {
    /* same size as the existing option */
    memcpy(&packet.at(oldEcsOptionStartPosition), newECSOption.c_str(), oldEcsOptionSize);
  }
  else {
    /* different size than the existing option */
    const size_t newPacketLen = packet.size() - oldEcsOptionSize + newECSOption.size();
    const size_t beforeOptionLen = oldEcsOptionStartPosition;
    const size_t dataBehindSize = packet.size() - beforeOptionLen - oldEcsOptionSize;

    /* check that it fits in the existing buffer */
    if (newPacketLen > packet.size()) {
      if (newPacketLen > maximumSize) {
        return false;
      }

      packet.resize(newPacketLen);
    }

    /* fix the size of ECS Option RDLen (uint16_t arithmetic wraps as intended when shrinking) */
    uint16_t newRDLen = (packet.at(optRDLenPosition) * 256) + packet.at(optRDLenPosition + 1);
    newRDLen += (newECSOption.size() - oldEcsOptionSize);
    packet.at(optRDLenPosition) = newRDLen / 256;
    packet.at(optRDLenPosition + 1) = newRDLen % 256;

    if (dataBehindSize > 0) {
      memmove(&packet.at(oldEcsOptionStartPosition), &packet.at(oldEcsOptionStartPosition + oldEcsOptionSize), dataBehindSize);
    }
    memcpy(&packet.at(oldEcsOptionStartPosition + dataBehindSize), newECSOption.c_str(), newECSOption.size());
    packet.resize(newPacketLen);
  }

  return true;
}
464
465
/* This function looks for an OPT RR, return true if a valid one was found (even if there was no options)
466
   and false otherwise. */
467
bool parseEDNSOptions(const DNSQuestion& dnsQuestion)
468
0
{
469
0
  const auto dnsHeader = dnsQuestion.getHeader();
470
0
  if (dnsQuestion.ednsOptions != nullptr) {
471
0
    return true;
472
0
  }
473
474
  // dnsQuestion.ednsOptions is mutable
475
0
  dnsQuestion.ednsOptions = std::make_unique<EDNSOptionViewMap>();
476
477
0
  if (ntohs(dnsHeader->arcount) == 0) {
478
    /* nothing in additional so no EDNS */
479
0
    return false;
480
0
  }
481
482
0
  if (ntohs(dnsHeader->ancount) != 0 || ntohs(dnsHeader->nscount) != 0 || ntohs(dnsHeader->arcount) > 1) {
483
0
    return slowParseEDNSOptions(dnsQuestion.getData(), *dnsQuestion.ednsOptions);
484
0
  }
485
486
0
  size_t remaining = 0;
487
0
  uint16_t optRDPosition{};
488
0
  int res = dnsdist::getEDNSOptionsStart(dnsQuestion.getData(), dnsQuestion.ids.qname.wirelength(), &optRDPosition, &remaining);
489
490
0
  if (res == 0) {
491
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
492
0
    res = getEDNSOptions(reinterpret_cast<const char*>(&dnsQuestion.getData().at(optRDPosition)), remaining, *dnsQuestion.ednsOptions);
493
0
    return (res == 0);
494
0
  }
495
496
0
  return false;
497
0
}
498
499
/* Append newECSOption at the end of the packet and grow the RDLENGTH of the
   existing OPT RR (stored at optRDLenPosition) accordingly. Any data following
   the OPT RDATA is first trimmed to the length reported by getDNSPacketLength().
   Sets ecsAdded and returns true on success; returns false, leaving the RDLENGTH
   untouched, if the option does not fit within maximumSize. */
static bool addECSToExistingOPT(PacketBuffer& packet, size_t maximumSize, const string& newECSOption, size_t optRDLenPosition, bool& ecsAdded)
{
  /* we need to add one EDNS0 ECS option, fixing the size of EDNS0 RDLENGTH */
  /* getEDNSOptionsStart has already checked that there is exactly one AR,
     no NS and no AN */
  uint16_t oldRDLen = (packet.at(optRDLenPosition) * 256) + packet.at(optRDLenPosition + 1);
  if (packet.size() != (optRDLenPosition + sizeof(uint16_t) + oldRDLen)) {
    /* we are supposed to be the last record, do we have some trailing data to remove? */
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
    uint32_t realPacketLen = getDNSPacketLength(reinterpret_cast<const char*>(packet.data()), packet.size());
    packet.resize(realPacketLen);
  }

  /* check packet.size() first so that the subtraction cannot wrap around */
  if (packet.size() > maximumSize || (maximumSize - packet.size()) < newECSOption.size()) {
    return false;
  }

  uint16_t newRDLen = oldRDLen + newECSOption.size();
  packet.at(optRDLenPosition) = newRDLen / 256;
  packet.at(optRDLenPosition + 1) = newRDLen % 256;

  packet.insert(packet.end(), newECSOption.begin(), newECSOption.end());
  ecsAdded = true;

  return true;
}
525
526
/* Append a new OPT RR whose only option is newECSOption, using the configured
   EDNS UDP payload size, no extended RCODE and no DO flag, and bump ARCOUNT.
   On success sets ednsAdded and ecsAdded and returns true; returns false if
   the record does not fit within maximumSize. */
static bool addEDNSWithECS(PacketBuffer& packet, size_t maximumSize, const string& newECSOption, bool& ednsAdded, bool& ecsAdded)
{
  const bool appended = generateOptRR(newECSOption, packet, maximumSize, dnsdist::configuration::s_EdnsUDPPayloadSize, 0, false);
  if (!appended) {
    return false;
  }

  dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [](dnsheader& header) {
    header.arcount = htons(static_cast<uint16_t>(ntohs(header.arcount) + 1));
    return true;
  });

  ednsAdded = true;
  ecsAdded = true;
  return true;
}
543
544
/* Add or replace the EDNS Client Subnet option of the query held in `packet`.
   newECSOption is a complete wire-format ECS option (code, length, payload).
   - Queries with AN/NS records, or more than one AR, take the slow path,
     which rebuilds the whole packet.
   - Otherwise, without an OPT RR, a new OPT RR carrying the ECS option is
     appended. Data past the question (or past the additional record) is
     trimmed first.
   - With an OPT RR that already holds an ECS option, that option is replaced
     only when overrideExisting is true.
   - With an OPT RR but no ECS option, the option is appended to it.
   Returns false if the result would not fit within maximumSize or the slow
   path failed.
   Throws std::runtime_error if the packet is shorter than a DNS header or
   qnameWireLength exceeds the packet size. */
bool handleEDNSClientSubnet(PacketBuffer& packet, const size_t maximumSize, const size_t qnameWireLength, bool& ednsAdded, bool& ecsAdded, bool overrideExisting, const string& newECSOption)
{
  /* constructing the header view reads sizeof(dnsheader) bytes from the
     buffer, so a shorter packet must be rejected before that */
  if (packet.size() < sizeof(dnsheader) || qnameWireLength > packet.size()) {
    throw std::runtime_error("Invalid value passed to handleEDNSClientSubnet");
  }

  const dnsheader_aligned dnsHeader(packet.data());

  if (ntohs(dnsHeader->ancount) != 0 || ntohs(dnsHeader->nscount) != 0 || (ntohs(dnsHeader->arcount) != 0 && ntohs(dnsHeader->arcount) != 1)) {
    PacketBuffer newContent;
    newContent.reserve(packet.size());

    if (!slowRewriteEDNSOptionInQueryWithRecords(packet, newContent, ednsAdded, EDNSOptionCode::ECS, ecsAdded, overrideExisting, false, newECSOption)) {
      return false;
    }

    if (newContent.size() > maximumSize) {
      ednsAdded = false;
      ecsAdded = false;
      return false;
    }

    packet = std::move(newContent);
    return true;
  }

  uint16_t optRDPosition = 0;
  size_t remaining = 0;

  int res = dnsdist::getEDNSOptionsStart(packet, qnameWireLength, &optRDPosition, &remaining);

  if (res != 0) {
    /* no EDNS but there might be another record in additional (TSIG?) */
    /* Careful, this code assumes that ANCOUNT == 0 && NSCOUNT == 0 */
    size_t minimumPacketSize = sizeof(dnsheader) + qnameWireLength + sizeof(uint16_t) + sizeof(uint16_t);
    if (packet.size() > minimumPacketSize) {
      if (ntohs(dnsHeader->arcount) == 0) {
        /* well now.. */
        packet.resize(minimumPacketSize);
      }
      else {
        // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
        uint32_t realPacketLen = getDNSPacketLength(reinterpret_cast<const char*>(packet.data()), packet.size());
        packet.resize(realPacketLen);
      }
    }

    return addEDNSWithECS(packet, maximumSize, newECSOption, ednsAdded, ecsAdded);
  }

  size_t ecsOptionStartPosition = 0;
  size_t ecsOptionSize = 0;
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
  res = getEDNSOption(reinterpret_cast<const char*>(&packet.at(optRDPosition)), remaining, EDNSOptionCode::ECS, &ecsOptionStartPosition, &ecsOptionSize);

  if (res == 0) {
    /* there is already an ECS value */
    if (!overrideExisting) {
      return true;
    }

    return replaceEDNSClientSubnetOption(packet, maximumSize, optRDPosition + ecsOptionStartPosition, ecsOptionSize, optRDPosition, newECSOption);
  }

  /* we have an EDNS OPT RR but no existing ECS option */
  return addECSToExistingOPT(packet, maximumSize, newECSOption, optRDPosition, ecsAdded);
}
611
612
/* Add or replace the ECS option of the query held by dnsQuestion.
   The subnet comes from dnsQuestion.ecs when it is set. Otherwise it is built
   from the client's original address (ids.origRemote) and ecsPrefixLength.
   An existing ECS option is replaced only when dnsQuestion.ecsOverride is set. */
bool handleEDNSClientSubnet(DNSQuestion& dnsQuestion, bool& ednsAdded, bool& ecsAdded)
{
  string ecsOption;
  if (dnsQuestion.ecs) {
    generateECSOption(dnsQuestion.ecs->getNetwork(), ecsOption, dnsQuestion.ecs->getBits());
  }
  else {
    generateECSOption(dnsQuestion.ids.origRemote, ecsOption, dnsQuestion.ecsPrefixLength);
  }

  auto& packet = dnsQuestion.getMutableData();
  return handleEDNSClientSubnet(packet, dnsQuestion.getMaximumSize(), dnsQuestion.ids.qname.wirelength(), ednsAdded, ecsAdded, dnsQuestion.ecsOverride, ecsOption);
}
619
620
static int removeEDNSOptionFromOptions(unsigned char* optionsStart, const uint16_t optionsLen, const uint16_t optionCodeToRemove, uint16_t* newOptionsLen)
621
0
{
622
0
  const pdns::views::UnsignedCharView view(optionsStart, optionsLen);
623
0
  size_t pos = 0;
624
0
  while ((pos + 4) <= view.size()) {
625
0
    size_t optionBeginPos = pos;
626
0
    const uint16_t optionCode = 0x100 * view.at(pos) + view.at(pos + 1);
627
0
    pos += sizeof(optionCode);
628
0
    const uint16_t optionLen = 0x100 * view.at(pos) + view.at(pos + 1);
629
0
    pos += sizeof(optionLen);
630
0
    if ((pos + optionLen) > view.size()) {
631
0
      return EINVAL;
632
0
    }
633
0
    if (optionCode == optionCodeToRemove) {
634
0
      if (pos + optionLen < view.size()) {
635
        /* move remaining options over the removed one,
636
           if any */
637
        // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic)
638
0
        memmove(optionsStart + optionBeginPos, optionsStart + pos + optionLen, optionsLen - (pos + optionLen));
639
0
      }
640
0
      *newOptionsLen = optionsLen - (sizeof(optionCode) + sizeof(optionLen) + optionLen);
641
0
      return 0;
642
0
    }
643
0
    pos += optionLen;
644
0
  }
645
0
  return ENOENT;
646
0
}
647
648
int removeEDNSOptionFromOPT(char* optStart, size_t* optLen, const uint16_t optionCodeToRemove)
649
0
{
650
0
  if (*optLen < optRecordMinimumSize) {
651
0
    return EINVAL;
652
0
  }
653
0
  const pdns::views::UnsignedCharView view(optStart, *optLen);
654
  /* skip the root label, qtype, qclass and TTL */
655
0
  size_t position = 9;
656
0
  uint16_t rdLen = (0x100 * view.at(position) + view.at(position + 1));
657
0
  position += sizeof(rdLen);
658
0
  if (position + rdLen != view.size()) {
659
0
    return EINVAL;
660
0
  }
661
0
  uint16_t newRdLen = 0;
662
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast,cppcoreguidelines-pro-bounds-pointer-arithmetic)
663
0
  int res = removeEDNSOptionFromOptions(reinterpret_cast<unsigned char*>(optStart + position), rdLen, optionCodeToRemove, &newRdLen);
664
0
  if (res != 0) {
665
0
    return res;
666
0
  }
667
0
  *optLen -= (rdLen - newRdLen);
668
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast,cppcoreguidelines-pro-bounds-pointer-arithmetic)
669
0
  auto* rdLenPtr = reinterpret_cast<unsigned char*>(optStart + 9);
670
  // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic)
671
0
  rdLenPtr[0] = newRdLen / 0x100;
672
  // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic)
673
0
  rdLenPtr[1] = newRdLen % 0x100;
674
0
  return 0;
675
0
}
676
677
bool isEDNSOptionInOpt(const PacketBuffer& packet, const size_t optStart, const size_t optLen, const uint16_t optionCodeToFind, size_t* optContentStart, uint16_t* optContentLen)
678
0
{
679
0
  if (optLen < optRecordMinimumSize) {
680
0
    return false;
681
0
  }
682
0
  size_t position = optStart + 9;
683
0
  uint16_t rdLen = (0x100 * static_cast<unsigned char>(packet.at(position)) + static_cast<unsigned char>(packet.at(position + 1)));
684
0
  position += sizeof(rdLen);
685
0
  if (rdLen > (optLen - optRecordMinimumSize)) {
686
0
    return false;
687
0
  }
688
689
0
  size_t rdEnd = position + rdLen;
690
0
  while ((position + 4) <= rdEnd) {
691
0
    const uint16_t optionCode = 0x100 * static_cast<unsigned char>(packet.at(position)) + static_cast<unsigned char>(packet.at(position + 1));
692
0
    position += sizeof(optionCode);
693
0
    const uint16_t optionLen = 0x100 * static_cast<unsigned char>(packet.at(position)) + static_cast<unsigned char>(packet.at(position + 1));
694
0
    position += sizeof(optionLen);
695
696
0
    if ((position + optionLen) > rdEnd) {
697
0
      return false;
698
0
    }
699
700
0
    if (optionCode == optionCodeToFind) {
701
0
      if (optContentStart != nullptr) {
702
0
        *optContentStart = position;
703
0
      }
704
705
0
      if (optContentLen != nullptr) {
706
0
        *optContentLen = optionLen;
707
0
      }
708
709
0
      return true;
710
0
    }
711
0
    position += optionLen;
712
0
  }
713
0
  return false;
714
0
}
715
716
int rewriteResponseWithoutEDNSOption(const PacketBuffer& initialPacket, const uint16_t optionCodeToSkip, PacketBuffer& newContent)
717
0
{
718
0
  if (initialPacket.size() < sizeof(dnsheader)) {
719
0
    return ENOENT;
720
0
  }
721
722
0
  const dnsheader_aligned dnsHeader(initialPacket.data());
723
724
0
  if (ntohs(dnsHeader->arcount) == 0) {
725
0
    return ENOENT;
726
0
  }
727
728
0
  if (ntohs(dnsHeader->qdcount) == 0) {
729
0
    return ENOENT;
730
0
  }
731
732
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
733
0
  PacketReader packetReader(std::string_view(reinterpret_cast<const char*>(initialPacket.data()), initialPacket.size()));
734
735
0
  size_t idx = 0;
736
0
  DNSName rrname;
737
0
  uint16_t qdcount = ntohs(dnsHeader->qdcount);
738
0
  uint16_t ancount = ntohs(dnsHeader->ancount);
739
0
  uint16_t nscount = ntohs(dnsHeader->nscount);
740
0
  uint16_t arcount = ntohs(dnsHeader->arcount);
741
0
  uint16_t rrtype = 0;
742
0
  uint16_t rrclass = 0;
743
0
  string blob;
744
0
  dnsrecordheader recordHeader{};
745
746
0
  rrname = packetReader.getName();
747
0
  rrtype = packetReader.get16BitInt();
748
0
  rrclass = packetReader.get16BitInt();
749
750
0
  GenericDNSPacketWriter<PacketBuffer> packetWriter(newContent, rrname, rrtype, rrclass, dnsHeader->opcode);
751
0
  packetWriter.getHeader()->id = dnsHeader->id;
752
0
  packetWriter.getHeader()->qr = dnsHeader->qr;
753
0
  packetWriter.getHeader()->aa = dnsHeader->aa;
754
0
  packetWriter.getHeader()->tc = dnsHeader->tc;
755
0
  packetWriter.getHeader()->rd = dnsHeader->rd;
756
0
  packetWriter.getHeader()->ra = dnsHeader->ra;
757
0
  packetWriter.getHeader()->ad = dnsHeader->ad;
758
0
  packetWriter.getHeader()->cd = dnsHeader->cd;
759
0
  packetWriter.getHeader()->rcode = dnsHeader->rcode;
760
761
  /* consume remaining qd if any */
762
0
  if (qdcount > 1) {
763
0
    for (idx = 1; idx < qdcount; idx++) {
764
0
      rrname = packetReader.getName();
765
0
      rrtype = packetReader.get16BitInt();
766
0
      rrclass = packetReader.get16BitInt();
767
0
      (void)rrtype;
768
0
      (void)rrclass;
769
0
    }
770
0
  }
771
772
  /* copy AN and NS */
773
0
  for (idx = 0; idx < ancount; idx++) {
774
0
    rrname = packetReader.getName();
775
0
    packetReader.getDnsrecordheader(recordHeader);
776
777
0
    packetWriter.startRecord(rrname, recordHeader.d_type, recordHeader.d_ttl, recordHeader.d_class, DNSResourceRecord::ANSWER, true);
778
0
    packetReader.xfrBlob(blob);
779
0
    packetWriter.xfrBlob(blob);
780
0
  }
781
782
0
  for (idx = 0; idx < nscount; idx++) {
783
0
    rrname = packetReader.getName();
784
0
    packetReader.getDnsrecordheader(recordHeader);
785
786
0
    packetWriter.startRecord(rrname, recordHeader.d_type, recordHeader.d_ttl, recordHeader.d_class, DNSResourceRecord::AUTHORITY, true);
787
0
    packetReader.xfrBlob(blob);
788
0
    packetWriter.xfrBlob(blob);
789
0
  }
790
791
  /* consume AR, looking for OPT */
792
0
  for (idx = 0; idx < arcount; idx++) {
793
0
    rrname = packetReader.getName();
794
0
    packetReader.getDnsrecordheader(recordHeader);
795
796
0
    if (recordHeader.d_type != QType::OPT) {
797
0
      packetWriter.startRecord(rrname, recordHeader.d_type, recordHeader.d_ttl, recordHeader.d_class, DNSResourceRecord::ADDITIONAL, true);
798
0
      packetReader.xfrBlob(blob);
799
0
      packetWriter.xfrBlob(blob);
800
0
    }
801
0
    else {
802
0
      packetWriter.startRecord(rrname, recordHeader.d_type, recordHeader.d_ttl, recordHeader.d_class, DNSResourceRecord::ADDITIONAL, false);
803
0
      packetReader.xfrBlob(blob);
804
0
      uint16_t rdLen = blob.length();
805
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
806
0
      removeEDNSOptionFromOptions(reinterpret_cast<unsigned char*>(blob.data()), rdLen, optionCodeToSkip, &rdLen);
807
      /* xfrBlob(string, size) completely ignores size.. */
808
0
      if (rdLen > 0) {
809
0
        blob.resize((size_t)rdLen);
810
0
        packetWriter.xfrBlob(blob);
811
0
      }
812
0
      else {
813
0
        packetWriter.commit();
814
0
      }
815
0
    }
816
0
  }
817
0
  packetWriter.commit();
818
819
0
  return 0;
820
0
}
821
822
/*
  Append an OPT pseudo-record (no options) to `packet`, carrying `payloadSize`, the extended
  RCODE `ednsrcode` and the DO bit `dnssecOK`, then increment ARCOUNT to account for it.
  The header is only touched when generateOptRR() reports success; its return value is
  forwarded as-is. NOTE(review): a false return presumably means the record would not fit
  in `maximumSize` - confirm against generateOptRR().
*/
bool addEDNS(PacketBuffer& packet, size_t maximumSize, bool dnssecOK, uint16_t payloadSize, uint8_t ednsrcode)
{
  const bool generated = generateOptRR(std::string(), packet, maximumSize, payloadSize, ednsrcode, dnssecOK);
  if (generated) {
    dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [](dnsheader& header) {
      /* ARCOUNT is stored in network byte order */
      const uint16_t additionalCount = ntohs(header.arcount);
      header.arcount = htons(additionalCount + 1);
      return true;
    });
  }
  return generated;
}
835
836
/*
837
  This function keeps the existing header and DNSSECOK bit (if any) but wipes anything else,
838
  generating a NXD or NODATA answer with a SOA record in the additional section (or optionally the authority section for a full cacheable NXDOMAIN/NODATA).
839
*/
840
bool setNegativeAndAdditionalSOA(DNSQuestion& dnsQuestion, bool nxd, const DNSName& zone, uint32_t ttl, const DNSName& mname, const DNSName& rname, uint32_t serial, uint32_t refresh, uint32_t retry, uint32_t expire, uint32_t minimum, bool soaInAuthoritySection)
841
0
{
842
0
  auto& packet = dnsQuestion.getMutableData();
843
0
  auto dnsHeader = dnsQuestion.getHeader();
844
0
  if (ntohs(dnsHeader->qdcount) != 1) {
845
0
    return false;
846
0
  }
847
848
0
  size_t queryPartSize = sizeof(dnsheader) + dnsQuestion.ids.qname.wirelength() + DNS_TYPE_SIZE + DNS_CLASS_SIZE;
849
0
  if (packet.size() < queryPartSize) {
850
    /* something is already wrong, don't build on flawed foundations */
851
0
    return false;
852
0
  }
853
854
0
  uint16_t qtype = htons(QType::SOA);
855
0
  uint16_t qclass = htons(QClass::IN);
856
0
  uint16_t rdLength = mname.wirelength() + rname.wirelength() + sizeof(serial) + sizeof(refresh) + sizeof(retry) + sizeof(expire) + sizeof(minimum);
857
0
  size_t soaSize = zone.wirelength() + sizeof(qtype) + sizeof(qclass) + sizeof(ttl) + sizeof(rdLength) + rdLength;
858
0
  bool hadEDNS = false;
859
0
  bool dnssecOK = false;
860
861
0
  if (dnsdist::configuration::getCurrentRuntimeConfiguration().d_addEDNSToSelfGeneratedResponses) {
862
0
    uint16_t payloadSize = 0;
863
0
    uint16_t zValue = 0;
864
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
865
0
    hadEDNS = getEDNSUDPPayloadSizeAndZ(reinterpret_cast<const char*>(packet.data()), packet.size(), &payloadSize, &zValue);
866
0
    if (hadEDNS) {
867
0
      dnssecOK = (zValue & EDNS_HEADER_FLAG_DO) != 0;
868
0
    }
869
0
  }
870
871
  /* chop off everything after the question */
872
0
  packet.resize(queryPartSize);
873
0
  dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [nxd](dnsheader& header) {
874
0
    if (nxd) {
875
0
      header.rcode = RCode::NXDomain;
876
0
    }
877
0
    else {
878
0
      header.rcode = RCode::NoError;
879
0
    }
880
0
    header.qr = true;
881
0
    header.ancount = 0;
882
0
    header.nscount = 0;
883
0
    header.arcount = 0;
884
0
    return true;
885
0
  });
886
887
0
  rdLength = htons(rdLength);
888
0
  ttl = htonl(ttl);
889
0
  serial = htonl(serial);
890
0
  refresh = htonl(refresh);
891
0
  retry = htonl(retry);
892
0
  expire = htonl(expire);
893
0
  minimum = htonl(minimum);
894
895
0
  std::string soa;
896
0
  soa.reserve(soaSize);
897
0
  soa.append(zone.toDNSString());
898
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
899
0
  soa.append(reinterpret_cast<const char*>(&qtype), sizeof(qtype));
900
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
901
0
  soa.append(reinterpret_cast<const char*>(&qclass), sizeof(qclass));
902
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
903
0
  soa.append(reinterpret_cast<const char*>(&ttl), sizeof(ttl));
904
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
905
0
  soa.append(reinterpret_cast<const char*>(&rdLength), sizeof(rdLength));
906
0
  soa.append(mname.toDNSString());
907
0
  soa.append(rname.toDNSString());
908
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
909
0
  soa.append(reinterpret_cast<const char*>(&serial), sizeof(serial));
910
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
911
0
  soa.append(reinterpret_cast<const char*>(&refresh), sizeof(refresh));
912
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
913
0
  soa.append(reinterpret_cast<const char*>(&retry), sizeof(retry));
914
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
915
0
  soa.append(reinterpret_cast<const char*>(&expire), sizeof(expire));
916
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
917
0
  soa.append(reinterpret_cast<const char*>(&minimum), sizeof(minimum));
918
919
0
  if (soa.size() != soaSize) {
920
0
    throw std::runtime_error("Unexpected SOA response size: " + std::to_string(soa.size()) + " vs " + std::to_string(soaSize));
921
0
  }
922
923
0
  packet.insert(packet.end(), soa.begin(), soa.end());
924
925
  /* We are populating a response with only the query in place, order of sections is QD,AN,NS,AR
926
     NS (authority) is before AR (additional) so we can just decide which section the SOA record is in here
927
     and have EDNS added to AR afterwards */
928
0
  dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [soaInAuthoritySection](dnsheader& header) {
929
0
    if (soaInAuthoritySection) {
930
0
      header.nscount = htons(1);
931
0
    }
932
0
    else {
933
0
      header.arcount = htons(1);
934
0
    }
935
0
    return true;
936
0
  });
937
938
0
  if (hadEDNS) {
939
    /* now we need to add a new OPT record */
940
0
    return addEDNS(packet, dnsQuestion.getMaximumSize(), dnssecOK, dnsdist::configuration::getCurrentRuntimeConfiguration().d_payloadSizeSelfGenAnswers, dnsQuestion.ednsRCode);
941
0
  }
942
943
0
  return true;
944
0
}
945
946
/*
  Called when a query buffer is being turned into a response in place.
  If the query carried an OPT record directly after its question (as located by
  dnsdist::getEDNSOptionsStart()), that record and everything following it are removed
  and ARCOUNT is reset to 0. When d_addEDNSToSelfGeneratedResponses is enabled, a fresh OPT
  record is then appended with the query's DO bit, d_payloadSizeSelfGenAnswers and
  dnsQuestion.ednsRCode.
  Returns true when there was no OPT record to handle, or when it was handled successfully.
  Returns false when the computed OPT length is inconsistent with the packet, or when adding
  the new OPT record fails.
*/
bool addEDNSToQueryTurnedResponse(DNSQuestion& dnsQuestion)
{
  /* OPT bytes preceding RDLENGTH: root name, type, class, extended rcode, version, Z */
  const size_t optFixedPartSize = /* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE + /* Z */ 2;

  auto& packet = dnsQuestion.getMutableData();
  uint16_t optRDPosition{};
  /* remaining is at least the size of the rdlen + the options if any + the following records if any */
  size_t remaining = 0;
  if (dnsdist::getEDNSOptionsStart(packet, dnsQuestion.ids.qname.wirelength(), &optRDPosition, &remaining) != 0) {
    /* if the initial query did not have EDNS0, we are done */
    return true;
  }

  const size_t existingOptLen = optFixedPartSize + remaining;
  if (packet.size() <= existingOptLen) {
    /* something is wrong, bail out */
    return false;
  }

  /* Z is the last 2-byte field right before RDLENGTH */
  const size_t optPosition = optRDPosition - optFixedPartSize;
  const size_t zPosition = optPosition + optFixedPartSize - /* Z */ 2;
  const uint16_t zValue = 0x100 * packet.at(zPosition) + packet.at(zPosition + 1);
  const bool dnssecOK = (zValue & EDNS_HEADER_FLAG_DO) != 0;

  /* remove the existing OPT record, and everything else that follows (any SIG or TSIG would be useless anyway) */
  packet.resize(packet.size() - existingOptLen);
  dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [](dnsheader& header) {
    header.arcount = 0;
    return true;
  });

  if (!dnsdist::configuration::getCurrentRuntimeConfiguration().d_addEDNSToSelfGeneratedResponses) {
    /* otherwise we are just fine */
    return true;
  }

  /* now we need to add a new OPT record */
  return addEDNS(packet, dnsQuestion.getMaximumSize(), dnssecOK, dnsdist::configuration::getCurrentRuntimeConfiguration().d_payloadSizeSelfGenAnswers, dnsQuestion.ednsRCode);
}
987
988
namespace dnsdist
989
{
990
static std::optional<size_t> getEDNSRecordPosition(const DNSQuestion& dnsQuestion)
991
0
{
992
0
  try {
993
0
    const auto& packet = dnsQuestion.getData();
994
0
    if (packet.size() <= sizeof(dnsheader)) {
995
0
      return std::nullopt;
996
0
    }
997
998
0
    uint16_t optRDPosition = 0;
999
0
    size_t remaining = 0;
1000
0
    auto res = getEDNSOptionsStart(packet, dnsQuestion.ids.qname.wirelength(), &optRDPosition, &remaining);
1001
0
    if (res != 0) {
1002
0
      return std::nullopt;
1003
0
    }
1004
1005
0
    if (optRDPosition < DNS_TTL_SIZE) {
1006
0
      return std::nullopt;
1007
0
    }
1008
1009
0
    return optRDPosition - DNS_TTL_SIZE;
1010
0
  }
1011
0
  catch (...) {
1012
0
    return std::nullopt;
1013
0
  }
1014
0
}
1015
1016
// goal in life - if you send us a reasonably normal packet, we'll get Z for you, otherwise 0
1017
int getEDNSZ(const DNSQuestion& dnsQuestion)
1018
0
{
1019
0
  try {
1020
0
    auto position = getEDNSRecordPosition(dnsQuestion);
1021
1022
0
    if (!position) {
1023
0
      return 0;
1024
0
    }
1025
1026
0
    const auto& packet = dnsQuestion.getData();
1027
0
    if ((*position + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE + 1) >= packet.size()) {
1028
0
      return 0;
1029
0
    }
1030
1031
0
    return 0x100 * packet.at(*position + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE) + packet.at(*position + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE + 1);
1032
0
  }
1033
0
  catch (...) {
1034
0
    return 0;
1035
0
  }
1036
0
}
1037
1038
std::optional<uint8_t> getEDNSVersion(const DNSQuestion& dnsQuestion)
1039
0
{
1040
0
  try {
1041
0
    auto position = getEDNSRecordPosition(dnsQuestion);
1042
1043
0
    if (!position) {
1044
0
      return std::nullopt;
1045
0
    }
1046
1047
0
    const auto& packet = dnsQuestion.getData();
1048
0
    if ((*position + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE) >= packet.size()) {
1049
0
      return std::nullopt;
1050
0
    }
1051
1052
0
    return packet.at(*position + EDNS_EXTENDED_RCODE_SIZE);
1053
0
  }
1054
0
  catch (...) {
1055
0
    return std::nullopt;
1056
0
  }
1057
0
}
1058
1059
std::optional<uint8_t> getEDNSExtendedRCode(const DNSQuestion& dnsQuestion)
1060
0
{
1061
0
  try {
1062
0
    auto position = getEDNSRecordPosition(dnsQuestion);
1063
1064
0
    if (!position) {
1065
0
      return std::nullopt;
1066
0
    }
1067
1068
0
    const auto& packet = dnsQuestion.getData();
1069
0
    if ((*position + EDNS_EXTENDED_RCODE_SIZE) >= packet.size()) {
1070
0
      return std::nullopt;
1071
0
    }
1072
1073
0
    return packet.at(*position);
1074
0
  }
1075
0
  catch (...) {
1076
0
    return std::nullopt;
1077
0
  }
1078
0
}
1079
1080
}
1081
1082
bool queryHasEDNS(const DNSQuestion& dnsQuestion)
1083
0
{
1084
0
  uint16_t optRDPosition = 0;
1085
0
  size_t ecsRemaining = 0;
1086
1087
0
  int res = dnsdist::getEDNSOptionsStart(dnsQuestion.getData(), dnsQuestion.ids.qname.wirelength(), &optRDPosition, &ecsRemaining);
1088
0
  return res == 0;
1089
0
}
1090
1091
bool getEDNS0Record(const PacketBuffer& packet, EDNS0Record& edns0)
1092
0
{
1093
0
  uint16_t optStart = 0;
1094
0
  size_t optLen = 0;
1095
0
  bool last = false;
1096
0
  int res = locateEDNSOptRR(packet, &optStart, &optLen, &last);
1097
0
  if (res != 0) {
1098
    // no EDNS OPT RR
1099
0
    return false;
1100
0
  }
1101
1102
0
  if (optLen < optRecordMinimumSize) {
1103
0
    return false;
1104
0
  }
1105
1106
0
  if (optStart < packet.size() && packet.at(optStart) != 0) {
1107
    // OPT RR Name != '.'
1108
0
    return false;
1109
0
  }
1110
1111
0
  static_assert(sizeof(EDNS0Record) == sizeof(uint32_t), "sizeof(EDNS0Record) must match sizeof(uint32_t) AKA RR TTL size");
1112
  // copy out 4-byte "ttl" (really the EDNS0 record), after root label (1) + type (2) + class (2).
1113
0
  memcpy(&edns0, &packet.at(optStart + 5), sizeof edns0);
1114
0
  return true;
1115
0
}
1116
1117
/*
  Add (or set) the EDNS option `ednsCode` with payload `ednsData` in the query held in `buf`.

  - If the packet already has additional records, it is rewritten through
    slowRewriteEDNSOptionInQueryWithRecords(). The rewrite is discarded (buf untouched, false
    returned) when it fails or when it would exceed `maximumSize`.
  - Otherwise a brand new OPT record containing only this option is generated, and ARCOUNT is
    set to 1.

  On return, `ednsAdded` tells whether an OPT record was created by us, and `optionAdded`
  tells whether the option was inserted. Both are false whenever `buf` was left unchanged.
  Note that true is also returned when generateOptRR() declines to add the record;
  NOTE(review): presumably because it would not fit in `maximumSize`.
*/
bool setEDNSOption(PacketBuffer& buf, uint16_t ednsCode, const std::string& ednsData, size_t maximumSize, bool& ednsAdded, bool& optionAdded)
{
  ednsAdded = false;
  optionAdded = false;

  if (buf.size() < sizeof(dnsheader)) {
    return false;
  }
  std::string optRData;
  generateEDNSOption(ednsCode, ednsData, optRData);

  const dnsheader_aligned dnsHeader(buf.data());
  if (dnsHeader->arcount != 0) {
    PacketBuffer newContent;
    newContent.reserve(buf.size());

    if (!slowRewriteEDNSOptionInQueryWithRecords(buf, newContent, ednsAdded, ednsCode, optionAdded, true, false, optRData)) {
      ednsAdded = false;
      optionAdded = false;
      return false;
    }

    if (newContent.size() > maximumSize) {
      /* the rewritten packet is discarded, so nothing was actually added */
      ednsAdded = false;
      optionAdded = false;
      return false;
    }

    buf = std::move(newContent);
    return true;
  }

  if (generateOptRR(optRData, buf, maximumSize, dnsdist::configuration::s_EdnsUDPPayloadSize, 0, false)) {
    dnsdist::PacketMangling::editDNSHeaderFromPacket(buf, [](dnsheader& header) {
      header.arcount = htons(1);
      return true;
    });
    /* The OPT record is ours: report it, so that callers can record it (and presumably strip
       any EDNS sent back by the backend before answering a client that did not use EDNS). */
    ednsAdded = true;
    optionAdded = true;
  }

  return true;
}
1153
1154
/*
  DNSQuestion convenience wrapper around the PacketBuffer overload: operates on the question's
  mutable data, limited to its maximum size. For queries (`isQuery`), records in
  dnsQuestion.ids.ednsAdded that an OPT record was added by us. Returns false, leaving ids
  untouched, when the underlying call fails.
*/
bool setEDNSOption(DNSQuestion& dnsQuestion, uint16_t ednsCode, const std::string& ednsData, bool isQuery)
{
  bool addedOpt = false;
  bool addedOption = false;
  if (!setEDNSOption(dnsQuestion.getMutableData(), ednsCode, ednsData, dnsQuestion.getMaximumSize(), addedOpt, addedOption)) {
    return false;
  }

  /* only ever raise the flag, never clear one that was already set */
  if (isQuery && addedOpt) {
    dnsQuestion.ids.ednsAdded = true;
  }

  return true;
}
1169
1170
namespace dnsdist
1171
{
1172
bool setInternalQueryRCode(InternalQueryState& state, PacketBuffer& buffer, uint8_t rcode, bool clearAnswers)
1173
0
{
1174
0
  const auto qnameLength = state.qname.wirelength();
1175
0
  if (buffer.size() < sizeof(dnsheader) + qnameLength + sizeof(uint16_t) + sizeof(uint16_t)) {
1176
0
    return false;
1177
0
  }
1178
1179
0
  EDNS0Record edns0{};
1180
0
  bool hadEDNS = false;
1181
0
  if (clearAnswers) {
1182
0
    hadEDNS = getEDNS0Record(buffer, edns0);
1183
0
  }
1184
1185
0
  dnsdist::PacketMangling::editDNSHeaderFromPacket(buffer, [rcode, clearAnswers](dnsheader& header) {
1186
0
    header.rcode = rcode;
1187
0
    header.ad = false;
1188
0
    header.aa = false;
1189
0
    header.ra = header.rd;
1190
0
    header.qr = true;
1191
1192
0
    if (clearAnswers) {
1193
0
      header.ancount = 0;
1194
0
      header.nscount = 0;
1195
0
      header.arcount = 0;
1196
0
    }
1197
0
    return true;
1198
0
  });
1199
1200
0
  if (clearAnswers) {
1201
0
    buffer.resize(sizeof(dnsheader) + qnameLength + sizeof(uint16_t) + sizeof(uint16_t));
1202
0
    if (hadEDNS) {
1203
0
      DNSQuestion dnsQuestion(state, buffer);
1204
0
      if (!addEDNS(buffer, dnsQuestion.getMaximumSize(), (edns0.extFlags & htons(EDNS_HEADER_FLAG_DO)) != 0, dnsdist::configuration::getCurrentRuntimeConfiguration().d_payloadSizeSelfGenAnswers, 0)) {
1205
0
        return false;
1206
0
      }
1207
0
    }
1208
0
  }
1209
1210
0
  return true;
1211
0
}
1212
}