Coverage Report

Created: 2023-06-06 06:17

/src/uWebSockets/src/WebSocketProtocol.h
Line
Count
Source (jump to first uncovered line)
1
/*
2
 * Authored by Alex Hultman, 2018-2020.
3
 * Intellectual property of third-party.
4
5
 * Licensed under the Apache License, Version 2.0 (the "License");
6
 * you may not use this file except in compliance with the License.
7
 * You may obtain a copy of the License at
8
9
 *     http://www.apache.org/licenses/LICENSE-2.0
10
11
 * Unless required by applicable law or agreed to in writing, software
12
 * distributed under the License is distributed on an "AS IS" BASIS,
13
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
 * See the License for the specific language governing permissions and
15
 * limitations under the License.
16
 */
17
18
#ifndef UWS_WEBSOCKETPROTOCOL_H
19
#define UWS_WEBSOCKETPROTOCOL_H
20
21
#include <libusockets.h>
22
23
#include <cstdint>
24
#include <cstring>
25
#include <cstdlib>
26
#include <string_view>
27
28
namespace uWS {
29
30
/* Human-readable reasons used when a connection is closed abnormally.
 * constexpr implies const, so each translation unit keeps its own internal copy. */
constexpr std::string_view ERR_TOO_BIG_MESSAGE{"Received too big message"};
constexpr std::string_view ERR_WEBSOCKET_TIMEOUT{"WebSocket timed out from inactivity"};
constexpr std::string_view ERR_INVALID_TEXT{"Received invalid UTF-8"};
constexpr std::string_view ERR_TOO_BIG_MESSAGE_INFLATION{"Received too big message, or other inflation error"};
constexpr std::string_view ERR_INVALID_CLOSE_PAYLOAD{"Received invalid close payload"};
37
/* Frame opcodes as carried in the low 4 bits of the first header byte (RFC 6455 §5.2).
 * Values 0-2 are data frames, 8-10 are control frames. */
enum OpCode : unsigned char {
    CONTINUATION = 0x0,
    TEXT = 0x1,
    BINARY = 0x2,
    CLOSE = 0x8,
    PING = 0x9,
    PONG = 0xA
};
45
46
/* Endpoint roles; CLIENT is 0 and SERVER is 1. */
enum {
    CLIENT = 0,
    SERVER = 1
};
50
51
// 24 bytes perfectly
52
template <bool isServer>
53
struct WebSocketState {
54
public:
55
    static const unsigned int SHORT_MESSAGE_HEADER = isServer ? 6 : 2;
56
    static const unsigned int MEDIUM_MESSAGE_HEADER = isServer ? 8 : 4;
57
    static const unsigned int LONG_MESSAGE_HEADER = isServer ? 14 : 10;
58
59
    // 16 bytes
60
    struct State {
61
        unsigned int wantsHead : 1;
62
        unsigned int spillLength : 4;
63
        signed int opStack : 2; // -1, 0, 1
64
        unsigned int lastFin : 1;
65
66
        // 15 bytes
67
        unsigned char spill[LONG_MESSAGE_HEADER - 1];
68
        OpCode opCode[2];
69
70
1.49k
        State() {
71
1.49k
            wantsHead = true;
72
1.49k
            spillLength = 0;
73
1.49k
            opStack = -1;
74
1.49k
            lastFin = true;
75
1.49k
        }
76
77
    } state;
78
79
    // 8 bytes
80
    unsigned int remainingBytes = 0;
81
    char mask[isServer ? 4 : 1];
82
};
83
84
namespace protocol {
85
86
/* Loads a T from the bytes at c, which need not be suitably aligned for T.
 * Copying through memcpy avoids both misaligned loads and strict-aliasing violations. */
template <typename T>
T bit_cast(char *c) {
    T result;
    memcpy(&result, c, sizeof(result));
    return result;
}
92
93
/* Converts between host and network (big-endian) byte order.
 *
 * On a little-endian host the bytes of value are reversed; on a big-endian host
 * value is returned unchanged. Applying it twice is the identity.
 *
 * The bytes are moved through a plain byte array with memcpy. The previous union
 * version read a member other than the one last written, which is undefined
 * behavior in C++ (it only worked as a compiler extension). T must be trivially
 * copyable. */
template <typename T>
T cond_byte_swap(T value) {
    /* The lowest-addressed byte of 1 is non-zero only on little-endian hosts */
    uint32_t endian_test = 1;
    unsigned char lowest;
    memcpy(&lowest, &endian_test, 1);
    if (!lowest) {
        return value;
    }

    unsigned char bytes[sizeof(T)];
    memcpy(bytes, &value, sizeof(T));
    for (size_t i = 0; i < sizeof(T) / 2; i++) {
        unsigned char tmp = bytes[i];
        bytes[i] = bytes[sizeof(T) - 1 - i];
        bytes[sizeof(T) - 1 - i] = tmp;
    }
    memcpy(&value, bytes, sizeof(T));
    return value;
}
111
112
// Based on utf8_check.c by Markus Kuhn, 2005
// https://www.cl.cam.ac.uk/~mgk25/ucs/utf8_check.c
// Optimized for predominantly 7-bit content by Alex Hultman, 2016
// Licensed as Zlib, like the rest of this project
/* Returns true if the length bytes at s are well-formed UTF-8. The input is
 * rejected if it has overlong encodings, UTF-16 surrogates (U+D800..U+DFFF),
 * code points above U+10FFFF, stray continuation bytes, or a sequence truncated
 * by the end of the buffer. An empty buffer is valid.
 *
 * Bounds are checked with the remaining byte count (e - s) instead of forming
 * s + k. When fewer than k bytes remain, s + k points more than one past the
 * end of the buffer, which is undefined behavior even if never dereferenced. */
static inline bool isValidUtf8(unsigned char *s, size_t length)
{
    for (unsigned char *e = s + length; s != e; ) {
        /* Fast path: skip four ASCII bytes at a time */
        if (e - s >= 4) {
            uint32_t tmp;
            memcpy(&tmp, s, 4);
            if ((tmp & 0x80808080) == 0) {
                s += 4;
                continue;
            }
        }

        /* Skip single ASCII bytes until the next byte with the high bit set */
        while (!(*s & 0x80)) {
            if (++s == e) {
                return true;
            }
        }

        size_t remaining = (size_t) (e - s);
        if ((s[0] & 0x60) == 0x40) {
            /* 2-byte sequence 110xxxxx 10xxxxxx; lead bytes C0/C1 are always overlong */
            if (remaining < 2 || (s[1] & 0xc0) != 0x80 || (s[0] & 0xfe) == 0xc0) {
                return false;
            }
            s += 2;
        } else if ((s[0] & 0xf0) == 0xe0) {
            /* 3-byte sequence; E0 80..9F is overlong, ED A0..BF encodes a surrogate */
            if (remaining < 3 || (s[1] & 0xc0) != 0x80 || (s[2] & 0xc0) != 0x80 ||
                    (s[0] == 0xe0 && (s[1] & 0xe0) == 0x80) || (s[0] == 0xed && (s[1] & 0xe0) == 0xa0)) {
                return false;
            }
            s += 3;
        } else if ((s[0] & 0xf8) == 0xf0) {
            /* 4-byte sequence; F0 80..8F is overlong, F4 90+ and F5..F7 exceed U+10FFFF */
            if (remaining < 4 || (s[1] & 0xc0) != 0x80 || (s[2] & 0xc0) != 0x80 || (s[3] & 0xc0) != 0x80 ||
                    (s[0] == 0xf0 && (s[1] & 0xf0) == 0x80) || (s[0] == 0xf4 && s[1] > 0x8f) || s[0] > 0xf4) {
                return false;
            }
            s += 4;
        } else {
            /* Stray continuation byte (80..BF) or invalid lead byte (F8..FF) */
            return false;
        }
    }
    return true;
}
157
158
/* A decoded close frame: status code plus an optional reason text. */
struct CloseFrame {
    uint16_t code;   // close status code in host byte order
    char *message;   // reason text, not owned (may point into the received payload); nullptr if absent
    size_t length;   // number of bytes at message
};
163
164
1.45k
/* Decodes the payload of a received close frame.
 *
 * - Empty payload: reports 1005 ("no status code present").
 * - Otherwise the first two bytes are the status code in network byte order and
 *   the remainder is the reason text; message points into src (nothing is copied).
 * - A malformed payload yields {1006, nullptr, 0}. Malformed means a 1-byte body
 *   (RFC 6455 §5.5.1 requires a 2-byte code whenever there is a body), a code
 *   that may not appear on the wire, or a reason that is not valid UTF-8.
 *
 * Accepted codes (RFC 6455 §7.4): 1000-1003, 1007-1011, 3000-3999 (registered
 * for libraries and frameworks) and 4000-4999 (private use). 1004-1006 are
 * reserved and must never be sent by a peer. */
static inline CloseFrame parseClosePayload(char *src, size_t length) {
    /* If we get no code or message, default to reporting 1005 no status code present */
    CloseFrame cf = {1005, nullptr, 0};
    if (length == 1) {
        /* A body that cannot even hold the status code is abnormal */
        return {1006, nullptr, 0};
    }
    if (length >= 2) {
        memcpy(&cf.code, src, 2);
        cf = {cond_byte_swap<uint16_t>(cf.code), src + 2, length - 2};
        bool validCode = (cf.code >= 1000 && cf.code <= 1003) ||
                         (cf.code >= 1007 && cf.code <= 1011) ||
                         (cf.code >= 3000 && cf.code <= 4999);
        if (!validCode || !isValidUtf8((unsigned char *) cf.message, cf.length)) {
            /* Even though we got a WebSocket close frame, it in itself is abnormal */
            return {1006, nullptr, 0};
        }
    }
    return cf;
}
178
179
0
/* Serializes a close payload into dst: the 2-byte status code in network byte
 * order, followed by length bytes of message. Returns the number of bytes written.
 *
 * Codes 0, 1005 and 1006 are never put on the wire; for them nothing is written
 * and 0 is returned (the message is dropped too). dst needs room for length + 2 bytes.
 * We could have more strict checks here. */
static inline size_t formatClosePayload(char *dst, uint16_t code, const char *message, size_t length) {
    bool omitCode = (code == 0) || (code == 1005) || (code == 1006);
    if (omitCode) {
        return 0;
    }

    /* Network byte order: most significant byte first, independent of host endianness */
    dst[0] = (char) (code >> 8);
    dst[1] = (char) (code & 0xff);

    /* It is invalid to pass nullptr to memcpy, even though length is 0 */
    if (message != nullptr) {
        memcpy(dst + 2, message, length);
    }
    return 2 + length;
}
192
193
0
/* Total size of a frame without a masking key that carries messageSize payload
 * bytes. The base header is 2 bytes. A 2-byte extended length is added from 126
 * bytes up to UINT16_MAX, and an 8-byte extended length beyond that. */
static inline size_t messageFrameSize(size_t messageSize) {
    size_t header = 2;
    if (messageSize >= 126) {
        header += (messageSize > UINT16_MAX) ? 8 : 2;
    }
    return header + messageSize;
}
201
202
/* Send flags. SND_COMPRESSED has the same value as the RSV1 bit (0x40) of the
 * first frame-header byte, which formatMessage sets for compressed frames. */
enum {
    SND_CONTINUATION = 1 << 0,
    SND_NO_FIN = 1 << 1,
    SND_COMPRESSED = 1 << 6
};
207
208
template <bool isServer>
209
static inline size_t formatMessage(char *dst, const char *src, size_t length, OpCode opCode, size_t reportedLength, bool compressed, bool fin) {
210
    size_t messageLength;
211
    size_t headerLength;
212
    if (reportedLength < 126) {
213
        headerLength = 2;
214
        dst[1] = (char) reportedLength;
215
    } else if (reportedLength <= UINT16_MAX) {
216
        headerLength = 4;
217
        dst[1] = 126;
218
        uint16_t tmp = cond_byte_swap<uint16_t>((uint16_t) reportedLength);
219
        memcpy(&dst[2], &tmp, sizeof(uint16_t));
220
    } else {
221
        headerLength = 10;
222
        dst[1] = 127;
223
        uint64_t tmp = cond_byte_swap<uint64_t>((uint64_t) reportedLength);
224
        memcpy(&dst[2], &tmp, sizeof(uint64_t));
225
    }
226
227
    dst[0] = (char) ((fin ? 128 : 0) | ((compressed && opCode) ? SND_COMPRESSED : 0) | (char) opCode);
228
229
    //printf("%d\n", (int)dst[0]);
230
231
    char mask[4];
232
    if (!isServer) {
233
        dst[1] |= 0x80;
234
        uint32_t random = (uint32_t) rand();
235
        memcpy(mask, &random, 4);
236
        memcpy(dst + headerLength, &random, 4);
237
        headerLength += 4;
238
    }
239
240
    messageLength = headerLength + length;
241
    memcpy(dst + headerLength, src, length);
242
243
    if (!isServer) {
244
245
        // overwrites up to 3 bytes outside of the given buffer!
246
        //WebSocketProtocol<isServer>::unmaskInplace(dst + headerLength, dst + headerLength + length, mask);
247
248
        // this is not optimal
249
        char *start = dst + headerLength;
250
        char *stop = start + length;
251
        int i = 0;
252
        while (start != stop) {
253
            (*start++) ^= mask[i++ % 4];
254
        }
255
    }
256
    return messageLength;
257
}
258
259
}
260
261
// essentially this is only a parser
262
template <const bool isServer, typename Impl>
263
struct WebSocketProtocol {
264
public:
265
    /* Header size for each payload-length encoding (7-bit, 16-bit, 64-bit).
     * When parsing as a server the 4-byte masking key is counted as part of the header. */
    static constexpr unsigned int SHORT_MESSAGE_HEADER = isServer ? 6 : 2;
    static constexpr unsigned int MEDIUM_MESSAGE_HEADER = isServer ? 8 : 4;
    static constexpr unsigned int LONG_MESSAGE_HEADER = isServer ? 14 : 10;
268
269
protected:
270
65.8k
    /* Accessors for the first two bytes of a raw frame header (RFC 6455 §5.2). */

