Coverage Report

Created: 2023-09-25 07:18

/src/uWebSockets/src/WebSocketContext.h
Line
Count
Source (jump to first uncovered line)
1
/*
2
 * Authored by Alex Hultman, 2018-2020.
3
 * Intellectual property of third-party.
4
5
 * Licensed under the Apache License, Version 2.0 (the "License");
6
 * you may not use this file except in compliance with the License.
7
 * You may obtain a copy of the License at
8
9
 *     http://www.apache.org/licenses/LICENSE-2.0
10
11
 * Unless required by applicable law or agreed to in writing, software
12
 * distributed under the License is distributed on an "AS IS" BASIS,
13
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
 * See the License for the specific language governing permissions and
15
 * limitations under the License.
16
 */
17
18
#ifndef UWS_WEBSOCKETCONTEXT_H
19
#define UWS_WEBSOCKETCONTEXT_H
20
21
#include "WebSocketContextData.h"
22
#include "WebSocketProtocol.h"
23
#include "WebSocketData.h"
24
#include "WebSocket.h"
25
26
namespace uWS {
27
28
template <bool SSL, bool isServer, typename USERDATA>
29
struct WebSocketContext {
30
    template <bool> friend struct TemplatedApp;
31
    template <bool, typename> friend struct WebSocketProtocol;
32
private:
33
    WebSocketContext() = delete;
34
35
25.2k
    us_socket_context_t *getSocketContext() {
36
25.2k
        return (us_socket_context_t *) this;
37
25.2k
    }
38
39
63.1k
    WebSocketContextData<SSL, USERDATA> *getExt() {
40
63.1k
        return (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL, (us_socket_context_t *) this);
41
63.1k
    }
42
43
    /* If we have negotiated compression, set this frame compressed */
44
10.4k
    static bool setCompressed(WebSocketState<isServer> */*wState*/, void *s) {
45
10.4k
        WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) s);
46
47
10.4k
        if (webSocketData->compressionStatus == WebSocketData::CompressionStatus::ENABLED) {
48
0
            webSocketData->compressionStatus = WebSocketData::CompressionStatus::COMPRESSED_FRAME;
49
0
            return true;
50
10.4k
        } else {
51
10.4k
            return false;
52
10.4k
        }
53
10.4k
    }
54
55
36.0k
    static void forceClose(WebSocketState<isServer> */*wState*/, void *s, std::string_view reason = {}) {
56
36.0k
        us_socket_close(SSL, (us_socket_t *) s, (int) reason.length(), (void *) reason.data());
57
36.0k
    }
58
59
    /* Returns true on breakage */
