Coverage Report

Created: 2023-09-25 07:00

/src/pdns/pdns/dnsdistdist/dnsdist-ecs.cc
Line
Count
Source (jump to first uncovered line)
1
/*
2
 * This file is part of PowerDNS or dnsdist.
3
 * Copyright -- PowerDNS.COM B.V. and its contributors
4
 *
5
 * This program is free software; you can redistribute it and/or modify
6
 * it under the terms of version 2 of the GNU General Public License as
7
 * published by the Free Software Foundation.
8
 *
9
 * In addition, for the avoidance of any doubt, permission is granted to
10
 * link this program with OpenSSL and to (re)distribute the binaries
11
 * produced as the result of such linking.
12
 *
13
 * This program is distributed in the hope that it will be useful,
14
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16
 * GNU General Public License for more details.
17
 *
18
 * You should have received a copy of the GNU General Public License
19
 * along with this program; if not, write to the Free Software
20
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21
 */
22
#include "dolog.hh"
23
#include "dnsdist.hh"
24
#include "dnsdist-ecs.hh"
25
#include "dnsparser.hh"
26
#include "dnswriter.hh"
27
#include "ednsoptions.hh"
28
#include "ednssubnet.hh"
29
30
/* when we add EDNS to a query, we don't want to advertise
31
   a large buffer size */
32
size_t g_EdnsUDPPayloadSize = 512;
33
static const uint16_t defaultPayloadSizeSelfGenAnswers = 1232;
34
static_assert(defaultPayloadSizeSelfGenAnswers < s_udpIncomingBufferSize, "The UDP responder's payload size should be smaller or equal to our incoming buffer size");
35
uint16_t g_PayloadSizeSelfGenAnswers{defaultPayloadSizeSelfGenAnswers};
36
37
/* draft-ietf-dnsop-edns-client-subnet-04 "11.1.  Privacy" */
38
uint16_t g_ECSSourcePrefixV4 = 24;
39
uint16_t g_ECSSourcePrefixV6 = 56;
40
41
bool g_ECSOverride{false};
42
bool g_addEDNSToSelfGeneratedResponses{true};
43
44
int rewriteResponseWithoutEDNS(const PacketBuffer& initialPacket, PacketBuffer& newContent)
45
0
{
46
0
  assert(initialPacket.size() >= sizeof(dnsheader));
47
0
  const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(initialPacket.data());
48
49
0
  if (ntohs(dh->arcount) == 0)
50
0
    return ENOENT;
51
52
0
  if (ntohs(dh->qdcount) == 0)
53
0
    return ENOENT;
54
55
0
  PacketReader pr(std::string_view(reinterpret_cast<const char*>(initialPacket.data()), initialPacket.size()));
56
57
0
  size_t idx = 0;
58
0
  DNSName rrname;
59
0
  uint16_t qdcount = ntohs(dh->qdcount);
60
0
  uint16_t ancount = ntohs(dh->ancount);
61
0
  uint16_t nscount = ntohs(dh->nscount);
62
0
  uint16_t arcount = ntohs(dh->arcount);
63
0
  uint16_t rrtype;
64
0
  uint16_t rrclass;
65
0
  string blob;
66
0
  struct dnsrecordheader ah;
67
68
0
  rrname = pr.getName();
69
0
  rrtype = pr.get16BitInt();
70
0
  rrclass = pr.get16BitInt();
71
72
0
  GenericDNSPacketWriter<PacketBuffer> pw(newContent, rrname, rrtype, rrclass, dh->opcode);
73
0
  pw.getHeader()->id=dh->id;
74
0
  pw.getHeader()->qr=dh->qr;
75
0
  pw.getHeader()->aa=dh->aa;
76
0
  pw.getHeader()->tc=dh->tc;
77
0
  pw.getHeader()->rd=dh->rd;
78
0
  pw.getHeader()->ra=dh->ra;
79
0
  pw.getHeader()->ad=dh->ad;
80
0
  pw.getHeader()->cd=dh->cd;
81
0
  pw.getHeader()->rcode=dh->rcode;
82
83
  /* consume remaining qd if any */
84
0
  if (qdcount > 1) {
85
0
    for(idx = 1; idx < qdcount; idx++) {
86
0
      rrname = pr.getName();
87
0
      rrtype = pr.get16BitInt();
88
0
      rrclass = pr.get16BitInt();
89
0
      (void) rrtype;
90
0
      (void) rrclass;
91
0
    }
92
0
  }
93
94
  /* copy AN and NS */
95
0
  for (idx = 0; idx < ancount; idx++) {
96
0
    rrname = pr.getName();
97
0
    pr.getDnsrecordheader(ah);
98
99
0
    pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ANSWER, true);
100
0
    pr.xfrBlob(blob);
101
0
    pw.xfrBlob(blob);
102
0
  }
103
104
0
  for (idx = 0; idx < nscount; idx++) {
105
0
    rrname = pr.getName();
106
0
    pr.getDnsrecordheader(ah);
107
108
0
    pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::AUTHORITY, true);
109
0
    pr.xfrBlob(blob);
110
0
    pw.xfrBlob(blob);
111
0
  }
112
  /* consume AR, looking for OPT */
113
0
  for (idx = 0; idx < arcount; idx++) {
114
0
    rrname = pr.getName();
115
0
    pr.getDnsrecordheader(ah);
116
117
0
    if (ah.d_type != QType::OPT) {
118
0
      pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ADDITIONAL, true);
119
0
      pr.xfrBlob(blob);
120
0
      pw.xfrBlob(blob);
121
0
    } else {
122
123
0
      pr.skip(ah.d_clen);
124
0
    }
125
0
  }
126
0
  pw.commit();
127
128
0
  return 0;
