Coverage Report

Created: 2025-09-05 06:36

/src/pdns/pdns/sstuff.hh
Line
Count
Source (jump to first uncovered line)
1
/*
2
 * This file is part of PowerDNS or dnsdist.
3
 * Copyright -- PowerDNS.COM B.V. and its contributors
4
 *
5
 * This program is free software; you can redistribute it and/or modify
6
 * it under the terms of version 2 of the GNU General Public License as
7
 * published by the Free Software Foundation.
8
 *
9
 * In addition, for the avoidance of any doubt, permission is granted to
10
 * link this program with OpenSSL and to (re)distribute the binaries
11
 * produced as the result of such linking.
12
 *
13
 * This program is distributed in the hope that it will be useful,
14
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16
 * GNU General Public License for more details.
17
 *
18
 * You should have received a copy of the GNU General Public License
19
 * along with this program; if not, write to the Free Software
20
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21
 */
22
#pragma once
23
#include <string>
24
#include <sstream>
25
#include <iostream>
26
#include "iputils.hh"
27
#include <cerrno>
28
#include <sys/types.h>
29
#include <unistd.h>
30
#include <sys/socket.h>
31
#include <netinet/in.h>
32
#include <netinet/tcp.h>
33
#include <arpa/inet.h>
34
#include <sys/select.h>
35
#include <fcntl.h>
36
#include <stdexcept>
37
38
#include <csignal>
39
#include "namespaces.hh"
40
#include "noinitvector.hh"
41
42
//! Type of the protocol argument that Socket hands to socket(2); 0 selects the default protocol for the family/type
using ProtocolType = int;
43
44
//! Representation of a Socket and many of the Berkeley functions available
45
class Socket
46
{
47
public:
48
  Socket(const Socket&) = delete;
49
  Socket& operator=(const Socket&) = delete;
50
51
  Socket(int socketDesc) :
52
    d_socket(socketDesc)
53
0
  {
54
0
  }
55
56
  //! Construct a socket of specified address family and socket type.
57
  Socket(int addressFamily, int socketType, ProtocolType protocolType = 0) :
58
    d_socket(socket(addressFamily, socketType, protocolType))
59
0
  {
60
0
    if (d_socket < 0) {
61
0
      throw NetworkError(stringerror());
62
0
    }
63
0
    setCloseOnExec(d_socket);
64
0
  }
65
66
  Socket(Socket&& rhs) noexcept :
67
    d_buffer(std::move(rhs.d_buffer)), d_socket(rhs.d_socket)
68
0
  {
69
0
    rhs.d_socket = -1;
70
0
  }
71
72
  Socket& operator=(Socket&& rhs) noexcept
73
0
  {
74
0
    if (d_socket != -1) {
75
0
      close(d_socket);
76
0
    }
77
0
    d_socket = rhs.d_socket;
78
0
    rhs.d_socket = -1;
79
0
    d_buffer = std::move(rhs.d_buffer);
80
0
    return *this;
81
0
  }
82
83
  ~Socket()
84
0
  {
85
0
    try {
86
0
      if (d_socket != -1) {
87
0
        closesocket(d_socket);
88
0
      }
89
0
    }
90
0
    catch (const PDNSException& e) {
91
0
    }
92
0
  }
93
94
  //! If the socket is capable of doing so, this function will wait for a connection
95
  [[nodiscard]] std::unique_ptr<Socket> accept() const
96
0
  {
97
0
    sockaddr_in remote{};
98
0
    socklen_t remlen = sizeof(remote);
99
0
    memset(&remote, 0, sizeof(remote));
100
0
    int sock = ::accept(d_socket, reinterpret_cast<sockaddr*>(&remote), &remlen); // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast): it's the API
101
0
    if (sock < 0) {
102
0
      if (errno == EAGAIN) {
103
0
        return nullptr;
104
0
      }
105
0
106
0
      throw NetworkError("Accepting a connection: " + stringerror());
107
0
    }
108
0
109
0
    return std::make_unique<Socket>(sock);
110
0
  }
111
112
  //! Get remote address
113
  bool getRemote(ComboAddress& remote) const
114
0
  {
115
0
    socklen_t remotelen = sizeof(remote);
116
0
    return getpeername(d_socket, reinterpret_cast<struct sockaddr*>(&remote), &remotelen) >= 0; // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast): it's the API
117
0
  }
118
119
  //! Check remote address against netmaskgroup ng
120
  [[nodiscard]] bool acl(const NetmaskGroup& netmaskGroup) const
121
0
  {
122
0
    ComboAddress remote;
123
0
    if (getRemote(remote)) {
124
0
      return netmaskGroup.match(remote);
125
0
    }
126
0
127
0
    return false;
128
0
  }
129
130
  //! Set the socket to non-blocking
131
  void setNonBlocking() const
132
0
  {
133
0
    ::setNonBlocking(d_socket);
134
0
  }
135
136
  //! Set the socket to blocking
137
  void setBlocking() const
138
0
  {
139
0
    ::setBlocking(d_socket);
140
0
  }
141
142
  void setReuseAddr() const
143
0
  {
144
0
    try {
145
0
      ::setReuseAddr(d_socket);
146
0
    }
147
0
    catch (const PDNSException& e) {
148
0
      throw NetworkError(e.reason);
149
0
    }
150
0
  }
151
152
  void setFastOpenConnect()
153
0
  {
154
0
#ifdef TCP_FASTOPEN_CONNECT
155
0
    int on = 1;
156
0
    if (setsockopt(d_socket, IPPROTO_TCP, TCP_FASTOPEN_CONNECT, &on, sizeof(on)) < 0) {
157
0
      throw NetworkError("While setting TCP_FASTOPEN_CONNECT: " + stringerror());
158
0
    }
159
0
#else
160
0
    throw NetworkError("While setting TCP_FASTOPEN_CONNECT: not compiled in");
161
0
#endif
162
0
  }
163
164
  //! Bind the socket to a specified endpoint
165
  template <typename T>
166
  void bind(const T& local, bool reuseaddr = true) const
167
  {
168
    int tmp = 1;
169
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
170
    if (reuseaddr && setsockopt(d_socket, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<char*>(&tmp), sizeof tmp) < 0) {
171
      throw NetworkError("Setsockopt failed: " + stringerror());
172
    }
173
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
174
    if (::bind(d_socket, reinterpret_cast<const struct sockaddr*>(&local), local.getSocklen()) < 0) {
175
      throw NetworkError("While binding: " + stringerror());
176
    }
177
  }
178
179
  //! Connect the socket to a specified endpoint
180
  void connect(const ComboAddress& address, int timeout = 0) const
181
0
  {
182
0
    SConnectWithTimeout(d_socket, address, timeval{timeout, 0});
183
0
  }
184
185
  //! For datagram sockets, receive a datagram and learn where it came from
186
  /** For datagram sockets, receive a datagram and learn where it came from
187
      \param dgram Will be filled with the datagram
188
      \param ep Will be filled with the origin of the datagram */
189
  void recvFrom(string& dgram, ComboAddress& remote) const
190
0
  {
191
0
    socklen_t remlen = sizeof(remote);
192
0
    if (dgram.size() < s_buflen) {
193
0
      dgram.resize(s_buflen);
194
0
    }
195
0
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
196
0
    auto bytes = recvfrom(d_socket, dgram.data(), dgram.size(), 0, reinterpret_cast<sockaddr*>(&remote), &remlen);
197
0
    if (bytes < 0) {
198
0
      throw NetworkError("After recvfrom: " + stringerror());
199
0
    }
200
0
    dgram.resize(static_cast<size_t>(bytes));
201
0
  }
202
203
  bool recvFromAsync(PacketBuffer& dgram, ComboAddress& remote) const
204
0
  {
205
0
    socklen_t remlen = sizeof(remote);
206
0
    if (dgram.size() < s_buflen) {
207
0
      dgram.resize(s_buflen);
208
0
    }
209
0
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
210
0
    auto bytes = recvfrom(d_socket, dgram.data(), dgram.size(), 0, reinterpret_cast<sockaddr*>(&remote), &remlen);
211
0
    if (bytes < 0) {
212
0
      if (errno != EAGAIN) {
213
0
        throw NetworkError("After async recvfrom: " + stringerror());
214
0
      }
215
0
      return false;
216
0
    }
217
0
    dgram.resize(static_cast<size_t>(bytes));
218
0
    return true;
219
0
  }
220
221
  //! For datagram sockets, send a datagram to a destination
222
  void sendTo(const char* msg, size_t len, const ComboAddress& remote) const
223
0
  {
224
0
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
225
0
    if (sendto(d_socket, msg, len, 0, reinterpret_cast<const sockaddr*>(&remote), remote.getSocklen()) < 0) {
226
0
      throw NetworkError("After sendto: " + stringerror());
227
0
    }
228
0
  }
229
230
  //! For connected datagram sockets, send a datagram
231
  void send(const std::string& msg) const
232
0
  {
233
0
    if (::send(d_socket, msg.data(), msg.size(), 0) < 0) {
234
0
      throw NetworkError("After send: " + stringerror());
235
0
    }
236
0
  }
237
238
  /** For datagram sockets, send a datagram to a destination
239
      \param dgram The datagram
240
      \param remote The intended destination of the datagram */
241
  void sendTo(const string& dgram, const ComboAddress& remote) const
242
0
  {
243
0
    sendTo(dgram.data(), dgram.length(), remote);
244
0
  }
245
246
  //! Write this data to the socket, taking care that all bytes are written out
247
  void writen(const string& data) const
248
0
  {
249
0
    if (data.empty()) {
250
0
      return;
251
0
    }
252
0
253
0
    size_t toWrite = data.length();
254
0
    const char* ptr = data.data();
255
0
256
0
    do {
257
0
      auto res = ::send(d_socket, ptr, toWrite, 0);
258
0
      if (res < 0) {
259
0
        throw NetworkError("Writing to a socket: " + stringerror());
260
0
      }
261
0
      if (res == 0) {
262
0
        throw NetworkError("EOF on socket");
263
0
      }
264
0
      toWrite -= static_cast<size_t>(res);
265
0
      ptr += static_cast<size_t>(res); // NOLINT(cppcoreguidelines-pro-bounds-pointer-arithmetic)
266
0
    } while (toWrite > 0);
267
0
  }
268
269
  //! tries to write toWrite bytes from ptr to the socket
270
  /** tries to write toWrite bytes from ptr to the socket, but does not make sure they al get written out
271
      \param ptr Location to write from
272
      \param toWrite number of bytes to try
273
  */
274
  size_t tryWrite(const char* ptr, size_t toWrite) const
275
0
  {
276
0
    auto res = ::send(d_socket, ptr, toWrite, 0);
277
0
    if (res == 0) {
278
0
      throw NetworkError("EOF on writing to a socket");
279
0
    }
280
0
    if (res > 0) {
281
0
      return res;
282
0
    }
283
0
284
0
    if (errno == EAGAIN) {
285
0
      return 0;
286
0
    }
287
0
288
0
    throw NetworkError("Writing to a socket: " + stringerror());
289
0
  }
290
291
  //! Writes toWrite bytes from ptr to the socket
292
  /** Writes toWrite bytes from ptr to the socket. Returns how many bytes were written */
293
  size_t write(const char* ptr, size_t toWrite) const
294
0
  {
295
0
    auto res = ::send(d_socket, ptr, toWrite, 0);
296
0
    if (res < 0) {
297
0
      throw NetworkError("Writing to a socket: " + stringerror());
298
0
    }
299
0
    return res;
300
0
  }
301
302
  void writenWithTimeout(const void* buffer, size_t n, int timeout) const
303
0
  {
304
0
    size_t bytes = n;
305
0
    const char* ptr = static_cast<const char*>(buffer);
306
0
307
0
    while (bytes > 0) {
308
0
      auto ret = ::write(d_socket, ptr, bytes);
309
0
      if (ret < 0) {
310
0
        if (errno == EAGAIN) {
311
0
          ret = waitForRWData(d_socket, false, timeout, 0);
312
0
          if (ret < 0) {
313
0
            throw NetworkError("Waiting for data write");
314
0
          }
315
0
          if (ret == 0) {
316
0
            throw NetworkError("Timeout writing data");
317
0
          }
318
0
          continue;
319
0
        }
320
0
        throw NetworkError("Writing data: " + stringerror());
321
0
      }
322
0
      if (ret == 0) {
323
0
        throw NetworkError("Did not fulfill TCP write due to EOF");
324
0
      }
325
0
326
0
      ptr += static_cast<size_t>(ret); // NOLINT(cppcoreguidelines-pro-bounds-pointer-arithmetic)
327
0
      bytes -= static_cast<size_t>(ret);
328
0
    }
329
0
  }
330
331
  //! reads one character from the socket
332
  [[nodiscard]] int getChar() const
333
0
  {
334
0
    char character{};
335
0
336
0
    ssize_t res = ::recv(d_socket, &character, 1, 0);
337
0
    if (res == 0) {
338
0
      return character;
339
0
    }
340
0
    return -1;
341
0
  }
342
343
  void getline(string& data) const
344
0
  {
345
0
    data.clear();
346
0
    while (true) {
347
0
      int character = getChar();
348
0
      if (character == -1) {
349
0
        break;
350
0
      }
351
0
      data += (char)character;
352
0
      if (character == '\n') {
353
0
        break;
354
0
      }
355
0
    }
356
0
  }
357
358
  //! Reads a block of data from the socket to a string
359
  void read(string& data)
360
0
  {
361
0
    d_buffer.resize(s_buflen);
362
0
    ssize_t res = ::recv(d_socket, d_buffer.data(), s_buflen, 0);
363
0
    if (res < 0) {
364
0
      throw NetworkError("Reading from a socket: " + stringerror());
365
0
    }
366
0
    data.assign(d_buffer, 0, static_cast<size_t>(res));
367
0
  }
368
369
  //! Reads a block of data from the socket to a block of memory
370
  size_t read(char* buffer, size_t bytes) const
371
0
  {
372
0
    auto res = ::recv(d_socket, buffer, bytes, 0);
373
0
    if (res < 0) {
374
0
      throw NetworkError("Reading from a socket: " + stringerror());
375
0
    }
376
0
    return static_cast<size_t>(res);
377
0
  }
378
379
  /** Read a bock of data from the socket to a block of memory,
380
   *   waiting at most 'timeout' seconds for the data to become
381
   *   available. Be aware that this does _NOT_ handle partial reads
382
   *   for you.
383
   */
384
  size_t readWithTimeout(char* buffer, size_t n, int timeout) const
385
0
  {
386
0
    int err = waitForRWData(d_socket, true, timeout, 0);
387
0
388
0
    if (err == 0) {
389
0
      throw NetworkError("timeout reading");
390
0
    }
391
0
    if (err < 0) {
392
0
      throw NetworkError("nonblocking read failed: " + stringerror());
393
0
    }
394
0
395
0
    return read(buffer, n);
396
0
  }
397
398
  //! Sets the socket to listen with a default listen backlog of 10 pending connections
399
  void listen(int length = 10) const
400
0
  {
401
0
    if (::listen(d_socket, length) < 0) {
402
0
      throw NetworkError("Setting socket to listen: " + stringerror());
403
0
    }
404
0
  }
405
406
  //! Returns the internal file descriptor of the socket
407
  [[nodiscard]] int getHandle() const
408
0
  {
409
0
    return d_socket;
410
0
  }
411
412
  int releaseHandle()
413
0
  {
414
0
    int ret = d_socket;
415
0
    d_socket = -1;
416
0
    return ret;
417
0
  }
418
419
private:
420
  static constexpr size_t s_buflen{4096};
421
  std::string d_buffer;
422
  int d_socket;
423
};