Coverage Report

Created: 2025-06-13 06:28

/src/pdns/pdns/dnsdistdist/tcpiohandler.hh
Line
Count
Source (jump to first uncovered line)
1
2
#pragma once
3
#include <memory>
4
/* needed for proper TCP_FASTOPEN_CONNECT detection */
5
#include <netinet/tcp.h>
6
7
#include "iputils.hh"
8
#include "libssl.hh"
9
#include "misc.hh"
10
#include "noinitvector.hh"
11
12
/* Result of an I/O attempt on a connection:
   - Done:      the operation completed
   - NeedRead:  it would block until the descriptor becomes readable
   - NeedWrite: it would block until the descriptor becomes writable
   - Async:     only returned for TLS connections, if OpenSSL's async mode has been enabled */
enum class IOState : uint8_t
{
  Done,
  NeedRead,
  NeedWrite,
  Async
};
14
15
/* Opaque TLS session handle, passed around as std::unique_ptr<TLSSession>
   (see TLSConnection::getSessions() / setSession()). The only member is a virtual
   destructor, so that backend-specific sessions are destroyed correctly through
   a pointer to this base. */
class TLSSession
{
public:
  virtual ~TLSSession() noexcept = default;
};
20
21
class TLSConnection
22
{
23
public:
24
  virtual ~TLSConnection() = default;
25
  virtual void doHandshake() = 0;
26
  virtual IOState tryConnect(bool fastOpen, const ComboAddress& remote) = 0;
27
  virtual void connect(bool fastOpen, const ComboAddress& remote, const struct timeval& timeout) = 0;
28
  virtual IOState tryHandshake() = 0;
29
  virtual size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout={0,0}, bool allowIncomplete=false) = 0;
30
  virtual size_t write(const void* buffer, size_t bufferSize, const struct timeval& writeTimeout) = 0;
31
  virtual IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) = 0;
32
  virtual IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete=false) = 0;
33
  virtual std::string getServerNameIndication() const = 0;
34
  virtual std::vector<uint8_t> getNextProtocol() const = 0;
35
  virtual LibsslTLSVersion getTLSVersion() const = 0;
36
  virtual bool hasSessionBeenResumed() const = 0;
37
  virtual std::vector<std::unique_ptr<TLSSession>> getSessions() = 0;
38
  virtual void setSession(std::unique_ptr<TLSSession>& session) = 0;
39
  virtual bool isUsable() const = 0;
40
  virtual std::vector<int> getAsyncFDs() = 0;
41
  virtual void close() = 0;
42
43
  void setUnknownTicketKey()
44
0
  {
45
0
    d_unknownTicketKey = true;
46
0
  }
47
48
  bool getUnknownTicketKey() const
49
0
  {
50
0
    return d_unknownTicketKey;
51
0
  }
52
53
  void setResumedFromInactiveTicketKey()
54
0
  {
55
0
    d_resumedFromInactiveTicketKey = true;
56
0
  }
57
58
  bool getResumedFromInactiveTicketKey() const
59
0
  {
60
0
    return d_resumedFromInactiveTicketKey;
61
0
  }
62
63
protected:
64
  int d_socket{-1};
65
  bool d_unknownTicketKey{false};
66
  bool d_resumedFromInactiveTicketKey{false};
67
};
68
69
class TLSCtx
70
{
71
public:
72
  TLSCtx()
73
0
  {
74
0
    d_rotatingTicketsKey.clear();
75
0
  }
76
  virtual ~TLSCtx() = default;
77
  virtual std::unique_ptr<TLSConnection> getConnection(int socket, const struct timeval& timeout, time_t now) = 0;
78
  virtual std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, bool hostIsAddr, int socket, const struct timeval& timeout) = 0;
79
  virtual void rotateTicketsKey(time_t now) = 0;
80
  virtual void loadTicketsKeys(const std::string& /* file */)
81
0
  {
82
0
    throw std::runtime_error("This TLS backend does not have the capability to load a tickets key from a file");
83
0
  }