60
652k
    static bool handleFragment(char *data, size_t length, unsigned int remainingBytes, int opCode, bool fin, WebSocketState<isServer> *webSocketState, void *s) {
61
        /* WebSocketData and WebSocketContextData */
62
652k
        WebSocketContextData<SSL, USERDATA> *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
63
652k
        WebSocketData *webSocketData = (WebSocketData *) us_socket_ext(SSL, (us_socket_t *) s);
64
65
        /* Is this a non-control frame? */
66
652k
        if (opCode < 3) {
67
            /* Did we get everything in one go? */
68
633k
            if (!remainingBytes && fin && !webSocketData->fragmentBuffer.length()) {
69
70
                /* Handle compressed frame */
71
606k
                if (webSocketData->compressionStatus == WebSocketData::CompressionStatus::COMPRESSED_FRAME) {
72
0
                        webSocketData->compressionStatus = WebSocketData::CompressionStatus::ENABLED;
73
74
0
                        LoopData *loopData = (LoopData *) us_loop_ext(us_socket_context_loop(SSL, us_socket_context(SSL, (us_socket_t *) s)));
75
                        /* Decompress using shared or dedicated decompressor */
76
0
                        std::optional<std::string_view> inflatedFrame;
77
0
                        if (webSocketData->inflationStream) {
78
0
                            inflatedFrame = webSocketData->inflationStream->inflate(loopData->zlibContext, {data, length}, webSocketContextData->maxPayloadLength, false);
79
0
                        } else {
80
0
                            inflatedFrame = loopData->inflationStream->inflate(loopData->zlibContext, {data, length}, webSocketContextData->maxPayloadLength, true);
81
0
                        }
82
83
0
                        if (!inflatedFrame.has_value()) {
84
0
                            forceClose(webSocketState, s, ERR_TOO_BIG_MESSAGE_INFLATION);
85
0
                            return true;
86
0
                        } else {
87
0
                            data = (char *) inflatedFrame->data();
88
0
                            length = inflatedFrame->length();
89
0
                        }
90
0
                }
91
92
                /* Check text messages for Utf-8 validity */
93
606k
                if (opCode == 1 && !protocol::isValidUtf8((unsigned char *) data, length)) {
94
9.40k
                    forceClose(webSocketState, s, ERR_INVALID_TEXT);
95
9.40k
                    return true;
96
9.40k
                }
97
98
                /* Emit message event & break if we are closed or shut down when returning */
99
596k
                if (webSocketContextData->messageHandler) {
100
596k
                    webSocketContextData->messageHandler((WebSocket<SSL, isServer, USERDATA> *) s, std::string_view(data, length), (OpCode) opCode);
101
596k
                    if (us_socket_is_closed(SSL, (us_socket_t *) s) || webSocketData->isShuttingDown) {
102
0
                        return true;
103
0
                    }
104
596k
                }
105
596k
            } else {
106
                /* Allocate fragment buffer up front first time */
107
27.3k
                if (!webSocketData->fragmentBuffer.length()) {
108
17.2k
                    webSocketData->fragmentBuffer.reserve(length + remainingBytes);
109
17.2k
                }
110
                /* Fragments forming a big message are not caught until appending them */
111
27.3k
                if (refusePayloadLength(length + webSocketData->fragmentBuffer.length(), webSocketState, s)) {
112
209
                    forceClose(webSocketState, s, ERR_TOO_BIG_MESSAGE);
113
209
                    return true;
114
209
                }
115
27.1k
                webSocketData->fragmentBuffer.append(data, length);
116
117
                /* Are we done now? */
118
                // todo: what if we don't have any remaining bytes yet we are not fin? forceclose!
119
27.1k
                if (!remainingBytes && fin) {
120
121
                    /* Handle compression */
122
4.34k
                    if (webSocketData->compressionStatus == WebSocketData::CompressionStatus::COMPRESSED_FRAME) {
123
0
                            webSocketData->compressionStatus = WebSocketData::CompressionStatus::ENABLED;
124
125
                            /* 9 bytes of padding for libdeflate, 4 for zlib */
126
0
                            webSocketData->fragmentBuffer.append("123456789");
127
128
0
                            LoopData *loopData = (LoopData *) us_loop_ext(
129
0
                                us_socket_context_loop(SSL,
130
0
                                    us_socket_context(SSL, (us_socket_t *) s)
131
0
                                )
132
0
                            );
133
134
                            /* Decompress using shared or dedicated decompressor */
135
0
                            std::optional<std::string_view> inflatedFrame;
136
0
                            if (webSocketData->inflationStream) {
137
0
                                inflatedFrame = webSocketData->inflationStream->inflate(loopData->zlibContext, {webSocketData->fragmentBuffer.data(), webSocketData->fragmentBuffer.length() - 9}, webSocketContextData->maxPayloadLength, false);
138
0
                            } else {
139
0
                                inflatedFrame = loopData->inflationStream->inflate(loopData->zlibContext, {webSocketData->fragmentBuffer.data(), webSocketData->fragmentBuffer.length() - 9}, webSocketContextData->maxPayloadLength, true);
140
0
                            }
141
142
0
                            if (!inflatedFrame.has_value()) {
143
0
                                forceClose(webSocketState, s, ERR_TOO_BIG_MESSAGE_INFLATION);
144
0
                                return true;
145
0
                            } else {
146
0
                                data = (char *) inflatedFrame->data();
147
0
                                length = inflatedFrame->length();
148
0
                            }
149
150
151
4.34k
                    } else {
152
                        // reset length and data ptrs
153
4.34k
                        length = webSocketData->fragmentBuffer.length();
154
4.34k
                        data = webSocketData->fragmentBuffer.data();
155
4.34k
                    }
156
157
                    /* Check text messages for Utf-8 validity */
158
4.34k
                    if (opCode == 1 && !protocol::isValidUtf8((unsigned char *) data, length)) {
159
517
                        forceClose(webSocketState, s, ERR_INVALID_TEXT);
160
517
                        return true;
161
517
                    }
162
163
                    /* Emit message and check for shutdown or close */
164
3.83k
                    if (webSocketContextData->messageHandler) {
165
3.83k
                        webSocketContextData->messageHandler((WebSocket<SSL, isServer, USERDATA> *) s, std::string_view(data, length), (OpCode) opCode);
166
3.83k
                        if (us_socket_is_closed(SSL, (us_socket_t *) s) || webSocketData->isShuttingDown) {
167
0
                            return true;
168
0
                        }
169
3.83k
                    }
170
171
                    /* If we shutdown or closed, this will be taken care of elsewhere */
172
3.83k
                    webSocketData->fragmentBuffer.clear();
173
3.83k
                }
174
27.1k
            }
175
633k
        } else {
176
            /* Control frames need the websocket to send pings, pongs and close */
177
18.8k
            WebSocket<SSL, isServer, USERDATA> *webSocket = (WebSocket<SSL, isServer, USERDATA> *) s;
178
179
18.8k
            if (!remainingBytes && fin && !webSocketData->controlTipLength) {
180
11.1k
                if (opCode == CLOSE) {
181
5.94k
                    auto closeFrame = protocol::parseClosePayload(data, length);
182
5.94k
                    webSocket->end(closeFrame.code, std::string_view(closeFrame.message, closeFrame.length));
183
5.94k
                    return true;
184
5.94k
                } else {
185
5.25k
                    if (opCode == PING) {
186
4.83k
                        webSocket->send(std::string_view(data, length), (OpCode) OpCode::PONG);
187
4.83k
                        if (webSocketContextData->pingHandler) {
188
4.83k
                            webSocketContextData->pingHandler(webSocket, {data, length});
189
4.83k
                            if (us_socket_is_closed(SSL, (us_socket_t *) s) || webSocketData->isShuttingDown) {
190
0
                                return true;
191
0
                            }
192
4.83k
                        }
193
4.83k
                    } else if (opCode == PONG) {
194
414
                        if (webSocketContextData->pongHandler) {
195
414
                            webSocketContextData->pongHandler(webSocket, {data, length});
196
414
                            if (us_socket_is_closed(SSL, (us_socket_t *) s) || webSocketData->isShuttingDown) {
197
0
                                return true;
198
0
                            }
199
414
                        }
200
414
                    }
201
5.25k
                }
202
11.1k
            } else {
203
                /* Here we never mind any size optimizations as we are in the worst possible path */
204
7.64k
                webSocketData->fragmentBuffer.append(data, length);
205
7.64k
                webSocketData->controlTipLength += (unsigned int) length;
206
207
7.64k
                if (!remainingBytes && fin) {
208
3.40k
                    char *controlBuffer = (char *) webSocketData->fragmentBuffer.data() + webSocketData->fragmentBuffer.length() - webSocketData->controlTipLength;
209
3.40k
                    if (opCode == CLOSE) {
210
2.76k
                        protocol::CloseFrame closeFrame = protocol::parseClosePayload(controlBuffer, webSocketData->controlTipLength);
211
2.76k
                        webSocket->end(closeFrame.code, std::string_view(closeFrame.message, closeFrame.length));
212
2.76k
                        return true;
213
2.76k
                    } else {
214
640
                        if (opCode == PING) {
215
407
                            webSocket->send(std::string_view(controlBuffer, webSocketData->controlTipLength), (OpCode) OpCode::PONG);
216
407
                            if (webSocketContextData->pingHandler) {
217
407
                                webSocketContextData->pingHandler(webSocket, std::string_view(controlBuffer, webSocketData->controlTipLength));
218
407
                                if (us_socket_is_closed(SSL, (us_socket_t *) s) || webSocketData->isShuttingDown) {
219
0
                                    return true;
220
0
                                }
221
407
                            }
222
407
                        } else if (opCode == PONG) {
223
233
                            if (webSocketContextData->pongHandler) {
224
233
                                webSocketContextData->pongHandler(webSocket, std::string_view(controlBuffer, webSocketData->controlTipLength));
225
233
                                if (us_socket_is_closed(SSL, (us_socket_t *) s) || webSocketData->isShuttingDown) {
226
0
                                    return true;
227
0
                                }
228
233
                            }
229
233
                        }
230
640
                    }
231
232
                    /* Same here, we do not care for any particular smart allocation scheme */
233
640
                    webSocketData->fragmentBuffer.resize((unsigned int) webSocketData->fragmentBuffer.length() - webSocketData->controlTipLength);
234
640
                    webSocketData->controlTipLength = 0;
235
640
                }
236
7.64k
            }
237
18.8k
        }
238
633k
        return false;
239
652k
    }
