Coverage Report

Created: 2026-03-07 06:10

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/pdns/pdns/dnsdistdist/tcpiohandler.hh
Line
Count
Source
1
2
#pragma once
3
#include <memory>
4
/* needed for proper TCP_FASTOPEN_CONNECT detection */
5
#include <netinet/tcp.h>
6
7
#include "iputils.hh"
8
#include "libssl.hh"
9
#include "misc.hh"
10
#include "noinitvector.hh"
11
12
/* Outcome of an attempted non-blocking I/O operation:
   - Done: the operation completed;
   - NeedRead / NeedWrite: the operation would block, retry once the socket is readable / writable;
   - Async: only returned for TLS connections, if OpenSSL's async mode has been enabled. */
enum class IOState : uint8_t
{
  Done,
  NeedRead,
  NeedWrite,
  Async,
};
14
15
/* Opaque, backend-specific TLS session. Instances are exchanged through
   TLSConnection::getSessions() / setSession(); the virtual destructor makes it safe
   to own a backend-derived session through a std::unique_ptr<TLSSession>. */
class TLSSession
{
public:
  virtual ~TLSSession() = default;
};
20
21
class TLSConnection
22
{
23
public:
24
  virtual ~TLSConnection() = default;
25
  virtual void doHandshake() = 0;
26
  virtual IOState tryConnect(bool fastOpen, const ComboAddress& remote) = 0;
27
  virtual void connect(bool fastOpen, const ComboAddress& remote, const struct timeval& timeout) = 0;
28
  virtual IOState tryHandshake() = 0;
29
  virtual size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout={0,0}, bool allowIncomplete=false) = 0;
30
  virtual size_t write(const void* buffer, size_t bufferSize, const struct timeval& writeTimeout) = 0;
31
  virtual IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) = 0;
32
  virtual IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete=false) = 0;
33
  virtual std::string getServerNameIndication() const = 0;
34
  virtual std::vector<uint8_t> getNextProtocol() const = 0;
35
  virtual LibsslTLSVersion getTLSVersion() const = 0;
36
  virtual bool hasSessionBeenResumed() const = 0;
37
  virtual std::vector<std::unique_ptr<TLSSession>> getSessions() = 0;
38
  virtual void setSession(std::unique_ptr<TLSSession>& session) = 0;
39
  virtual bool isUsable() const = 0;
40
  virtual std::vector<int> getAsyncFDs() = 0;
41
  virtual void close() = 0;
42
  [[nodiscard]] virtual std::pair<long, std::string> getVerifyResult() const = 0;
43
44
  void setUnknownTicketKey()
45
0
  {
46
0
    d_unknownTicketKey = true;
47
0
  }
48
49
  bool getUnknownTicketKey() const
50
0
  {
51
0
    return d_unknownTicketKey;
52
0
  }
53
54
  void setResumedFromInactiveTicketKey()
55
0
  {
56
0
    d_resumedFromInactiveTicketKey = true;
57
0
  }
58
59
  bool getResumedFromInactiveTicketKey() const
60
0
  {
61
0
    return d_resumedFromInactiveTicketKey;
62
0
  }
63
64
protected:
65
  int d_socket{-1};
66
  bool d_unknownTicketKey{false};
67
  bool d_resumedFromInactiveTicketKey{false};
68
};
69
70
class TLSCtx
71
{
72
public:
73
  TLSCtx()
74
0
  {
75
0
    d_rotatingTicketsKey.clear();
76
0
  }
77
  virtual ~TLSCtx() = default;
78
  virtual std::unique_ptr<TLSConnection> getConnection(int socket, const struct timeval& timeout, time_t now) = 0;
79
  virtual std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, bool hostIsAddr, int socket, const struct timeval& timeout) = 0;
80
  virtual void rotateTicketsKey(time_t now) = 0;
81
  virtual void loadTicketsKeys(const std::string& /* file */)
82
0
  {
83
0
    throw std::runtime_error("This TLS backend does not have the capability to load a tickets key from a file");
84
0
  }