84
  virtual void loadTicketsKey(const std::string& /* key */)
85
0
  {
86
0
    throw std::runtime_error("This TLS backend does not have the capability to load a ticket key");
87
0
  }
88
  void handleTicketsKeyRotation(time_t now)
89
0
  {
90
0
    if (d_ticketsKeyRotationDelay != 0 && now > d_ticketsKeyNextRotation) {
91
0
      if (d_rotatingTicketsKey.test_and_set()) {
92
0
        /* someone is already rotating */
93
0
        return;
94
0
      }
95
0
      try {
96
0
        rotateTicketsKey(now);
97
0
        d_rotatingTicketsKey.clear();
98
0
      }
99
0
      catch(const std::runtime_error& e) {
100
0
        d_rotatingTicketsKey.clear();
101
0
        throw std::runtime_error(std::string("Error generating a new tickets key for TLS context:") + e.what());
102
0
      }
103
0
      catch(...) {
104
0
        d_rotatingTicketsKey.clear();
105
0
        throw;
106
0
      }
107
0
    }
108
0
  }
109
110
  time_t getNextTicketsKeyRotation() const
111
0
  {
112
0
    return d_ticketsKeyNextRotation;
113
0
  }
114
115
  virtual size_t getTicketsKeysCount() = 0;
116
  virtual std::string getName() const = 0;
117
118
  using tickets_key_added_hook = std::function<void(const std::string& key)>;
119
120
  static void setTicketsKeyAddedHook(const tickets_key_added_hook& hook)
121
0
  {
122
0
    TLSCtx::s_ticketsKeyAddedHook = hook;
123
0
  }
124
  static const tickets_key_added_hook& getTicketsKeyAddedHook()
125
0
  {
126
0
    return TLSCtx::s_ticketsKeyAddedHook;
127
0
  }
128
  static bool hasTicketsKeyAddedHook()
129
0
  {
130
0
    return TLSCtx::s_ticketsKeyAddedHook != nullptr;
131
0
  }
132
protected:
133
  std::atomic_flag d_rotatingTicketsKey;
134
  std::atomic<time_t> d_ticketsKeyNextRotation{0};
135
  time_t d_ticketsKeyRotationDelay{0};
136
137
private:
138
  static tickets_key_added_hook s_ticketsKeyAddedHook;
