Coverage Report

Created: 2025-06-13 06:09

/src/uWebSockets/src/WebSocketProtocol.h
Line
Count
Source (jump to first uncovered line)
1
/*
2
 * Authored by Alex Hultman, 2018-2020.
3
 * Intellectual property of third-party.
4
5
 * Licensed under the Apache License, Version 2.0 (the "License");
6
 * you may not use this file except in compliance with the License.
7
 * You may obtain a copy of the License at
8
9
 *     http://www.apache.org/licenses/LICENSE-2.0
10
11
 * Unless required by applicable law or agreed to in writing, software
12
 * distributed under the License is distributed on an "AS IS" BASIS,
13
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
 * See the License for the specific language governing permissions and
15
 * limitations under the License.
16
 */
17
18
#ifndef UWS_WEBSOCKETPROTOCOL_H
19
#define UWS_WEBSOCKETPROTOCOL_H
20
21
#include <libusockets.h>
22
23
#include <cstdint>
24
#include <cstring>
25
#include <cstdlib>
26
#include <string_view>
27
28
namespace uWS {
29
30
/* We should not overcomplicate these */
31
constexpr std::string_view ERR_TOO_BIG_MESSAGE("Received too big message");
32
constexpr std::string_view ERR_WEBSOCKET_TIMEOUT("WebSocket timed out from inactivity");
33
constexpr std::string_view ERR_INVALID_TEXT("Received invalid UTF-8");
34
constexpr std::string_view ERR_TOO_BIG_MESSAGE_INFLATION("Received too big message, or other inflation error");
35
constexpr std::string_view ERR_INVALID_CLOSE_PAYLOAD("Received invalid close payload");
36
constexpr std::string_view ERR_PROTOCOL("Received invalid WebSocket frame");
37
constexpr std::string_view ERR_TCP_FIN("Received TCP FIN before WebSocket close frame");
38
39
enum OpCode : unsigned char {
40
    CONTINUATION = 0,
41
    TEXT = 1,
42
    BINARY = 2,
43
    CLOSE = 8,
44
    PING = 9,
45
    PONG = 10
46
};
47
48
enum {
49
    CLIENT,
50
    SERVER
51
};
52
53
// 24 bytes perfectly
54
template <bool isServer>
55
struct WebSocketState {
56
public:
57
    static const unsigned int SHORT_MESSAGE_HEADER = isServer ? 6 : 2;
58
    static const unsigned int MEDIUM_MESSAGE_HEADER = isServer ? 8 : 4;
59
    static const unsigned int LONG_MESSAGE_HEADER = isServer ? 14 : 10;
60
61
    // 16 bytes
62
    struct State {
63
        unsigned int wantsHead : 1;
64
        unsigned int spillLength : 4;
65
        signed int opStack : 2; // -1, 0, 1
66
        unsigned int lastFin : 1;
67
68
        // 15 bytes
69
        unsigned char spill[LONG_MESSAGE_HEADER - 1];
70
        OpCode opCode[2];
71
72
40.4k
        State() {
73
40.4k
            wantsHead = true;
74
40.4k
            spillLength = 0;
75
40.4k
            opStack = -1;
76
40.4k
            lastFin = true;
77
40.4k
        }
78
79
    } state;
80
81
    // 8 bytes
82
    unsigned int remainingBytes = 0;
83
    char mask[isServer ? 4 : 1];
84
};
85
86
namespace protocol {
87
88
template <typename T>
89
8.06k
T bit_cast(char *c) {
90
8.06k
    T val;
91
8.06k
    memcpy(&val, c, sizeof(T));
92
8.06k
    return val;
93
8.06k
}
unsigned short uWS::protocol::bit_cast<unsigned short>(char*)
Line
Count
Source
89
4.31k
T bit_cast(char *c) {
90
4.31k
    T val;
91
4.31k
    memcpy(&val, c, sizeof(T));
92
4.31k
    return val;
93
4.31k
}
unsigned long uWS::protocol::bit_cast<unsigned long>(char*)
Line
Count
Source
89
3.74k
T bit_cast(char *c) {
90
3.74k
    T val;
91
3.74k
    memcpy(&val, c, sizeof(T));
92
3.74k
    return val;
93
3.74k
}
94
95
/* Byte swap for little-endian systems */
96
template <typename T>
97
16.0k
T cond_byte_swap(T value) {
98
16.0k
    static_assert(std::is_trivially_copyable<T>::value, "T must be trivially copyable");
99
16.0k
    uint32_t endian_test = 1;
100
16.0k
    if (*reinterpret_cast<char*>(&endian_test)) {
101
16.0k
        uint8_t src[sizeof(T)];
102
16.0k
        uint8_t dst[sizeof(T)];
103
104
16.0k
        std::memcpy(src, &value, sizeof(T));
105
70.6k
        for (size_t i = 0; i < sizeof(T); ++i) {
106
54.5k
            dst[i] = src[sizeof(T) - 1 - i];
107
54.5k
        }
108
109
16.0k
        T result;
110
16.0k
        std::memcpy(&result, dst, sizeof(T));
111
16.0k
        return result;
112
16.0k
    }
113
0
    return value;
114
16.0k
}
unsigned short uWS::protocol::cond_byte_swap<unsigned short>(unsigned short)
Line
Count
Source
97
12.2k
T cond_byte_swap(T value) {
98
12.2k
    static_assert(std::is_trivially_copyable<T>::value, "T must be trivially copyable");
99
12.2k
    uint32_t endian_test = 1;
100
12.2k
    if (*reinterpret_cast<char*>(&endian_test)) {
101
12.2k
        uint8_t src[sizeof(T)];
102
12.2k
        uint8_t dst[sizeof(T)];
103
104
12.2k
        std::memcpy(src, &value, sizeof(T));
105
36.8k
        for (size_t i = 0; i < sizeof(T); ++i) {
106
24.5k
            dst[i] = src[sizeof(T) - 1 - i];
107
24.5k
        }
108
109
12.2k
        T result;
110
12.2k
        std::memcpy(&result, dst, sizeof(T));
111
12.2k
        return result;
112
12.2k
    }
113
0
    return value;
114
12.2k
}
unsigned long uWS::protocol::cond_byte_swap<unsigned long>(unsigned long)
Line
Count
Source
97
3.74k
T cond_byte_swap(T value) {
98
3.74k
    static_assert(std::is_trivially_copyable<T>::value, "T must be trivially copyable");
99
3.74k
    uint32_t endian_test = 1;
100
3.74k
    if (*reinterpret_cast<char*>(&endian_test)) {
101
3.74k
        uint8_t src[sizeof(T)];
102
3.74k
        uint8_t dst[sizeof(T)];
103
104
3.74k
        std::memcpy(src, &value, sizeof(T));
105
33.7k
        for (size_t i = 0; i < sizeof(T); ++i) {
106
29.9k
            dst[i] = src[sizeof(T) - 1 - i];
107
29.9k
        }
108
109
3.74k
        T result;
110
3.74k
        std::memcpy(&result, dst, sizeof(T));
111
3.74k
        return result;
112
3.74k
    }
113
0
    return value;
114
3.74k
}
115
116
// Based on utf8_check.c by Markus Kuhn, 2005
117
// https://www.cl.cam.ac.uk/~mgk25/ucs/utf8_check.c
118
// Optimized for predominantly 7-bit content by Alex Hultman, 2016
119
// Licensed as Zlib, like the rest of this project
120
// This runs about 40% faster than simdutf with g++ -mavx
121
static bool isValidUtf8(unsigned char *s, size_t length)
122
14.0k
{
123
24.2k
    for (unsigned char *e = s + length; s != e; ) {
124
23.1k
        if (s + 16 <= e) {
125
10.0k
            uint64_t tmp[2];
126
10.0k
            memcpy(tmp, s, 16);
127
10.0k
            if (((tmp[0] & 0x8080808080808080) | (tmp[1] & 0x8080808080808080)) == 0) {
128
7.12k
                s += 16;
129
7.12k
                continue;
130
7.12k
            }
131
10.0k
        }
132
133
27.6k
        while (!(*s & 0x80)) {
134
20.4k
            if (++s == e) {
135
8.92k
                return true;
136
8.92k
            }
137
20.4k
        }
138
139
7.14k
        if ((s[0] & 0x60) == 0x40) {
140
1.75k
            if (s + 1 >= e || (s[1] & 0xc0) != 0x80 || (s[0] & 0xfe) == 0xc0) {
141
829
                return false;
142
829
            }
143
926
            s += 2;
144
5.38k
        } else if ((s[0] & 0xf0) == 0xe0) {
145
2.60k
            if (s + 2 >= e || (s[1] & 0xc0) != 0x80 || (s[2] & 0xc0) != 0x80 ||
146
2.60k
                    (s[0] == 0xe0 && (s[1] & 0xe0) == 0x80) || (s[0] == 0xed && (s[1] & 0xe0) == 0xa0)) {
147
1.11k
                return false;
148
1.11k
            }
149
1.48k
            s += 3;
150
2.78k
        } else if ((s[0] & 0xf8) == 0xf0) {
151
2.16k
            if (s + 3 >= e || (s[1] & 0xc0) != 0x80 || (s[2] & 0xc0) != 0x80 || (s[3] & 0xc0) != 0x80 ||
152
2.16k
                    (s[0] == 0xf0 && (s[1] & 0xf0) == 0x80) || (s[0] == 0xf4 && s[1] > 0x8f) || s[0] > 0xf4) {
153
1.52k
                return false;
154
1.52k
            }
155
637
            s += 4;
156
637
        } else {
157
624
            return false;
158
624
        }
159
7.14k
    }
160
1.05k
    return true;
161
14.0k
}
162
163
struct CloseFrame {
164
    uint16_t code;
165
    char *message;
166
    size_t length;
167
};
168
169
5.99k
static inline CloseFrame parseClosePayload(char *src, size_t length) {
170
    /* If we get no code or message, default to reporting 1005 no status code present */
171
5.99k
    CloseFrame cf = {1005, nullptr, 0};
172
5.99k
    if (length >= 2) {
173
5.69k
        memcpy(&cf.code, src, 2);
174
5.69k
        cf = {cond_byte_swap<uint16_t>(cf.code), src + 2, length - 2};
175
5.69k
        if (cf.code < 1000 || cf.code > 4999 || (cf.code > 1011 && cf.code < 4000) ||
176
5.69k
            (cf.code >= 1004 && cf.code <= 1006) || !isValidUtf8((unsigned char *) cf.message, cf.length)) {
177
            /* Even though we got a WebSocket close frame, it in itself is abnormal */
178
4.23k
            return {1006, (char *) ERR_INVALID_CLOSE_PAYLOAD.data(), ERR_INVALID_CLOSE_PAYLOAD.length()};
179
4.23k
        }
180
5.69k
    }
181
1.76k
    return cf;
182
5.99k
}
183
184
5.99k
static inline size_t formatClosePayload(char *dst, uint16_t code, const char *message, size_t length) {
185
    /* We could have more strict checks here, but never append code 0 or 1005 or 1006 */
186
5.99k
    if (code && code != 1005 && code != 1006) {
187
1.46k
        code = cond_byte_swap<uint16_t>(code);
188
1.46k
        memcpy(dst, &code, 2);
189
        /* It is invalid to pass nullptr to memcpy, even though length is 0 */
190
1.46k
        if (message) {
191
1.46k
            memcpy(dst + 2, message, length);
192
1.46k
        }
193
1.46k
        return length + 2;
194
1.46k
    }
195
4.53k
    return 0;
196
5.99k
}
197
198
23.5k
static inline size_t messageFrameSize(size_t messageSize) {
199
23.5k
    if (messageSize < 126) {
200
22.6k
        return 2 + messageSize;
201
22.6k
    } else if (messageSize <= UINT16_MAX) {
202
824
        return 4 + messageSize;
203
824
    }
204
0
    return 10 + messageSize;
205
23.5k
}
206
207
enum {
208
    SND_CONTINUATION = 1,
209
    SND_NO_FIN = 2,
210
    SND_COMPRESSED = 64
211
};
212
213
template <bool isServer>
214
23.5k
static inline size_t formatMessage(char *dst, const char *src, size_t length, OpCode opCode, size_t reportedLength, bool compressed, bool fin) {
215
23.5k
    size_t messageLength;
216
23.5k
    size_t headerLength;
217
23.5k
    if (reportedLength < 126) {
218
22.6k
        headerLength = 2;
219
22.6k
        dst[1] = (char) reportedLength;
220
22.6k
    } else if (reportedLength <= UINT16_MAX) {
221
824
        headerLength = 4;
222
824
        dst[1] = 126;
223
824
        uint16_t tmp = cond_byte_swap<uint16_t>((uint16_t) reportedLength);
224
824
        memcpy(&dst[2], &tmp, sizeof(uint16_t));
225
824
    } else {
226
0
        headerLength = 10;
227
0
        dst[1] = 127;
228
0
        uint64_t tmp = cond_byte_swap<uint64_t>((uint64_t) reportedLength);
229
0
        memcpy(&dst[2], &tmp, sizeof(uint64_t));
230
0
    }
231
232
23.5k
    dst[0] = (char) ((fin ? 128 : 0) | ((compressed && opCode) ? SND_COMPRESSED : 0) | (char) opCode);
233
234
    //printf("%d\n", (int)dst[0]);
235
236
23.5k
    char mask[4];
237
23.5k
    if (!isServer) {
238
0
        dst[1] |= 0x80;
239
0
        uint32_t random = (uint32_t) rand();
240
0
        memcpy(mask, &random, 4);
241
0
        memcpy(dst + headerLength, &random, 4);
242
0
        headerLength += 4;
243
0
    }
244
245
23.5k
    messageLength = headerLength + length;
246
23.5k
    memcpy(dst + headerLength, src, length);
247
248
23.5k
    if (!isServer) {
249
250
        // overwrites up to 3 bytes outside of the given buffer!
251
        //WebSocketProtocol<isServer>::unmaskInplace(dst + headerLength, dst + headerLength + length, mask);
252
253
        // this is not optimal
254
0
        char *start = dst + headerLength;
255
0
        char *stop = start + length;
256
0
        int i = 0;
257
0
        while (start != stop) {
258
0
            (*start++) ^= mask[i++ % 4];
259
0
        }
260
0
    }
261
23.5k
    return messageLength;
262
23.5k
}
263
264
}
265
266
// essentially this is only a parser
267
template <const bool isServer, typename Impl>
268
struct WebSocketProtocol {
269
public:
270
    static const unsigned int SHORT_MESSAGE_HEADER = isServer ? 6 : 2;
271
    static const unsigned int MEDIUM_MESSAGE_HEADER = isServer ? 8 : 4;
272
    static const unsigned int LONG_MESSAGE_HEADER = isServer ? 14 : 10;
273
274
protected:
275
104k
    static inline bool isFin(char *frame) {return *((unsigned char *) frame) & 128;}
276
247k
    static inline unsigned char getOpCode(char *frame) {return *((unsigned char *) frame) & 15;}
277
109k
    static inline unsigned char payloadLength(char *frame) {return ((unsigned char *) frame)[1] & 127;}
278
48.7k
    static inline bool rsv23(char *frame) {return *((unsigned char *) frame) & 48;}
279
49.8k
    static inline bool rsv1(char *frame) {return *((unsigned char *) frame) & 64;}
280
281
    template <int N>
282
0
    static inline void UnrolledXor(char * __restrict data, char * __restrict mask) {
283
0
        if constexpr (N != 1) {
284
0
            UnrolledXor<N - 1>(data, mask);
285
0
        }
286
0
        data[N - 1] ^= mask[(N - 1) % 4];
287
0
    }
Unexecuted instantiation: EpollHelloWorld.cpp:void uWS::WebSocketProtocol<true, uWS::WebSocketContext<false, true, test()::PerSocketData> >::UnrolledXor<16>(char*, char*)
Unexecuted instantiation: EpollHelloWorld.cpp:void uWS::WebSocketProtocol<true, uWS::WebSocketContext<false, true, test()::PerSocketData> >::UnrolledXor<15>(char*, char*)
Unexecuted instantiation: EpollHelloWorld.cpp:void uWS::WebSocketProtocol<true, uWS::WebSocketContext<false, true, test()::PerSocketData> >::UnrolledXor<14>(char*, char*)
Unexecuted instantiation: EpollHelloWorld.cpp:void uWS::WebSocketProtocol<true, uWS::WebSocketContext<false, true, test()::PerSocketData> >::UnrolledXor<13>(char*, char*)
Unexecuted instantiation: EpollHelloWorld.cpp:void uWS::WebSocketProtocol<true, uWS::WebSocketContext<false, true, test()::PerSocketData> >::UnrolledXor<12>(char*, char*)
Unexecuted instantiation: EpollHelloWorld.cpp:void uWS::WebSocketProtocol<true, uWS::WebSocketContext<false, true, test()::PerSocketData> >::UnrolledXor<11>(char*, char*)
Unexecuted instantiation: EpollHelloWorld.cpp:void uWS::WebSocketProtocol<true, uWS::WebSocketContext<false, true, test()::PerSocketData> >::UnrolledXor<10>(char*, char*)
Unexecuted instantiation: EpollHelloWorld.cpp:void uWS::WebSocketProtocol<true, uWS::WebSocketContext<false, true, test()::PerSocketData> >::UnrolledXor<9>(char*, char*)
Unexecuted instantiation: EpollHelloWorld.cpp:void uWS::WebSocketProtocol<true, uWS::WebSocketContext<false, true, test()::PerSocketData> >::UnrolledXor<8>(char*, char*)
Unexecuted instantiation: EpollHelloWorld.cpp:void uWS::WebSocketProtocol<true, uWS::WebSocketContext<false, true, test()::PerSocketData> >::UnrolledXor<7>(char*, char*)
Unexecuted instantiation: EpollHelloWorld.cpp:void uWS::WebSocketProtocol<true, uWS::WebSocketContext<false, true, test()::PerSocketData> >::UnrolledXor<6>(char*, char*)
Unexecuted instantiation: EpollHelloWorld.cpp:void uWS::WebSocketProtocol<true, uWS::WebSocketContext<false, true, test()::PerSocketData> >::UnrolledXor<5>(char*, char*)
Unexecuted instantiation: EpollHelloWorld.cpp:void uWS::WebSocketProtocol<true, uWS::WebSocketContext<false, true, test()::PerSocketData> >::UnrolledXor<4>(char*, char*)
Unexecuted instantiation: EpollHelloWorld.cpp:void uWS::WebSocketProtocol<true, uWS::WebSocketContext<false, true, test()::PerSocketData> >::UnrolledXor<3>(char*, char*)
Unexecuted instantiation: EpollHelloWorld.cpp:void uWS::WebSocketProtocol<true, uWS::WebSocketContext<false, true, test()::PerSocketData> >::UnrolledXor<2>(char*, char*)
Unexecuted instantiation: EpollHelloWorld.cpp:void uWS::WebSocketProtocol<true, uWS::WebSocketContext<false, true, test()::PerSocketData> >::UnrolledXor<1>(char*, char*)
288
289
    template <int DESTINATION>
290
13.1k
    static inline void unmaskImprecise8(char *src, uint64_t mask, unsigned int length) {
291
78.8k
        for (unsigned int n = (length >> 3) + 1; n; n--) {
292
65.6k
            uint64_t loaded;
293
65.6k
            memcpy(&loaded, src, 8);
294
65.6k
            loaded ^= mask;
295
65.6k
            memcpy(src - DESTINATION, &loaded, 8);
296
65.6k
            src += 8;
297
65.6k
        }
298
13.1k
    }
EpollHelloWorld.cpp:void uWS::WebSocketProtocol<true, uWS::WebSocketContext<false, true, test()::PerSocketData> >::unmaskImprecise8<0>(char*, unsigned long, unsigned int)
Line
Count
Source
290
9.48k
    static inline void unmaskImprecise8(char *src, uint64_t mask, unsigned int length) {
291
60.2k
        for (unsigned int n = (length >> 3) + 1; n; n--) {
292
50.7k
            uint64_t loaded;
293
50.7k
            memcpy(&loaded, src, 8);
294
50.7k
            loaded ^= mask;
295
50.7k
            memcpy(src - DESTINATION, &loaded, 8);
296
50.7k
            src += 8;
297
50.7k
        }
298
9.48k
    }
EpollHelloWorld.cpp:void uWS::WebSocketProtocol<true, uWS::WebSocketContext<false, true, test()::PerSocketData> >::unmaskImprecise8<8>(char*, unsigned long, unsigned int)
Line
Count
Source
290
2.19k
    static inline void unmaskImprecise8(char *src, uint64_t mask, unsigned int length) {
291
12.9k
        for (unsigned int n = (length >> 3) + 1; n; n--) {
292
10.7k
            uint64_t loaded;
293
10.7k
            memcpy(&loaded, src, 8);
294
10.7k
            loaded ^= mask;
295
10.7k
            memcpy(src - DESTINATION, &loaded, 8);
296
10.7k
            src += 8;
297
10.7k
        }
298
2.19k
    }
EpollHelloWorld.cpp:void uWS::WebSocketProtocol<true, uWS::WebSocketContext<false, true, test()::PerSocketData> >::unmaskImprecise8<14>(char*, unsigned long, unsigned int)
Line
Count
Source
290
1.44k
    static inline void unmaskImprecise8(char *src, uint64_t mask, unsigned int length) {
291
5.61k
        for (unsigned int n = (length >> 3) + 1; n; n--) {
292
4.16k
            uint64_t loaded;
293
4.16k
            memcpy(&loaded, src, 8);
294
4.16k
            loaded ^= mask;
295
4.16k
            memcpy(src - DESTINATION, &loaded, 8);
296
4.16k
            src += 8;
297
4.16k
        }
298
1.44k
    }
299
300
    /* DESTINATION = 6 makes this not SIMD, DESTINATION = 4 is with SIMD but we don't want that for short messages */
301
    template <int DESTINATION>
302
30.2k
    static inline void unmaskImprecise4(char *src, uint32_t mask, unsigned int length) {
303
122k
        for (unsigned int n = (length >> 2) + 1; n; n--) {
304
91.9k
            uint32_t loaded;
305
91.9k
            memcpy(&loaded, src, 4);
306
91.9k
            loaded ^= mask;
307
91.9k
            memcpy(src - DESTINATION, &loaded, 4);
308
91.9k
            src += 4;
309
91.9k
        }
310
30.2k
    }
311
312
    template <int HEADER_SIZE>
313
33.8k
    static inline void unmaskImpreciseCopyMask(char *src, unsigned int length) {
314
33.8k
        if constexpr (HEADER_SIZE != 6) {
315
3.64k
            char mask[8] = {src[-4], src[-3], src[-2], src[-1], src[-4], src[-3], src[-2], src[-1]};
316
3.64k
            uint64_t maskInt;
317
3.64k
            memcpy(&maskInt, mask, 8);
318
3.64k
            unmaskImprecise8<HEADER_SIZE>(src, maskInt, length);
319
30.2k
        } else {
320
30.2k
            char mask[4] = {src[-4], src[-3], src[-2], src[-1]};
321
30.2k
            uint32_t maskInt;
322
30.2k
            memcpy(&maskInt, mask, 4);
323
30.2k
            unmaskImprecise4<HEADER_SIZE>(src, maskInt, length);
324
30.2k
        }
325
33.8k
    }
EpollHelloWorld.cpp:void uWS::WebSocketProtocol<true, uWS::WebSocketContext<false, true, test()::PerSocketData> >::unmaskImpreciseCopyMask<6>(char*, unsigned int)
Line
Count
Source
313
30.2k
    static inline void unmaskImpreciseCopyMask(char *src, unsigned int length) {
314
        if constexpr (HEADER_SIZE != 6) {
315
            char mask[8] = {src[-4], src[-3], src[-2], src[-1], src[-4], src[-3], src[-2], src[-1]};
316
            uint64_t maskInt;
317
            memcpy(&maskInt, mask, 8);
318
            unmaskImprecise8<HEADER_SIZE>(src, maskInt, length);
319
30.2k
        } else {
320
30.2k
            char mask[4] = {src[-4], src[-3], src[-2], src[-1]};
321
30.2k
            uint32_t maskInt;
322
30.2k
            memcpy(&maskInt, mask, 4);
323
30.2k
            unmaskImprecise4<HEADER_SIZE>(src, maskInt, length);
324
30.2k
        }
325
30.2k
    }
EpollHelloWorld.cpp:void uWS::WebSocketProtocol<true, uWS::WebSocketContext<false, true, test()::PerSocketData> >::unmaskImpreciseCopyMask<8>(char*, unsigned int)
Line
Count
Source
313
2.19k
    static inline void unmaskImpreciseCopyMask(char *src, unsigned int length) {
314
2.19k
        if constexpr (HEADER_SIZE != 6) {
315
2.19k
            char mask[8] = {src[-4], src[-3], src[-2], src[-1], src[-4], src[-3], src[-2], src[-1]};
316
2.19k
            uint64_t maskInt;
317
2.19k
            memcpy(&maskInt, mask, 8);
318
2.19k
            unmaskImprecise8<HEADER_SIZE>(src, maskInt, length);
319
        } else {
320
            char mask[4] = {src[-4], src[-3], src[-2], src[-1]};
321
            uint32_t maskInt;
322
            memcpy(&maskInt, mask, 4);
323
            unmaskImprecise4<HEADER_SIZE>(src, maskInt, length);
324
        }
325
2.19k
    }
EpollHelloWorld.cpp:void uWS::WebSocketProtocol<true, uWS::WebSocketContext<false, true, test()::PerSocketData> >::unmaskImpreciseCopyMask<14>(char*, unsigned int)
Line
Count
Source
313
1.44k
    static inline void unmaskImpreciseCopyMask(char *src, unsigned int length) {
314
1.44k
        if constexpr (HEADER_SIZE != 6) {
315
1.44k
            char mask[8] = {src[-4], src[-3], src[-2], src[-1], src[-4], src[-3], src[-2], src[-1]};
316
1.44k
            uint64_t maskInt;
317
1.44k
            memcpy(&maskInt, mask, 8);
318
1.44k
            unmaskImprecise8<HEADER_SIZE>(src, maskInt, length);
319
        } else {
320
            char mask[4] = {src[-4], src[-3], src[-2], src[-1]};
321
            uint32_t maskInt;
322
            memcpy(&maskInt, mask, 4);
323
            unmaskImprecise4<HEADER_SIZE>(src, maskInt, length);
324
        }
325
1.44k
    }
326
327
40.7k
    static inline void rotateMask(unsigned int offset, char *mask) {
328
40.7k
        char originalMask[4] = {mask[0], mask[1], mask[2], mask[3]};
329
40.7k
        mask[(0 + offset) % 4] = originalMask[0];
330
40.7k
        mask[(1 + offset) % 4] = originalMask[1];
331
40.7k
        mask[(2 + offset) % 4] = originalMask[2];
332
40.7k
        mask[(3 + offset) % 4] = originalMask[3];
333
40.7k
    }
334
335
38.4k
    static inline void unmaskInplace(char *data, char *stop, char *mask) {
336
1.17M
        while (data < stop) {
337
1.14M
            *(data++) ^= mask[0];
338
1.14M
            *(data++) ^= mask[1];
339
1.14M
            *(data++) ^= mask[2];
340
1.14M
            *(data++) ^= mask[3];
341
1.14M
        }
342
38.4k
    }
343
344
    template <unsigned int MESSAGE_HEADER, typename T>
345
46.2k
    static inline bool consumeMessage(T payLength, char *&src, unsigned int &length, WebSocketState<isServer> *wState, void *user) {
346
46.2k
        if (getOpCode(src)) {
347
38.7k
            if (wState->state.opStack == 1 || (!wState->state.lastFin && getOpCode(src) < 2)) {
348
1.22k
                Impl::forceClose(wState, user, ERR_PROTOCOL);
349
1.22k
                return true;
350
1.22k
            }
351
37.5k
            wState->state.opCode[++wState->state.opStack] = (OpCode) getOpCode(src);
352
37.5k
        } else if (wState->state.opStack == -1) {
353
657
            Impl::forceClose(wState, user, ERR_PROTOCOL);
354
657
            return true;
355
657
        }
356
44.3k
        wState->state.lastFin = isFin(src);
357
358
44.3k
        if (Impl::refusePayloadLength(payLength, wState, user)) {
359
985
            Impl::forceClose(wState, user, ERR_TOO_BIG_MESSAGE);
360
985
            return true;
361
985
        }
362
363
43.3k
        if (payLength + MESSAGE_HEADER <= length) {
364
33.8k
            bool fin = isFin(src);
365
33.8k
            if (isServer) {
366
                /* This guy can never be assumed to be perfectly aligned since we can get multiple messages in one read */
367
33.8k
                unmaskImpreciseCopyMask<MESSAGE_HEADER>(src + MESSAGE_HEADER, (unsigned int) payLength);
368
33.8k
                if (Impl::handleFragment(src, payLength, 0, wState->state.opCode[wState->state.opStack], fin, wState, user)) {
369
6.53k
                    return true;
370
6.53k
                }
371
33.8k
            } else {
372
0
                if (Impl::handleFragment(src + MESSAGE_HEADER, payLength, 0, wState->state.opCode[wState->state.opStack], isFin(src), wState, user)) {
373
0
                    return true;
374
0
                }
375
0
            }
376
377
27.3k
            if (fin) {
378
18.4k
                wState->state.opStack--;
379
18.4k
            }
380
381
27.3k
            src += payLength + MESSAGE_HEADER;
382
27.3k
            length -= (unsigned int) (payLength + MESSAGE_HEADER);
383
27.3k
            wState->state.spillLength = 0;
384
27.3k
            return false;
385
33.8k
        } else {
386
9.48k
            wState->state.spillLength = 0;
387
9.48k
            wState->state.wantsHead = false;
388
9.48k
            wState->remainingBytes = (unsigned int) (payLength - length + MESSAGE_HEADER);
389
9.48k
            bool fin = isFin(src);
390
9.48k
            if constexpr (isServer) {
391
9.48k
                memcpy(wState->mask, src + MESSAGE_HEADER - 4, 4);
392
9.48k
                uint64_t mask;
393
9.48k
                memcpy(&mask, src + MESSAGE_HEADER - 4, 4);
394
9.48k
                memcpy(((char *)&mask) + 4, src + MESSAGE_HEADER - 4, 4);
395
9.48k
                unmaskImprecise8<0>(src + MESSAGE_HEADER, mask, length);
396
9.48k
                rotateMask(4 - (length - MESSAGE_HEADER) % 4, wState->mask);
397
9.48k
            }
398
9.48k
            Impl::handleFragment(src + MESSAGE_HEADER, length - MESSAGE_HEADER, wState->remainingBytes, wState->state.opCode[wState->state.opStack], fin, wState, user);
399
9.48k
            return true;
400
9.48k
        }
401
43.3k
    }
EpollHelloWorld.cpp:bool uWS::WebSocketProtocol<true, uWS::WebSocketContext<false, true, test()::PerSocketData> >::consumeMessage<6u, unsigned char>(unsigned char, char*&, unsigned int&, uWS::WebSocketState<true>*, void*)
Line
Count
Source
345
38.1k
    static inline bool consumeMessage(T payLength, char *&src, unsigned int &length, WebSocketState<isServer> *wState, void *user) {
346
38.1k
        if (getOpCode(src)) {
347
32.3k
            if (wState->state.opStack == 1 || (!wState->state.lastFin && getOpCode(src) < 2)) {
348
446
                Impl::forceClose(wState, user, ERR_PROTOCOL);
349
446
                return true;
350
446
            }
351
31.9k
            wState->state.opCode[++wState->state.opStack] = (OpCode) getOpCode(src);
352
31.9k
        } else if (wState->state.opStack == -1) {
353
269
            Impl::forceClose(wState, user, ERR_PROTOCOL);
354
269
            return true;
355
269
        }
356
37.4k
        wState->state.lastFin = isFin(src);
357
358
37.4k
        if (Impl::refusePayloadLength(payLength, wState, user)) {
359
0
            Impl::forceClose(wState, user, ERR_TOO_BIG_MESSAGE);
360
0
            return true;
361
0
        }
362
363
37.4k
        if (payLength + MESSAGE_HEADER <= length) {
364
30.2k
            bool fin = isFin(src);
365
30.2k
            if (isServer) {
366
                /* This guy can never be assumed to be perfectly aligned since we can get multiple messages in one read */
367
30.2k
                unmaskImpreciseCopyMask<MESSAGE_HEADER>(src + MESSAGE_HEADER, (unsigned int) payLength);
368
30.2k
                if (Impl::handleFragment(src, payLength, 0, wState->state.opCode[wState->state.opStack], fin, wState, user)) {
369
6.04k
                    return true;
370
6.04k
                }
371
30.2k
            } else {
372
0
                if (Impl::handleFragment(src + MESSAGE_HEADER, payLength, 0, wState->state.opCode[wState->state.opStack], isFin(src), wState, user)) {
373
0
                    return true;
374
0
                }
375
0
            }
376
377
24.1k
            if (fin) {
378
16.4k
                wState->state.opStack--;
379
16.4k
            }
380
381
24.1k
            src += payLength + MESSAGE_HEADER;
382
24.1k
            length -= (unsigned int) (payLength + MESSAGE_HEADER);
383
24.1k
            wState->state.spillLength = 0;
384
24.1k
            return false;
385
30.2k
        } else {
386
7.22k
            wState->state.spillLength = 0;
387
7.22k
            wState->state.wantsHead = false;
388
7.22k
            wState->remainingBytes = (unsigned int) (payLength - length + MESSAGE_HEADER);
389
7.22k
            bool fin = isFin(src);
390
7.22k
            if constexpr (isServer) {
391
7.22k
                memcpy(wState->mask, src + MESSAGE_HEADER - 4, 4);
392
7.22k
                uint64_t mask;
393
7.22k
                memcpy(&mask, src + MESSAGE_HEADER - 4, 4);
394
7.22k
                memcpy(((char *)&mask) + 4, src + MESSAGE_HEADER - 4, 4);
395
7.22k
                unmaskImprecise8<0>(src + MESSAGE_HEADER, mask, length);
396
7.22k
                rotateMask(4 - (length - MESSAGE_HEADER) % 4, wState->mask);
397
7.22k
            }
398
7.22k
            Impl::handleFragment(src + MESSAGE_HEADER, length - MESSAGE_HEADER, wState->remainingBytes, wState->state.opCode[wState->state.opStack], fin, wState, user);
399
7.22k
            return true;
400
7.22k
        }
401
37.4k
    }
EpollHelloWorld.cpp:bool uWS::WebSocketProtocol<true, uWS::WebSocketContext<false, true, test()::PerSocketData> >::consumeMessage<8u, unsigned short>(unsigned short, char*&, unsigned int&, uWS::WebSocketState<true>*, void*)
Line
Count
Source
345
4.31k
    static inline bool consumeMessage(T payLength, char *&src, unsigned int &length, WebSocketState<isServer> *wState, void *user) {
346
4.31k
        if (getOpCode(src)) {
347
3.56k
            if (wState->state.opStack == 1 || (!wState->state.lastFin && getOpCode(src) < 2)) {
348
389
                Impl::forceClose(wState, user, ERR_PROTOCOL);
349
389
                return true;
350
389
            }
351
3.17k
            wState->state.opCode[++wState->state.opStack] = (OpCode) getOpCode(src);
352
3.17k
        } else if (wState->state.opStack == -1) {
353
194
            Impl::forceClose(wState, user, ERR_PROTOCOL);
354
194
            return true;
355
194
        }
356
3.73k
        wState->state.lastFin = isFin(src);
357
358
3.73k
        if (Impl::refusePayloadLength(payLength, wState, user)) {
359
209
            Impl::forceClose(wState, user, ERR_TOO_BIG_MESSAGE);
360
209
            return true;
361
209
        }
362
363
3.52k
        if (payLength + MESSAGE_HEADER <= length) {
364
2.19k
            bool fin = isFin(src);
365
2.19k
            if (isServer) {
366
                /* This guy can never be assumed to be perfectly aligned since we can get multiple messages in one read */
367
2.19k
                unmaskImpreciseCopyMask<MESSAGE_HEADER>(src + MESSAGE_HEADER, (unsigned int) payLength);
368
2.19k
                if (Impl::handleFragment(src, payLength, 0, wState->state.opCode[wState->state.opStack], fin, wState, user)) {
369
247
                    return true;
370
247
                }
371
2.19k
            } else {
372
0
                if (Impl::handleFragment(src + MESSAGE_HEADER, payLength, 0, wState->state.opCode[wState->state.opStack], isFin(src), wState, user)) {
373
0
                    return true;
374
0
                }
375
0
            }
376
377
1.94k
            if (fin) {
378
1.38k
                wState->state.opStack--;
379
1.38k
            }
380
381
1.94k
            src += payLength + MESSAGE_HEADER;
382
1.94k
            length -= (unsigned int) (payLength + MESSAGE_HEADER);
383
1.94k
            wState->state.spillLength = 0;
384
1.94k
            return false;
385
2.19k
        } else {
386
1.32k
            wState->state.spillLength = 0;
387
1.32k
            wState->state.wantsHead = false;
388
1.32k
            wState->remainingBytes = (unsigned int) (payLength - length + MESSAGE_HEADER);
389
1.32k
            bool fin = isFin(src);
390
1.32k
            if constexpr (isServer) {
391
1.32k
                memcpy(wState->mask, src + MESSAGE_HEADER - 4, 4);
392
1.32k
                uint64_t mask;
393
1.32k
                memcpy(&mask, src + MESSAGE_HEADER - 4, 4);
394
1.32k
                memcpy(((char *)&mask) + 4, src + MESSAGE_HEADER - 4, 4);
395
1.32k
                unmaskImprecise8<0>(src + MESSAGE_HEADER, mask, length);
396
1.32k
                rotateMask(4 - (length - MESSAGE_HEADER) % 4, wState->mask);
397
1.32k
            }
398
1.32k
            Impl::handleFragment(src + MESSAGE_HEADER, length - MESSAGE_HEADER, wState->remainingBytes, wState->state.opCode[wState->state.opStack], fin, wState, user);
399
1.32k
            return true;
400
1.32k
        }
401
3.52k
    }
EpollHelloWorld.cpp:bool uWS::WebSocketProtocol<true, uWS::WebSocketContext<false, true, test()::PerSocketData> >::consumeMessage<14u, unsigned long>(unsigned long, char*&, unsigned int&, uWS::WebSocketState<true>*, void*)
Line
Count
Source
345
3.74k
    static inline bool consumeMessage(T payLength, char *&src, unsigned int &length, WebSocketState<isServer> *wState, void *user) {
346
3.74k
        if (getOpCode(src)) {
347
2.81k
            if (wState->state.opStack == 1 || (!wState->state.lastFin && getOpCode(src) < 2)) {
348
393
                Impl::forceClose(wState, user, ERR_PROTOCOL);
349
393
                return true;
350
393
            }
351
2.42k
            wState->state.opCode[++wState->state.opStack] = (OpCode) getOpCode(src);
352
2.42k
        } else if (wState->state.opStack == -1) {
353
194
            Impl::forceClose(wState, user, ERR_PROTOCOL);
354
194
            return true;
355
194
        }
356
3.16k
        wState->state.lastFin = isFin(src);
357
358
3.16k
        if (Impl::refusePayloadLength(payLength, wState, user)) {
359
776
            Impl::forceClose(wState, user, ERR_TOO_BIG_MESSAGE);
360
776
            return true;
361
776
        }
362
363
2.38k
        if (payLength + MESSAGE_HEADER <= length) {
364
1.44k
            bool fin = isFin(src);
365
1.44k
            if (isServer) {
366
                /* This guy can never be assumed to be perfectly aligned since we can get multiple messages in one read */
367
1.44k
                unmaskImpreciseCopyMask<MESSAGE_HEADER>(src + MESSAGE_HEADER, (unsigned int) payLength);
368
1.44k
                if (Impl::handleFragment(src, payLength, 0, wState->state.opCode[wState->state.opStack], fin, wState, user)) {
369
245
                    return true;
370
245
                }
371
1.44k
            } else {
372
0
                if (Impl::handleFragment(src + MESSAGE_HEADER, payLength, 0, wState->state.opCode[wState->state.opStack], isFin(src), wState, user)) {
373
0
                    return true;
374
0
                }
375
0
            }
376
377
1.20k
            if (fin) {
378
596
                wState->state.opStack--;
379
596
            }
380
381
1.20k
            src += payLength + MESSAGE_HEADER;
382
1.20k
            length -= (unsigned int) (payLength + MESSAGE_HEADER);
383
1.20k
            wState->state.spillLength = 0;
384
1.20k
            return false;
385
1.44k
        } else {
386
937
            wState->state.spillLength = 0;
387
937
            wState->state.wantsHead = false;
388
937
            wState->remainingBytes = (unsigned int) (payLength - length + MESSAGE_HEADER);
389
937
            bool fin = isFin(src);
390
937
            if constexpr (isServer) {
391
937
                memcpy(wState->mask, src + MESSAGE_HEADER - 4, 4);
392
937
                uint64_t mask;
393
937
                memcpy(&mask, src + MESSAGE_HEADER - 4, 4);
394
937
                memcpy(((char *)&mask) + 4, src + MESSAGE_HEADER - 4, 4);
395
937
                unmaskImprecise8<0>(src + MESSAGE_HEADER, mask, length);
396
937
                rotateMask(4 - (length - MESSAGE_HEADER) % 4, wState->mask);
397
937
            }
398
937
            Impl::handleFragment(src + MESSAGE_HEADER, length - MESSAGE_HEADER, wState->remainingBytes, wState->state.opCode[wState->state.opStack], fin, wState, user);
399
937
            return true;
400
937
        }
401
2.38k
    }
402
403
    /* This one is nicely vectorized on both ARM64 and X64 - especially with -mavx */
404
0
    static inline void unmaskAll(char * __restrict data, char * __restrict mask) {
405
0
        for (int i = 0; i < LIBUS_RECV_BUFFER_LENGTH; i += 16) {
406
0
            UnrolledXor<16>(data + i, mask);
407
0
        }
408
0
    }
409
410
38.6k
    static inline bool consumeContinuation(char *&src, unsigned int &length, WebSocketState<isServer> *wState, void *user) {
411
38.6k
        if (wState->remainingBytes <= length) {
412
6.85k
            if (isServer) {
413
6.85k
                unsigned int n = wState->remainingBytes >> 2;
414
6.85k
                unmaskInplace(src, src + n * 4, wState->mask);
415
16.8k
                for (unsigned int i = 0, s = wState->remainingBytes % 4; i < s; i++) {
416
10.0k
                    src[n * 4 + i] ^= wState->mask[i];
417
10.0k
                }
418
6.85k
            }
419
420
6.85k
            if (Impl::handleFragment(src, wState->remainingBytes, 0, wState->state.opCode[wState->state.opStack], wState->state.lastFin, wState, user)) {
421
2.05k
                return false;
422
2.05k
            }
423
424
4.80k
            if (wState->state.lastFin) {
425
3.51k
                wState->state.opStack--;
426
3.51k
            }
427
428
4.80k
            src += wState->remainingBytes;
429
4.80k
            length -= wState->remainingBytes;
430
4.80k
            wState->state.wantsHead = true;
431
4.80k
            return true;
432
31.8k
        } else {
433
31.8k
            if (isServer) {
434
                /* No need to unmask if mask is 0 */
435
31.8k
                uint32_t nullmask = 0;
436
31.8k
                if (memcmp(wState->mask, &nullmask, sizeof(uint32_t))) {
437
31.5k
                    if /*constexpr*/ (LIBUS_RECV_BUFFER_LENGTH == length) {
438
0
                        unmaskAll(src, wState->mask);
439
31.5k
                    } else {
440
                        // Slow path
441
31.5k
                        unmaskInplace(src, src + ((length >> 2) + 1) * 4, wState->mask);
442
31.5k
                    }
443
31.5k
                }
444
31.8k
            }
445
446
31.8k
            wState->remainingBytes -= length;
447
31.8k
            if (Impl::handleFragment(src, length, wState->remainingBytes, wState->state.opCode[wState->state.opStack], wState->state.lastFin, wState, user)) {
448
71
                return false;
449
71
            }
450
451
31.7k
            if (isServer && length % 4) {
452
31.2k
                rotateMask(4 - (length % 4), wState->mask);
453
31.2k
            }
454
31.7k
            return false;
455
31.8k
        }
456
38.6k
    }
457
458
public:
459
    WebSocketProtocol() {
460
461
    }
462
463
60.0k
    static inline void consume(char *src, unsigned int length, WebSocketState<isServer> *wState, void *user) {
464
60.0k
        if (wState->state.spillLength) {
465
2.61k
            src -= wState->state.spillLength;
466
2.61k
            length += wState->state.spillLength;
467
2.61k
            memcpy(src, wState->state.spill, wState->state.spillLength);
468
2.61k
        }
469
60.0k
        if (wState->state.wantsHead) {
470
26.2k
            parseNext:
471
53.5k
            while (length >= SHORT_MESSAGE_HEADER) {
472
473
                // invalid reserved bits / invalid opcodes / invalid control frames / set compressed frame
474
49.8k
                if ((rsv1(src) && !Impl::setCompressed(wState, user)) || rsv23(src) || (getOpCode(src) > 2 && getOpCode(src) < 8) ||
475
49.8k
                    getOpCode(src) > 10 || (getOpCode(src) > 2 && (!isFin(src) || payloadLength(src) > 125))) {
476
3.09k
                    Impl::forceClose(wState, user, ERR_PROTOCOL);
477
3.09k
                    return;
478
3.09k
                }
479
480
46.7k
                if (payloadLength(src) < 126) {
481
38.1k
                    if (consumeMessage<SHORT_MESSAGE_HEADER, uint8_t>(payloadLength(src), src, length, wState, user)) {
482
13.9k
                        return;
483
13.9k
                    }
484
38.1k
                } else if (payloadLength(src) == 126) {
485
4.53k
                    if (length < MEDIUM_MESSAGE_HEADER) {
486
220
                        break;
487
4.31k
                    } else if(consumeMessage<MEDIUM_MESSAGE_HEADER, uint16_t>(protocol::cond_byte_swap<uint16_t>(protocol::bit_cast<uint16_t>(src + 2)), src, length, wState, user)) {
488
2.36k
                        return;
489
2.36k
                    }
490
4.53k
                } else if (length < LONG_MESSAGE_HEADER) {
491
283
                    break;
492
3.74k
                } else if (consumeMessage<LONG_MESSAGE_HEADER, uint64_t>(protocol::cond_byte_swap<uint64_t>(protocol::bit_cast<uint64_t>(src + 2)), src, length, wState, user)) {
493
2.54k
                    return;
494
2.54k
                }
495
46.7k
            }
496
4.22k
            if (length) {
497
3.14k
                memcpy(wState->state.spill, src, length);
498
3.14k
                wState->state.spillLength = length & 0xf;
499
3.14k
            }
500
38.6k
        } else if (consumeContinuation(src, length, wState, user)) {
501
4.80k
            goto parseNext;
502
4.80k
        }
503
60.0k
    }
504
505
    static const int CONSUME_POST_PADDING = 4;
506
    static const int CONSUME_PRE_PADDING = LONG_MESSAGE_HEADER - 1;
507
};
508
509
}
510
511
#endif // UWS_WEBSOCKETPROTOCOL_H