Coverage Report

Created: 2026-05-30 06:31

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/pdns/pdns/sstuff.hh
Line
Count
Source
1
/*
2
 * This file is part of PowerDNS or dnsdist.
3
 * Copyright -- PowerDNS.COM B.V. and its contributors
4
 *
5
 * This program is free software; you can redistribute it and/or modify
6
 * it under the terms of version 2 of the GNU General Public License as
7
 * published by the Free Software Foundation.
8
 *
9
 * In addition, for the avoidance of any doubt, permission is granted to
10
 * link this program with OpenSSL and to (re)distribute the binaries
11
 * produced as the result of such linking.
12
 *
13
 * This program is distributed in the hope that it will be useful,
14
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16
 * GNU General Public License for more details.
17
 *
18
 * You should have received a copy of the GNU General Public License
19
 * along with this program; if not, write to the Free Software
20
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21
 */
22
#pragma once
23
#include <string>
24
#include <sstream>
25
#include <iostream>
26
#include "iputils.hh"
27
#include <cerrno>
28
#include <sys/types.h>
29
#include <unistd.h>
30
#include <sys/socket.h>
31
#include <netinet/in.h>
32
#include <netinet/tcp.h>
33
#include <arpa/inet.h>
34
#include <sys/select.h>
35
#include <fcntl.h>
36
#include <stdexcept>
37
38
#include <csignal>
39
#include "namespaces.hh"
40
#include "noinitvector.hh"
41
42
using ProtocolType = int; //!< Protocol number passed to socket(2) as its third argument (the Socket constructor defaults it to 0)
43
44
//! Representation of a Socket and many of the Berkeley functions available
45
class Socket
46
{
47
public:
48
  Socket(const Socket&) = delete;
49
  Socket& operator=(const Socket&) = delete;
50
51
  Socket(int socketDesc) :
52
    d_socket(socketDesc)
53
0
  {
54
0
  }
55
56
  //! Construct a socket of specified address family and socket type.
57
  Socket(int addressFamily, int socketType, ProtocolType protocolType = 0) :
58
    d_socket(socket(addressFamily, socketType, protocolType))
59
0
  {
60
0
    if (d_socket < 0) {
61
0
      throw NetworkError(stringerror());
62
0
    }
63
0
    setCloseOnExec(d_socket);
64
0
  }
65
66
  Socket(Socket&& rhs) noexcept :
67
    d_buffer(std::move(rhs.d_buffer)), d_socket(rhs.d_socket)
68
0
  {
69
0
    rhs.d_socket = -1;
70
0
  }
71
72
  Socket& operator=(Socket&& rhs) noexcept
73
0
  {
74
0
    if (d_socket != -1) {
75
0
      close(d_socket);
76
0
    }
77
0
    d_socket = rhs.d_socket;
78
0
    rhs.d_socket = -1;
79
0
    d_buffer = std::move(rhs.d_buffer);
80
0
    return *this;
81
0
  }
82
83
  ~Socket()
84
0
  {
85
0
    try {
86
0
      if (d_socket != -1) {
87
0
        closesocket(d_socket);
88
0
      }
89
0
    }
90
0
    catch (const PDNSException& e) {
91
0
    }
92
0
  }
93
94
  //! If the socket is capable of doing so, this function will wait for a connection
95
  [[nodiscard]] std::unique_ptr<Socket> accept() const
96
0
  {
97
0
    sockaddr_in remote{};
98
0
    socklen_t remlen = sizeof(remote);
99
0
    memset(&remote, 0, sizeof(remote));
100
0
    int sock = ::accept(d_socket, reinterpret_cast<sockaddr*>(&remote), &remlen); // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast): it's the API
101
0
    if (sock < 0) {
102
0
      if (errno == EAGAIN) {
103
0
        return nullptr;
104
0
      }
105
0
106
0
      throw NetworkError("Accepting a connection: " + stringerror());
107
0
    }
108
0
109
0
    return std::make_unique<Socket>(sock);
110
0
  }
111
112
  //! Get remote address
113
  bool getRemote(ComboAddress& remote) const
114
0
  {
115
0
    socklen_t remotelen = sizeof(remote);
116
0
    return getpeername(d_socket, reinterpret_cast<struct sockaddr*>(&remote), &remotelen) >= 0; // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast): it's the API
117
0
  }
118
119
  //! Check remote address against netmaskgroup ng
120
  [[nodiscard]] bool acl(const NetmaskGroup& netmaskGroup) const
121
0
  {
122
0
    ComboAddress remote;
123
0
    if (getRemote(remote)) {
124
0
      if (netmaskGroup.match(remote)) {
125
0
        return true;
126
0
      }
127
0
128
0
      if (remote.isMappedIPv4()) {
129
0
        return netmaskGroup.match(remote.mapToIPv4());
130
0
      }
131
0
    }
132
0
133
0
    return false;
134
0
  }
135
136
  //! Set the socket to non-blocking
137
  void setNonBlocking() const
138
0
  {
139
0
    ::setNonBlocking(d_socket);
140
0
  }
141
142
  //! Set the socket to blocking
143
  void setBlocking() const
144
0
  {
145
0
    ::setBlocking(d_socket);
146
0
  }
147
148
  void setReuseAddr() const
149
0
  {
150
0
    try {
151
0
      ::setReuseAddr(d_socket);
152
0
    }
153
0
    catch (const PDNSException& e) {
154
0
      throw NetworkError(e.reason);
155
0
    }
156
0
  }
157
158
  void setFastOpenConnect()
159
0
  {
160
0
#ifdef TCP_FASTOPEN_CONNECT
161
0
    int on = 1;
162
0
    if (setsockopt(d_socket, IPPROTO_TCP, TCP_FASTOPEN_CONNECT, &on, sizeof(on)) < 0) {
163
0
      throw NetworkError("While setting TCP_FASTOPEN_CONNECT: " + stringerror());
164
0
    }
165
0
#else
166
0
    throw NetworkError("While setting TCP_FASTOPEN_CONNECT: not compiled in");
167
0
#endif
168
0
  }
169
170
  //! Bind the socket to a specified endpoint
171
  template <typename T>
172
  void bind(const T& local, bool reuseaddr = true) const
173
  {
174
    int tmp = 1;
175
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
176
    if (reuseaddr && setsockopt(d_socket, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<char*>(&tmp), sizeof tmp) < 0) {
177
      throw NetworkError("Setsockopt failed: " + stringerror());
178
    }
179
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
180
    if (::bind(d_socket, reinterpret_cast<const struct sockaddr*>(&local), local.getSocklen()) < 0) {
181
      throw NetworkError("While binding: " + stringerror());
182
    }
183
  }
184
185
  //! Connect the socket to a specified endpoint
186
  void connect(const ComboAddress& address, int timeout = 0) const
187
0
  {
188
0
    SConnectWithTimeout(d_socket, false, address, timeval{timeout, 0});
189
0
  }
190
191
  //! For datagram sockets, receive a datagram and learn where it came from
192
  /** For datagram sockets, receive a datagram and learn where it came from
193
      \param dgram Will be filled with the datagram
194
      \param ep Will be filled with the origin of the datagram */
195
  void recvFrom(string& dgram, ComboAddress& remote) const
196
0
  {
197
0
    socklen_t remlen = sizeof(remote);
198
0
    if (dgram.size() < s_buflen) {
199
0
      dgram.resize(s_buflen);
200
0
    }
201
0
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
202
0
    auto bytes = recvfrom(d_socket, dgram.data(), dgram.size(), 0, reinterpret_cast<sockaddr*>(&remote), &remlen);
203
0
    if (bytes < 0) {
204
0
      throw NetworkError("After recvfrom: " + stringerror());
205
0
    }
206
0
    dgram.resize(static_cast<size_t>(bytes));
207
0
  }
208
209
  bool recvFromAsync(PacketBuffer& dgram, ComboAddress& remote) const
210
0
  {
211
0
    socklen_t remlen = sizeof(remote);
212
0
    if (dgram.size() < s_buflen) {
213
0
      dgram.resize(s_buflen);
214
0
    }
215
0
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
216
0
    auto bytes = recvfrom(d_socket, dgram.data(), dgram.size(), 0, reinterpret_cast<sockaddr*>(&remote), &remlen);
217
0
    if (bytes < 0) {
218
0
      if (errno != EAGAIN) {
219
0
        throw NetworkError("After async recvfrom: " + stringerror());
220
0
      }
221
0
      return false;
222
0
    }
223
0
    dgram.resize(static_cast<size_t>(bytes));
224
0
    return true;
225
0
  }
226
227
  //! For datagram sockets, send a datagram to a destination
228
  void sendTo(const char* msg, size_t len, const ComboAddress& remote) const
229
0
  {
230
0
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
231
0
    if (sendto(d_socket, msg, len, 0, reinterpret_cast<const sockaddr*>(&remote), remote.getSocklen()) < 0) {
232
0
      throw NetworkError("After sendto: " + stringerror());
233
0
    }
234
0
  }
235
236
  //! For connected datagram sockets, send a datagram
237
  void send(const std::string& msg) const
238
0
  {
239
0
    if (::send(d_socket, msg.data(), msg.size(), 0) < 0) {
240
0
      throw NetworkError("After send: " + stringerror());
241
0
    }
242
0
  }
243
244
  /** For datagram sockets, send a datagram to a destination
245
      \param dgram The datagram
246
      \param remote The intended destination of the datagram */
247
  void sendTo(const string& dgram, const ComboAddress& remote) const
248
0
  {
249
0
    sendTo(dgram.data(), dgram.length(), remote);
250
0
  }
251
252
  //! Write this data to the socket, taking care that all bytes are written out
253
  void writen(const string& data) const
254
0
  {
255
0
    if (data.empty()) {
256
0
      return;
257
0
    }
258
0
259
0
    size_t toWrite = data.length();
260
0
    const char* ptr = data.data();
261
0
262
0
    do {
263
0
      auto res = ::send(d_socket, ptr, toWrite, 0);
264
0
      if (res < 0) {
265
0
        throw NetworkError("Writing to a socket: " + stringerror());
266
0
      }
267
0
      if (res == 0) {
268
0
        throw NetworkError("EOF on socket");
269
0
      }
270
0
      toWrite -= static_cast<size_t>(res);
271
0
      ptr += static_cast<size_t>(res); // NOLINT(cppcoreguidelines-pro-bounds-pointer-arithmetic)
272
0
    } while (toWrite > 0);
273
0
  }
274
275
  //! tries to write toWrite bytes from ptr to the socket
276
  /** tries to write toWrite bytes from ptr to the socket, but does not make sure they al get written out
277
      \param ptr Location to write from
278
      \param toWrite number of bytes to try
279
  */
280
  size_t tryWrite(const char* ptr, size_t toWrite) const
281
0
  {
282
0
    auto res = ::send(d_socket, ptr, toWrite, 0);
283
0
    if (res == 0) {
284
0
      throw NetworkError("EOF on writing to a socket");
285
0
    }
286
0
    if (res > 0) {
287
0
      return res;
288
0
    }
289
0
290
0
    if (errno == EAGAIN) {
291
0
      return 0;
292
0
    }
293
0
294
0
    throw NetworkError("Writing to a socket: " + stringerror());
295
0
  }
296
297
  //! Writes toWrite bytes from ptr to the socket
298
  /** Writes toWrite bytes from ptr to the socket. Returns how many bytes were written */
299
  size_t write(const char* ptr, size_t toWrite) const
300
0
  {
301
0
    auto res = ::send(d_socket, ptr, toWrite, 0);
302
0
    if (res < 0) {
303
0
      throw NetworkError("Writing to a socket: " + stringerror());
304
0
    }
305
0
    return res;
306
0
  }
307
308
  void writenWithTimeout(const void* buffer, size_t n, int timeout) const
309
0
  {
310
0
    size_t bytes = n;
311
0
    const char* ptr = static_cast<const char*>(buffer);
312
0
313
0
    while (bytes > 0) {
314
0
      auto ret = ::write(d_socket, ptr, bytes);
315
0
      if (ret < 0) {
316
0
        if (errno == EAGAIN) {
317
0
          ret = waitForRWData(d_socket, false, timeout, 0);
318
0
          if (ret < 0) {
319
0
            throw NetworkError("Waiting for data write");
320
0
          }
321
0
          if (ret == 0) {
322
0
            throw NetworkError("Timeout writing data");
323
0
          }
324
0
          continue;
325
0
        }
326
0
        throw NetworkError("Writing data: " + stringerror());
327
0
      }
328
0
      if (ret == 0) {
329
0
        throw NetworkError("Did not fulfill TCP write due to EOF");
330
0
      }
331
0
332
0
      ptr += static_cast<size_t>(ret); // NOLINT(cppcoreguidelines-pro-bounds-pointer-arithmetic)
333
0
      bytes -= static_cast<size_t>(ret);
334
0
    }
335
0
  }
336
337
  //! Reads a block of data from the socket to a string
338
  void read(string& data)
339
0
  {
340
0
    d_buffer.resize(s_buflen);
341
0
    ssize_t res = ::recv(d_socket, d_buffer.data(), s_buflen, 0);
342
0
    if (res < 0) {
343
0
      throw NetworkError("Reading from a socket: " + stringerror());
344
0
    }
345
0
    data.assign(d_buffer, 0, static_cast<size_t>(res));
346
0
  }
347
348
  //! Reads a block of data from the socket to a block of memory
349
  size_t read(char* buffer, size_t bytes) const
350
0
  {
351
0
    auto res = ::recv(d_socket, buffer, bytes, 0);
352
0
    if (res < 0) {
353
0
      throw NetworkError("Reading from a socket: " + stringerror());
354
0
    }
355
0
    return static_cast<size_t>(res);
356
0
  }
357
358
  /** Read a bock of data from the socket to a block of memory,
359
   *   waiting at most 'timeout' seconds for the data to become
360
   *   available. Be aware that this does _NOT_ handle partial reads
361
   *   for you.
362
   */
363
  size_t readWithTimeout(char* buffer, size_t n, int timeout) const
364
0
  {
365
0
    int err = waitForRWData(d_socket, true, timeout, 0);
366
0
367
0
    if (err == 0) {
368
0
      throw NetworkError("timeout reading");
369
0
    }
370
0
    if (err < 0) {
371
0
      throw NetworkError("nonblocking read failed: " + stringerror());
372
0
    }
373
0
374
0
    return read(buffer, n);
375
0
  }
376
377
  //! Sets the socket to listen with a default listen backlog of 10 pending connections
378
  void listen(int length = 10) const
379
0
  {
380
0
    if (::listen(d_socket, length) < 0) {
381
0
      throw NetworkError("Setting socket to listen: " + stringerror());
382
0
    }
383
0
  }
384
385
  //! Returns the internal file descriptor of the socket
386
  [[nodiscard]] int getHandle() const
387
0
  {
388
0
    return d_socket;
389
0
  }
390
391
  int releaseHandle()
392
0
  {
393
0
    int ret = d_socket;
394
0
    d_socket = -1;
395
0
    return ret;
396
0
  }
397
398
private:
399
  static constexpr size_t s_buflen{4096};
400
  std::string d_buffer;
401
  int d_socket;
402
};