139
};
140
141
class TLSFrontend
142
{
143
public:
144
  enum class ALPN : uint8_t { Unset, DoT, DoH };
145
146
  TLSFrontend(ALPN alpn): d_alpn(alpn)
147
0
  {
148
0
  }
149
150
  TLSFrontend(std::shared_ptr<TLSCtx> ctx): d_ctx(std::move(ctx))
151
0
  {
152
0
  }
153
154
  bool setupTLS();
155
156
  void rotateTicketsKey(time_t now)
157
0
  {
158
0
    if (d_ctx != nullptr && d_parentFrontend == nullptr) {
159
0
      d_ctx->rotateTicketsKey(now);
160
0
    }
161
0
  }
162
163
  void loadTicketsKeys(const std::string& file)
164
0
  {
165
0
    if (d_ctx != nullptr && d_parentFrontend == nullptr) {
166
0
      d_ctx->loadTicketsKeys(file);
167
0
    }
168
0
  }
169
170
  void loadTicketsKey(const std::string& key)
171
0
  {
172
0
    if (d_ctx != nullptr && d_parentFrontend == nullptr) {
173
0
      d_ctx->loadTicketsKey(key);
174
0
    }
175
0
  }
176
177
  std::shared_ptr<TLSCtx> getContext() const
178
0
  {
179
0
    return std::atomic_load_explicit(&d_ctx, std::memory_order_acquire);
180
0
  }
181
182
  void setParent(std::shared_ptr<const TLSFrontend> parent)
183
0
  {
184
0
    std::atomic_store_explicit(&d_parentFrontend, std::move(parent), std::memory_order_release);
185
0
  }
186
187
  void cleanup()
188
0
  {
189
0
    d_ctx.reset();
190
0
  }
191
192
  size_t getTicketsKeysCount()
193
0
  {
194
0
    if (d_ctx != nullptr) {
195
0
      return d_ctx->getTicketsKeysCount();
196
0
    }
197
0
198
0
    return 0;
199
0
  }
200
201
  static std::string timeToString(time_t rotationTime)
202
0
  {
203
0
    char buf[20];
204
0
    struct tm date_tm;
205
0
206
0
    localtime_r(&rotationTime, &date_tm);
207
0
    strftime(buf, sizeof(buf), "%Y-%m-%d %H:%M:%S", &date_tm);
208
0
209
0
    return std::string(buf);
210
0
  }
211
212
  time_t getTicketsKeyRotationDelay() const
213
0
  {
214
0
    return d_tlsConfig.d_ticketsKeyRotationDelay;
215
0
  }
216
217
  std::string getNextTicketsKeyRotation() const
218
0
  {
219
0
    std::string res;
220
0
221
0
    if (d_ctx != nullptr) {
222
0
      res = timeToString(d_ctx->getNextTicketsKeyRotation());
223
0
    }
224
0
225
0
    return res;
226
0
  }
227
228
  std::string getRequestedProvider() const
229
0
  {
230
0
    return d_provider;
231
0
  }
232
233
  std::string getEffectiveProvider() const
234
0
  {
235
0
    if (d_ctx) {
236
0
      return d_ctx->getName();
237
0
    }
238
0
    return "";
239
0
  }
240
241
  TLSConfig d_tlsConfig;
242
  TLSErrorCounters d_tlsCounters;
243
  ComboAddress d_addr;
244
  std::string d_provider;
245
  ALPN d_alpn{ALPN::Unset};
246
  /* whether the proxy protocol is inside or outside the TLS layer */
247
  bool d_proxyProtocolOutsideTLS{false};
248
protected:
249
  std::shared_ptr<TLSCtx> d_ctx{nullptr};
250
  std::shared_ptr<const TLSFrontend> d_parentFrontend{nullptr};
251
};
252
253
class TCPIOHandler
254
{
255
public:
256
  TCPIOHandler(const std::string& host, bool hostIsAddr, int socket, const struct timeval& timeout, const std::shared_ptr<TLSCtx>& ctx) :
257
    d_socket(socket)
258
0
  {
259
0
    if (ctx) {
260
0
      d_conn = ctx->getClientConnection(host, hostIsAddr, d_socket, timeout);
261
0
    }
262
0
  }
263
264
  TCPIOHandler(int socket, const struct timeval& timeout, const std::shared_ptr<TLSCtx>& ctx, time_t now) :
265
    d_socket(socket)
266
0
  {
267
0
    if (ctx) {
268
0
      d_conn = ctx->getConnection(d_socket, timeout, now);
269
0
    }
270
0
  }
271
272
  ~TCPIOHandler()
273
0
  {
274
0
    close();
275
0
  }
276
277
  void close()
278
0
  {
279
0
    if (d_conn) {
280
0
      d_conn->close();
281
0
      d_conn.reset();
282
0
    }
283
0
284
0
    if (d_socket != -1) {
285
0
      shutdown(d_socket, SHUT_RDWR);
286
0
      ::close(d_socket);
287
0
      d_socket = -1;
288
0
    }
289
0
  }
290
291
  int getDescriptor() const
292
0
  {
293
0
    return d_socket;
294
0
  }
295
296
  IOState tryConnect(bool fastOpen, const ComboAddress& remote)
297
0
  {
298
0
    d_remote = remote;
299
0
300
0
#ifdef TCP_FASTOPEN_CONNECT /* Linux >= 4.11 */
301
0
    if (fastOpen) {
302
0
      int value = 1;
303
0
      int res = setsockopt(d_socket, IPPROTO_TCP, TCP_FASTOPEN_CONNECT, &value, sizeof(value));
304
0
      if (res == 0) {
305
0
        fastOpen = false;
306
0
      }
307
0
    }
308
0
#endif /* TCP_FASTOPEN_CONNECT */
309
0
310
0
#ifdef MSG_FASTOPEN
311
0
    if (!d_conn && fastOpen) {
312
0
      d_fastOpen = true;
313
0
    }
314
0
    else {
315
0
      if (!s_disableConnectForUnitTests) {
316
0
        SConnectWithTimeout(d_socket, remote, /* no timeout, we will handle it ourselves */ timeval{0,0});
317
0
      }
318
0
    }
319
0
#else
320
0
    if (!s_disableConnectForUnitTests) {
321
0
      SConnectWithTimeout(d_socket, remote, /* no timeout, we will handle it ourselves */ timeval{0,0});
322
0
    }
323
0
#endif /* MSG_FASTOPEN */
324
0
325
0
    if (d_conn) {
326
0
      return d_conn->tryConnect(fastOpen, remote);
327
0
    }
328
0
329
0
    return IOState::Done;
330
0
  }
331
332
  void connect(bool fastOpen, const ComboAddress& remote, const struct timeval& timeout)
333
0
  {
334
0
    d_remote = remote;
335
0
336
0
#ifdef TCP_FASTOPEN_CONNECT /* Linux >= 4.11 */
337
0
    if (fastOpen) {
338
0
      int value = 1;
339
0
      int res = setsockopt(d_socket, IPPROTO_TCP, TCP_FASTOPEN_CONNECT, &value, sizeof(value));
340
0
      if (res == 0) {
341
0
        fastOpen = false;
342
0
      }
343
0
    }
344
0
#endif /* TCP_FASTOPEN_CONNECT */
345
0
346
0
#ifdef MSG_FASTOPEN
347
0
    if (!d_conn && fastOpen) {
348
0
      d_fastOpen = true;
349
0
    }
350
0
    else {
351
0
      if (!s_disableConnectForUnitTests) {
352
0
        SConnectWithTimeout(d_socket, remote, timeout);
353
0
      }
354
0
    }
355
0
#else
356
0
    if (!s_disableConnectForUnitTests) {
357
0
      SConnectWithTimeout(d_socket, remote, timeout);
358
0
    }
359
0
#endif /* MSG_FASTOPEN */
360
0
361
0
    if (d_conn) {
362
0
      d_conn->connect(fastOpen, remote, timeout);
363
0
    }
364
0
  }
365
366
  IOState tryHandshake()
367
0
  {
368
0
    if (d_conn) {
369
0
      return d_conn->tryHandshake();
370
0
    }
371
0
    return IOState::Done;
372
0
  }
373
374
  size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout = {0,0}, bool allowIncomplete=false)