85
  virtual void loadTicketsKey(const std::string& /* key */)
86
0
  {
87
0
    throw std::runtime_error("This TLS backend does not have the capability to load a ticket key");
88
0
  }
89
  void handleTicketsKeyRotation(time_t now)
90
0
  {
91
0
    if (d_ticketsKeyRotationDelay != 0 && now > d_ticketsKeyNextRotation) {
92
0
      if (d_rotatingTicketsKey.test_and_set()) {
93
0
        /* someone is already rotating */
94
0
        return;
95
0
      }
96
0
      try {
97
0
        rotateTicketsKey(now);
98
0
        d_rotatingTicketsKey.clear();
99
0
      }
100
0
      catch(const std::runtime_error& e) {
101
0
        d_rotatingTicketsKey.clear();
102
0
        throw std::runtime_error(std::string("Error generating a new tickets key for TLS context:") + e.what());
103
0
      }
104
0
      catch(...) {
105
0
        d_rotatingTicketsKey.clear();
106
0
        throw;
107
0
      }
108
0
    }
109
0
  }
110
111
  time_t getNextTicketsKeyRotation() const
112
0
  {
113
0
    return d_ticketsKeyNextRotation;
114
0
  }
115
116
  virtual size_t getTicketsKeysCount() = 0;
117
  virtual std::string getName() const = 0;
118
119
  using tickets_key_added_hook = std::function<void(const std::string& key)>;
120
121
  static void setTicketsKeyAddedHook(const tickets_key_added_hook& hook)
122
0
  {
123
0
    TLSCtx::s_ticketsKeyAddedHook = hook;
124
0
  }
125
  static const tickets_key_added_hook& getTicketsKeyAddedHook()
126
0
  {
127
0
    return TLSCtx::s_ticketsKeyAddedHook;
128
0
  }
129
  static bool hasTicketsKeyAddedHook()
130
0
  {
131
0
    return TLSCtx::s_ticketsKeyAddedHook != nullptr;
132
0
  }
133
protected:
134
  std::atomic_flag d_rotatingTicketsKey;
135
  std::atomic<time_t> d_ticketsKeyNextRotation{0};
136
  time_t d_ticketsKeyRotationDelay{0};
137
138
private:
139
  static tickets_key_added_hook s_ticketsKeyAddedHook;