240
241
671k
    static bool refusePayloadLength(uint64_t length, WebSocketState<isServer> */*wState*/, void *s) {
242
671k
        auto *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
243
244
        /* Return true for refuse, false for accept */
245
671k
        return webSocketContextData->maxPayloadLength < length;
246
671k
    }
247
248
4.21k
    WebSocketContext<SSL, isServer, USERDATA> *init() {
249
        /* Adopting a socket does not trigger open event.
250
         * We arrive as WebSocket with timeout set and
251
         * any backpressure from HTTP state kept. */
252
253
        /* Handle socket disconnections */
254
79.5k
        us_socket_context_on_close(SSL, getSocketContext(), [](auto *s, int code, void *reason) {
255
            /* For whatever reason, if we already have emitted close event, do not emit it again */
256
79.5k
            WebSocketData *webSocketData = (WebSocketData *) (us_socket_ext(SSL, s));
257
79.5k
            if (!webSocketData->isShuttingDown) {
258
                /* Emit close event */
259
70.8k
                auto *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
260
261
                /* At this point we iterate all currently held subscriptions and emit an event for all of them */
262
70.8k
                if (webSocketData->subscriber && webSocketContextData->subscriptionHandler) {
263
0
                    for (Topic *t : webSocketData->subscriber->topics) {
264
0
                        webSocketContextData->subscriptionHandler((WebSocket<SSL, isServer, USERDATA> *) s, t->name, (int) t->size() - 1, (int) t->size());
265
0
                    }
266
0
                }
267
268
                /* Make sure to unsubscribe from any pub/sub node at exit */
269
70.8k
                webSocketContextData->topicTree->freeSubscriber(webSocketData->subscriber);
270
70.8k
                webSocketData->subscriber = nullptr;
271
272
70.8k
                if (webSocketContextData->closeHandler) {
273
70.8k
                    webSocketContextData->closeHandler((WebSocket<SSL, isServer, USERDATA> *) s, 1006, {(char *) reason, (size_t) code});
274
70.8k
                }
275
70.8k
            }
276
277
            /* Destruct in-placed data struct */
278
79.5k
            webSocketData->~WebSocketData();
279
280
79.5k
            return s;
281
79.5k
        });
282
283
        /* Handle WebSocket data streams */
284
70.1k
        us_socket_context_on_data(SSL, getSocketContext(), [](auto *s, char *data, int length) {
285
286
            /* We need the websocket data */
287
70.1k
            WebSocketData *webSocketData = (WebSocketData *) (us_socket_ext(SSL, s));
288
289
            /* When in websocket shutdown mode, we do not care for ANY message, whether responding close frame or not.
290
             * We only care for the TCP FIN really, not emitting any message after closing is key */
291
70.1k
            if (webSocketData->isShuttingDown) {
292
1.22k
                return s;
293
1.22k
            }
294
295
68.9k
            auto *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
296
68.9k
            auto *asyncSocket = (AsyncSocket<SSL> *) s;
297
298
            /* Every time we get data and not in shutdown state we simply reset the timeout */
299
68.9k
            asyncSocket->timeout(webSocketContextData->idleTimeoutComponents.first);
300
68.9k
            webSocketData->hasTimedOut = false;
301
302
            /* We always cork on data */
303
68.9k
            asyncSocket->cork();
304
305
            /* This parser has virtually no overhead */
306
68.9k
            WebSocketProtocol<isServer, WebSocketContext<SSL, isServer, USERDATA>>::consume(data, (unsigned int) length, (WebSocketState<isServer> *) webSocketData, s);
307
308
            /* Uncorking a closed socket is fine, in fact it is needed */
309
68.9k
            asyncSocket->uncork();
310
311
            /* If uncorking was successful and we are in shutdown state then send TCP FIN */
312
68.9k
            if (asyncSocket->getBufferedAmount() == 0) {
313
                /* We can now be in shutdown state */
314
5.40k
                if (webSocketData->isShuttingDown) {
315
                    /* Shutting down a closed socket is handled by uSockets and just fine */
316
1.21k
                    asyncSocket->shutdown();
317
1.21k
                }
318
5.40k
            }
319
320
68.9k
            return s;
321
70.1k
        });
322
323
        /* Handle HTTP write out (note: SSL_read may trigger this spuriously, the app needs to handle spurious calls) */
324
15.7k
        us_socket_context_on_writable(SSL, getSocketContext(), [](auto *s) {
325
326
            /* NOTE: Are we called here corked? If so, the below write code is broken, since
327
             * we will have 0 as getBufferedAmount due to writing to cork buffer, then sending TCP FIN before
328
             * we actually uncorked and sent off things */
329
330
            /* It makes sense to check for us_is_shut_down here and return if so, to avoid shutting down twice */
331
15.7k
            if (us_socket_is_shut_down(SSL, (us_socket_t *) s)) {
332
0
                return s;
333
0
            }
334
335
15.7k
            AsyncSocket<SSL> *asyncSocket = (AsyncSocket<SSL> *) s;
336
15.7k
            WebSocketData *webSocketData = (WebSocketData *)(us_socket_ext(SSL, s));
337
338
            /* We store old backpressure since it is unclear whether write drained anything,
339
             * however, in case of coming here with 0 backpressure we still need to emit drain event */
340
15.7k
            unsigned int backpressure = asyncSocket->getBufferedAmount();
341
342
            /* Drain as much as possible */
343
15.7k
            asyncSocket->write(nullptr, 0);
344
345
            /* Behavior: if we actively drain backpressure, always reset timeout (even if we are in shutdown) */
346
            /* Also reset timeout if we came here with 0 backpressure */
347
15.7k
            if (!backpressure || backpressure > asyncSocket->getBufferedAmount()) {
348
10.3k
                auto *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
349
10.3k
                asyncSocket->timeout(webSocketContextData->idleTimeoutComponents.first);
350
10.3k
                webSocketData->hasTimedOut = false;
351
10.3k
            }
352
353
            /* Are we in (WebSocket) shutdown mode? */
354
15.7k
            if (webSocketData->isShuttingDown) {
355
                /* Check if we just now drained completely */
356
2.85k
                if (asyncSocket->getBufferedAmount() == 0) {
357
                    /* Now perform the actual TCP/TLS shutdown which was postponed due to backpressure */
358
613
                    asyncSocket->shutdown();
359
613
                }
360
12.8k
            } else if (!backpressure || backpressure > asyncSocket->getBufferedAmount()) {
361
                /* Only call drain if we actually drained backpressure or if we came here with 0 backpressure */
362
8.91k
                auto *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
363
8.91k
                if (webSocketContextData->drainHandler) {
364
8.91k
                    webSocketContextData->drainHandler((WebSocket<SSL, isServer, USERDATA> *) s);
365
8.91k
                }
366
                /* No need to check for closed here as we leave the handler immediately*/
367
8.91k
            }
368
369
15.7k
            return s;
370
15.7k
        });
371
372
        /* Handle FIN, HTTP does not support half-closed sockets, so simply close */
373
4.21k
        us_socket_context_on_end(SSL, getSocketContext(), [](auto *s) {
374
375
            /* If we get a fin, we just close I guess */
376
1.67k
            us_socket_close(SSL, (us_socket_t *) s, 0, nullptr);
377
378
1.67k
            return s;
379
1.67k
        });
380
381
4.21k
        us_socket_context_on_long_timeout(SSL, getSocketContext(), [](auto *s) {
382
0
            ((WebSocket<SSL, isServer, USERDATA> *) s)->end(1000, "please reconnect");
383
384
0
            return s;
385
0
        });
386
387
        /* Handle socket timeouts, simply close them so as not to confuse the client with FIN */
388
4.21k
        us_socket_context_on_timeout(SSL, getSocketContext(), [](auto *s) {
389
390
210
            auto *webSocketData = (WebSocketData *)(us_socket_ext(SSL, s));
391
210
            auto *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL, us_socket_context(SSL, (us_socket_t *) s));
392
393
210
            if (webSocketContextData->sendPingsAutomatically && !webSocketData->isShuttingDown && !webSocketData->hasTimedOut) {
394
0
                webSocketData->hasTimedOut = true;
395
0
                us_socket_timeout(SSL, s, webSocketContextData->idleTimeoutComponents.second);
396
                /* Send ping without being corked */
397
0
                ((AsyncSocket<SSL> *) s)->write("\x89\x00", 2);
398
0
                return s;
399
0
            }
400
401
            /* Timeout is very simple; we just close it */
402
            /* Warning: we happen to know forceClose will not use first parameter so pass nullptr here */
403
210
            forceClose(nullptr, s, ERR_WEBSOCKET_TIMEOUT);
404
405
210
            return s;
406
210
        });