    /* FIN bit (0x80 of byte 0): this frame is the last fragment of its message */
    static inline bool isFin(char *frame) {
        unsigned char first = static_cast<unsigned char>(frame[0]);
        return (first & 0x80) != 0;
    }

    /* Opcode: low nibble of byte 0 */
    static inline unsigned char getOpCode(char *frame) {
        unsigned char first = static_cast<unsigned char>(frame[0]);
        return static_cast<unsigned char>(first & 0x0F);
    }

    /* 7-bit payload length field of byte 1 (126 and 127 announce extended lengths) */
    static inline unsigned char payloadLength(char *frame) {
        unsigned char second = static_cast<unsigned char>(frame[1]);
        return static_cast<unsigned char>(second & 0x7F);
    }

    /* RSV2 or RSV3 (0x30 of byte 0) is set */
    static inline bool rsv23(char *frame) {
        unsigned char first = static_cast<unsigned char>(frame[0]);
        return (first & 0x30) != 0;
    }

    /* RSV1 (0x40 of byte 0) is set */
    static inline bool rsv1(char *frame) {
        unsigned char first = static_cast<unsigned char>(frame[0]);
        return (first & 0x40) != 0;
    }
275
276
    /* XORs data[0..N) with the repeating 4-byte mask. The recursion is resolved
     * at compile time, so the result is a fully unrolled sequence of N XORs.
     * Bytes are processed in ascending order. N must be at least 1. */
    template <int N>
    static inline void UnrolledXor(char * __restrict data, char * __restrict mask) {
        constexpr int last = N - 1;
        if constexpr (N != 1) {
            /* Bytes [0, last) first */
            UnrolledXor<last>(data, mask);
        }
        data[last] ^= mask[last % 4];
    }
283
284
    /* Unmasks a payload 8 bytes at a time with mask (the 4-byte key repeated twice).
     * Each unmasked word is stored DESTINATION bytes before where it was read, so
     * a non-zero DESTINATION also shifts the payload towards lower addresses.
     *
     * "Imprecise": it always processes length / 8 + 1 words, so it reads and
     * writes up to 8 bytes beyond src + length. The caller must guarantee that slack. */
    template <int DESTINATION>
    static inline void unmaskImprecise8(char *src, uint64_t mask, unsigned int length) {
        unsigned int words = length / 8 + 1;
        for (unsigned int w = 0; w < words; w++) {
            char *chunk = src + 8 * w;
            uint64_t word;
            memcpy(&word, chunk, sizeof(word));
            word ^= mask;
            memcpy(chunk - DESTINATION, &word, sizeof(word));
        }
    }
294
295
    /* 4-byte variant of unmaskImprecise8, used for short frames. Each unmasked word
     * is stored DESTINATION bytes before where it was read. It processes
     * length / 4 + 1 words, so it touches up to 4 bytes beyond src + length.
     * DESTINATION = 6 makes this not SIMD, DESTINATION = 4 is with SIMD but we
     * don't want that for short messages. */
    template <int DESTINATION>
    static inline void unmaskImprecise4(char *src, uint32_t mask, unsigned int length) {
        unsigned int words = length / 4 + 1;
        for (unsigned int w = 0; w < words; w++) {
            char *chunk = src + 4 * w;
            uint32_t word;
            memcpy(&word, chunk, sizeof(word));
            word ^= mask;
            memcpy(chunk - DESTINATION, &word, sizeof(word));
        }
    }
306
307
    template <int HEADER_SIZE>
308
27.2k
    static inline void unmaskImpreciseCopyMask(char *src, unsigned int length) {
309
27.2k
        if constexpr (HEADER_SIZE != 6) {
310
23.3k
            char mask[8] = {src[-4], src[-3], src[-2], src[-1], src[-4], src[-3], src[-2], src[-1]};
311
23.3k
            uint64_t maskInt;
312
23.3k
            memcpy(&maskInt, mask, 8);
313
23.3k
            unmaskImprecise8<HEADER_SIZE>(src, maskInt, length);
314
23.3k
        } else {
315
23.3k
            char mask[4] = {src[-4], src[-3], src[-2], src[-1]};
316
23.3k
            uint32_t maskInt;
317
23.3k
            memcpy(&maskInt, mask, 4);
318
23.3k
            unmaskImprecise4<HEADER_SIZE>(src, maskInt, length);
319
23.3k
        }
320
27.2k
    }
void uWS::WebSocketProtocol<true, Impl>::unmaskImpreciseCopyMask<6>(char*, unsigned int)
Line
Count
Source
308
23.3k
    static inline void unmaskImpreciseCopyMask(char *src, unsigned int length) {
309
23.3k
        if constexpr (HEADER_SIZE != 6) {
310
23.3k
            char mask[8] = {src[-4], src[-3], src[-2], src[-1], src[-4], src[-3], src[-2], src[-1]};
311
23.3k
            uint64_t maskInt;
312
23.3k
            memcpy(&maskInt, mask, 8);
313
23.3k
            unmaskImprecise8<HEADER_SIZE>(src, maskInt, length);
314
23.3k
        } else {
315
23.3k
            char mask[4] = {src[-4], src[-3], src[-2], src[-1]};
316
23.3k
            uint32_t maskInt;
317
23.3k
            memcpy(&maskInt, mask, 4);
318
23.3k
            unmaskImprecise4<HEADER_SIZE>(src, maskInt, length);
319
23.3k
        }
320
23.3k
    }
void uWS::WebSocketProtocol<true, Impl>::unmaskImpreciseCopyMask<8>(char*, unsigned int)
Line
Count
Source
308
2.06k
    static inline void unmaskImpreciseCopyMask(char *src, unsigned int length) {
309
2.06k
        if constexpr (HEADER_SIZE != 6) {
310
2.06k
            char mask[8] = {src[-4], src[-3], src[-2], src[-1], src[-4], src[-3], src[-2], src[-1]};
311
2.06k
            uint64_t maskInt;
312
2.06k
            memcpy(&maskInt, mask, 8);
313
2.06k
            unmaskImprecise8<HEADER_SIZE>(src, maskInt, length);
314
2.06k
        } else {
315
2.06k
            char mask[4] = {src[-4], src[-3], src[-2], src[-1]};
316
2.06k
            uint32_t maskInt;
317
2.06k
            memcpy(&maskInt, mask, 4);
318
2.06k
            unmaskImprecise4<HEADER_SIZE>(src, maskInt, length);
319
2.06k
        }
320
2.06k
    }
void uWS::WebSocketProtocol<true, Impl>::unmaskImpreciseCopyMask<14>(char*, unsigned int)
Line
Count
Source
308
1.85k
    static inline void unmaskImpreciseCopyMask(char *src, unsigned int length) {
309
1.85k
        if constexpr (HEADER_SIZE != 6) {
310
1.85k
            char mask[8] = {src[-4], src[-3], src[-2], src[-1], src[-4], src[-3], src[-2], src[-1]};
311
1.85k
            uint64_t maskInt;
312
1.85k
            memcpy(&maskInt, mask, 8);
313
1.85k
            unmaskImprecise8<HEADER_SIZE>(src, maskInt, length);
314
1.85k
        } else {
315
1.85k
            char mask[4] = {src[-4], src[-3], src[-2], src[-1]};
316
1.85k
            uint32_t maskInt;
317
1.85k
            memcpy(&maskInt, mask, 4);
318
1.85k
            unmaskImprecise4<HEADER_SIZE>(src, maskInt, length);
319
1.85k
        }
320
1.85k
    }
321
322
5.67k
    /* Rotates the 4-byte mask in place: the byte at mask[i] moves to mask[(i + offset) % 4].
     * An offset of 0 or 4 leaves the mask unchanged. */
    static inline void rotateMask(unsigned int offset, char *mask) {
        char before[4];
        memcpy(before, mask, 4);
        for (unsigned int i = 0; i < 4; i++) {
            mask[(i + offset) % 4] = before[i];
        }
    }
329
330
7.77k
    /* XORs [data, stop) with the 4-byte mask, four bytes per iteration. The bound
     * is only checked between groups, so when the span is not a multiple of 4 it
     * also modifies up to 3 bytes at or beyond stop. The caller must own that slack. */
    static inline void unmaskInplace(char *data, char *stop, char *mask) {
        for (; data < stop; data += 4) {
            data[0] ^= mask[0];
            data[1] ^= mask[1];
            data[2] ^= mask[2];
            data[3] ^= mask[3];
        }
    }
338
339
    template <unsigned int MESSAGE_HEADER, typename T>
340
34.9k
    static inline bool consumeMessage(T payLength, char *&src, unsigned int &length, WebSocketState<isServer> *wState, void *user) {
341
34.9k
        if (getOpCode(src)) {
342
6.61k
            if (wState->state.opStack == 1 || (!wState->state.lastFin && getOpCode(src) < 2)) {
343
1.40k
                Impl::forceClose(wState, user);
344
1.40k
                return true;
345
1.40k
            }
346
5.21k
            wState->state.opCode[++wState->state.opStack] = (OpCode) getOpCode(src);
347
28.3k
        } else if (wState->state.opStack == -1) {
348
1.38k
            Impl::forceClose(wState, user);
349
1.38k
            return true;
350
1.38k
        }
351
32.1k
        wState->state.lastFin = isFin(src);
352
353
32.1k
        if (Impl::refusePayloadLength(payLength, wState, user)) {
354
633
            Impl::forceClose(wState, user, ERR_TOO_BIG_MESSAGE);
355
633
            return true;
356
633
        }
357
358
31.5k
        if (payLength + MESSAGE_HEADER <= length) {
359
27.2k
            bool fin = isFin(src);
360
27.2k
            if (isServer) {
361
                /* This guy can never be assumed to be perfectly aligned since we can get multiple messages in one read */
362
27.2k
                unmaskImpreciseCopyMask<MESSAGE_HEADER>(src + MESSAGE_HEADER, (unsigned int) payLength);
363
27.2k
                if (Impl::handleFragment(src, payLength, 0, wState->state.opCode[wState->state.opStack], fin, wState, user)) {
364
3.55k
                    return true;
365
3.55k
                }
366
27.2k
            } else {
367
0
                if (Impl::handleFragment(src + MESSAGE_HEADER, payLength, 0, wState->state.opCode[wState->state.opStack], isFin(src), wState, user)) {
368
0
                    return true;
369
0
                }
370
0
            }
371
372
23.7k
            if (fin) {
373
2.69k
                wState->state.opStack--;
374
2.69k
            }
375
376
23.7k
            src += payLength + MESSAGE_HEADER;
377
23.7k
            length -= (unsigned int) (payLength + MESSAGE_HEADER);
378
23.7k
            wState->state.spillLength = 0;
379
23.7k
            return false;
380
27.2k
        } else {
381
4.26k
            wState->state.spillLength = 0;
382
4.26k
            wState->state.wantsHead = false;
383
4.26k
            wState->remainingBytes = (unsigned int) (payLength - length + MESSAGE_HEADER);
384
4.26k
            bool fin = isFin(src);
385
4.26k
            if constexpr (isServer) {
386
4.26k
                memcpy(wState->mask, src + MESSAGE_HEADER - 4, 4);
387
4.26k
                uint64_t mask;
388
4.26k
                memcpy(&mask, src + MESSAGE_HEADER - 4, 4);
389
4.26k
                memcpy(((char *)&mask) + 4, src + MESSAGE_HEADER - 4, 4);
390
4.26k
                unmaskImprecise8<0>(src + MESSAGE_HEADER, mask, length);
391
4.26k
                rotateMask(4 - (length - MESSAGE_HEADER) % 4, wState->mask);
392
4.26k
            }
393
4.26k
            Impl::handleFragment(src + MESSAGE_HEADER, length - MESSAGE_HEADER, wState->remainingBytes, wState->state.opCode[wState->state.opStack], fin, wState, user);
394
4.26k
            return true;
395
4.26k
        }
396
31.5k
    }
bool uWS::WebSocketProtocol<true, Impl>::consumeMessage<6u, unsigned char>(unsigned char, char*&, unsigned int&, uWS::WebSocketState<true>*, void*)
Line
Count
Source
340
26.9k
    static inline bool consumeMessage(T payLength, char *&src, unsigned int &length, WebSocketState<isServer> *wState, void *user) {
341
26.9k
        if (getOpCode(src)) {
342
3.46k
            if (wState->state.opStack == 1 || (!wState->state.lastFin && getOpCode(src) < 2)) {
343
619
                Impl::forceClose(wState, user);
344
619
                return true;
345
619
            }
346
2.84k
            wState->state.opCode[++wState->state.opStack] = (OpCode) getOpCode(src);
347
23.4k
        } else if (wState->state.opStack == -1) {
348
847
            Impl::forceClose(wState, user);
349
847
            return true;
350
847
        }
351
25.4k
        wState->state.lastFin = isFin(src);
352
353
25.4k
        if (Impl::refusePayloadLength(payLength, wState, user)) {
354
0
            Impl::forceClose(wState, user, ERR_TOO_BIG_MESSAGE);
355
0
            return true;
356
0
        }
357
358
25.4k
        if (payLength + MESSAGE_HEADER <= length) {
359
23.3k
            bool fin = isFin(src);
360
23.3k
            if (isServer) {
361
                /* This guy can never be assumed to be perfectly aligned since we can get multiple messages in one read */
362
23.3k
                unmaskImpreciseCopyMask<MESSAGE_HEADER>(src + MESSAGE_HEADER, (unsigned int) payLength);
363
23.3k
                if (Impl::handleFragment(src, payLength, 0, wState->state.opCode[wState->state.opStack], fin, wState, user)) {
364
1.56k
                    return true;
365
1.56k
                }
366
23.3k
            } else {
367
0
                if (Impl::handleFragment(src + MESSAGE_HEADER, payLength, 0, wState->state.opCode[wState->state.opStack], isFin(src), wState, user)) {
368
0
                    return true;
369
0
                }
370
0
            }
371
372
21.8k
            if (fin) {
373
1.37k
                wState->state.opStack--;
374
1.37k
            }
375
376
21.8k
            src += payLength + MESSAGE_HEADER;
377
21.8k
            length -= (unsigned int) (payLength + MESSAGE_HEADER);
378
21.8k
            wState->state.spillLength = 0;
379
21.8k
            return false;
380
23.3k
        } else {
381
2.08k
            wState->state.spillLength = 0;
382
2.08k
            wState->state.wantsHead = false;
383
2.08k
            wState->remainingBytes = (unsigned int) (payLength - length + MESSAGE_HEADER);
384
2.08k
            bool fin = isFin(src);
385
2.08k
            if constexpr (isServer) {
386
2.08k
                memcpy(wState->mask, src + MESSAGE_HEADER - 4, 4);
387
2.08k
                uint64_t mask;
388
2.08k
                memcpy(&mask, src + MESSAGE_HEADER - 4, 4);
389
2.08k
                memcpy(((char *)&mask) + 4, src + MESSAGE_HEADER - 4, 4);
390
2.08k
                unmaskImprecise8<0>(src + MESSAGE_HEADER, mask, length);
391
2.08k
                rotateMask(4 - (length - MESSAGE_HEADER) % 4, wState->mask);
392
2.08k
            }
393
2.08k
            Impl::handleFragment(src + MESSAGE_HEADER, length - MESSAGE_HEADER, wState->remainingBytes, wState->state.opCode[wState->state.opStack], fin, wState, user);
394
2.08k
            return true;
395
2.08k
        }
396
25.4k
    }
bool uWS::WebSocketProtocol<true, Impl>::consumeMessage<8u, unsigned short>(unsigned short, char*&, unsigned int&, uWS::WebSocketState<true>*, void*)
Line
Count
Source
340
3.90k
    static inline bool consumeMessage(T payLength, char *&src, unsigned int &length, WebSocketState<isServer> *wState, void *user) {
341
3.90k
        if (getOpCode(src)) {
342
1.59k
            if (wState->state.opStack == 1 || (!wState->state.lastFin && getOpCode(src) < 2)) {
343
390
                Impl::forceClose(wState, user);
344
390
                return true;
345
390
            }
346
1.20k
            wState->state.opCode[++wState->state.opStack] = (OpCode) getOpCode(src);
347
2.31k
        } else if (wState->state.opStack == -1) {
348
194
            Impl::forceClose(wState, user);
349
194
            return true;
350
194
        }
351
3.32k
        wState->state.lastFin = isFin(src);
352
353
3.32k
        if (Impl::refusePayloadLength(payLength, wState, user)) {
354
244
            Impl::forceClose(wState, user, ERR_TOO_BIG_MESSAGE);
355
244
            return true;
356
244
        }
357
358
3.08k
        if (payLength + MESSAGE_HEADER <= length) {
359
2.06k
            bool fin = isFin(src);
360
2.06k
            if (isServer) {
361
                /* This guy can never be assumed to be perfectly aligned since we can get multiple messages in one read */
362
2.06k
                unmaskImpreciseCopyMask<MESSAGE_HEADER>(src + MESSAGE_HEADER, (unsigned int) payLength);
363
2.06k
                if (Impl::handleFragment(src, payLength, 0, wState->state.opCode[wState->state.opStack], fin, wState, user)) {
364
1.10k
                    return true;
365
1.10k
                }
366
2.06k
            } else {
367
0
                if (Impl::handleFragment(src + MESSAGE_HEADER, payLength, 0, wState->state.opCode[wState->state.opStack], isFin(src), wState, user)) {
368
0
                    return true;
369
0
                }
370
0
            }
371
372
962
            if (fin) {
373
723
                wState->state.opStack--;
374
723
            }
375
376
962
            src += payLength + MESSAGE_HEADER;
377
962
            length -= (unsigned int) (payLength + MESSAGE_HEADER);
378
962
            wState->state.spillLength = 0;
379
962
            return false;
380
2.06k
        } else {
381
1.01k
            wState->state.spillLength = 0;
382
1.01k
            wState->state.wantsHead = false;
383
1.01k
            wState->remainingBytes = (unsigned int) (payLength - length + MESSAGE_HEADER);
384
1.01k
            bool fin = isFin(src);
385
1.01k
            if constexpr (isServer) {
386
1.01k
                memcpy(wState->mask, src + MESSAGE_HEADER - 4, 4);
387
1.01k
                uint64_t mask;
388
1.01k
                memcpy(&mask, src + MESSAGE_HEADER - 4, 4);
389
1.01k
                memcpy(((char *)&mask) + 4, src + MESSAGE_HEADER - 4, 4);
390
1.01k
                unmaskImprecise8<0>(src + MESSAGE_HEADER, mask, length);
391
1.01k
                rotateMask(4 - (length - MESSAGE_HEADER) % 4, wState->mask);
392
1.01k
            }
393
1.01k
            Impl::handleFragment(src + MESSAGE_HEADER, length - MESSAGE_HEADER, wState->remainingBytes, wState->state.opCode[wState->state.opStack], fin, wState, user);
394
1.01k
            return true;
395
1.01k
        }
396
3.08k
    }
bool uWS::WebSocketProtocol<true, Impl>::consumeMessage<14u, unsigned long>(unsigned long, char*&, unsigned int&, uWS::WebSocketState<true>*, void*)
Line
Count
Source
340
4.14k
    /* Parses one frame whose MESSAGE_HEADER-byte header starts at src and whose
     * payload length was already decoded by the caller into payLength.
     *
     * Returns true when the caller must stop parsing this buffer:
     *  - the connection was force-closed,
     *  - Impl::handleFragment returned true, or
     *  - the frame runs past the end of the buffer. The outstanding
     *    wState->remainingBytes are then fed through consumeContinuation.
     * Returns false after a frame was handled completely. In that case src and
     * length have been advanced past it, so the caller can parse the next header. */
    static inline bool consumeMessage(T payLength, char *&src, unsigned int &length, WebSocketState<isServer> *wState, void *user) {
        if (getOpCode(src)) {
            /* A non-continuation frame pushes its opcode. It is refused in two cases:
             *  - the stack already uses its deepest slot (opStack == 1), or
             *  - it is a new data frame (TEXT = 1 or BINARY = 2) and the previous
             *    frame was not final, i.e. we are inside a fragmented message.
             * consume() already rejected opcodes 3-7, so "< 3" selects exactly the
             * data opcodes. Testing "< 2" would let BINARY through.
             * NOTE(review): lastFin is also overwritten by interleaved control
             * frames (which are always final). A new data frame arriving right
             * after such a control frame therefore still passes this check.
             * Confirm whether that is intended. */
            if (wState->state.opStack == 1 || (!wState->state.lastFin && getOpCode(src) < 3)) {
                Impl::forceClose(wState, user);
                return true;
            }
            wState->state.opCode[++wState->state.opStack] = static_cast<OpCode>(getOpCode(src));
        } else if (wState->state.opStack == -1) {
            /* Continuation frame with no message in progress to continue */
            Impl::forceClose(wState, user);
            return true;
        }
        wState->state.lastFin = isFin(src);

        if (Impl::refusePayloadLength(payLength, wState, user)) {
            Impl::forceClose(wState, user, ERR_TOO_BIG_MESSAGE);
            return true;
        }

        /* Read FIN before any unmasking below may overwrite the header bytes */
        const bool fin = isFin(src);

        if (payLength + MESSAGE_HEADER <= length) {
            /* The whole frame is inside this buffer */
            if constexpr (isServer) {
                /* This guy can never be assumed to be perfectly aligned since we can get multiple messages in one read */
                unmaskImpreciseCopyMask<MESSAGE_HEADER>(src + MESSAGE_HEADER, static_cast<unsigned int>(payLength));
                /* Presumably the helper above writes the unmasked payload
                 * MESSAGE_HEADER bytes earlier, i.e. starting at src. Confirm. */
                if (Impl::handleFragment(src, payLength, 0, wState->state.opCode[wState->state.opStack], fin, wState, user)) {
                    return true;
                }
            } else {
                if (Impl::handleFragment(src + MESSAGE_HEADER, payLength, 0, wState->state.opCode[wState->state.opStack], fin, wState, user)) {
                    return true;
                }
            }

            if (fin) {
                wState->state.opStack--;
            }

            src += payLength + MESSAGE_HEADER;
            length -= static_cast<unsigned int>(payLength + MESSAGE_HEADER);
            wState->state.spillLength = 0;
            return false;
        }

        /* Only part of the payload is here. Remember how much is still missing,
         * and have the next call continue the payload instead of parsing a header. */
        wState->state.spillLength = 0;
        wState->state.wantsHead = false;
        wState->remainingBytes = static_cast<unsigned int>(payLength - length + MESSAGE_HEADER);
        if constexpr (isServer) {
            /* The 4 mask bytes end the header. Keep a copy for the continuation,
             * and build an 8-byte repeated mask for the bulk unmask. */
            memcpy(wState->mask, src + MESSAGE_HEADER - 4, 4);
            uint64_t mask;
            memcpy(&mask, src + MESSAGE_HEADER - 4, 4);
            memcpy(reinterpret_cast<char *>(&mask) + 4, src + MESSAGE_HEADER - 4, 4);
            /* Unmask only the payload bytes actually present (length - MESSAGE_HEADER),
             * matching what is handed to handleFragment below. Passing the full
             * buffer length made the unmask run MESSAGE_HEADER bytes further past
             * the received data. */
            unmaskImprecise8<0>(src + MESSAGE_HEADER, mask, length - MESSAGE_HEADER);
            /* Presumably re-aligns the stored mask to the payload offset the next read starts at */
            rotateMask(4 - (length - MESSAGE_HEADER) % 4, wState->mask);
        }
        Impl::handleFragment(src + MESSAGE_HEADER, length - MESSAGE_HEADER, wState->remainingBytes, wState->state.opCode[wState->state.opStack], fin, wState, user);
        return true;
    }
397
398
    /* This one is nicely vectorized on both ARM64 and X64 - especially with -mavx */
399
0
    static inline void unmaskAll(char * __restrict data, char * __restrict mask) {
400
0
        for (int i = 0; i < LIBUS_RECV_BUFFER_LENGTH; i += 16) {
401
0
            UnrolledXor<16>(data + i, mask);
402
0
        }
403
0
    }
404
405
8.99k
    static inline bool consumeContinuation(char *&src, unsigned int &length, WebSocketState<isServer> *wState, void *user) {
406
8.99k
        if (wState->remainingBytes <= length) {
407
6.87k
            if (isServer) {
408
6.87k
                unsigned int n = wState->remainingBytes >> 2;
409
6.87k
                unmaskInplace(src, src + n * 4, wState->mask);
410
15.7k
                for (unsigned int i = 0, s = wState->remainingBytes % 4; i < s; i++) {
411
8.87k
                    src[n * 4 + i] ^= wState->mask[i];
412
8.87k
                }
413
6.87k
            }
414
415
6.87k
            if (Impl::handleFragment(src, wState->remainingBytes, 0, wState->state.opCode[wState->state.opStack], wState->state.lastFin, wState, user)) {
416
3.17k
                return false;
417
3.17k
            }
418
419
3.69k
            if (wState->state.lastFin) {
420
1.16k
                wState->state.opStack--;
421
1.16k
            }
422
423
3.69k
            src += wState->remainingBytes;
424
3.69k
            length -= wState->remainingBytes;
425
3.69k
            wState->state.wantsHead = true;
426
3.69k
            return true;
427
6.87k
        } else {
428
2.11k
            if (isServer) {
429
                /* No need to unmask if mask is 0 */
430
2.11k
                uint32_t nullmask = 0;
431
2.11k
                if (memcmp(wState->mask, &nullmask, sizeof(uint32_t))) {
432
895
                    if /*constexpr*/ (LIBUS_RECV_BUFFER_LENGTH == length) {
433
0
                        unmaskAll(src, wState->mask);
434
895
                    } else {
435
                        // Slow path
436
895
                        unmaskInplace(src, src + ((length >> 2) + 1) * 4, wState->mask);
437
895
                    }
438
895
                }
439
2.11k
            }
440
441
2.11k
            wState->remainingBytes -= length;
442
2.11k
            if (Impl::handleFragment(src, length, wState->remainingBytes, wState->state.opCode[wState->state.opStack], wState->state.lastFin, wState, user)) {
443
416
                return false;
444
416
            }
445
446
1.70k
            if (isServer && length % 4) {
447
1.41k
                rotateMask(4 - (length % 4), wState->mask);
448
1.41k
            }
449
1.70k
            return false;
450
2.11k
        }
451
8.99k
    }
452
453
public:
454
    /* User-provided, empty default constructor: nothing to initialize here */
    WebSocketProtocol() {}
457
458
60.2k
    static inline void consume(char *src, unsigned int length, WebSocketState<isServer> *wState, void *user) {
459
60.2k
        if (wState->state.spillLength) {
460
24.2k
            src -= wState->state.spillLength;
461
24.2k
            length += wState->state.spillLength;
462
24.2k
            memcpy(src, wState->state.spill, wState->state.spillLength);
463
24.2k
        }
464
60.2k
        if (wState->state.wantsHead) {
465
54.9k
            parseNext:
466
78.7k
            while (length >= SHORT_MESSAGE_HEADER) {
467
468
                // invalid reserved bits / invalid opcodes / invalid control frames / set compressed frame
469
76.7k
                if ((rsv1(src) && !Impl::setCompressed(wState, user)) || rsv23(src) || (getOpCode(src) > 2 && getOpCode(src) < 8) ||
470
76.7k
                    getOpCode(src) > 10 || (getOpCode(src) > 2 && (!isFin(src) || payloadLength(src) > 125))) {
471
39.6k
                    Impl::forceClose(wState, user);
472
39.6k
                    return;
473
39.6k
                }
474
475
37.1k
                if (payloadLength(src) < 126) {
476
26.9k
                    if (consumeMessage<SHORT_MESSAGE_HEADER, uint8_t>(payloadLength(src), src, length, wState, user)) {
477
5.11k
                        return;
478
5.11k
                    }
479
26.9k
                } else if (payloadLength(src) == 126) {
480
4.94k
                    if (length < MEDIUM_MESSAGE_HEADER) {
481
1.03k
                        break;
482
3.90k
                    } else if(consumeMessage<MEDIUM_MESSAGE_HEADER, uint16_t>(protocol::cond_byte_swap<uint16_t>(protocol::bit_cast<uint16_t>(src + 2)), src, length, wState, user)) {
483
2.94k
                        return;
484
2.94k
                    }
485
5.23k
                } else if (length < LONG_MESSAGE_HEADER) {
486
1.09k
                    break;
487
4.14k
                } else if (consumeMessage<LONG_MESSAGE_HEADER, uint64_t>(protocol::cond_byte_swap<uint64_t>(protocol::bit_cast<uint64_t>(src + 2)), src, length, wState, user)) {
488
3.17k
                    return;
489
3.17k
                }
490
37.1k
            }
491
4.11k
            if (length) {
492
3.17k
                memcpy(wState->state.spill, src, length);
493
3.17k
                wState->state.spillLength = length & 0xf;
494
3.17k
            }
495
8.99k
        } else if (consumeContinuation(src, length, wState, user)) {
496
3.69k
            goto parseNext;
497
3.69k
        }
498
60.2k
    }
499
500
    /* Writable slack callers should leave after src + length when calling consume().
     * The partial-frame path of consumeContinuation unmasks in whole 4-byte
     * words, up to src + ((length >> 2) + 1) * 4.
     * NOTE(review): the imprecise unmask helpers used by consumeMessage may run
     * further than this. Confirm against their definitions. */
    static const int CONSUME_POST_PADDING = 4;
501
    /* Writable space callers must leave before src when calling consume().
     * A leftover partial header is copied back in front of src. The parse loop
     * only leaves bytes over when fewer than a full header remain, so at most
     * LONG_MESSAGE_HEADER - 1 bytes are copied. */
    static const int CONSUME_PRE_PADDING = LONG_MESSAGE_HEADER - 1;
502
};
503
504
}
505
506
#endif // UWS_WEBSOCKETPROTOCOL_H