/src/crow/include/crow/websocket.h
Line | Count | Source |
1 | | #pragma once |
2 | | #include <array> |
3 | | #include <memory> |
4 | | #include <optional> |
5 | | #include <string> |
6 | | #include <thread> |
7 | | #include "crow/http_response.h" |
8 | | #include "crow/logging.h" |
9 | | #include "crow/socket_adaptors.h" |
10 | | #include "crow/http_request.h" |
11 | | #include "crow/TinySHA1.hpp" |
12 | | #include "crow/utility.h" |
13 | | |
14 | | namespace crow // NOTE: Already documented in "crow/app.h" |
15 | | { |
16 | | #ifdef CROW_USE_BOOST |
17 | | namespace asio = boost::asio; |
18 | | using error_code = boost::system::error_code; |
19 | | #else |
20 | | using error_code = asio::error_code; |
21 | | #endif |
22 | | |
23 | | /** |
24 | | * \namespace crow::websocket |
25 | | * \brief Namespace that contains the \ref Connection class |
26 | | * and the \ref connection struct, which implement WebSocket connections. |
27 | | * |
28 | | * Used mainly in crow/websocket.h, crow/app.h and crow/routing.h |
29 | | */ |
30 | | namespace websocket |
31 | | { |
/// States of the incremental frame reader in Connection::do_read().
///
/// A frame is read as MiniHeader -> (Len16 | Len64, only for extended
/// lengths) -> Mask -> Payload, after which the reader returns to MiniHeader.
enum class WebSocketReadState : int
{
    MiniHeader = 0, ///< the two fixed header bytes
    Len16 = 1,      ///< 16-bit extended payload length
    Len64 = 2,      ///< 64-bit extended payload length
    Mask = 3,       ///< payload-size check and optional 4-byte masking key
    Payload = 4,    ///< payload bytes
};
40 | | |
/// WebSocket close status codes, as listed in
/// https://www.rfc-editor.org/rfc/rfc6455#section-7.4.1
///
/// Codes marked "reserved" must not be put into a Close frame by an endpoint;
/// they only describe a local condition to the application.
enum CloseStatusCode : uint16_t
{
    NormalClosure = 1000,
    EndpointGoingAway = 1001,
    ProtocolError = 1002,
    UnacceptableData = 1003,
    NoStatusCodePresent = 1005, // reserved: the Close frame carried no status code
    ClosedAbnormally = 1006,    // reserved: connection dropped without a Close frame
    InconsistentData = 1007,
    PolicyViolated = 1008,
    MessageTooBig = 1009,
    ExtensionsNotNegotiated = 1010,
    UnexpectedCondition = 1011,
    TLSHandshakeFailure = 1015, // reserved: TLS handshake failed

    // First code of the range for libraries/frameworks, and of the private-use range.
    StartStatusCodesForLibraries = 3000,
    StartStatusCodesForPrivateUse = 4000,

    // A valid status code lies in [StartStatusCodes, EndStatusCodes], i.e. 1000..4999.
    StartStatusCodes = NormalClosure,
    EndStatusCodes = 4999,
};
64 | | |
65 | | /// A base class for websocket connection. |
66 | | struct connection |
67 | | { |
68 | | virtual void send_binary(std::string msg) = 0; |
69 | | virtual void send_text(std::string msg) = 0; |
70 | | virtual void send_ping(std::string msg) = 0; |
71 | | virtual void send_pong(std::string msg) = 0; |
72 | | virtual void close(std::string const& msg = "quit", uint16_t status_code = CloseStatusCode::NormalClosure) = 0; |
73 | | virtual std::string get_remote_ip() = 0; |
74 | | virtual std::string get_subprotocol() const = 0; |
75 | | virtual ~connection() = default; |
76 | | |
77 | 0 | void userdata(void* u) { userdata_ = u; } |
78 | 0 | void* userdata() { return userdata_; } |
79 | | |
80 | | private: |
81 | | void* userdata_; |
82 | | }; |
83 | | |
84 | | // Modified version of the illustration in RFC6455 Section-5.2 |
85 | | // |
86 | | // |
87 | | // 0 1 2 3 -byte |
88 | | // 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 -bit |
89 | | // +-+-+-+-+-------+-+-------------+-------------------------------+ |
90 | | // |F|R|R|R| opcode|M| Payload len | Extended payload length | |
91 | | // |I|S|S|S| (4) |A| (7) | (16/64) | |
92 | | // |N|V|V|V| |S| | (if payload len==126/127) | |
93 | | // | |1|2|3| |K| | | |
94 | | // +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + |
95 | | // | Extended payload length continued, if payload len == 127 | |
96 | | // + - - - - - - - - - - - - - - - +-------------------------------+ |
97 | | // | |Masking-key, if MASK set to 1 | |
98 | | // +-------------------------------+-------------------------------+ |
99 | | // | Masking-key (continued) | Payload Data | |
100 | | // +-------------------------------- - - - - - - - - - - - - - - - + |
101 | | // : Payload Data continued ... : |
102 | | // + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + |
103 | | // | Payload Data continued ... | |
104 | | // +---------------------------------------------------------------+ |
105 | | // |
106 | | |
107 | | /// A websocket connection. |
108 | | |
109 | | template<typename Adaptor, typename Handler> |
110 | | class Connection : public connection, public std::enable_shared_from_this<Connection<Adaptor, Handler>> |
111 | | { |
112 | | public: |
113 | | /// Factory for a connection. |
114 | | /// |
115 | | /// Requires a request with an "Upgrade: websocket" header.<br> |
116 | | /// Automatically handles the handshake. |
117 | | static void create(const crow::request& req, Adaptor adaptor, Handler* handler, |
118 | | uint64_t max_payload, const std::vector<std::string>& subprotocols, |
119 | | std::function<void(crow::websocket::connection&)> open_handler, |
120 | | std::function<void(crow::websocket::connection&, const std::string&, bool)> message_handler, |
121 | | std::function<void(crow::websocket::connection&, const std::string&, uint16_t)> close_handler, |
122 | | std::function<void(crow::websocket::connection&, const std::string&)> error_handler, |
123 | | std::function<void(const crow::request&, std::optional<crow::response>&, void**)> accept_handler, |
124 | | bool mirror_protocols) |
125 | | { |
126 | | auto conn = std::shared_ptr<Connection>(new Connection(std::move(adaptor), |
127 | | handler, max_payload, |
128 | | std::move(open_handler), |
129 | | std::move(message_handler), |
130 | | std::move(close_handler), |
131 | | std::move(error_handler), |
132 | | std::move(accept_handler))); |
133 | | |
134 | | // Perform handshake validation |
135 | | if (!utility::string_equals(req.get_header_value("upgrade"), "websocket")) |
136 | | { |
137 | | conn->adaptor_.close(); |
138 | | return; |
139 | | } |
140 | | |
141 | | std::string requested_subprotocols_header = req.get_header_value("Sec-WebSocket-Protocol"); |
142 | | if (!subprotocols.empty() || !requested_subprotocols_header.empty()) |
143 | | { |
144 | | auto requested_subprotocols = utility::split(requested_subprotocols_header, ", "); |
145 | | auto subprotocol = utility::find_first_of(subprotocols.begin(), subprotocols.end(), requested_subprotocols.begin(), requested_subprotocols.end()); |
146 | | if (subprotocol != subprotocols.end()) |
147 | | { |
148 | | conn->subprotocol_ = *subprotocol; |
149 | | } |
150 | | } |
151 | | |
152 | | if (mirror_protocols & !requested_subprotocols_header.empty()) |
153 | | { |
154 | | conn->subprotocol_ = requested_subprotocols_header; |
155 | | } |
156 | | |
157 | | if (conn->accept_handler_) |
158 | | { |
159 | | void* ud = nullptr; |
160 | | std::optional<crow::response> res; |
161 | | conn->accept_handler_(req, res, &ud); |
162 | | if (res) |
163 | | { |
164 | | std::vector<asio::const_buffer> buffers; |
165 | | auto server_name = ""; |
166 | | std::string content_length_buffer; |
167 | | res->write_header_into_buffer(buffers, content_length_buffer, req.keep_alive, server_name); |
168 | | buffers.emplace_back(res->body.data(), res->body.size()); |
169 | | error_code ec; |
170 | | asio::write(conn->adaptor_.socket(), buffers, ec); |
171 | | conn->adaptor_.close(); |
172 | | return; |
173 | | } |
174 | | conn->userdata(ud); |
175 | | } |
176 | | |
177 | | // Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ== |
178 | | // Sec-WebSocket-Version: 13 |
179 | | std::string magic = req.get_header_value("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; |
180 | | sha1::SHA1 s; |
181 | | s.processBytes(magic.data(), magic.size()); |
182 | | uint8_t digest[20]; |
183 | | s.getDigestBytes(digest); |
184 | | |
185 | | conn->handler_->add_websocket(conn); |
186 | | conn->start(crow::utility::base64encode((unsigned char*)digest, 20)); |
187 | | } |
188 | | |
189 | | ~Connection() noexcept override = default; |
190 | | |
191 | | template<typename Callable> |
192 | | struct WeakWrappedMessage |
193 | | { |
194 | | Callable callable; |
195 | | std::weak_ptr<void> watch; |
196 | | |
197 | | void operator()() |
198 | | { |
199 | | if (auto anchor = watch.lock()) |
200 | | { |
201 | | std::move(callable)(); |
202 | | } |
203 | | } |
204 | | }; |
205 | | |
206 | | /// Send data through the socket. |
207 | | template<typename CompletionHandler> |
208 | | void dispatch(CompletionHandler&& handler) |
209 | | { |
210 | | asio::dispatch(adaptor_.get_io_context(), |
211 | | WeakWrappedMessage<typename std::decay<CompletionHandler>::type>{ |
212 | | std::forward<CompletionHandler>(handler), anchor_}); |
213 | | } |
214 | | |
215 | | /// Send data through the socket and return immediately. |
216 | | template<typename CompletionHandler> |
217 | | void post(CompletionHandler&& handler) |
218 | | { |
219 | | asio::post(adaptor_.get_io_context(), |
220 | | WeakWrappedMessage<typename std::decay<CompletionHandler>::type>{ |
221 | | std::forward<CompletionHandler>(handler), anchor_}); |
222 | | } |
223 | | |
224 | | /// Send a "Ping" message. |
225 | | |
226 | | /// |
227 | | /// Usually invoked to check if the other point is still online. |
228 | | void send_ping(std::string msg) override |
229 | | { |
230 | | send_data(0x9, std::move(msg)); |
231 | | } |
232 | | |
233 | | /// Send a "Pong" message. |
234 | | |
235 | | /// |
236 | | /// Usually automatically invoked as a response to a "Ping" message. |
237 | | void send_pong(std::string msg) override |
238 | | { |
239 | | send_data(0xA, std::move(msg)); |
240 | | } |
241 | | |
242 | | /// Send a binary encoded message. |
243 | | void send_binary(std::string msg) override |
244 | | { |
245 | | send_data(0x2, std::move(msg)); |
246 | | } |
247 | | |
248 | | /// Send a plaintext message. |
249 | | void send_text(std::string msg) override |
250 | | { |
251 | | send_data(0x1, std::move(msg)); |
252 | | } |
253 | | |
254 | | /// Send a close signal. |
255 | | |
256 | | /// |
257 | | /// Sets a flag to destroy the object once the message is sent. |
258 | | void close(std::string const& msg, uint16_t status_code) override |
259 | | { |
260 | | dispatch([shared_this = this->shared_from_this(), msg, status_code]() mutable { |
261 | | shared_this->has_sent_close_ = true; |
262 | | if (shared_this->has_recv_close_ && !shared_this->is_close_handler_called_) |
263 | | { |
264 | | shared_this->is_close_handler_called_ = true; |
265 | | if (shared_this->close_handler_) |
266 | | shared_this->close_handler_(*shared_this, msg, status_code); |
267 | | } |
268 | | auto header = shared_this->build_header(0x8, msg.size() + 2); |
269 | | char status_buf[2]; |
270 | | *(uint16_t*)(status_buf) = htons(status_code); |
271 | | |
272 | | shared_this->write_buffers_.emplace_back(std::move(header)); |
273 | | shared_this->write_buffers_.emplace_back(std::string(status_buf, 2)); |
274 | | shared_this->write_buffers_.emplace_back(msg); |
275 | | shared_this->do_write(); |
276 | | }); |
277 | | } |
278 | | |
279 | | std::string get_remote_ip() override |
280 | | { |
281 | | return adaptor_.address(); |
282 | | } |
283 | | |
284 | | void set_max_payload_size(uint64_t payload) |
285 | | { |
286 | | max_payload_bytes_ = payload; |
287 | | } |
288 | | |
289 | | /// Returns the matching client/server subprotocol, empty string if none matched. |
290 | | std::string get_subprotocol() const override |
291 | | { |
292 | | return subprotocol_; |
293 | | } |
294 | | |
295 | | protected: |
296 | | /// Generate the websocket headers using an opcode and the message size (in bytes). |
297 | | std::string build_header(int opcode, size_t size) |
298 | | { |
299 | | char buf[2 + 8] = "\x80\x00"; |
300 | | buf[0] += opcode; |
301 | | if (size < 126) |
302 | | { |
303 | | buf[1] += static_cast<char>(size); |
304 | | return {buf, buf + 2}; |
305 | | } |
306 | | else if (size < 0x10000) |
307 | | { |
308 | | buf[1] += 126; |
309 | | *(uint16_t*)(buf + 2) = htons(static_cast<uint16_t>(size)); |
310 | | return {buf, buf + 4}; |
311 | | } |
312 | | else |
313 | | { |
314 | | buf[1] += 127; |
315 | | *reinterpret_cast<uint64_t*>(buf + 2) = ((1 == htonl(1)) ? static_cast<uint64_t>(size) : (static_cast<uint64_t>(htonl((size)&0xFFFFFFFF)) << 32) | htonl(static_cast<uint64_t>(size) >> 32)); |
316 | | return {buf, buf + 10}; |
317 | | } |
318 | | } |
319 | | |
320 | | /// Send the HTTP upgrade response. |
321 | | |
322 | | /// |
323 | | /// Finishes the handshake process, then starts reading messages from the socket. |
324 | | void start(std::string&& hello) |
325 | | { |
326 | | static const std::string header = |
327 | | "HTTP/1.1 101 Switching Protocols\r\n" |
328 | | "Upgrade: websocket\r\n" |
329 | | "Connection: Upgrade\r\n" |
330 | | "Sec-WebSocket-Accept: "; |
331 | | write_buffers_.emplace_back(header); |
332 | | write_buffers_.emplace_back(std::move(hello)); |
333 | | write_buffers_.emplace_back(crlf); |
334 | | if (!subprotocol_.empty()) |
335 | | { |
336 | | write_buffers_.emplace_back("Sec-WebSocket-Protocol: "); |
337 | | write_buffers_.emplace_back(subprotocol_); |
338 | | write_buffers_.emplace_back(crlf); |
339 | | } |
340 | | write_buffers_.emplace_back(crlf); |
341 | | do_write(); |
342 | | if (open_handler_) |
343 | | open_handler_(*this); |
344 | | do_read(); |
345 | | } |
346 | | |
347 | | /// Read a websocket message. |
348 | | |
349 | | /// |
350 | | /// Involves:<br> |
351 | | /// Handling headers (opcodes, size).<br> |
352 | | /// Unmasking the payload.<br> |
353 | | /// Reading the actual payload.<br> |
354 | | void do_read() |
355 | | { |
356 | | if (has_sent_close_ && has_recv_close_) |
357 | | { |
358 | | close_connection_ = true; |
359 | | adaptor_.shutdown_readwrite(); |
360 | | adaptor_.close(); |
361 | | check_destroy(); |
362 | | return; |
363 | | } |
364 | | |
365 | | is_reading = true; |
366 | | switch (state_) |
367 | | { |
368 | | case WebSocketReadState::MiniHeader: |
369 | | { |
370 | | mini_header_ = 0; |
371 | | //asio::async_read(adaptor_.socket(), asio::buffer(&mini_header_, 1), |
372 | | adaptor_.socket().async_read_some( |
373 | | asio::buffer(&mini_header_, 2), |
374 | | [shared_this = this->shared_from_this()](const error_code& ec, std::size_t |
375 | | #ifdef CROW_ENABLE_DEBUG |
376 | | bytes_transferred |
377 | | #endif |
378 | | ) |
379 | | |
380 | | { |
381 | | shared_this->is_reading = false; |
382 | | shared_this->mini_header_ = ntohs(shared_this->mini_header_); |
383 | | #ifdef CROW_ENABLE_DEBUG |
384 | | |
385 | | if (!ec && bytes_transferred != 2) |
386 | | { |
387 | | throw std::runtime_error("WebSocket:MiniHeader:async_read fail:asio bug?"); |
388 | | } |
389 | | #endif |
390 | | |
391 | | if (!ec) |
392 | | { |
393 | | if ((shared_this->mini_header_ & 0x80) == 0x80) |
394 | | shared_this->has_mask_ = true; |
395 | | else //if the websocket specification is enforced and the message isn't masked, terminate the connection |
396 | | { |
397 | | #ifndef CROW_ENFORCE_WS_SPEC |
398 | | shared_this->has_mask_ = false; |
399 | | #else |
400 | | shared_this->close_connection_ = true; |
401 | | shared_this->adaptor_.shutdown_readwrite(); |
402 | | shared_this->adaptor_.close(); |
403 | | if (shared_this->error_handler_) |
404 | | shared_this->error_handler_(*shared_this, "Client connection not masked."); |
405 | | shared_this->check_destroy(CloseStatusCode::UnacceptableData); |
406 | | #endif |
407 | | } |
408 | | |
409 | | if ((shared_this->mini_header_ & 0x7f) == 127) |
410 | | { |
411 | | shared_this->state_ = WebSocketReadState::Len64; |
412 | | } |
413 | | else if ((shared_this->mini_header_ & 0x7f) == 126) |
414 | | { |
415 | | shared_this->state_ = WebSocketReadState::Len16; |
416 | | } |
417 | | else |
418 | | { |
419 | | shared_this->remaining_length_ = shared_this->mini_header_ & 0x7f; |
420 | | shared_this->state_ = WebSocketReadState::Mask; |
421 | | } |
422 | | shared_this->do_read(); |
423 | | } |
424 | | else |
425 | | { |
426 | | shared_this->close_connection_ = true; |
427 | | shared_this->adaptor_.shutdown_readwrite(); |
428 | | shared_this->adaptor_.close(); |
429 | | if (shared_this->error_handler_) |
430 | | shared_this->error_handler_(*shared_this, ec.message()); |
431 | | shared_this->check_destroy(); |
432 | | } |
433 | | }); |
434 | | } |
435 | | break; |
436 | | case WebSocketReadState::Len16: |
437 | | { |
438 | | remaining_length_ = 0; |
439 | | remaining_length16_ = 0; |
440 | | asio::async_read( |
441 | | adaptor_.socket(), asio::buffer(&remaining_length16_, 2), |
442 | | [shared_this = this->shared_from_this()](const error_code& ec, std::size_t |
443 | | #ifdef CROW_ENABLE_DEBUG |
444 | | bytes_transferred |
445 | | #endif |
446 | | ) { |
447 | | shared_this->is_reading = false; |
448 | | shared_this->remaining_length16_ = ntohs(shared_this->remaining_length16_); |
449 | | shared_this->remaining_length_ = shared_this->remaining_length16_; |
450 | | #ifdef CROW_ENABLE_DEBUG |
451 | | if (!ec && bytes_transferred != 2) |
452 | | { |
453 | | throw std::runtime_error("WebSocket:Len16:async_read fail:asio bug?"); |
454 | | } |
455 | | #endif |
456 | | |
457 | | if (!ec) |
458 | | { |
459 | | shared_this->state_ = WebSocketReadState::Mask; |
460 | | shared_this->do_read(); |
461 | | } |
462 | | else |
463 | | { |
464 | | shared_this->close_connection_ = true; |
465 | | shared_this->adaptor_.shutdown_readwrite(); |
466 | | shared_this->adaptor_.close(); |
467 | | if (shared_this->error_handler_) |
468 | | shared_this->error_handler_(*shared_this, ec.message()); |
469 | | shared_this->check_destroy(); |
470 | | } |
471 | | }); |
472 | | } |
473 | | break; |
474 | | case WebSocketReadState::Len64: |
475 | | { |
476 | | asio::async_read( |
477 | | adaptor_.socket(), asio::buffer(&remaining_length_, 8), |
478 | | [shared_this = this->shared_from_this()](const error_code& ec, std::size_t |
479 | | #ifdef CROW_ENABLE_DEBUG |
480 | | bytes_transferred |
481 | | #endif |
482 | | ) { |
483 | | shared_this->is_reading = false; |
484 | | shared_this->remaining_length_ = ((1 == ntohl(1)) ? (shared_this->remaining_length_) : (static_cast<uint64_t>(ntohl((shared_this->remaining_length_)&0xFFFFFFFF)) << 32) | ntohl((shared_this->remaining_length_) >> 32)); |
485 | | #ifdef CROW_ENABLE_DEBUG |
486 | | if (!ec && bytes_transferred != 8) |
487 | | { |
488 | | throw std::runtime_error("WebSocket:Len16:async_read fail:asio bug?"); |
489 | | } |
490 | | #endif |
491 | | |
492 | | if (!ec) |
493 | | { |
494 | | shared_this->state_ = WebSocketReadState::Mask; |
495 | | shared_this->do_read(); |
496 | | } |
497 | | else |
498 | | { |
499 | | shared_this->close_connection_ = true; |
500 | | shared_this->adaptor_.shutdown_readwrite(); |
501 | | shared_this->adaptor_.close(); |
502 | | if (shared_this->error_handler_) |
503 | | shared_this->error_handler_(*shared_this, ec.message()); |
504 | | shared_this->check_destroy(); |
505 | | } |
506 | | }); |
507 | | } |
508 | | break; |
509 | | case WebSocketReadState::Mask: |
510 | | if (remaining_length_ > max_payload_bytes_) |
511 | | { |
512 | | close_connection_ = true; |
513 | | adaptor_.close(); |
514 | | if (error_handler_) |
515 | | error_handler_(*this, "Message length exceeds maximum payload."); |
516 | | check_destroy(MessageTooBig); |
517 | | } |
518 | | else if (has_mask_) |
519 | | { |
520 | | asio::async_read( |
521 | | adaptor_.socket(), asio::buffer((char*)&mask_, 4), |
522 | | [shared_this = this->shared_from_this()](const error_code& ec, std::size_t |
523 | | #ifdef CROW_ENABLE_DEBUG |
524 | | bytes_transferred |
525 | | #endif |
526 | | ) { |
527 | | shared_this->is_reading = false; |
528 | | #ifdef CROW_ENABLE_DEBUG |
529 | | if (!ec && bytes_transferred != 4) |
530 | | { |
531 | | throw std::runtime_error("WebSocket:Mask:async_read fail:asio bug?"); |
532 | | } |
533 | | #endif |
534 | | |
535 | | if (!ec) |
536 | | { |
537 | | shared_this->state_ = WebSocketReadState::Payload; |
538 | | shared_this->do_read(); |
539 | | } |
540 | | else |
541 | | { |
542 | | shared_this->close_connection_ = true; |
543 | | if (shared_this->error_handler_) |
544 | | shared_this->error_handler_(*shared_this, ec.message()); |
545 | | shared_this->adaptor_.shutdown_readwrite(); |
546 | | shared_this->adaptor_.close(); |
547 | | shared_this->check_destroy(); |
548 | | } |
549 | | }); |
550 | | } |
551 | | else |
552 | | { |
553 | | state_ = WebSocketReadState::Payload; |
554 | | do_read(); |
555 | | } |
556 | | break; |
557 | | case WebSocketReadState::Payload: |
558 | | { |
559 | | auto to_read = static_cast<std::uint64_t>(buffer_.size()); |
560 | | if (remaining_length_ < to_read) |
561 | | to_read = remaining_length_; |
562 | | adaptor_.socket().async_read_some( |
563 | | asio::buffer(buffer_, static_cast<std::size_t>(to_read)), |
564 | | [shared_this = this->shared_from_this()](const error_code& ec, std::size_t bytes_transferred) { |
565 | | shared_this->is_reading = false; |
566 | | |
567 | | if (!ec) |
568 | | { |
569 | | shared_this->fragment_.insert(shared_this->fragment_.end(), shared_this->buffer_.begin(), shared_this->buffer_.begin() + bytes_transferred); |
570 | | shared_this->remaining_length_ -= bytes_transferred; |
571 | | if (shared_this->remaining_length_ == 0) |
572 | | { |
573 | | if (shared_this->handle_fragment()) |
574 | | { |
575 | | shared_this->state_ = WebSocketReadState::MiniHeader; |
576 | | shared_this->do_read(); |
577 | | } |
578 | | } |
579 | | else |
580 | | shared_this->do_read(); |
581 | | } |
582 | | else |
583 | | { |
584 | | shared_this->close_connection_ = true; |
585 | | if (shared_this->error_handler_) |
586 | | shared_this->error_handler_(*shared_this, ec.message()); |
587 | | shared_this->adaptor_.shutdown_readwrite(); |
588 | | shared_this->adaptor_.close(); |
589 | | shared_this->check_destroy(); |
590 | | } |
591 | | }); |
592 | | } |
593 | | break; |
594 | | } |
595 | | } |
596 | | |
597 | | /// Check if the FIN bit is set. |
598 | | bool is_FIN() |
599 | | { |
600 | | return mini_header_ & 0x8000; |
601 | | } |
602 | | |
603 | | /// Extract the opcode from the header. |
604 | | int opcode() |
605 | | { |
606 | | return (mini_header_ & 0x0f00) >> 8; |
607 | | } |
608 | | |
609 | | /// Process the payload fragment. |
610 | | |
611 | | /// |
612 | | /// Unmasks the fragment, checks the opcode, merges fragments into 1 message body, and calls the appropriate handler. |
613 | | bool handle_fragment() |
614 | | { |
615 | | if (has_mask_) |
616 | | { |
617 | | for (decltype(fragment_.length()) i = 0; i < fragment_.length(); i++) |
618 | | { |
619 | | fragment_[i] ^= ((char*)&mask_)[i % 4]; |
620 | | } |
621 | | } |
622 | | switch (opcode()) |
623 | | { |
624 | | case 0: // Continuation |
625 | | { |
626 | | message_ += fragment_; |
627 | | if (is_FIN()) |
628 | | { |
629 | | if (message_handler_) |
630 | | message_handler_(*this, message_, is_binary_); |
631 | | message_.clear(); |
632 | | } |
633 | | } |
634 | | break; |
635 | | case 1: // Text |
636 | | { |
637 | | is_binary_ = false; |
638 | | message_ += fragment_; |
639 | | if (is_FIN()) |
640 | | { |
641 | | if (message_handler_) |
642 | | message_handler_(*this, message_, is_binary_); |
643 | | message_.clear(); |
644 | | } |
645 | | } |
646 | | break; |
647 | | case 2: // Binary |
648 | | { |
649 | | is_binary_ = true; |
650 | | message_ += fragment_; |
651 | | if (is_FIN()) |
652 | | { |
653 | | if (message_handler_) |
654 | | message_handler_(*this, message_, is_binary_); |
655 | | message_.clear(); |
656 | | } |
657 | | } |
658 | | break; |
659 | | case 0x8: // Close |
660 | | { |
661 | | has_recv_close_ = true; |
662 | | |
663 | | |
664 | | uint16_t status_code = NoStatusCodePresent; |
665 | | std::string::size_type message_start = 2; |
666 | | if (fragment_.size() >= 2) |
667 | | { |
668 | | status_code = ntohs(((uint16_t*)fragment_.data())[0]); |
669 | | } else { |
670 | | // no message will crash substr |
671 | | message_start = 0; |
672 | | } |
673 | | |
674 | | if (!has_sent_close_) |
675 | | { |
676 | | close(fragment_.substr(message_start), status_code); |
677 | | } |
678 | | else |
679 | | { |
680 | | |
681 | | close_connection_ = true; |
682 | | if (!is_close_handler_called_) |
683 | | { |
684 | | if (close_handler_) |
685 | | close_handler_(*this, fragment_.substr(message_start), status_code); |
686 | | is_close_handler_called_ = true; |
687 | | } |
688 | | adaptor_.shutdown_readwrite(); |
689 | | adaptor_.close(); |
690 | | |
691 | | // Close handler must have been called at this point so code does not matter |
692 | | check_destroy(); |
693 | | return false; |
694 | | } |
695 | | } |
696 | | break; |
697 | | case 0x9: // Ping |
698 | | { |
699 | | send_pong(fragment_); |
700 | | } |
701 | | break; |
702 | | case 0xA: // Pong |
703 | | { |
704 | | pong_received_ = true; |
705 | | } |
706 | | break; |
707 | | } |
708 | | |
709 | | fragment_.clear(); |
710 | | return true; |
711 | | } |
712 | | |
713 | | /// Send the buffers' data through the socket. |
714 | | |
715 | | /// |
716 | | /// Also destroys the object if the Close flag is set. |
717 | | void do_write() |
718 | | { |
719 | | if (sending_buffers_.empty()) { |
720 | | if (write_buffers_.empty()) return; |
721 | | |
722 | | sending_buffers_.swap(write_buffers_); |
723 | | std::vector<asio::const_buffer> buffers; |
724 | | buffers.reserve(sending_buffers_.size()); |
725 | | for (auto &s: sending_buffers_) |
726 | | { |
727 | | buffers.emplace_back(asio::buffer(s)); |
728 | | } |
729 | | auto watch = std::weak_ptr<void>{anchor_}; |
730 | | asio::async_write( |
731 | | adaptor_.socket(), buffers, |
732 | | [shared_this = this->shared_from_this(), watch](const error_code &ec, std::size_t /*bytes_transferred*/) { |
733 | | auto anchor = watch.lock(); |
734 | | if (anchor == nullptr) |
735 | | return; |
736 | | |
737 | | if (!ec && !shared_this->close_connection_) |
738 | | { |
739 | | shared_this->sending_buffers_.clear(); |
740 | | if (!shared_this->write_buffers_.empty()) |
741 | | shared_this->do_write(); |
742 | | if (shared_this->has_sent_close_) |
743 | | shared_this->close_connection_ = true; |
744 | | } |
745 | | else |
746 | | { |
747 | | shared_this->sending_buffers_.clear(); |
748 | | shared_this->close_connection_ = true; |
749 | | shared_this->check_destroy(); |
750 | | } |
751 | | }); |
752 | | } |
753 | | } |
754 | | |
755 | | /// Destroy the Connection. |
756 | | void check_destroy(websocket::CloseStatusCode code = CloseStatusCode::ClosedAbnormally) |
757 | | { |
758 | | // Note that if the close handler was not yet called at this point we did not receive a close packet (or send one) |
759 | | // and thus we use ClosedAbnormally unless instructed otherwise |
760 | | if (!is_close_handler_called_) |
761 | | { |
762 | | if (close_handler_) |
763 | | { |
764 | | close_handler_(*this, "uncleanly", code); |
765 | | } |
766 | | } |
767 | | |
768 | | handler_->remove_websocket(this->shared_from_this()); |
769 | | } |
770 | | |
771 | | |
772 | | struct SendMessageType |
773 | | { |
774 | | std::string payload; |
775 | | Connection* self; |
776 | | int opcode; |
777 | | |
778 | | void operator()() |
779 | | { |
780 | | self->send_data_impl(this); |
781 | | } |
782 | | }; |
783 | | |
784 | | void send_data_impl(SendMessageType* s) |
785 | | { |
786 | | auto header = build_header(s->opcode, s->payload.size()); |
787 | | write_buffers_.emplace_back(std::move(header)); |
788 | | write_buffers_.emplace_back(std::move(s->payload)); |
789 | | do_write(); |
790 | | } |
791 | | |
792 | | void send_data(int opcode, std::string&& msg) |
793 | | { |
794 | | SendMessageType event_arg{ |
795 | | std::move(msg), |
796 | | this, |
797 | | opcode}; |
798 | | |
799 | | post(std::move(event_arg)); |
800 | | } |
801 | | |
802 | | private: |
803 | | Connection(Adaptor&& adaptor, Handler* handler, uint64_t max_payload, |
804 | | std::function<void(crow::websocket::connection&)> open_handler, |
805 | | std::function<void(crow::websocket::connection&, const std::string&, bool)> message_handler, |
806 | | std::function<void(crow::websocket::connection&, const std::string&, uint16_t)> close_handler, |
807 | | std::function<void(crow::websocket::connection&, const std::string&)> error_handler, |
808 | | std::function<void(const crow::request&, std::optional<crow::response>&, void**)> accept_handler): |
809 | | adaptor_(std::move(adaptor)), |
810 | | handler_(handler), |
811 | | max_payload_bytes_(max_payload), |
812 | | open_handler_(std::move(open_handler)), |
813 | | message_handler_(std::move(message_handler)), |
814 | | close_handler_(std::move(close_handler)), |
815 | | error_handler_(std::move(error_handler)), |
816 | | accept_handler_(std::move(accept_handler)) |
817 | | {} |
818 | | |
819 | | Adaptor adaptor_; |
820 | | Handler* handler_; |
821 | | |
822 | | std::vector<std::string> sending_buffers_; |
823 | | std::vector<std::string> write_buffers_; |
824 | | |
825 | | std::array<char, 4096> buffer_; |
826 | | bool is_binary_; |
827 | | std::string message_; |
828 | | std::string fragment_; |
829 | | WebSocketReadState state_{WebSocketReadState::MiniHeader}; |
830 | | uint16_t remaining_length16_{0}; |
831 | | uint64_t remaining_length_{0}; |
832 | | uint64_t max_payload_bytes_{UINT64_MAX}; |
833 | | std::string subprotocol_; |
834 | | bool close_connection_{false}; |
835 | | bool is_reading{false}; |
836 | | bool has_mask_{false}; |
837 | | uint32_t mask_; |
838 | | uint16_t mini_header_; |
839 | | bool has_sent_close_{false}; |
840 | | bool has_recv_close_{false}; |
841 | | bool error_occurred_{false}; |
842 | | bool pong_received_{false}; |
843 | | bool is_close_handler_called_{false}; |
844 | | |
845 | | std::shared_ptr<void> anchor_ = std::make_shared<int>(); // Value is just for placeholding |
846 | | |
847 | | std::function<void(crow::websocket::connection&)> open_handler_; |
848 | | std::function<void(crow::websocket::connection&, const std::string&, bool)> message_handler_; |
849 | | std::function<void(crow::websocket::connection&, const std::string&, uint16_t status_code)> close_handler_; |
850 | | std::function<void(crow::websocket::connection&, const std::string&)> error_handler_; |
851 | | std::function<void(const crow::request&, std::optional<crow::response>&, void**)> accept_handler_; |
852 | | }; |
853 | | } // namespace websocket |
854 | | } // namespace crow |