Line | Count | Source (jump to first uncovered line) |
1 | | /* |
2 | | * This file is part of PowerDNS or dnsdist. |
3 | | * Copyright -- PowerDNS.COM B.V. and its contributors |
4 | | * |
5 | | * This program is free software; you can redistribute it and/or modify |
6 | | * it under the terms of version 2 of the GNU General Public License as |
7 | | * published by the Free Software Foundation. |
8 | | * |
9 | | * In addition, for the avoidance of any doubt, permission is granted to |
10 | | * link this program with OpenSSL and to (re)distribute the binaries |
11 | | * produced as the result of such linking. |
12 | | * |
13 | | * This program is distributed in the hope that it will be useful, |
14 | | * but WITHOUT ANY WARRANTY; without even the implied warranty of |
15 | | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
16 | | * GNU General Public License for more details. |
17 | | * |
18 | | * You should have received a copy of the GNU General Public License |
19 | | * along with this program; if not, write to the Free Software |
20 | | * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. |
21 | | */ |
22 | | #pragma once |
23 | | #include <string> |
24 | | #include <sstream> |
25 | | #include <iostream> |
26 | | #include "iputils.hh" |
27 | | #include <cerrno> |
28 | | #include <sys/types.h> |
29 | | #include <unistd.h> |
30 | | #include <sys/socket.h> |
31 | | #include <netinet/in.h> |
32 | | #include <netinet/tcp.h> |
33 | | #include <arpa/inet.h> |
34 | | #include <sys/select.h> |
35 | | #include <fcntl.h> |
36 | | #include <stdexcept> |
37 | | |
38 | | #include <csignal> |
39 | | #include "namespaces.hh" |
40 | | #include "noinitvector.hh" |
41 | | |
/**
 * Protocol number handed to socket(2) as its third argument (see the
 * Socket(addressFamily, socketType, protocolType) constructor, which defaults
 * it to 0).
 */