129
0
}
130
131
static bool addOrReplaceEDNSOption(std::vector<std::pair<uint16_t, std::string>>& options, uint16_t optionCode, bool& optionAdded, bool overrideExisting, const string& newOptionContent)
132
0
{
133
0
  for (auto it = options.begin(); it != options.end(); ) {
134
0
    if (it->first == optionCode) {
135
0
      optionAdded = false;
136
137
0
      if (!overrideExisting) {
138
0
        return false;
139
0
      }
140
141
0
      it = options.erase(it);
142
0
    }
143
0
    else {
144
0
      ++it;
145
0
    }
146
0
  }
147
148
0
  options.emplace_back(optionCode, std::string(&newOptionContent.at(EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), newOptionContent.size() - (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE)));
149
0
  return true;
150
0
}
151
152
bool slowRewriteEDNSOptionInQueryWithRecords(const PacketBuffer& initialPacket, PacketBuffer& newContent, bool& ednsAdded, uint16_t optionToReplace, bool& optionAdded, bool overrideExisting, const string& newOptionContent)
153
0
{
154
0
  assert(initialPacket.size() >= sizeof(dnsheader));
155
0
  const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(initialPacket.data());
156
157
0
  if (ntohs(dh->qdcount) == 0) {
158
0
    return false;
159
0
  }
160
161
0
  if (ntohs(dh->ancount) == 0 && ntohs(dh->nscount) == 0 && ntohs(dh->arcount) == 0) {
162
0
    throw std::runtime_error(std::string(__PRETTY_FUNCTION__) + " should not be called for queries that have no records");
163
0
  }
164
165
0
  optionAdded = false;
166
0
  ednsAdded = true;
167
168
0
  PacketReader pr(std::string_view(reinterpret_cast<const char*>(initialPacket.data()), initialPacket.size()));
169
170
0
  size_t idx = 0;
171
0
  DNSName rrname;
172
0
  uint16_t qdcount = ntohs(dh->qdcount);
173
0
  uint16_t ancount = ntohs(dh->ancount);
174
0
  uint16_t nscount = ntohs(dh->nscount);
175
0
  uint16_t arcount = ntohs(dh->arcount);
176
0
  uint16_t rrtype;
177
0
  uint16_t rrclass;
178
0
  string blob;
179
0
  struct dnsrecordheader ah;
180
181
0
  rrname = pr.getName();
182
0
  rrtype = pr.get16BitInt();
183
0
  rrclass = pr.get16BitInt();
184
185
0
  GenericDNSPacketWriter<PacketBuffer> pw(newContent, rrname, rrtype, rrclass, dh->opcode);
186
0
  pw.getHeader()->id=dh->id;
187
0
  pw.getHeader()->qr=dh->qr;
188
0
  pw.getHeader()->aa=dh->aa;
189
0
  pw.getHeader()->tc=dh->tc;
190
0
  pw.getHeader()->rd=dh->rd;
191
0
  pw.getHeader()->ra=dh->ra;
192
0
  pw.getHeader()->ad=dh->ad;
193
0
  pw.getHeader()->cd=dh->cd;
194
0
  pw.getHeader()->rcode=dh->rcode;
195
196
  /* consume remaining qd if any */
197
0
  if (qdcount > 1) {
198
0
    for(idx = 1; idx < qdcount; idx++) {
199
0
      rrname = pr.getName();
200
0
      rrtype = pr.get16BitInt();
201
0
      rrclass = pr.get16BitInt();
202
0
      (void) rrtype;
203
0
      (void) rrclass;
204
0
    }
205
0
  }
206
207
  /* copy AN and NS */
208
0
  for (idx = 0; idx < ancount; idx++) {
209
0
    rrname = pr.getName();
210
0
    pr.getDnsrecordheader(ah);
211
212
0
    pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ANSWER, true);
213
0
    pr.xfrBlob(blob);
214
0
    pw.xfrBlob(blob);
215
0
  }
216
217
0
  for (idx = 0; idx < nscount; idx++) {
218
0
    rrname = pr.getName();
219
0
    pr.getDnsrecordheader(ah);
220
221
0
    pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::AUTHORITY, true);
222
0
    pr.xfrBlob(blob);
223
0
    pw.xfrBlob(blob);
224
0
  }
225
226
  /* consume AR, looking for OPT */
227
0
  for (idx = 0; idx < arcount; idx++) {
228
0
    rrname = pr.getName();
229
0
    pr.getDnsrecordheader(ah);
230
231
0
    if (ah.d_type != QType::OPT) {
232
0
      pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ADDITIONAL, true);
233
0
      pr.xfrBlob(blob);
234
0
      pw.xfrBlob(blob);
235
0
    } else {
236
237
0
      ednsAdded = false;
238
0
      pr.xfrBlob(blob);
239
240
0
      std::vector<std::pair<uint16_t, std::string>> options;
241
0
      getEDNSOptionsFromContent(blob, options);
242
243
      /* getDnsrecordheader() has helpfully converted the TTL for us, which we do not want in that case */
244
0
      uint32_t ttl = htonl(ah.d_ttl);
245
0
      EDNS0Record edns0;
246
0
      static_assert(sizeof(edns0) == sizeof(ttl), "sizeof(EDNS0Record) must match sizeof(uint32_t) AKA RR TTL size");
247
0
      memcpy(&edns0, &ttl, sizeof(edns0));
248
249
      /* addOrReplaceEDNSOption will set it to false if there is already an existing option */
250
0
      optionAdded = true;
251
0
      addOrReplaceEDNSOption(options, optionToReplace, optionAdded, overrideExisting, newOptionContent);
252
0
      pw.addOpt(ah.d_class, edns0.extRCode, edns0.extFlags, options, edns0.version);
253
0
    }
254
0
  }
255
256
0
  if (ednsAdded) {
257
0
    pw.addOpt(g_EdnsUDPPayloadSize, 0, 0, {{optionToReplace, std::string(&newOptionContent.at(EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), newOptionContent.size() - (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE))}}, 0);
258
0
    optionAdded = true;
259
0
  }
260
261
0
  pw.commit();
262
263
0
  return true;