375
0
  {
376
0
    if (d_conn) {
377
0
      return d_conn->read(buffer, bufferSize, readTimeout, totalTimeout, allowIncomplete);
378
0
    } else {
379
0
      return readn2WithTimeout(d_socket, buffer, bufferSize, readTimeout, totalTimeout, allowIncomplete);
380
0
    }
381
0
  }
382
383
  /* Tries to read exactly toRead - pos bytes into the buffer, starting at position pos.
384
     Updates pos everytime a successful read occurs,
385
     throws an std::runtime_error in case of IO error,
386
     return Done when toRead bytes have been read, needRead or needWrite if the IO operation
387
     would block.
388
  */
389
  IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete=false, bool bypassFilters=false)
390
0
  {
391
0
    if (buffer.size() < toRead || pos >= toRead) {
392
0
      throw std::out_of_range("Calling tryRead() with a too small buffer (" + std::to_string(buffer.size()) + ") for a read of " + std::to_string(toRead - pos) + " bytes starting at " + std::to_string(pos));
393
0
    }
394
0
395
0
    if (!bypassFilters && d_conn) {
396
0
      return d_conn->tryRead(buffer, pos, toRead, allowIncomplete);
397
0
    }
398
0
399
0
    do {
400
0
      ssize_t res = ::read(d_socket, reinterpret_cast<char*>(&buffer.at(pos)), toRead - pos);
401
0
      if (res == 0) {
402
0
        throw runtime_error("EOF while reading message");
403
0
      }
404
0
      if (res < 0) {
405
0
        if (errno == EAGAIN || errno == EWOULDBLOCK || errno == ENOTCONN) {
406
0
          return IOState::NeedRead;
407
0
        }
408
0
        else {
409
0
          throw std::runtime_error("Error while reading message: " + stringerror());
410
0
        }
411
0
      }
412
0
413
0
      pos += static_cast<size_t>(res);
414
0
      if (allowIncomplete) {
415
0
        break;
416
0
      }
417
0
    }
418
0
    while (pos < toRead);
419
0
420
0
    return IOState::Done;
421
0
  }
