Line | Count | Source |
1 | | /* |
2 | | * This file is part of PowerDNS or dnsdist. |
3 | | * Copyright -- PowerDNS.COM B.V. and its contributors |
4 | | * |
5 | | * This program is free software; you can redistribute it and/or modify |
6 | | * it under the terms of version 2 of the GNU General Public License as |
7 | | * published by the Free Software Foundation. |
8 | | * |
9 | | * In addition, for the avoidance of any doubt, permission is granted to |
10 | | * link this program with OpenSSL and to (re)distribute the binaries |
11 | | * produced as the result of such linking. |
12 | | * |
13 | | * This program is distributed in the hope that it will be useful, |
14 | | * but WITHOUT ANY WARRANTY; without even the implied warranty of |
15 | | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
16 | | * GNU General Public License for more details. |
17 | | * |
18 | | * You should have received a copy of the GNU General Public License |
19 | | * along with this program; if not, write to the Free Software |
20 | | * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. |
21 | | */ |
22 | | #pragma once |
23 | | #include <string> |
24 | | #include <sstream> |
25 | | #include <iostream> |
26 | | #include "iputils.hh" |
27 | | #include <cerrno> |
28 | | #include <sys/types.h> |
29 | | #include <unistd.h> |
30 | | #include <sys/socket.h> |
31 | | #include <netinet/in.h> |
32 | | #include <netinet/tcp.h> |
33 | | #include <arpa/inet.h> |
34 | | #include <sys/select.h> |
35 | | #include <fcntl.h> |
36 | | #include <stdexcept> |
37 | | |
38 | | #include <csignal> |
39 | | #include "namespaces.hh" |
40 | | #include "noinitvector.hh" |
41 | | |
42 | | using ProtocolType = int; //!< Supported protocol types |
43 | | |
44 | | //! Representation of a Socket and many of the Berkeley functions available |
45 | | class Socket |
46 | | { |
47 | | public: |
48 | | Socket(const Socket&) = delete; |
49 | | Socket& operator=(const Socket&) = delete; |
50 | | |
51 | | Socket(int socketDesc) : |
52 | | d_socket(socketDesc) |
53 | 0 | { |
54 | 0 | } |
55 | | |
56 | | //! Construct a socket of specified address family and socket type. |
57 | | Socket(int addressFamily, int socketType, ProtocolType protocolType = 0) : |
58 | | d_socket(socket(addressFamily, socketType, protocolType)) |
59 | 0 | { |
60 | 0 | if (d_socket < 0) { |
61 | 0 | throw NetworkError(stringerror()); |
62 | 0 | } |
63 | 0 | setCloseOnExec(d_socket); |
64 | 0 | } |
65 | | |
66 | | Socket(Socket&& rhs) noexcept : |
67 | | d_buffer(std::move(rhs.d_buffer)), d_socket(rhs.d_socket) |
68 | 0 | { |
69 | 0 | rhs.d_socket = -1; |
70 | 0 | } |
71 | | |
72 | | Socket& operator=(Socket&& rhs) noexcept |
73 | 0 | { |
74 | 0 | if (d_socket != -1) { |
75 | 0 | close(d_socket); |
76 | 0 | } |
77 | 0 | d_socket = rhs.d_socket; |
78 | 0 | rhs.d_socket = -1; |
79 | 0 | d_buffer = std::move(rhs.d_buffer); |
80 | 0 | return *this; |
81 | 0 | } |
82 | | |
83 | | ~Socket() |
84 | 0 | { |
85 | 0 | try { |
86 | 0 | if (d_socket != -1) { |
87 | 0 | closesocket(d_socket); |
88 | 0 | } |
89 | 0 | } |
90 | 0 | catch (const PDNSException& e) { |
91 | 0 | } |
92 | 0 | } |
93 | | |
94 | | //! If the socket is capable of doing so, this function will wait for a connection |
95 | | [[nodiscard]] std::unique_ptr<Socket> accept() const |
96 | 0 | { |
97 | 0 | sockaddr_in remote{}; |
98 | 0 | socklen_t remlen = sizeof(remote); |
99 | 0 | memset(&remote, 0, sizeof(remote)); |
100 | 0 | int sock = ::accept(d_socket, reinterpret_cast<sockaddr*>(&remote), &remlen); // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast): it's the API |
101 | 0 | if (sock < 0) { |
102 | 0 | if (errno == EAGAIN) { |
103 | 0 | return nullptr; |
104 | 0 | } |
105 | 0 |
|
106 | 0 | throw NetworkError("Accepting a connection: " + stringerror()); |
107 | 0 | } |
108 | 0 |
|
109 | 0 | return std::make_unique<Socket>(sock); |
110 | 0 | } |
111 | | |
112 | | //! Get remote address |
113 | | bool getRemote(ComboAddress& remote) const |
114 | 0 | { |
115 | 0 | socklen_t remotelen = sizeof(remote); |
116 | 0 | return getpeername(d_socket, reinterpret_cast<struct sockaddr*>(&remote), &remotelen) >= 0; // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast): it's the API |
117 | 0 | } |
118 | | |
119 | | //! Check remote address against netmaskgroup ng |
120 | | [[nodiscard]] bool acl(const NetmaskGroup& netmaskGroup) const |
121 | 0 | { |
122 | 0 | ComboAddress remote; |
123 | 0 | if (getRemote(remote)) { |
124 | 0 | if (netmaskGroup.match(remote)) { |
125 | 0 | return true; |
126 | 0 | } |
127 | 0 |
|
128 | 0 | if (remote.isMappedIPv4()) { |
129 | 0 | return netmaskGroup.match(remote.mapToIPv4()); |
130 | 0 | } |
131 | 0 | } |
132 | 0 |
|
133 | 0 | return false; |
134 | 0 | } |
135 | | |
136 | | //! Set the socket to non-blocking |
137 | | void setNonBlocking() const |
138 | 0 | { |
139 | 0 | ::setNonBlocking(d_socket); |
140 | 0 | } |
141 | | |
142 | | //! Set the socket to blocking |
143 | | void setBlocking() const |
144 | 0 | { |
145 | 0 | ::setBlocking(d_socket); |
146 | 0 | } |
147 | | |
148 | | void setReuseAddr() const |
149 | 0 | { |
150 | 0 | try { |
151 | 0 | ::setReuseAddr(d_socket); |
152 | 0 | } |
153 | 0 | catch (const PDNSException& e) { |
154 | 0 | throw NetworkError(e.reason); |
155 | 0 | } |
156 | 0 | } |
157 | | |
158 | | void setFastOpenConnect() |
159 | 0 | { |
160 | 0 | #ifdef TCP_FASTOPEN_CONNECT |
161 | 0 | int on = 1; |
162 | 0 | if (setsockopt(d_socket, IPPROTO_TCP, TCP_FASTOPEN_CONNECT, &on, sizeof(on)) < 0) { |
163 | 0 | throw NetworkError("While setting TCP_FASTOPEN_CONNECT: " + stringerror()); |
164 | 0 | } |
165 | 0 | #else |
166 | 0 | throw NetworkError("While setting TCP_FASTOPEN_CONNECT: not compiled in"); |
167 | 0 | #endif |
168 | 0 | } |
169 | | |
170 | | //! Bind the socket to a specified endpoint |
171 | | template <typename T> |
172 | | void bind(const T& local, bool reuseaddr = true) const |
173 | | { |
174 | | int tmp = 1; |
175 | | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) |
176 | | if (reuseaddr && setsockopt(d_socket, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<char*>(&tmp), sizeof tmp) < 0) { |
177 | | throw NetworkError("Setsockopt failed: " + stringerror()); |
178 | | } |
179 | | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) |
180 | | if (::bind(d_socket, reinterpret_cast<const struct sockaddr*>(&local), local.getSocklen()) < 0) { |
181 | | throw NetworkError("While binding: " + stringerror()); |
182 | | } |
183 | | } |
184 | | |
185 | | //! Connect the socket to a specified endpoint |
186 | | void connect(const ComboAddress& address, int timeout = 0) const |
187 | 0 | { |
188 | 0 | SConnectWithTimeout(d_socket, false, address, timeval{timeout, 0}); |
189 | 0 | } |
190 | | |
191 | | //! For datagram sockets, receive a datagram and learn where it came from |
192 | | /** For datagram sockets, receive a datagram and learn where it came from |
193 | | \param dgram Will be filled with the datagram |
194 | | \param ep Will be filled with the origin of the datagram */ |
195 | | void recvFrom(string& dgram, ComboAddress& remote) const |
196 | 0 | { |
197 | 0 | socklen_t remlen = sizeof(remote); |
198 | 0 | if (dgram.size() < s_buflen) { |
199 | 0 | dgram.resize(s_buflen); |
200 | 0 | } |
201 | 0 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) |
202 | 0 | auto bytes = recvfrom(d_socket, dgram.data(), dgram.size(), 0, reinterpret_cast<sockaddr*>(&remote), &remlen); |
203 | 0 | if (bytes < 0) { |
204 | 0 | throw NetworkError("After recvfrom: " + stringerror()); |
205 | 0 | } |
206 | 0 | dgram.resize(static_cast<size_t>(bytes)); |
207 | 0 | } |
208 | | |
209 | | bool recvFromAsync(PacketBuffer& dgram, ComboAddress& remote) const |
210 | 0 | { |
211 | 0 | socklen_t remlen = sizeof(remote); |
212 | 0 | if (dgram.size() < s_buflen) { |
213 | 0 | dgram.resize(s_buflen); |
214 | 0 | } |
215 | 0 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) |
216 | 0 | auto bytes = recvfrom(d_socket, dgram.data(), dgram.size(), 0, reinterpret_cast<sockaddr*>(&remote), &remlen); |
217 | 0 | if (bytes < 0) { |
218 | 0 | if (errno != EAGAIN) { |
219 | 0 | throw NetworkError("After async recvfrom: " + stringerror()); |
220 | 0 | } |
221 | 0 | return false; |
222 | 0 | } |
223 | 0 | dgram.resize(static_cast<size_t>(bytes)); |
224 | 0 | return true; |
225 | 0 | } |
226 | | |
227 | | //! For datagram sockets, send a datagram to a destination |
228 | | void sendTo(const char* msg, size_t len, const ComboAddress& remote) const |
229 | 0 | { |
230 | 0 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) |
231 | 0 | if (sendto(d_socket, msg, len, 0, reinterpret_cast<const sockaddr*>(&remote), remote.getSocklen()) < 0) { |
232 | 0 | throw NetworkError("After sendto: " + stringerror()); |
233 | 0 | } |
234 | 0 | } |
235 | | |
236 | | //! For connected datagram sockets, send a datagram |
237 | | void send(const std::string& msg) const |
238 | 0 | { |
239 | 0 | if (::send(d_socket, msg.data(), msg.size(), 0) < 0) { |
240 | 0 | throw NetworkError("After send: " + stringerror()); |
241 | 0 | } |
242 | 0 | } |
243 | | |
244 | | /** For datagram sockets, send a datagram to a destination |
245 | | \param dgram The datagram |
246 | | \param remote The intended destination of the datagram */ |
247 | | void sendTo(const string& dgram, const ComboAddress& remote) const |
248 | 0 | { |
249 | 0 | sendTo(dgram.data(), dgram.length(), remote); |
250 | 0 | } |
251 | | |
252 | | //! Write this data to the socket, taking care that all bytes are written out |
253 | | void writen(const string& data) const |
254 | 0 | { |
255 | 0 | if (data.empty()) { |
256 | 0 | return; |
257 | 0 | } |
258 | 0 |
|
259 | 0 | size_t toWrite = data.length(); |
260 | 0 | const char* ptr = data.data(); |
261 | 0 |
|
262 | 0 | do { |
263 | 0 | auto res = ::send(d_socket, ptr, toWrite, 0); |
264 | 0 | if (res < 0) { |
265 | 0 | throw NetworkError("Writing to a socket: " + stringerror()); |
266 | 0 | } |
267 | 0 | if (res == 0) { |
268 | 0 | throw NetworkError("EOF on socket"); |
269 | 0 | } |
270 | 0 | toWrite -= static_cast<size_t>(res); |
271 | 0 | ptr += static_cast<size_t>(res); // NOLINT(cppcoreguidelines-pro-bounds-pointer-arithmetic) |
272 | 0 | } while (toWrite > 0); |
273 | 0 | } |
274 | | |
275 | | //! tries to write toWrite bytes from ptr to the socket |
276 | | /** tries to write toWrite bytes from ptr to the socket, but does not make sure they al get written out |
277 | | \param ptr Location to write from |
278 | | \param toWrite number of bytes to try |
279 | | */ |
280 | | size_t tryWrite(const char* ptr, size_t toWrite) const |
281 | 0 | { |
282 | 0 | auto res = ::send(d_socket, ptr, toWrite, 0); |
283 | 0 | if (res == 0) { |
284 | 0 | throw NetworkError("EOF on writing to a socket"); |
285 | 0 | } |
286 | 0 | if (res > 0) { |
287 | 0 | return res; |
288 | 0 | } |
289 | 0 |
|
290 | 0 | if (errno == EAGAIN) { |
291 | 0 | return 0; |
292 | 0 | } |
293 | 0 |
|
294 | 0 | throw NetworkError("Writing to a socket: " + stringerror()); |
295 | 0 | } |
296 | | |
297 | | //! Writes toWrite bytes from ptr to the socket |
298 | | /** Writes toWrite bytes from ptr to the socket. Returns how many bytes were written */ |
299 | | size_t write(const char* ptr, size_t toWrite) const |
300 | 0 | { |
301 | 0 | auto res = ::send(d_socket, ptr, toWrite, 0); |
302 | 0 | if (res < 0) { |
303 | 0 | throw NetworkError("Writing to a socket: " + stringerror()); |
304 | 0 | } |
305 | 0 | return res; |
306 | 0 | } |
307 | | |
308 | | void writenWithTimeout(const void* buffer, size_t n, int timeout) const |
309 | 0 | { |
310 | 0 | size_t bytes = n; |
311 | 0 | const char* ptr = static_cast<const char*>(buffer); |
312 | 0 |
|
313 | 0 | while (bytes > 0) { |
314 | 0 | auto ret = ::write(d_socket, ptr, bytes); |
315 | 0 | if (ret < 0) { |
316 | 0 | if (errno == EAGAIN) { |
317 | 0 | ret = waitForRWData(d_socket, false, timeout, 0); |
318 | 0 | if (ret < 0) { |
319 | 0 | throw NetworkError("Waiting for data write"); |
320 | 0 | } |
321 | 0 | if (ret == 0) { |
322 | 0 | throw NetworkError("Timeout writing data"); |
323 | 0 | } |
324 | 0 | continue; |
325 | 0 | } |
326 | 0 | throw NetworkError("Writing data: " + stringerror()); |
327 | 0 | } |
328 | 0 | if (ret == 0) { |
329 | 0 | throw NetworkError("Did not fulfill TCP write due to EOF"); |
330 | 0 | } |
331 | 0 |
|
332 | 0 | ptr += static_cast<size_t>(ret); // NOLINT(cppcoreguidelines-pro-bounds-pointer-arithmetic) |
333 | 0 | bytes -= static_cast<size_t>(ret); |
334 | 0 | } |
335 | 0 | } |
336 | | |
337 | | //! Reads a block of data from the socket to a string |
338 | | void read(string& data) |
339 | 0 | { |
340 | 0 | d_buffer.resize(s_buflen); |
341 | 0 | ssize_t res = ::recv(d_socket, d_buffer.data(), s_buflen, 0); |
342 | 0 | if (res < 0) { |
343 | 0 | throw NetworkError("Reading from a socket: " + stringerror()); |
344 | 0 | } |
345 | 0 | data.assign(d_buffer, 0, static_cast<size_t>(res)); |
346 | 0 | } |
347 | | |
348 | | //! Reads a block of data from the socket to a block of memory |
349 | | size_t read(char* buffer, size_t bytes) const |
350 | 0 | { |
351 | 0 | auto res = ::recv(d_socket, buffer, bytes, 0); |
352 | 0 | if (res < 0) { |
353 | 0 | throw NetworkError("Reading from a socket: " + stringerror()); |
354 | 0 | } |
355 | 0 | return static_cast<size_t>(res); |
356 | 0 | } |
357 | | |
358 | | /** Read a bock of data from the socket to a block of memory, |
359 | | * waiting at most 'timeout' seconds for the data to become |
360 | | * available. Be aware that this does _NOT_ handle partial reads |
361 | | * for you. |
362 | | */ |
363 | | size_t readWithTimeout(char* buffer, size_t n, int timeout) const |
364 | 0 | { |
365 | 0 | int err = waitForRWData(d_socket, true, timeout, 0); |
366 | 0 |
|
367 | 0 | if (err == 0) { |
368 | 0 | throw NetworkError("timeout reading"); |
369 | 0 | } |
370 | 0 | if (err < 0) { |
371 | 0 | throw NetworkError("nonblocking read failed: " + stringerror()); |
372 | 0 | } |
373 | 0 |
|
374 | 0 | return read(buffer, n); |
375 | 0 | } |
376 | | |
377 | | //! Sets the socket to listen with a default listen backlog of 10 pending connections |
378 | | void listen(int length = 10) const |
379 | 0 | { |
380 | 0 | if (::listen(d_socket, length) < 0) { |
381 | 0 | throw NetworkError("Setting socket to listen: " + stringerror()); |
382 | 0 | } |
383 | 0 | } |
384 | | |
385 | | //! Returns the internal file descriptor of the socket |
386 | | [[nodiscard]] int getHandle() const |
387 | 0 | { |
388 | 0 | return d_socket; |
389 | 0 | } |
390 | | |
391 | | int releaseHandle() |
392 | 0 | { |
393 | 0 | int ret = d_socket; |
394 | 0 | d_socket = -1; |
395 | 0 | return ret; |
396 | 0 | } |
397 | | |
398 | | private: |
399 | | static constexpr size_t s_buflen{4096}; |
400 | | std::string d_buffer; |
401 | | int d_socket; |
402 | | }; |