264
0
}
265
266
static bool slowParseEDNSOptions(const PacketBuffer& packet, EDNSOptionViewMap& options)
267
0
{
268
0
  if (packet.size() < sizeof(dnsheader)) {
269
0
    return false;
270
0
  }
271
272
0
  const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(packet.data());
273
274
0
  if (ntohs(dh->qdcount) == 0) {
275
0
    return false;
276
0
  }
277
278
0
  if (ntohs(dh->arcount) == 0) {
279
0
    throw std::runtime_error("slowParseEDNSOptions() should not be called for queries that have no EDNS");
280
0
  }
281
282
0
  try {
283
0
    uint64_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
284
0
    DNSPacketMangler dpm(const_cast<char*>(reinterpret_cast<const char*>(&packet.at(0))), packet.size());
285
0
    uint64_t n;
286
0
    for(n=0; n < ntohs(dh->qdcount) ; ++n) {
287
0
      dpm.skipDomainName();
288
      /* type and class */
289
0
      dpm.skipBytes(4);
290
0
    }
291
292
0
    for(n=0; n < numrecords; ++n) {
293
0
      dpm.skipDomainName();
294
295
0
      uint8_t section = n < ntohs(dh->ancount) ? 1 : (n < (ntohs(dh->ancount) + ntohs(dh->nscount)) ? 2 : 3);
296
0
      uint16_t dnstype = dpm.get16BitInt();
297
0
      dpm.get16BitInt();
298
0
      dpm.skipBytes(4); /* TTL */
299
300
0
      if(section == 3 && dnstype == QType::OPT) {
301
0
        uint32_t offset = dpm.getOffset();
302
0
        if (offset >= packet.size()) {
303
0
          return false;
304
0
        }
305
        /* if we survive this call, we can parse it safely */
306
0
        dpm.skipRData();
307
0
        return getEDNSOptions(reinterpret_cast<const char*>(&packet.at(offset)), packet.size() - offset, options) == 0;
308
0
      }
309
0
      else {
310
0
        dpm.skipRData();
311
0
      }
312
0
    }
313
0
  }
314
0
  catch(...)
315
0
  {
316
0
    return false;
317
0
  }
318
319
0
  return true;
320
0
}
321
322
int locateEDNSOptRR(const PacketBuffer& packet, uint16_t * optStart, size_t * optLen, bool * last)
323
0
{
324
0
  assert(optStart != NULL);
325
0
  assert(optLen != NULL);
326
0
  assert(last != NULL);
327
0
  const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(packet.data());
328
329
0
  if (ntohs(dh->arcount) == 0)
330
0
    return ENOENT;
331
332
0
  PacketReader pr(std::string_view(reinterpret_cast<const char*>(packet.data()), packet.size()));
333
334
0
  size_t idx = 0;
335
0
  DNSName rrname;
336
0
  uint16_t qdcount = ntohs(dh->qdcount);
337
0
  uint16_t ancount = ntohs(dh->ancount);
338
0
  uint16_t nscount = ntohs(dh->nscount);
339
0
  uint16_t arcount = ntohs(dh->arcount);
340
0
  uint16_t rrtype;
341
0
  uint16_t rrclass;
342
0
  struct dnsrecordheader ah;
343
344
  /* consume qd */
345
0
  for(idx = 0; idx < qdcount; idx++) {
346
0
    rrname = pr.getName();
347
0
    rrtype = pr.get16BitInt();
348
0
    rrclass = pr.get16BitInt();
349
0
    (void) rrtype;
350
0
    (void) rrclass;
351
0
  }
352
353
  /* consume AN and NS */
354
0
  for (idx = 0; idx < ancount + nscount; idx++) {
355
0
    rrname = pr.getName();
356
0
    pr.getDnsrecordheader(ah);
357
0
    pr.skip(ah.d_clen);
358
0
  }
359
360
  /* consume AR, looking for OPT */
361
0
  for (idx = 0; idx < arcount; idx++) {
362
0
    uint16_t start = pr.getPosition();
363
0
    rrname = pr.getName();
364
0
    pr.getDnsrecordheader(ah);
365
366
0
    if (ah.d_type == QType::OPT) {
367
0
      *optStart = start;
368
0
      *optLen = (pr.getPosition() - start) + ah.d_clen;
369
370
0
      if (packet.size() < (*optStart + *optLen)) {
371
0
        throw std::range_error("Opt record overflow");
372
0
      }
373
374
0
      if (idx == ((size_t) arcount - 1)) {
375
0
        *last = true;
376
0
      }
377
0
      else {
378
0
        *last = false;
379
0
      }
380
0
      return 0;
381
0
    }
382
0
    pr.skip(ah.d_clen);
383
0
  }
384
385
0
  return ENOENT;
386
0
}
387
388
/* extract the start of the OPT RR in a QUERY packet if any */
389
int getEDNSOptionsStart(const PacketBuffer& packet, const size_t offset, uint16_t* optRDPosition, size_t* remaining)
390
961
{
391
961
  assert(optRDPosition != nullptr);
392
0
  assert(remaining != nullptr);
393
0
  const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(packet.data());
394
395
961
  if (offset >= packet.size()) {
396
0
    return ENOENT;
397
0
  }
398
399
961
  if (ntohs(dh->qdcount) != 1 || ntohs(dh->ancount) != 0 || ntohs(dh->arcount) != 1 || ntohs(dh->nscount) != 0)
400
149
    return ENOENT;
401
402
812
  size_t pos = sizeof(dnsheader) + offset;
403
812
  pos += DNS_TYPE_SIZE + DNS_CLASS_SIZE;
404
405
812
  if (pos >= packet.size())
406
9
    return ENOENT;
407
408
803
  if ((pos + /* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE) >= packet.size()) {
409
18
    return ENOENT;
410
18
  }
411
412
785
  if (packet[pos] != 0) {
413
    /* not the root so not an OPT record */
414
103
    return ENOENT;
415
103
  }
416
682
  pos += 1;
417
418
682
  uint16_t qtype = packet.at(pos)*256 + packet.at(pos+1);
419
682
  pos += DNS_TYPE_SIZE;
420
682
  pos += DNS_CLASS_SIZE;
421
422
682
  if (qtype != QType::OPT || (packet.size() - pos) < (DNS_TTL_SIZE + DNS_RDLENGTH_SIZE)) {
423
143
    return ENOENT;
424
143
  }
425
426
539
  pos += DNS_TTL_SIZE;
427
539
  *optRDPosition = pos;
428
539
  *remaining = packet.size() - pos;
429
430
539
  return 0;
431
682
}
432
433
void generateECSOption(const ComboAddress& source, string& res, uint16_t ECSPrefixLength)
434
0
{
435
0
  Netmask sourceNetmask(source, ECSPrefixLength);
436
0
  EDNSSubnetOpts ecsOpts;
437
0
  ecsOpts.source = sourceNetmask;
438
0
  string payload = makeEDNSSubnetOptsString(ecsOpts);
439
0
  generateEDNSOption(EDNSOptionCode::ECS, payload, res);
440
0
}
441
442
bool generateOptRR(const std::string& optRData, PacketBuffer& res, size_t maximumSize, uint16_t udpPayloadSize, uint8_t ednsrcode, bool dnssecOK)
443
0
{
444
0
  const uint8_t name = 0;
445
0
  dnsrecordheader dh;
446
0
  EDNS0Record edns0;
447
0
  edns0.extRCode = ednsrcode;
448
0
  edns0.version = 0;
449
0
  edns0.extFlags = dnssecOK ? htons(EDNS_HEADER_FLAG_DO) : 0;
450
451
0
  if ((maximumSize - res.size()) < (sizeof(name) + sizeof(dh) + optRData.length())) {
452
0
    return false;
453
0
  }
454
455
0
  dh.d_type = htons(QType::OPT);
456
0
  dh.d_class = htons(udpPayloadSize);
457
0
  static_assert(sizeof(EDNS0Record) == sizeof(dh.d_ttl), "sizeof(EDNS0Record) must match sizeof(dnsrecordheader.d_ttl)");
458
0
  memcpy(&dh.d_ttl, &edns0, sizeof edns0);
459
0
  dh.d_clen = htons(static_cast<uint16_t>(optRData.length()));
460
461
0
  res.reserve(res.size() + sizeof(name) + sizeof(dh) + optRData.length());
462
0
  res.insert(res.end(), reinterpret_cast<const uint8_t*>(&name), reinterpret_cast<const uint8_t*>(&name) + sizeof(name));
463
0
  res.insert(res.end(), reinterpret_cast<const uint8_t*>(&dh), reinterpret_cast<const uint8_t*>(&dh) + sizeof(dh));
464
0
  res.insert(res.end(), reinterpret_cast<const uint8_t*>(optRData.data()), reinterpret_cast<const uint8_t*>(optRData.data()) + optRData.length());
465
466
0
  return true;
467
0
}
468
469
/*
 * Replace the existing ECS option of the OPT record with newECSOption.
 * The old option sits at oldEcsOptionStartPosition and is oldEcsOptionSize
 * bytes long (option header included). The OPT RDLENGTH at optRDLenPosition
 * is fixed up.
 * - When both options have the same size, the option is overwritten in place.
 * - Otherwise the option is moved to the end of the OPT RDATA. The OPT record
 *   is expected to be the last record of the packet; trailing data after its
 *   RDATA is removed first.
 * Returns false, without modifying the packet, if it would grow past maximumSize.
 */
static bool replaceEDNSClientSubnetOption(PacketBuffer& packet, size_t maximumSize, size_t const oldEcsOptionStartPosition, size_t const oldEcsOptionSize, size_t const optRDLenPosition, const string& newECSOption)
{
  assert(oldEcsOptionStartPosition < packet.size());
  assert(optRDLenPosition < packet.size());

  if (newECSOption.size() == oldEcsOptionSize) {
    /* same size as the existing option */
    memcpy(&packet.at(oldEcsOptionStartPosition), newECSOption.c_str(), oldEcsOptionSize);
    return true;
  }

  /* different size than the existing option */
  const uint16_t oldRDLen = (packet.at(optRDLenPosition) * 256) + packet.at(optRDLenPosition + 1);

  /* Everything after the old option is moved over it and the new option is
     written at the end of the packet. Trailing data after the OPT RDATA would
     otherwise push the new option outside of the OPT record, so ignore it, as
     addECSToExistingOPT() does. */
  size_t packetLen = packet.size();
  const size_t optRDataEnd = optRDLenPosition + sizeof(uint16_t) + oldRDLen;
  if (optRDataEnd < packetLen && (oldEcsOptionStartPosition + oldEcsOptionSize) <= optRDataEnd) {
    packetLen = optRDataEnd;
  }

  const size_t newPacketLen = packetLen - oldEcsOptionSize + newECSOption.size();

  /* check that it fits, before touching the packet */
  if (newPacketLen > packetLen && newPacketLen > maximumSize) {
    return false;
  }

  /* drop the trailing data, if any */
  packet.resize(packetLen);

  const size_t dataBehindSize = packetLen - oldEcsOptionStartPosition - oldEcsOptionSize;
  if (newPacketLen > packetLen) {
    packet.resize(newPacketLen);
  }

  /* fix the size of the OPT RDLEN (modulo 2^16, as the size difference may be negative) */
  const uint16_t newRDLen = static_cast<uint16_t>(oldRDLen + newECSOption.size() - oldEcsOptionSize);
  packet.at(optRDLenPosition) = static_cast<uint8_t>(newRDLen / 256);
  packet.at(optRDLenPosition + 1) = static_cast<uint8_t>(newRDLen % 256);

  if (dataBehindSize > 0) {
    memmove(&packet.at(oldEcsOptionStartPosition), &packet.at(oldEcsOptionStartPosition + oldEcsOptionSize), dataBehindSize);
  }
  memcpy(&packet.at(oldEcsOptionStartPosition + dataBehindSize), newECSOption.c_str(), newECSOption.size());
  packet.resize(newPacketLen);

  return true;
}
508
509
/* This function looks for an OPT RR, return true if a valid one was found (even if there was no options)
510
   and false otherwise. */
511
bool parseEDNSOptions(const DNSQuestion& dq)
512
0
{
513
0
  const auto dh = dq.getHeader();
514
0
  if (dq.ednsOptions != nullptr) {
515
0
    return true;
516
0
  }
517
518
  // dq.ednsOptions is mutable
519
0
  dq.ednsOptions = std::make_unique<EDNSOptionViewMap>();
520
521
0
  if (ntohs(dh->arcount) == 0) {
522
    /* nothing in additional so no EDNS */
523
0
    return false;
524
0
  }
525
526
0
  if (ntohs(dh->ancount) != 0 || ntohs(dh->nscount) != 0 || ntohs(dh->arcount) > 1) {
527
0
    return slowParseEDNSOptions(dq.getData(), *dq.ednsOptions);
528
0
  }
529
530
0
  size_t remaining = 0;
531
0
  uint16_t optRDPosition;
532
0
  int res = getEDNSOptionsStart(dq.getData(), dq.ids.qname.wirelength(), &optRDPosition, &remaining);
533
534
0
  if (res == 0) {
535
0
    res = getEDNSOptions(reinterpret_cast<const char*>(&dq.getData().at(optRDPosition)), remaining, *dq.ednsOptions);
536
0
    return (res == 0);
537
0
  }
538
539
0
  return false;
540
0
}
541
542
/*
 * Append newECSOption (a complete EDNS option) at the end of the OPT RDATA
 * whose RDLENGTH field is at optRDLenPosition. RDLENGTH is fixed up and
 * ecsAdded is set to true.
 * getEDNSOptionsStart() has already checked that there is exactly one AR,
 * no NS and no AN, so the OPT record is the last record. Any trailing data
 * after it is removed first.
 * Returns false if the packet would exceed maximumSize, or if the new
 * RDLENGTH would not fit in 16 bits.
 */
static bool addECSToExistingOPT(PacketBuffer& packet, size_t maximumSize, const string& newECSOption, size_t optRDLenPosition, bool& ecsAdded)
{
  const uint16_t oldRDLen = (packet.at(optRDLenPosition) * 256) + packet.at(optRDLenPosition + 1);
  if (packet.size() != (optRDLenPosition + sizeof(uint16_t) + oldRDLen)) {
    /* we are supposed to be the last record, do we have some trailing data to remove? */
    const uint32_t realPacketLen = getDNSPacketLength(reinterpret_cast<const char*>(packet.data()), packet.size());
    packet.resize(realPacketLen);
  }

  /* check packet.size() first, otherwise the subtraction could wrap around */
  if (packet.size() > maximumSize || (maximumSize - packet.size()) < newECSOption.size()) {
    return false;
  }

  /* RDLENGTH is 16 bits, do not let it wrap silently */
  const size_t newRDLen = oldRDLen + newECSOption.size();
  if (newRDLen > 0xFFFF) {
    return false;
  }

  packet.at(optRDLenPosition) = static_cast<uint8_t>(newRDLen / 256);
  packet.at(optRDLenPosition + 1) = static_cast<uint8_t>(newRDLen % 256);

  packet.insert(packet.end(), newECSOption.begin(), newECSOption.end());
  ecsAdded = true;

  return true;
}
567
568
/*
 * Append a new OPT record containing only newECSOption, advertising
 * g_EdnsUDPPayloadSize, and bump ARCOUNT.
 * On success ednsAdded and ecsAdded are set to true.
 * Returns false, and touches neither flag, if generateOptRR() refuses to add the record.
 */
static bool addEDNSWithECS(PacketBuffer& packet, size_t maximumSize, const string& newECSOption, bool& ednsAdded, bool& ecsAdded)
{
  const bool appended = generateOptRR(newECSOption, packet, maximumSize, g_EdnsUDPPayloadSize, 0, false);
  if (appended) {
    auto* header = reinterpret_cast<struct dnsheader*>(packet.data());
    const uint16_t additionalCount = ntohs(header->arcount) + 1;
    header->arcount = htons(additionalCount);
    ednsAdded = true;
    ecsAdded = true;
  }
  return appended;
}
583
584
/*
 * Make sure the query in `packet` carries the ECS option newECSOption (a
 * complete EDNS option), adding an OPT record when needed.
 * qnameWireLength is the wire length of the question name.
 * - Queries with AN or NS records, or more than one AR, are rebuilt by
 *   slowRewriteEDNSOptionInQueryWithRecords(). The result replaces `packet`
 *   only if it fits in maximumSize; otherwise ednsAdded and ecsAdded are reset
 *   and false is returned.
 * - Without an OPT record, data past the question (or past the last record)
 *   is trimmed, and a new OPT record holding the option is appended.
 * - With an OPT record, an existing ECS option is kept (and true returned)
 *   unless overrideExisting is set, in which case it is replaced. A missing
 *   ECS option is appended.
 */
bool handleEDNSClientSubnet(PacketBuffer& packet, const size_t maximumSize, const size_t qnameWireLength, bool& ednsAdded, bool& ecsAdded, bool overrideExisting, const string& newECSOption)
{
  assert(qnameWireLength <= packet.size());

  const auto* header = reinterpret_cast<const struct dnsheader*>(packet.data());
  const uint16_t additionalCount = ntohs(header->arcount);
  const bool onlyQuestionAndMaybeOneAR = ntohs(header->ancount) == 0 && ntohs(header->nscount) == 0 && additionalCount <= 1;

  if (!onlyQuestionAndMaybeOneAR) {
    PacketBuffer rebuilt;
    rebuilt.reserve(packet.size());

    if (!slowRewriteEDNSOptionInQueryWithRecords(packet, rebuilt, ednsAdded, EDNSOptionCode::ECS, ecsAdded, overrideExisting, newECSOption)) {
      return false;
    }

    if (rebuilt.size() > maximumSize) {
      ednsAdded = false;
      ecsAdded = false;
      return false;
    }

    packet = std::move(rebuilt);
    return true;
  }

  uint16_t optRDPosition = 0;
  size_t remaining = 0;

  if (getEDNSOptionsStart(packet, qnameWireLength, &optRDPosition, &remaining) != 0) {
    /* no EDNS but there might be another record in additional (TSIG?) */
    /* Careful, this code assumes that ANCOUNT == 0 && NSCOUNT == 0 */
    const size_t minimumPacketSize = sizeof(dnsheader) + qnameWireLength + sizeof(uint16_t) + sizeof(uint16_t);
    if (packet.size() > minimumPacketSize) {
      if (additionalCount == 0) {
        /* well now.. */
        packet.resize(minimumPacketSize);
      }
      else {
        const uint32_t realPacketLen = getDNSPacketLength(reinterpret_cast<const char*>(packet.data()), packet.size());
        packet.resize(realPacketLen);
      }
    }

    return addEDNSWithECS(packet, maximumSize, newECSOption, ednsAdded, ecsAdded);
  }

  size_t ecsOptionStartPosition = 0;
  size_t ecsOptionSize = 0;
  const int lookup = getEDNSOption(reinterpret_cast<const char*>(&packet.at(optRDPosition)), remaining, EDNSOptionCode::ECS, &ecsOptionStartPosition, &ecsOptionSize);

  if (lookup != 0) {
    /* we have an EDNS OPT RR but no existing ECS option */
    return addECSToExistingOPT(packet, maximumSize, newECSOption, optRDPosition, ecsAdded);
  }

  /* there is already an ECS value, leave it alone unless asked to override it */
  if (!overrideExisting) {
    return true;
  }

  return replaceEDNSClientSubnetOption(packet, maximumSize, optRDPosition + ecsOptionStartPosition, ecsOptionSize, optRDPosition, newECSOption);
}
650
651
/*
 * Build the ECS option for dq and add or replace it in the query data.
 * The subnet comes from dq.ecs when it is set (network and prefix length),
 * otherwise from dq.ids.origRemote and dq.ecsPrefixLength.
 * An existing ECS option is replaced only when dq.ecsOverride is set.
 */
bool handleEDNSClientSubnet(DNSQuestion& dq, bool& ednsAdded, bool& ecsAdded)
{
  const bool haveSubnet = static_cast<bool>(dq.ecs);
  const ComboAddress& sourceAddress = haveSubnet ? dq.ecs->getNetwork() : dq.ids.origRemote;
  const uint16_t prefixLength = haveSubnet ? dq.ecs->getBits() : dq.ecsPrefixLength;

  string newECSOption;
  generateECSOption(sourceAddress, newECSOption, prefixLength);

  return handleEDNSClientSubnet(dq.getMutableData(), dq.getMaximumSize(), dq.ids.qname.wirelength(), ednsAdded, ecsAdded, dq.ecsOverride, newECSOption);
}
658
659
/*
 * Remove, in place, the first option with code optionCodeToRemove from the
 * optionsLen bytes of EDNS options starting at optionsStart. Each option is
 * laid out as a 2-byte code, a 2-byte length, then the payload.
 * Options after the removed one are moved over it.
 * Returns:
 * - 0 on success, with *newOptionsLen set to the shortened length;
 * - EINVAL if an option runs past optionsLen;
 * - ENOENT if the option is not present.
 * *newOptionsLen is left untouched unless 0 is returned.
 */
static int removeEDNSOptionFromOptions(unsigned char* optionsStart, const uint16_t optionsLen, const uint16_t optionCodeToRemove, uint16_t* newOptionsLen)
{
  const size_t optionHeaderSize = 2 * sizeof(uint16_t);
  size_t offset = 0;

  while (offset + optionHeaderSize <= optionsLen) {
    unsigned char* const option = optionsStart + offset;
    const uint16_t code = static_cast<uint16_t>(0x100 * option[0] + option[1]);
    const uint16_t length = static_cast<uint16_t>(0x100 * option[2] + option[3]);
    const size_t optionEnd = offset + optionHeaderSize + length;

    if (optionEnd > optionsLen) {
      return EINVAL;
    }

    if (code == optionCodeToRemove) {
      if (optionEnd < optionsLen) {
        /* slide the following options over the removed one */
        memmove(option, optionsStart + optionEnd, optionsLen - optionEnd);
      }
      *newOptionsLen = static_cast<uint16_t>(optionsLen - (optionHeaderSize + length));
      return 0;
    }

    offset = optionEnd;
  }

  return ENOENT;
}
688
689
int removeEDNSOptionFromOPT(char* optStart, size_t* optLen, const uint16_t optionCodeToRemove)
690
0
{
691
0
  if (*optLen < optRecordMinimumSize) {
692
0
    return EINVAL;
693
0
  }
694
0
  const unsigned char* end = (const unsigned char*) optStart + *optLen;
695
0
  unsigned char* p = (unsigned char*) optStart + 9;
696
0
  unsigned char* rdLenPtr = p;
697
0
  uint16_t rdLen = (0x100*p[0] + p[1]);
698
0
  p += sizeof(rdLen);
699
0
  if (p + rdLen != end) {
700
0
    return EINVAL;
701
0
  }
702
0
  uint16_t newRdLen = 0;
703
0
  int res = removeEDNSOptionFromOptions(p, rdLen, optionCodeToRemove, &newRdLen);
704
0
  if (res != 0) {
705
0
    return res;
706
0
  }
707
0
  *optLen -= (rdLen - newRdLen);
708
0
  rdLenPtr[0] = newRdLen / 0x100;
709
0
  rdLenPtr[1] = newRdLen % 0x100;
710
0
  return 0;
711
0
}
712
713
bool isEDNSOptionInOpt(const PacketBuffer& packet, const size_t optStart, const size_t optLen, const uint16_t optionCodeToFind, size_t* optContentStart, uint16_t* optContentLen)
714
0
{
715
0
  if (optLen < optRecordMinimumSize) {
716
0
    return false;
717
0
  }
718
0
  size_t p = optStart + 9;
719
0
  uint16_t rdLen = (0x100*static_cast<unsigned char>(packet.at(p)) + static_cast<unsigned char>(packet.at(p+1)));
720
0
  p += sizeof(rdLen);
721
0
  if (rdLen > (optLen - optRecordMinimumSize)) {
722
0
    return false;
723
0
  }
724
725
0
  size_t rdEnd = p + rdLen;
726
0
  while ((p + 4) <= rdEnd) {
727
0
    const uint16_t optionCode = 0x100*static_cast<unsigned char>(packet.at(p)) + static_cast<unsigned char>(packet.at(p+1));
728
0
    p += sizeof(optionCode);
729
0
    const uint16_t optionLen = 0x100*static_cast<unsigned char>(packet.at(p)) + static_cast<unsigned char>(packet.at(p+1));
730
0
    p += sizeof(optionLen);
731
732
0
    if ((p + optionLen) > rdEnd) {
733
0
      return false;
734
0
    }
735
736
0
    if (optionCode == optionCodeToFind) {
737
0
      if (optContentStart != nullptr) {
738
0
        *optContentStart = p;
739
0
      }
740
741
0
      if (optContentLen != nullptr) {
742
0
        *optContentLen = optionLen;
743
0
      }
744
745
0
      return true;
746
0
    }
747
0
    p += optionLen;
748
0
  }
749
0
  return false;
750
0
}
751
752
int rewriteResponseWithoutEDNSOption(const PacketBuffer& initialPacket, const uint16_t optionCodeToSkip, PacketBuffer& newContent)
753
0
{
754
0
  assert(initialPacket.size() >= sizeof(dnsheader));
755
0
  const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(initialPacket.data());
756
757
0
  if (ntohs(dh->arcount) == 0)
758
0
    return ENOENT;
759
760
0
  if (ntohs(dh->qdcount) == 0)
761
0
    return ENOENT;
762
763
0
  PacketReader pr(std::string_view(reinterpret_cast<const char*>(initialPacket.data()), initialPacket.size()));
764
765
0
  size_t idx = 0;
766
0
  DNSName rrname;
767
0
  uint16_t qdcount = ntohs(dh->qdcount);
768
0
  uint16_t ancount = ntohs(dh->ancount);
769
0
  uint16_t nscount = ntohs(dh->nscount);
770
0
  uint16_t arcount = ntohs(dh->arcount);
771
0
  uint16_t rrtype;
772
0
  uint16_t rrclass;
773
0
  string blob;
774
0
  struct dnsrecordheader ah;
775
776
0
  rrname = pr.getName();
777
0
  rrtype = pr.get16BitInt();
778
0
  rrclass = pr.get16BitInt();
779
780
0
  GenericDNSPacketWriter<PacketBuffer> pw(newContent, rrname, rrtype, rrclass, dh->opcode);
781
0
  pw.getHeader()->id=dh->id;
782
0
  pw.getHeader()->qr=dh->qr;
783
0
  pw.getHeader()->aa=dh->aa;
784
0
  pw.getHeader()->tc=dh->tc;
785
0
  pw.getHeader()->rd=dh->rd;
786
0
  pw.getHeader()->ra=dh->ra;
787
0
  pw.getHeader()->ad=dh->ad;
788
0
  pw.getHeader()->cd=dh->cd;
789
0
  pw.getHeader()->rcode=dh->rcode;
790
791
  /* consume remaining qd if any */
792
0
  if (qdcount > 1) {
793
0
    for(idx = 1; idx < qdcount; idx++) {
794
0
      rrname = pr.getName();
795
0
      rrtype = pr.get16BitInt();
796
0
      rrclass = pr.get16BitInt();
797
0
      (void) rrtype;
798
0
      (void) rrclass;
799
0
    }
800
0
  }
801
802
  /* copy AN and NS */
803
0
  for (idx = 0; idx < ancount; idx++) {
804
0
    rrname = pr.getName();
805
0
    pr.getDnsrecordheader(ah);
806
807
0
    pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ANSWER, true);
808
0
    pr.xfrBlob(blob);
809
0
    pw.xfrBlob(blob);
810
0
  }
811
812
0
  for (idx = 0; idx < nscount; idx++) {
813
0
    rrname = pr.getName();
814
0
    pr.getDnsrecordheader(ah);
815
816
0
    pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::AUTHORITY, true);
817
0
    pr.xfrBlob(blob);
818
0
    pw.xfrBlob(blob);
819
0
  }