140
};
141
142
class TLSFrontend
143
{
144
public:
145
  enum class ALPN : uint8_t { Unset, DoT, DoH };
146
147
  TLSFrontend(ALPN alpn): d_alpn(alpn)
148
0
  {
149
0
  }
150
151
  TLSFrontend(std::shared_ptr<TLSCtx> ctx): d_ctx(std::move(ctx))
152
0
  {
153
0
  }
154
155
  bool setupTLS();
156
157
  void rotateTicketsKey(time_t now)
158
0
  {
159
0
    if (d_ctx != nullptr && d_parentFrontend == nullptr) {
160
0
      d_ctx->rotateTicketsKey(now);
161
0
    }
162
0
  }
163
164
  void loadTicketsKeys(const std::string& file)
165
0
  {
166
0
    if (d_ctx != nullptr && d_parentFrontend == nullptr) {
167
0
      d_ctx->loadTicketsKeys(file);
168
0
    }
169
0
  }
170
171
  void loadTicketsKey(const std::string& key)
172
0
  {
173
0
    if (d_ctx != nullptr && d_parentFrontend == nullptr) {
174
0
      d_ctx->loadTicketsKey(key);
175
0
    }
176
0
  }
177
178
  std::shared_ptr<TLSCtx> getContext() const
179
0
  {
180
0
    return std::atomic_load_explicit(&d_ctx, std::memory_order_acquire);
181
0
  }
182
183
  void setParent(std::shared_ptr<const TLSFrontend> parent)
184
0
  {
185
0
    std::atomic_store_explicit(&d_parentFrontend, std::move(parent), std::memory_order_release);
186
0
  }
187
188
  void cleanup()
189
0
  {
190
0
    d_ctx.reset();
191
0
  }
192
193
  size_t getTicketsKeysCount()
194
0
  {
195
0
    if (d_ctx != nullptr) {
196
0
      return d_ctx->getTicketsKeysCount();
197
0
    }
198
0
199
0
    return 0;
200
0
  }
201
202
  static std::string timeToString(time_t rotationTime)
203
0
  {
204
0
    char buf[20];
205
0
    struct tm date_tm;
206
0
207
0
    localtime_r(&rotationTime, &date_tm);
208
0
    strftime(buf, sizeof(buf), "%Y-%m-%d %H:%M:%S", &date_tm);
209
0
210
0
    return std::string(buf);
211
0
  }
212
213
  time_t getTicketsKeyRotationDelay() const
214
0
  {
215
0
    return d_tlsConfig.d_ticketsKeyRotationDelay;
216
0
  }
217
218
  std::string getNextTicketsKeyRotation() const
219
0
  {
220
0
    std::string res;
221
0
222
0
    if (d_ctx != nullptr) {
223
0
      res = timeToString(d_ctx->getNextTicketsKeyRotation());
224
0
    }
225
0
226
0
    return res;
227
0
  }
228
229
  std::string getRequestedProvider() const
230
0
  {
231
0
    return d_provider;
232
0
  }
233
234
  std::string getEffectiveProvider() const
235
0
  {
236
0
    if (d_ctx) {
237
0
      return d_ctx->getName();
238
0
    }
239
0
    return "";
240
0
  }
241
242
  TLSConfig d_tlsConfig;
243
  TLSErrorCounters d_tlsCounters;
244
  ComboAddress d_addr;
245
  std::string d_provider;
246
  ALPN d_alpn{ALPN::Unset};
247
  /* whether the proxy protocol is inside or outside the TLS layer */
248
  bool d_proxyProtocolOutsideTLS{false};
249
protected:
250
  std::shared_ptr<TLSCtx> d_ctx{nullptr};
251
  std::shared_ptr<const TLSFrontend> d_parentFrontend{nullptr};
252
};
253
254
class TCPIOHandler
255
{
256
public:
257
  TCPIOHandler(const std::string& host, bool hostIsAddr, int socket, const struct timeval& timeout, const std::shared_ptr<TLSCtx>& ctx) :
258
    d_socket(socket)
259
0
  {
260
0
    if (ctx) {
261
0
      d_conn = ctx->getClientConnection(host, hostIsAddr, d_socket, timeout);
262
0
    }
263
0
  }
264
265
  TCPIOHandler(int socket, const struct timeval& timeout, const std::shared_ptr<TLSCtx>& ctx, time_t now) :
266
    d_socket(socket)
267
0
  {
268
0
    if (ctx) {
269
0
      d_conn = ctx->getConnection(d_socket, timeout, now);
270
0
    }
271
0
  }
272
273
  ~TCPIOHandler()
274
0
  {
275
0
    close();
276
0
  }
277
278
  void close()
279
0
  {
280
0
    if (d_conn) {
281
0
      d_conn->close();
282
0
      d_conn.reset();
283
0
    }
284
0
285
0
    if (d_socket != -1) {
286
0
      shutdown(d_socket, SHUT_RDWR);
287
0
      ::close(d_socket);
288
0
      d_socket = -1;
289
0
    }
290
0
  }
291
292
  int getDescriptor() const
293
0
  {
294
0
    return d_socket;
295
0
  }
296
297
  IOState tryConnect(bool fastOpen, const ComboAddress& remote)
298
0
  {
299
0
    d_remote = remote;
300
0
301
0
#ifdef TCP_FASTOPEN_CONNECT /* Linux >= 4.11 */
302
0
    if (fastOpen) {
303
0
      int value = 1;
304
0
      int res = setsockopt(d_socket, IPPROTO_TCP, TCP_FASTOPEN_CONNECT, &value, sizeof(value));
305
0
      if (res == 0) {
306
0
        fastOpen = false;
307
0
      }
308
0
    }
309
0
#endif /* TCP_FASTOPEN_CONNECT */
310
0
311
0
#ifdef MSG_FASTOPEN
312
0
    if (!d_conn && fastOpen) {
313
0
      d_fastOpen = true;
314
0
    }
315
0
    else {
316
0
      if (!s_disableConnectForUnitTests) {
317
0
        SConnectWithTimeout(d_socket, fastOpen, remote, /* no timeout, we will handle it ourselves */ timeval{0,0});
318
0
      }
319
0
    }
320
0
#else
321
0
    if (!s_disableConnectForUnitTests) {
322
0
      SConnectWithTimeout(d_socket, fastOpen, remote, /* no timeout, we will handle it ourselves */ timeval{0,0});
323
0
    }
324
0
#endif /* MSG_FASTOPEN */
325
0
326
0
    if (d_conn) {
327
0
      return d_conn->tryConnect(fastOpen, remote);
328
0
    }
329
0
330
0
    return IOState::Done;
331
0
  }
332
333
  void connect(bool fastOpen, const ComboAddress& remote, const struct timeval& timeout)
334
0
  {
335
0
    d_remote = remote;
336
0
337
0
#ifdef TCP_FASTOPEN_CONNECT /* Linux >= 4.11 */
338
0
    if (fastOpen) {
339
0
      int value = 1;
340
0
      int res = setsockopt(d_socket, IPPROTO_TCP, TCP_FASTOPEN_CONNECT, &value, sizeof(value));
341
0
      if (res == 0) {
342
0
        fastOpen = false;
343
0
      }
344
0
    }
345
0
#endif /* TCP_FASTOPEN_CONNECT */
346
0
347
0
#ifdef MSG_FASTOPEN
348
0
    if (!d_conn && fastOpen) {
349
0
      d_fastOpen = true;
350
0
    }
351
0
    else {
352
0
      if (!s_disableConnectForUnitTests) {
353
0
        SConnectWithTimeout(d_socket, fastOpen, remote, timeout);
354
0
      }
355
0
    }
356
0
#else
357
0
    if (!s_disableConnectForUnitTests) {
358
0
      SConnectWithTimeout(d_socket, fastOpen, remote, timeout);
359
0
    }
360
0
#endif /* MSG_FASTOPEN */
361
0
362
0
    if (d_conn) {
363
0
      d_conn->connect(fastOpen, remote, timeout);
364
0
    }
365
0
  }
366
367
  IOState tryHandshake()
368
0
  {
369
0
    if (d_conn) {
370
0
      return d_conn->tryHandshake();
371
0
    }
372
0
    return IOState::Done;
373
0
  }
374
375
  size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout = {0,0}, bool allowIncomplete=false)