422
423
  /* Tries to write exactly toWrite - pos bytes from the buffer, starting at position pos.
424
     Updates pos everytime a successful write occurs,
425
     throws an std::runtime_error in case of IO error,
426
     return Done when toWrite bytes have been written, needRead or needWrite if the IO operation
427
     would block.
428
  */
429
  IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite)
430
0
  {
431
0
    if (buffer.size() < toWrite || pos >= toWrite) {
432
0
      throw std::out_of_range("Calling tryWrite() with a too small buffer (" + std::to_string(buffer.size()) + ") for a write of " + std::to_string(toWrite - pos) + " bytes starting at " + std::to_string(pos));
433
0
    }
434
0
    if (d_conn) {
435
0
      return d_conn->tryWrite(buffer, pos, toWrite);
436
0
    }
437
0
438
0
#ifdef MSG_FASTOPEN
439
0
    if (d_fastOpen) {
440
0
      int socketFlags = MSG_FASTOPEN;
441
0
      size_t sent = sendMsgWithOptions(d_socket, reinterpret_cast<const char *>(&buffer.at(pos)), toWrite - pos, &d_remote, nullptr, 0, socketFlags);
442
0
      if (sent > 0) {
443
0
        d_fastOpen = false;
444
0
        pos += sent;
445
0
      }
446
0
447
0
      if (pos < toWrite) {
448
0
        return IOState::NeedWrite;
449
0
      }
450
0
451
0
      return IOState::Done;
452
0
    }
453
0
#endif /* MSG_FASTOPEN */
454
0
455
0
    do {
456
0
      ssize_t res = ::write(d_socket, reinterpret_cast<const char*>(&buffer.at(pos)), toWrite - pos);
457
0
458
0
      if (res == 0) {
459
0
        throw runtime_error("EOF while sending message");
460
0
      }
461
0
      if (res < 0) {
462
0
        if (errno == EAGAIN || errno == EWOULDBLOCK || errno == ENOTCONN) {
463
0
          return IOState::NeedWrite;
464
0
        }
465
0
        else {
466
0
          throw std::runtime_error("Error while writing message: " + stringerror());
467
0
        }
468
0
      }
469
0
470
0
      pos += static_cast<size_t>(res);
471
0
    }
472
0
    while (pos < toWrite);
473
0
474
0
    return IOState::Done;
475
0
  }
476
477
  size_t write(const void* buffer, size_t bufferSize, const struct timeval& writeTimeout)
478
0
  {
479
0
    if (d_conn) {
480
0
      return d_conn->write(buffer, bufferSize, writeTimeout);
481
0
    }
482
0
483
0
#ifdef MSG_FASTOPEN
484
0
    if (d_fastOpen) {
485
0
      int socketFlags = MSG_FASTOPEN;
486
0
      size_t sent = sendMsgWithOptions(d_socket, reinterpret_cast<const char *>(buffer), bufferSize, &d_remote, nullptr, 0, socketFlags);
487
0
      if (sent > 0) {
488
0
        d_fastOpen = false;
489
0
      }
490
0
491
0
      return sent;
492
0
    }
493
0
#endif /* MSG_FASTOPEN */
494
0
495
0
    return writen2WithTimeout(d_socket, buffer, bufferSize, writeTimeout);
496
0
  }