820
821
  /* consume AR, looking for OPT */
822
0
  for (idx = 0; idx < arcount; idx++) {
823
0
    rrname = pr.getName();
824
0
    pr.getDnsrecordheader(ah);
825
826
0
    if (ah.d_type != QType::OPT) {
827
0
      pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ADDITIONAL, true);
828
0
      pr.xfrBlob(blob);
829
0
      pw.xfrBlob(blob);
830
0
    } else {
831
0
      pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ADDITIONAL, false);
832
0
      pr.xfrBlob(blob);
833
0
      uint16_t rdLen = blob.length();
834
0
      removeEDNSOptionFromOptions((unsigned char*)blob.c_str(), rdLen, optionCodeToSkip, &rdLen);
835
      /* xfrBlob(string, size) completely ignores size.. */
836
0
      if (rdLen > 0) {
837
0
        blob.resize((size_t)rdLen);
838
0
        pw.xfrBlob(blob);
839
0
      } else {
840
0
        pw.commit();
841
0
      }
842
0
    }
843
0
  }
844
0
  pw.commit();
845
846
0
  return 0;
847
0
}
848
849
/*
 * Append an OPT record without any option to the packet and bump ARCOUNT.
 * The record advertises payloadSize, carries the extended RCODE byte
 * ednsrcode, and sets the DO flag when dnssecOK is true.
 * Returns false, leaving the packet untouched, if generateOptRR() refuses to add it.
 */