376
0
  {
377
0
    if (d_conn) {
378
0
      return d_conn->read(buffer, bufferSize, readTimeout, totalTimeout, allowIncomplete);
379
0
    } else {
380
0
      return readn2WithTimeout(d_socket, buffer, bufferSize, readTimeout, totalTimeout, allowIncomplete);
381
0
    }
382
0
  }
383
384
  /* Tries to read exactly toRead - pos bytes into the buffer, starting at position pos.
385
     Updates pos every time a successful read occurs,
386
     throws an std::runtime_error in case of IO error,
387
     return Done when toRead bytes have been read, needRead or needWrite if the IO operation
388
     would block.
389
  */
390
  IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete=false, bool bypassFilters=false)
391
0
  {
392
0
    if (buffer.size() < toRead || pos >= toRead) {
393
0
      throw std::out_of_range("Calling tryRead() with a too small buffer (" + std::to_string(buffer.size()) + ") for a read of " + std::to_string(toRead - pos) + " bytes starting at " + std::to_string(pos));
394
0
    }
395
0
396
0
    if (!bypassFilters && d_conn) {
397
0
      return d_conn->tryRead(buffer, pos, toRead, allowIncomplete);
398
0
    }
399
0
400
0
    do {
401
0
      ssize_t res = ::read(d_socket, reinterpret_cast<char*>(&buffer.at(pos)), toRead - pos);
402
0
      if (res == 0) {
403
0
        throw runtime_error("EOF while reading message");
404
0
      }
405
0
      if (res < 0) {
406
0
        if (errno == EAGAIN || errno == EWOULDBLOCK || errno == ENOTCONN) {
407
0
          return IOState::NeedRead;
408
0
        }
409
0
        else {
410
0
          throw std::runtime_error("Error while reading message: " + stringerror());
411
0
        }
412
0
      }
413
0
414
0
      pos += static_cast<size_t>(res);
415
0
      if (allowIncomplete) {
416
0
        break;
417
0
      }
418
0
    }
419
0
    while (pos < toRead);
420
0
421
0
    return IOState::Done;
422
0
  }