using ProtocolType = int;
43 | | |
44 | | //! Representation of a Socket and many of the Berkeley functions available |
45 | | class Socket |
46 | | { |
47 | | public: |
48 | | Socket(const Socket&) = delete; |
49 | | Socket& operator=(const Socket&) = delete; |
50 | | |
51 | | Socket(int socketDesc) : |
52 | | d_socket(socketDesc) |
53 | 0 | { |
54 | 0 | } |
55 | | |
56 | | //! Construct a socket of specified address family and socket type. |
57 | | Socket(int addressFamily, int socketType, ProtocolType protocolType = 0) : |
58 | | d_socket(socket(addressFamily, socketType, protocolType)) |
59 | 0 | { |
60 | 0 | if (d_socket < 0) { |
61 | 0 | throw NetworkError(stringerror()); |
62 | 0 | } |
63 | 0 | setCloseOnExec(d_socket); |
64 | 0 | } |
65 | | |
66 | | Socket(Socket&& rhs) noexcept : |
67 | | d_buffer(std::move(rhs.d_buffer)), d_socket(rhs.d_socket) |
68 | 0 | { |
69 | 0 | rhs.d_socket = -1; |
70 | 0 | } |
71 | | |
72 | | Socket& operator=(Socket&& rhs) noexcept |
73 | 0 | { |
74 | 0 | if (d_socket != -1) { |
75 | 0 | close(d_socket); |
76 | 0 | } |
77 | 0 | d_socket = rhs.d_socket; |
78 | 0 | rhs.d_socket = -1; |
79 | 0 | d_buffer = std::move(rhs.d_buffer); |
80 | 0 | return *this; |
81 | 0 | } |
82 | | |
83 | | ~Socket() |
84 | 0 | { |
85 | 0 | try { |
86 | 0 | if (d_socket != -1) { |
87 | 0 | closesocket(d_socket); |
88 | 0 | } |
89 | 0 | } |
90 | 0 | catch (const PDNSException& e) { |
91 | 0 | } |
92 | 0 | } |
93 | | |
94 | | //! If the socket is capable of doing so, this function will wait for a connection |
95 | | [[nodiscard]] std::unique_ptr<Socket> accept() const |
96 | 0 | { |
97 | 0 | sockaddr_in remote{}; |
98 | 0 | socklen_t remlen = sizeof(remote); |
99 | 0 | memset(&remote, 0, sizeof(remote)); |
100 | 0 | int sock = ::accept(d_socket, reinterpret_cast<sockaddr*>(&remote), &remlen); // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast): it's the API |
101 | 0 | if (sock < 0) { |
102 | 0 | if (errno == EAGAIN) { |
103 | 0 | return nullptr; |
104 | 0 | } |
105 | 0 |
|
106 | 0 | throw NetworkError("Accepting a connection: " + stringerror()); |
107 | 0 | } |
108 | 0 |
|
109 | 0 | return std::make_unique<Socket>(sock); |
110 | 0 | } |
111 | | |
112 | | //! Get remote address |
113 | | bool getRemote(ComboAddress& remote) const |
114 | 0 | { |
115 | 0 | socklen_t remotelen = sizeof(remote); |
116 | 0 | return getpeername(d_socket, reinterpret_cast<struct sockaddr*>(&remote), &remotelen) >= 0; // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast): it's the API |
117 | 0 | } |
118 | | |
119 | | //! Check remote address against netmaskgroup ng |
120 | | [[nodiscard]] bool acl(const NetmaskGroup& netmaskGroup) const |
121 | 0 | { |
122 | 0 | ComboAddress remote; |
123 | 0 | if (getRemote(remote)) { |
124 | 0 | return netmaskGroup.match(remote); |
125 | 0 | } |
126 | 0 |
|
127 | 0 | return false; |
128 | 0 | } |
129 | | |
130 | | //! Set the socket to non-blocking |
131 | | void setNonBlocking() const |
132 | 0 | { |
133 | 0 | ::setNonBlocking(d_socket); |
134 | 0 | } |
135 | | |
136 | | //! Set the socket to blocking |
137 | | void setBlocking() const |
138 | 0 | { |
139 | 0 | ::setBlocking(d_socket); |
140 | 0 | } |
141 | | |
142 | | void setReuseAddr() const |
143 | 0 | { |
144 | 0 | try { |
145 | 0 | ::setReuseAddr(d_socket); |
146 | 0 | } |
147 | 0 | catch (const PDNSException& e) { |
148 | 0 | throw NetworkError(e.reason); |
149 | 0 | } |
150 | 0 | } |
151 | | |
152 | | void setFastOpenConnect() |
153 | 0 | { |
154 | 0 | #ifdef TCP_FASTOPEN_CONNECT |
155 | 0 | int on = 1; |
156 | 0 | if (setsockopt(d_socket, IPPROTO_TCP, TCP_FASTOPEN_CONNECT, &on, sizeof(on)) < 0) { |
157 | 0 | throw NetworkError("While setting TCP_FASTOPEN_CONNECT: " + stringerror()); |
158 | 0 | } |
159 | 0 | #else |
160 | 0 | throw NetworkError("While setting TCP_FASTOPEN_CONNECT: not compiled in"); |
161 | 0 | #endif |
162 | 0 | } |
163 | | |
164 | | //! Bind the socket to a specified endpoint |
165 | | template <typename T> |
166 | | void bind(const T& local, bool reuseaddr = true) const |
167 | | { |
168 | | int tmp = 1; |
169 | | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) |
170 | | if (reuseaddr && setsockopt(d_socket, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<char*>(&tmp), sizeof tmp) < 0) { |
171 | | throw NetworkError("Setsockopt failed: " + stringerror()); |
172 | | } |
173 | | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) |
174 | | if (::bind(d_socket, reinterpret_cast<const struct sockaddr*>(&local), local.getSocklen()) < 0) { |
175 | | throw NetworkError("While binding: " + stringerror()); |
176 | | } |
177 | | } |
178 | | |
179 | | //! Connect the socket to a specified endpoint |
180 | | void connect(const ComboAddress& address, int timeout = 0) const |
181 | 0 | { |
182 | 0 | SConnectWithTimeout(d_socket, address, timeval{timeout, 0}); |
183 | 0 | } |
184 | | |
185 | | //! For datagram sockets, receive a datagram and learn where it came from |
186 | | /** For datagram sockets, receive a datagram and learn where it came from |
187 | | \param dgram Will be filled with the datagram |
188 | | \param ep Will be filled with the origin of the datagram */ |
189 | | void recvFrom(string& dgram, ComboAddress& remote) const |
190 | 0 | { |
191 | 0 | socklen_t remlen = sizeof(remote); |
192 | 0 | if (dgram.size() < s_buflen) { |
193 | 0 | dgram.resize(s_buflen); |
194 | 0 | } |
195 | 0 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) |
196 | 0 | auto bytes = recvfrom(d_socket, dgram.data(), dgram.size(), 0, reinterpret_cast<sockaddr*>(&remote), &remlen); |
197 | 0 | if (bytes < 0) { |
198 | 0 | throw NetworkError("After recvfrom: " + stringerror()); |
199 | 0 | } |
200 | 0 | dgram.resize(static_cast<size_t>(bytes)); |
201 | 0 | } |
202 | | |
203 | | bool recvFromAsync(PacketBuffer& dgram, ComboAddress& remote) const |
204 | 0 | { |
205 | 0 | socklen_t remlen = sizeof(remote); |
206 | 0 | if (dgram.size() < s_buflen) { |
207 | 0 | dgram.resize(s_buflen); |
208 | 0 | } |
209 | 0 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) |
210 | 0 | auto bytes = recvfrom(d_socket, dgram.data(), dgram.size(), 0, reinterpret_cast<sockaddr*>(&remote), &remlen); |
211 | 0 | if (bytes < 0) { |
212 | 0 | if (errno != EAGAIN) { |
213 | 0 | throw NetworkError("After async recvfrom: " + stringerror()); |
214 | 0 | } |
215 | 0 | return false; |
216 | 0 | } |
217 | 0 | dgram.resize(static_cast<size_t>(bytes)); |
218 | 0 | return true; |
219 | 0 | } |
220 | | |
221 | | //! For datagram sockets, send a datagram to a destination |
222 | | void sendTo(const char* msg, size_t len, const ComboAddress& remote) const |
223 | 0 | { |
224 | 0 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) |
225 | 0 | if (sendto(d_socket, msg, len, 0, reinterpret_cast<const sockaddr*>(&remote), remote.getSocklen()) < 0) { |
226 | 0 | throw NetworkError("After sendto: " + stringerror()); |
227 | 0 | } |
228 | 0 | } |
229 | | |
230 | | //! For connected datagram sockets, send a datagram |
231 | | void send(const std::string& msg) const |
232 | 0 | { |
233 | 0 | if (::send(d_socket, msg.data(), msg.size(), 0) < 0) { |
234 | 0 | throw NetworkError("After send: " + stringerror()); |
235 | 0 | } |
236 | 0 | } |
237 | | |
238 | | /** For datagram sockets, send a datagram to a destination |
239 | | \param dgram The datagram |
240 | | \param remote The intended destination of the datagram */ |
241 | | void sendTo(const string& dgram, const ComboAddress& remote) const |
242 | 0 | { |
243 | 0 | sendTo(dgram.data(), dgram.length(), remote); |
244 | 0 | } |
245 | | |
246 | | //! Write this data to the socket, taking care that all bytes are written out |
247 | | void writen(const string& data) const |
248 | 0 | { |
249 | 0 | if (data.empty()) { |
250 | 0 | return; |
251 | 0 | } |
252 | 0 |
|
253 | 0 | size_t toWrite = data.length(); |
254 | 0 | const char* ptr = data.data(); |
255 | 0 |
|
256 | 0 | do { |
257 | 0 | auto res = ::send(d_socket, ptr, toWrite, 0); |
258 | 0 | if (res < 0) { |
259 | 0 | throw NetworkError("Writing to a socket: " + stringerror()); |
260 | 0 | } |
261 | 0 | if (res == 0) { |
262 | 0 | throw NetworkError("EOF on socket"); |
263 | 0 | } |
264 | 0 | toWrite -= static_cast<size_t>(res); |
265 | 0 | ptr += static_cast<size_t>(res); // NOLINT(cppcoreguidelines-pro-bounds-pointer-arithmetic) |
266 | 0 | } while (toWrite > 0); |
267 | 0 | } |
268 | | |
269 | | //! tries to write toWrite bytes from ptr to the socket |
270 | | /** tries to write toWrite bytes from ptr to the socket, but does not make sure they al get written out |
271 | | \param ptr Location to write from |
272 | | \param toWrite number of bytes to try |
273 | | */ |
274 | | size_t tryWrite(const char* ptr, size_t toWrite) const |
275 | 0 | { |
276 | 0 | auto res = ::send(d_socket, ptr, toWrite, 0); |
277 | 0 | if (res == 0) { |
278 | 0 | throw NetworkError("EOF on writing to a socket"); |
279 | 0 | } |
280 | 0 | if (res > 0) { |
281 | 0 | return res; |
282 | 0 | } |
283 | 0 |
|
284 | 0 | if (errno == EAGAIN) { |
285 | 0 | return 0; |
286 | 0 | } |
287 | 0 |
|
288 | 0 | throw NetworkError("Writing to a socket: " + stringerror()); |
289 | 0 | } |
290 | | |
291 | | //! Writes toWrite bytes from ptr to the socket |
292 | | /** Writes toWrite bytes from ptr to the socket. Returns how many bytes were written */ |
293 | | size_t write(const char* ptr, size_t toWrite) const |
294 | 0 | { |
295 | 0 | auto res = ::send(d_socket, ptr, toWrite, 0); |
296 | 0 | if (res < 0) { |
297 | 0 | throw NetworkError("Writing to a socket: " + stringerror()); |
298 | 0 | } |
299 | 0 | return res; |
300 | 0 | } |
301 | | |
302 | | void writenWithTimeout(const void* buffer, size_t n, int timeout) const |
303 | 0 | { |
304 | 0 | size_t bytes = n; |
305 | 0 | const char* ptr = static_cast<const char*>(buffer); |
306 | 0 |
|
307 | 0 | while (bytes > 0) { |
308 | 0 | auto ret = ::write(d_socket, ptr, bytes); |
309 | 0 | if (ret < 0) { |
310 | 0 | if (errno == EAGAIN) { |
311 | 0 | ret = waitForRWData(d_socket, false, timeout, 0); |
312 | 0 | if (ret < 0) { |
313 | 0 | throw NetworkError("Waiting for data write"); |
314 | 0 | } |
315 | 0 | if (ret == 0) { |
316 | 0 | throw NetworkError("Timeout writing data"); |
317 | 0 | } |
318 | 0 | continue; |
319 | 0 | } |
320 | 0 | throw NetworkError("Writing data: " + stringerror()); |
321 | 0 | } |
322 | 0 | if (ret == 0) { |
323 | 0 | throw NetworkError("Did not fulfill TCP write due to EOF"); |
324 | 0 | } |
325 | 0 |
|
326 | 0 | ptr += static_cast<size_t>(ret); // NOLINT(cppcoreguidelines-pro-bounds-pointer-arithmetic) |
327 | 0 | bytes -= static_cast<size_t>(ret); |
328 | 0 | } |
329 | 0 | } |
330 | | |
331 | | //! reads one character from the socket |
332 | | [[nodiscard]] int getChar() const |
333 | 0 | { |
334 | 0 | char character{}; |
335 | 0 |
|
336 | 0 | ssize_t res = ::recv(d_socket, &character, 1, 0); |
337 | 0 | if (res == 0) { |
338 | 0 | return character; |
339 | 0 | } |
340 | 0 | return -1; |
341 | 0 | } |
342 | | |
343 | | void getline(string& data) const |
344 | 0 | { |
345 | 0 | data.clear(); |
346 | 0 | while (true) { |
347 | 0 | int character = getChar(); |
348 | 0 | if (character == -1) { |
349 | 0 | break; |
350 | 0 | } |
351 | 0 | data += (char)character; |
352 | 0 | if (character == '\n') { |
353 | 0 | break; |
354 | 0 | } |
355 | 0 | } |
356 | 0 | } |
357 | | |
358 | | //! Reads a block of data from the socket to a string |
359 | | void read(string& data) |
360 | 0 | { |
361 | 0 | d_buffer.resize(s_buflen); |
362 | 0 | ssize_t res = ::recv(d_socket, d_buffer.data(), s_buflen, 0); |
363 | 0 | if (res < 0) { |
364 | 0 | throw NetworkError("Reading from a socket: " + stringerror()); |
365 | 0 | } |
366 | 0 | data.assign(d_buffer, 0, static_cast<size_t>(res)); |
367 | 0 | } |
368 | | |
369 | | //! Reads a block of data from the socket to a block of memory |
370 | | size_t read(char* buffer, size_t bytes) const |
371 | 0 | { |
372 | 0 | auto res = ::recv(d_socket, buffer, bytes, 0); |
373 | 0 | if (res < 0) { |
374 | 0 | throw NetworkError("Reading from a socket: " + stringerror()); |
375 | 0 | } |
376 | 0 | return static_cast<size_t>(res); |
377 | 0 | } |
378 | | |
379 | | /** Read a bock of data from the socket to a block of memory, |
380 | | * waiting at most 'timeout' seconds for the data to become |
381 | | * available. Be aware that this does _NOT_ handle partial reads |
382 | | * for you. |
383 | | */ |
384 | | size_t readWithTimeout(char* buffer, size_t n, int timeout) const |
385 | 0 | { |
386 | 0 | int err = waitForRWData(d_socket, true, timeout, 0); |
387 | 0 |
|
388 | 0 | if (err == 0) { |
389 | 0 | throw NetworkError("timeout reading"); |
390 | 0 | } |
391 | 0 | if (err < 0) { |
392 | 0 | throw NetworkError("nonblocking read failed: " + stringerror()); |
393 | 0 | } |
394 | 0 |
|
395 | 0 | return read(buffer, n); |
396 | 0 | } |
397 | | |
398 | | //! Sets the socket to listen with a default listen backlog of 10 pending connections |
399 | | void listen(int length = 10) const |
400 | 0 | { |
401 | 0 | if (::listen(d_socket, length) < 0) { |
402 | 0 | throw NetworkError("Setting socket to listen: " + stringerror()); |
403 | 0 | } |
404 | 0 | } |
405 | | |
406 | | //! Returns the internal file descriptor of the socket |
407 | | [[nodiscard]] int getHandle() const |
408 | 0 | { |
409 | 0 | return d_socket; |
410 | 0 | } |
411 | | |
412 | | int releaseHandle() |
413 | 0 | { |
414 | 0 | int ret = d_socket; |
415 | 0 | d_socket = -1; |
416 | 0 | return ret; |
417 | 0 | } |
418 | | |
419 | | private: |
420 | | static constexpr size_t s_buflen{4096}; |
421 | | std::string d_buffer; |
422 | | int d_socket; |
423 | | }; |