bool addEDNS(PacketBuffer& packet, size_t maximumSize, bool dnssecOK, uint16_t payloadSize, uint8_t ednsrcode)
{
  const bool appended = generateOptRR(std::string(), packet, maximumSize, payloadSize, ednsrcode, dnssecOK);
  if (appended) {
    auto* header = reinterpret_cast<dnsheader*>(packet.data());
    const uint16_t additionalCount = ntohs(header->arcount);
    header->arcount = htons(additionalCount + 1);
  }
  return appended;
}
860
861
/*
862
  This function keeps the existing header and DNSSECOK bit (if any) but wipes anything else,
863
  generating a NXD or NODATA answer with a SOA record in the additional section (or optionally the authority section for a full cacheable NXDOMAIN/NODATA).
864
*/
865
bool setNegativeAndAdditionalSOA(DNSQuestion& dq, bool nxd, const DNSName& zone, uint32_t ttl, const DNSName& mname, const DNSName& rname, uint32_t serial, uint32_t refresh, uint32_t retry, uint32_t expire, uint32_t minimum, bool soaInAuthoritySection)
866
0
{
867
0
  auto& packet = dq.getMutableData();
868
0
  auto dh = dq.getHeader();
869
0
  if (ntohs(dh->qdcount) != 1) {
870
0
    return false;
871
0
  }
872
873
0
  size_t queryPartSize = sizeof(dnsheader) + dq.ids.qname.wirelength() + DNS_TYPE_SIZE + DNS_CLASS_SIZE;
874
0
  if (packet.size() < queryPartSize) {
875
    /* something is already wrong, don't build on flawed foundations */
876
0
    return false;
877
0
  }
878
879
0
  uint16_t qtype = htons(QType::SOA);
880
0
  uint16_t qclass = htons(QClass::IN);
881
0
  uint16_t rdLength = mname.wirelength() + rname.wirelength() + sizeof(serial) + sizeof(refresh) + sizeof(retry) + sizeof(expire) + sizeof(minimum);
882
0
  size_t soaSize = zone.wirelength() + sizeof(qtype) + sizeof(qclass) + sizeof(ttl) + sizeof(rdLength) + rdLength;
883
0
  bool hadEDNS = false;
884
0
  bool dnssecOK = false;
885
886
0
  if (g_addEDNSToSelfGeneratedResponses) {
887
0
    uint16_t payloadSize = 0;
888
0
    uint16_t z = 0;
889
0
    hadEDNS = getEDNSUDPPayloadSizeAndZ(reinterpret_cast<const char*>(packet.data()), packet.size(), &payloadSize, &z);
890
0
    if (hadEDNS) {
891
0
      dnssecOK = z & EDNS_HEADER_FLAG_DO;
892
0
    }
893
0
  }
894
895
  /* chop off everything after the question */
896
0
  packet.resize(queryPartSize);
897
0
  dh = dq.getHeader();
898
0
  if (nxd) {
899
0
    dh->rcode = RCode::NXDomain;
900
0
  }
901
0
  else {
902
0
    dh->rcode = RCode::NoError;
903
0
  }
904
0
  dh->qr = true;
905
0
  dh->ancount = 0;
906
0
  dh->nscount = 0;
907
0
  dh->arcount = 0;
908
909
0
  rdLength = htons(rdLength);
910
0
  ttl = htonl(ttl);
911
0
  serial = htonl(serial);
912
0
  refresh = htonl(refresh);
913
0
  retry = htonl(retry);
914
0
  expire = htonl(expire);
915
0
  minimum = htonl(minimum);
916
917
0
  std::string soa;
918
0
  soa.reserve(soaSize);
919
0
  soa.append(zone.toDNSString());
920
0
  soa.append(reinterpret_cast<const char*>(&qtype), sizeof(qtype));
921
0
  soa.append(reinterpret_cast<const char*>(&qclass), sizeof(qclass));
922
0
  soa.append(reinterpret_cast<const char*>(&ttl), sizeof(ttl));
923
0
  soa.append(reinterpret_cast<const char*>(&rdLength), sizeof(rdLength));
924
0
  soa.append(mname.toDNSString());
925
0
  soa.append(rname.toDNSString());
926
0
  soa.append(reinterpret_cast<const char*>(&serial), sizeof(serial));
927
0
  soa.append(reinterpret_cast<const char*>(&refresh), sizeof(refresh));
928
0
  soa.append(reinterpret_cast<const char*>(&retry), sizeof(retry));
929
0
  soa.append(reinterpret_cast<const char*>(&expire), sizeof(expire));
930
0
  soa.append(reinterpret_cast<const char*>(&minimum), sizeof(minimum));
931
932
0
  if (soa.size() != soaSize) {
933
0
    throw std::runtime_error("Unexpected SOA response size: " + std::to_string(soa.size()) + " vs " + std::to_string(soaSize));
934
0
  }
935
936
0
  packet.insert(packet.end(), soa.begin(), soa.end());
937
0
  dh = dq.getHeader();
938
939
  /* We are populating a response with only the query in place, order of sections is QD,AN,NS,AR
940
     NS (authority) is before AR (additional) so we can just decide which section the SOA record is in here
941
     and have EDNS added to AR afterwards */
942
0
  if (soaInAuthoritySection) {
943
0
    dh->nscount = htons(1);
944
0
  } else {
945
0
    dh->arcount = htons(1);
946
0
  }
947
948
0
  if (hadEDNS) {
949
    /* now we need to add a new OPT record */
950
0
    return addEDNS(packet, dq.getMaximumSize(), dnssecOK, g_PayloadSizeSelfGenAnswers, dq.ednsRCode);
951
0
  }
952
953
0
  return true;
954
0
}
955
956
/* For a query that is being turned into a response. If getEDNSOptionsStart() finds an OPT
   record, that record and everything following it are removed, ARCOUNT is reset to 0 and,
   when g_addEDNSToSelfGeneratedResponses is set, a fresh OPT record is appended. The new
   record keeps the query's DO bit and advertises g_PayloadSizeSelfGenAnswers.
   Returns true when the query had no EDNS to begin with. Returns false when the OPT record
   length does not fit the packet. Otherwise returns the result of addEDNS(), or true when
   no new OPT record is requested. */