423
424
  /* Tries to write exactly toWrite - pos bytes from the buffer, starting at position pos.
425
     Updates pos every time a successful write occurs,
426
     throws an std::runtime_error in case of IO error,
427
     return Done when toWrite bytes have been written, needRead or needWrite if the IO operation
428
     would block.
429
  */
430
  IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite)
431
0
  {
432
0
    if (buffer.size() < toWrite || pos >= toWrite) {
433
0
      throw std::out_of_range("Calling tryWrite() with a too small buffer (" + std::to_string(buffer.size()) + ") for a write of " + std::to_string(toWrite - pos) + " bytes starting at " + std::to_string(pos));
434
0
    }
435
0
    if (d_conn) {
436
0
      return d_conn->tryWrite(buffer, pos, toWrite);
437
0
    }
438
0
439
0
#ifdef MSG_FASTOPEN
440
0
    if (d_fastOpen) {
441
0
      int socketFlags = MSG_FASTOPEN;
442
0
      size_t sent = sendMsgWithOptions(d_socket, reinterpret_cast<const char *>(&buffer.at(pos)), toWrite - pos, &d_remote, nullptr, 0, socketFlags);
443
0
      if (sent > 0) {
444
0
        d_fastOpen = false;
445
0
        pos += sent;
446
0
      }
447
0
448
0
      if (pos < toWrite) {
449
0
        return IOState::NeedWrite;
450
0
      }
451
0
452
0
      return IOState::Done;
453
0
    }
454
0
#endif /* MSG_FASTOPEN */
455
0
456
0
    do {
457
0
      ssize_t res = ::write(d_socket, reinterpret_cast<const char*>(&buffer.at(pos)), toWrite - pos);
458
0
459
0
      if (res == 0) {
460
0
        throw runtime_error("EOF while sending message");
461
0
      }
462
0
      if (res < 0) {
463
0
        if (errno == EAGAIN || errno == EWOULDBLOCK || errno == ENOTCONN) {
464
0
          return IOState::NeedWrite;
465
0
        }
466
0
        else {
467
0
          throw std::runtime_error("Error while writing message: " + stringerror());
468
0
        }
469
0
      }
470
0
471
0
      pos += static_cast<size_t>(res);
472
0
    }
473
0
    while (pos < toWrite);
474
0
475
0
    return IOState::Done;
476
0
  }
477
478
  size_t write(const void* buffer, size_t bufferSize, const struct timeval& writeTimeout)
479
0
  {
480
0
    if (d_conn) {
481
0
      return d_conn->write(buffer, bufferSize, writeTimeout);
482
0
    }
483
0
484
0
#ifdef MSG_FASTOPEN
485
0
    if (d_fastOpen) {
486
0
      int socketFlags = MSG_FASTOPEN;
487
0
      size_t sent = sendMsgWithOptions(d_socket, reinterpret_cast<const char *>(buffer), bufferSize, &d_remote, nullptr, 0, socketFlags);
488
0
      if (sent > 0) {
489
0
        d_fastOpen = false;
490
0
      }
491
0
492
0
      return sent;
493
0
    }
494
0
#endif /* MSG_FASTOPEN */
495
0
496
0
    return writen2WithTimeout(d_socket, buffer, bufferSize, writeTimeout);
497
0
  }
498
499
  std::string getServerNameIndication() const