497
498
  std::string getServerNameIndication() const
499
0
  {
500
0
    if (d_conn) {
501
0
      return d_conn->getServerNameIndication();
502
0
    }
503
0
    return std::string();
504
0
  }
505
506
  std::vector<uint8_t> getNextProtocol() const
507
0
  {
508
0
    if (d_conn) {
509
0
      return d_conn->getNextProtocol();
510
0
    }
511
0
    return std::vector<uint8_t>();
512
0
  }
513
514
  LibsslTLSVersion getTLSVersion() const
515
0
  {
516
0
    if (d_conn) {
517
0
      return d_conn->getTLSVersion();
518
0
    }
519
0
    return LibsslTLSVersion::Unknown;
520
0
  }
521
522
  bool isTLS() const
523
0
  {
524
0
    return d_conn != nullptr;
525
0
  }
526
527
  bool hasTLSSessionBeenResumed() const
528
0
  {
529
0
    return d_conn && d_conn->hasSessionBeenResumed();
530
0
  }
531
532
  bool getResumedFromInactiveTicketKey() const
533
0
  {
534
0
    return d_conn && d_conn->getResumedFromInactiveTicketKey();
535
0
  }
536
537
  bool getUnknownTicketKey() const
538
0
  {
539
0
    return d_conn && d_conn->getUnknownTicketKey();
540
0
  }
541
542
  void setTLSSession(std::unique_ptr<TLSSession>& session)
543
0
  {
544
0
    if (d_conn != nullptr) {
545
0
      d_conn->setSession(session);
546
0
    }
547
0
  }
548
549
  std::vector<std::unique_ptr<TLSSession>> getTLSSessions()
550
0
  {
551
0
    if (!d_conn) {
552
0
      throw std::runtime_error("Trying to get TLS sessions from a non-TLS handler");
553
0
    }
554
0
555
0
    return d_conn->getSessions();
556
0
  }
557
558
  bool isUsable() const
559
0
  {
560
0
    if (!d_conn) {
561
0
      return isTCPSocketUsable(d_socket);
562
0
    }
563
0
    return d_conn->isUsable();
564
0
  }
565
566
  std::vector<int> getAsyncFDs()
567
0
  {
568
0
    if (!d_conn) {
569
0
      return {};
570
0
    }
571
0
    return d_conn->getAsyncFDs();
572
0
  }
573
574
  static const bool s_disableConnectForUnitTests;
575
576
private:
577
  std::unique_ptr<TLSConnection> d_conn{nullptr};
578
  ComboAddress d_remote;
579
  int d_socket{-1};
580
#ifdef MSG_FASTOPEN
581
  bool d_fastOpen{false};
582
#endif
583
};
584
585
struct TLSContextParameters
586
{
587
  std::string d_provider;
588
  std::string d_ciphers;
589
  std::string d_ciphers13;
590
  std::string d_caStore;
591
  std::string d_keyLogFile;
592
  TLSFrontend::ALPN d_alpn{TLSFrontend::ALPN::Unset};
593
  bool d_validateCertificates{true};
594
  bool d_releaseBuffers{true};
595
  bool d_enableRenegotiation{false};
596
  bool d_ktls{false};
597
};
598
599
/* Builds a TLS context from params (defined out of line). Whether it can return
   nullptr on failure is not visible here — confirm before relying on it. */
auto getTLSContext(const TLSContextParameters& params) -> std::shared_ptr<TLSCtx>;
/* Set up protocol negotiation on ctx for DNS over TLS / DNS over HTTPS. The bool
   result presumably reports success — confirm in the implementation. */
auto setupDoTProtocolNegotiation(std::shared_ptr<TLSCtx>& ctx) -> bool;
auto setupDoHProtocolNegotiation(std::shared_ptr<TLSCtx>& ctx) -> bool;