bool addEDNSToQueryTurnedResponse(DNSQuestion& dq)
{
  /* fixed OPT fields located before RDLENGTH: root label, TYPE, CLASS (UDP payload size),
     extended RCODE, version and Z */
  const size_t optFixedPrefixLen = /* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE + /* Z */ 2;

  auto& packet = dq.getMutableData();
  uint16_t optRDPosition = 0;
  /* remaining is at least the size of the rdlen + the options if any + the following records if any */
  size_t remaining = 0;
  if (getEDNSOptionsStart(packet, dq.ids.qname.wirelength(), &optRDPosition, &remaining) != 0) {
    /* if the initial query did not have EDNS0, we are done */
    return true;
  }

  const size_t existingOptLen = optFixedPrefixLen + remaining;
  if (existingOptLen >= packet.size()) {
    /* something is wrong, bail out */
    return false;
  }

  /* the two bytes of Z sit right before RDLENGTH */
  const uint8_t* rdLengthPtr = &packet.at(optRDPosition);
  const uint8_t* zPtr = rdLengthPtr - /* Z */ 2;
  const uint16_t zValue = (zPtr[0] << 8) | zPtr[1];
  const bool dnssecOK = (zValue & EDNS_HEADER_FLAG_DO) != 0;

  /* remove the existing OPT record, and everything else that follows (any SIG or TSIG would be useless anyway) */
  packet.resize(packet.size() - existingOptLen);
  dq.getHeader()->arcount = 0;

  if (!g_addEDNSToSelfGeneratedResponses) {
    /* otherwise we are just fine */
    return true;
  }

  /* now we need to add a new OPT record */
  return addEDNS(packet, dq.getMaximumSize(), dnssecOK, g_PayloadSizeSelfGenAnswers, dq.ednsRCode);
}
995
996
// goal in life - if you send us a reasonably normal packet, we'll get Z for you, otherwise 0
997
int getEDNSZ(const DNSQuestion& dq)
998
0
{
999
0
  try
1000
0
  {
1001
0
    const auto& dh = dq.getHeader();
1002
0
    if (ntohs(dh->qdcount) != 1 || dh->ancount != 0 || ntohs(dh->arcount) != 1 || dh->nscount != 0) {
1003
0
      return 0;
1004
0
    }
1005
1006
0
    if (dq.getData().size() <= sizeof(dnsheader)) {
1007
0
      return 0;
1008
0
    }
1009
1010
0
    size_t pos = sizeof(dnsheader) + dq.ids.qname.wirelength() + DNS_TYPE_SIZE + DNS_CLASS_SIZE;
1011
1012
0
    if (dq.getData().size() <= (pos + /* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE)) {
1013
0
      return 0;
1014
0
    }
1015
1016
0
    auto& packet = dq.getData();
1017
1018
0
    if (packet.at(pos) != 0) {
1019
      /* not root, so not a valid OPT record */
1020
0
      return 0;
1021
0
    }
1022
1023
0
    pos++;
1024
1025
0
    uint16_t qtype = packet.at(pos)*256 + packet.at(pos+1);
1026
0
    pos += DNS_TYPE_SIZE;
1027
0
    pos += DNS_CLASS_SIZE;
1028
1029
0
    if (qtype != QType::OPT || (pos + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE + 1) >= packet.size()) {
1030
0
      return 0;
1031
0
    }
1032
1033
0
    const uint8_t* z = &packet.at(pos + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE);
1034
0
    return 0x100 * (*z) + *(z+1);
1035
0
  }
1036
0
  catch(...)
1037
0
  {
1038
0
    return 0;
1039
0
  }
1040
0
}
1041
1042
bool queryHasEDNS(const DNSQuestion& dq)
1043
0
{
1044
0
  uint16_t optRDPosition;
1045
0
  size_t ecsRemaining = 0;
1046
1047
0
  int res = getEDNSOptionsStart(dq.getData(), dq.ids.qname.wirelength(), &optRDPosition, &ecsRemaining);
1048
0
  if (res == 0) {
1049
0
    return true;
1050
0
  }
1051
1052
0
  return false;
1053
0
}
1054
1055
bool getEDNS0Record(const PacketBuffer& packet, EDNS0Record& edns0)
1056
0
{
1057
0
  uint16_t optStart;
1058
0
  size_t optLen = 0;
1059
0
  bool last = false;
1060
0
  int res = locateEDNSOptRR(packet, &optStart, &optLen, &last);
1061
0
  if (res != 0) {
1062
    // no EDNS OPT RR
1063
0
    return false;
1064
0
  }
1065
1066
0
  if (optLen < optRecordMinimumSize) {
1067
0
    return false;
1068
0
  }
1069
1070
0
  if (optStart < packet.size() && packet.at(optStart) != 0) {
1071
    // OPT RR Name != '.'
1072
0
    return false;
1073
0
  }
1074
1075
0
  static_assert(sizeof(EDNS0Record) == sizeof(uint32_t), "sizeof(EDNS0Record) must match sizeof(uint32_t) AKA RR TTL size");
1076
  // copy out 4-byte "ttl" (really the EDNS0 record), after root label (1) + type (2) + class (2).
1077
0
  memcpy(&edns0, &packet.at(optStart + 5), sizeof edns0);
1078
0
  return true;
1079
0
}
1080
1081
/* Set the EDNS option 'ednsCode' to 'ednsData' in the query held by 'dq'.
   - If the query already has additional records, the packet is rewritten by
     slowRewriteEDNSOptionInQueryWithRecords(), which reports through its 'ednsAdded'
     out-parameter whether it had to add an OPT record.
   - Otherwise a new OPT record carrying only this option, advertising g_EdnsUDPPayloadSize,
     is appended and ARCOUNT is set to 1.
   In both cases dq.ids.ednsAdded is set when dnsdist added the OPT record itself, so that the
   EDNS sent back by the backend gets removed before the response reaches the client.
   Returns false when:
   - the rewrite fails (the packet is left as-is);
   - the rewritten packet would exceed dq.getMaximumSize() (the packet is left as-is);
   - generateOptRR() could not append the OPT record.
   Returns true otherwise. */
bool setEDNSOption(DNSQuestion& dq, uint16_t ednsCode, const std::string& ednsData)
{
  std::string optRData;
  generateEDNSOption(ednsCode, ednsData, optRData);

  if (dq.getHeader()->arcount != 0) {
    bool ednsAdded = false;
    bool optionAdded = false;
    PacketBuffer newContent;
    newContent.reserve(dq.getData().size());

    if (!slowRewriteEDNSOptionInQueryWithRecords(dq.getData(), newContent, ednsAdded, ednsCode, optionAdded, true, optRData)) {
      return false;
    }

    if (newContent.size() > dq.getMaximumSize()) {
      return false;
    }

    dq.getMutableData() = std::move(newContent);
    if (ednsAdded) {
      // make sure that any EDNS sent by the backend is removed before forwarding the response to the client
      dq.ids.ednsAdded = true;
    }

    return true;
  }

  auto& data = dq.getMutableData();
  if (!generateOptRR(optRData, data, dq.getMaximumSize(), g_EdnsUDPPayloadSize, 0, false)) {
    /* the option could not be added: report it, consistently with the rewrite path above,
       instead of pretending that the option has been set */
    return false;
  }

  dq.getHeader()->arcount = htons(1);
  // make sure that any EDNS sent by the backend is removed before forwarding the response to the client
  dq.ids.ednsAdded = true;

  return true;
}
1117
1118
namespace dnsdist {
1119
bool setInternalQueryRCode(InternalQueryState& state, PacketBuffer& buffer,  uint8_t rcode, bool clearAnswers)
1120
0
{
1121
0
  const auto qnameLength = state.qname.wirelength();
1122
0
  if (buffer.size() < sizeof(dnsheader) + qnameLength + sizeof(uint16_t) + sizeof(uint16_t)) {
1123
0
    return false;
1124
0
  }
1125
1126
0
  EDNS0Record edns0;
1127
0
  bool hadEDNS = false;
1128
0
  if (clearAnswers) {
1129
0
    hadEDNS = getEDNS0Record(buffer, edns0);
1130
0
  }
1131
1132
0
  auto dh = reinterpret_cast<dnsheader*>(buffer.data());
1133
0
  dh->rcode = rcode;
1134
0
  dh->ad = false;
1135
0
  dh->aa = false;
1136
0
  dh->ra = dh->rd;
1137
0
  dh->qr = true;
1138
1139
0
  if (clearAnswers) {
1140
0
    dh->ancount = 0;
1141
0
    dh->nscount = 0;
1142
0
    dh->arcount = 0;
1143
0
    buffer.resize(sizeof(dnsheader) + qnameLength + sizeof(uint16_t) + sizeof(uint16_t));
1144
0
    if (hadEDNS) {
1145
0
      DNSQuestion dq(state, buffer);
1146
0
      if (!addEDNS(buffer, dq.getMaximumSize(), edns0.extFlags & htons(EDNS_HEADER_FLAG_DO), g_PayloadSizeSelfGenAnswers, 0)) {
1147
0
        return false;
1148
0
      }
1149
0
    }
1150
0
  }
1151
1152
0
  return true;
1153
0
}
1154
}