500
0
  {
501
0
    if (d_conn) {
502
0
      return d_conn->getServerNameIndication();
503
0
    }
504
0
    return std::string();
505
0
  }
506
507
  std::vector<uint8_t> getNextProtocol() const
508
0
  {
509
0
    if (d_conn) {
510
0
      return d_conn->getNextProtocol();
511
0
    }
512
0
    return std::vector<uint8_t>();
513
0
  }
514
515
  LibsslTLSVersion getTLSVersion() const
516
0
  {
517
0
    if (d_conn) {
518
0
      return d_conn->getTLSVersion();
519
0
    }
520
0
    return LibsslTLSVersion::Unknown;
521
0
  }
522
523
  bool isTLS() const
524
0
  {
525
0
    return d_conn != nullptr;
526
0
  }
527
528
  [[nodiscard]] std::pair<long, std::string> getVerifyResult() const
529
0
  {
530
0
    if (d_conn) {
531
0
      return d_conn->getVerifyResult();
532
0
    }
533
0
    return {0, ""};
534
0
  }
535
536
  bool hasTLSSessionBeenResumed() const
537
0
  {
538
0
    return d_conn && d_conn->hasSessionBeenResumed();
539
0
  }
540
541
  bool getResumedFromInactiveTicketKey() const
542
0
  {
543
0
    return d_conn && d_conn->getResumedFromInactiveTicketKey();
544
0
  }
545
546
  bool getUnknownTicketKey() const
547
0
  {
548
0
    return d_conn && d_conn->getUnknownTicketKey();
549
0
  }
550
551
  void setTLSSession(std::unique_ptr<TLSSession>& session)
552
0
  {
553
0
    if (d_conn != nullptr) {
554
0
      d_conn->setSession(session);
555
0
    }
556
0
  }
557
558
  std::vector<std::unique_ptr<TLSSession>> getTLSSessions()
559
0
  {
560
0
    if (!d_conn) {
561
0
      throw std::runtime_error("Trying to get TLS sessions from a non-TLS handler");
562
0
    }
563
0
564
0
    return d_conn->getSessions();
565
0
  }
566
567
  bool isUsable() const
568
0
  {
569
0
    if (!d_conn) {
570
0
      return isTCPSocketUsable(d_socket);
571
0
    }
572
0
    return d_conn->isUsable();
573
0
  }
574
575
  std::vector<int> getAsyncFDs()
576
0
  {
577
0
    if (!d_conn) {
578
0
      return {};
579
0
    }
580
0
    return d_conn->getAsyncFDs();
581
0
  }
582
583
  static const bool s_disableConnectForUnitTests;
584
585
private:
586
  std::unique_ptr<TLSConnection> d_conn{nullptr};
587
  ComboAddress d_remote;
588
  int d_socket{-1};
589
#ifdef MSG_FASTOPEN
590
  bool d_fastOpen{false};
591
#endif
592
};
593
594
struct TLSContextParameters
595
{
596
  std::string d_provider;
597
  std::string d_ciphers;
598
  std::string d_ciphers13;
599
  std::string d_caStore;
600
  std::string d_keyLogFile;
601
  std::string d_client_certificate;
602
  std::string d_client_certificate_key;
603
  std::string d_client_certificate_password;
604
  TLSFrontend::ALPN d_alpn{TLSFrontend::ALPN::Unset};
605
  bool d_validateCertificates{true};
606
  bool d_releaseBuffers{true};
607
  bool d_enableRenegotiation{false};
608
  bool d_ktls{false};
609
};
610
611
/* Returns a TLS context configured from params (implementation not visible in this header). */
std::shared_ptr<TLSCtx> getTLSContext(const TLSContextParameters& params);
/* Presumably set up ALPN on ctx for DNS over TLS and DNS over HTTPS respectively, returning
   whether that succeeded — confirm against the implementation. */
bool setupDoTProtocolNegotiation(std::shared_ptr<TLSCtx>& ctx);
bool setupDoHProtocolNegotiation(std::shared_ptr<TLSCtx>& ctx);