407
408
4.21k
        return this;
409
4.21k
    }
410
411
4.21k
    void free() {
412
4.21k
        WebSocketContextData<SSL, USERDATA> *webSocketContextData = (WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL, (us_socket_context_t *) this);
413
4.21k
        webSocketContextData->~WebSocketContextData();
414
415
4.21k
        us_socket_context_free(SSL, (us_socket_context_t *) this);
416
4.21k
    }
417
418
public:
419
    /* WebSocket contexts are always child contexts to a HTTP context so no SSL options are needed as they are inherited */
420
4.21k
    static WebSocketContext *create(Loop */*loop*/, us_socket_context_t *parentSocketContext, TopicTree<TopicTreeMessage, TopicTreeBigMessage> *topicTree) {
421
4.21k
        WebSocketContext *webSocketContext = (WebSocketContext *) us_create_child_socket_context(SSL, parentSocketContext, sizeof(WebSocketContextData<SSL, USERDATA>));
422
4.21k
        if (!webSocketContext) {
423
0
            return nullptr;
424
0
        }
425
426
        /* Init socket context data */
427
4.21k
        new ((WebSocketContextData<SSL, USERDATA> *) us_socket_context_ext(SSL, (us_socket_context_t *)webSocketContext)) WebSocketContextData<SSL, USERDATA>(topicTree);
428
4.21k
        return webSocketContext->init();
429
4.21k
    }
430
};
431
432
}
433
434
#endif // UWS_WEBSOCKETCONTEXT_H