Coverage Report

Created: 2026-05-30 06:16

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/libzmq/src/ws_engine.cpp
Line
Count
Source
1
/* SPDX-License-Identifier: MPL-2.0 */
2
3
#include "precompiled.hpp"
4
5
#ifdef ZMQ_USE_NSS
6
#include <secoid.h>
7
#include <sechash.h>
8
#define SHA_DIGEST_LENGTH 20
9
#elif defined ZMQ_USE_BUILTIN_SHA1
10
#include "../external/sha1/sha1.h"
11
#elif defined ZMQ_USE_GNUTLS
12
#define SHA_DIGEST_LENGTH 20
13
#include <gnutls/gnutls.h>
14
#include <gnutls/crypto.h>
15
#endif
16
17
#if !defined ZMQ_HAVE_WINDOWS
18
#include <sys/types.h>
19
#include <unistd.h>
20
#include <sys/socket.h>
21
#include <netinet/in.h>
22
#include <arpa/inet.h>
23
#ifdef ZMQ_HAVE_VXWORKS
24
#include <sockLib.h>
25
#endif
26
#endif
27
28
#include <cstring>
29
30
#include "compat.hpp"
31
#include "tcp.hpp"
32
#include "ws_engine.hpp"
33
#include "session_base.hpp"
34
#include "err.hpp"
35
#include "ip.hpp"
36
#include "random.hpp"
37
#include "ws_decoder.hpp"
38
#include "ws_encoder.hpp"
39
#include "null_mechanism.hpp"
40
#include "plain_server.hpp"
41
#include "plain_client.hpp"
42
43
#ifdef ZMQ_HAVE_CURVE
44
#include "curve_client.hpp"
45
#include "curve_server.hpp"
46
#endif
47
48
//  OSX uses a different name for this socket option
49
#ifndef IPV6_ADD_MEMBERSHIP
50
#define IPV6_ADD_MEMBERSHIP IPV6_JOIN_GROUP
51
#endif
52
53
#ifdef __APPLE__
54
#include <TargetConditionals.h>
55
#endif
56
57
static int
58
encode_base64 (const unsigned char *in_, int in_len_, char *out_, int out_len_);
59
60
static void compute_accept_key (char *key_,
61
                                unsigned char hash_[SHA_DIGEST_LENGTH]);
62
63
//  Constructs a WebSocket engine for the descriptor fd_.
//  client_ selects which side of the opening handshake this engine plays
//  (see start_ws_handshake () and handshake ()). Both handshake state
//  machines start in their initial state and all header scratch state
//  is cleared.
zmq::ws_engine_t::ws_engine_t (fd_t fd_,
                               const options_t &options_,
                               const endpoint_uri_pair_t &endpoint_uri_pair_,
                               const ws_address_t &address_,
                               bool client_) :
    stream_engine_base_t (fd_, options_, endpoint_uri_pair_, true),
    _client (client_),
    _address (address_),
    _client_handshake_state (client_handshake_initial),
    _server_handshake_state (handshake_initial),
    _header_name_position (0),
    _header_value_position (0),
    _header_upgrade_websocket (false),
    _header_connection_upgrade (false),
    _heartbeat_timeout (0)
{
    //  Handshake strings start out empty (all NUL bytes); the handshake
    //  code tests their first byte to see whether they have been set.
    memset (_websocket_protocol, 0, 256);
    memset (_websocket_accept, 0, MAX_HEADER_VALUE_LENGTH + 1);
    memset (_websocket_key, 0, MAX_HEADER_VALUE_LENGTH + 1);

    _close_msg.init ();

    //  Until a protocol is selected, messages go through the handshake
    //  command handlers.
    _process_msg = &ws_engine_t::process_handshake_command;
    _next_msg = &ws_engine_t::next_handshake_command;

    //  A heartbeat timeout of -1 means "use the heartbeat interval".
    if (_options.heartbeat_interval > 0)
        _heartbeat_timeout = _options.heartbeat_timeout == -1
                               ? _options.heartbeat_interval
                               : _options.heartbeat_timeout;
}
93
94
//  Releases the close message that the constructor initialised.
zmq::ws_engine_t::~ws_engine_t ()
{
    _close_msg.close ();
}
98
99
void zmq::ws_engine_t::start_ws_handshake ()
100
0
{
101
0
    if (_client) {
102
0
        const char *protocol;
103
0
        if (_options.mechanism == ZMQ_NULL)
104
0
            protocol = "ZWS2.0/NULL,ZWS2.0";
105
0
        else if (_options.mechanism == ZMQ_PLAIN)
106
0
            protocol = "ZWS2.0/PLAIN";
107
0
#ifdef ZMQ_HAVE_CURVE
108
0
        else if (_options.mechanism == ZMQ_CURVE)
109
0
            protocol = "ZWS2.0/CURVE";
110
0
#endif
111
0
        else {
112
            // Avoid uninitialized variable error breaking UWP build
113
0
            protocol = "";
114
0
            assert (false);
115
0
        }
116
117
0
        unsigned char nonce[16];
118
0
        int *p = reinterpret_cast<int *> (nonce);
119
120
        // The nonce doesn't have to be secure one, it is just use to avoid proxy cache
121
0
        *p = zmq::generate_random ();
122
0
        *(p + 1) = zmq::generate_random ();
123
0
        *(p + 2) = zmq::generate_random ();
124
0
        *(p + 3) = zmq::generate_random ();
125
126
0
        int size =
127
0
          encode_base64 (nonce, 16, _websocket_key, MAX_HEADER_VALUE_LENGTH);
128
0
        assert (size > 0);
129
130
0
        size = snprintf (
131
0
          reinterpret_cast<char *> (_write_buffer), WS_BUFFER_SIZE,
132
0
          "GET %s HTTP/1.1\r\n"
133
0
          "Host: %s\r\n"
134
0
          "Upgrade: websocket\r\n"
135
0
          "Connection: Upgrade\r\n"
136
0
          "Sec-WebSocket-Key: %s\r\n"
137
0
          "Sec-WebSocket-Protocol: %s\r\n"
138
0
          "Sec-WebSocket-Version: 13\r\n\r\n",
139
0
          _address.path (), _address.host (), _websocket_key, protocol);
140
0
        assert (size > 0 && size < WS_BUFFER_SIZE);
141
0
        _outpos = _write_buffer;
142
0
        _outsize = size;
143
0
        set_pollout ();
144
0
    }
145
0
}
146
147
void zmq::ws_engine_t::plug_internal ()
148
0
{
149
0
    start_ws_handshake ();
150
0
    set_pollin ();
151
0
    in_event ();
152
0
}
153
154
//  Fills msg_ with the locally configured routing id and switches
//  _next_msg to pull_msg_from_session for subsequent messages.
//  Always returns 0.
int zmq::ws_engine_t::routing_id_msg (msg_t *msg_)
{
    const size_t id_size = _options.routing_id_size;

    const int init_rc = msg_->init_size (id_size);
    errno_assert (init_rc == 0);

    //  Nothing to copy for an empty routing id.
    if (id_size > 0)
        memcpy (msg_->data (), _options.routing_id, id_size);

    _next_msg = &ws_engine_t::pull_msg_from_session;
    return 0;
}
164
165
//  Handles the routing id message received from the peer.
//  If the recv_routing_id option is set, the message is flagged as a
//  routing id and pushed to the session; otherwise it is closed and
//  re-initialised as an empty message. Afterwards _process_msg is
//  switched to push_msg_to_session. Always returns 0.
int zmq::ws_engine_t::process_routing_id_msg (msg_t *msg_)
{
    if (!_options.recv_routing_id) {
        //  Drop the routing id: reset msg_ to an empty message.
        const int close_rc = msg_->close ();
        errno_assert (close_rc == 0);
        const int init_rc = msg_->init ();
        errno_assert (init_rc == 0);
    } else {
        msg_->set_flags (msg_t::routing_id);
        const int push_rc = session ()->push_msg (msg_);
        errno_assert (push_rc == 0);
    }

    _process_msg = &ws_engine_t::push_msg_to_session;
    return 0;
}
182
183
bool zmq::ws_engine_t::select_protocol (const char *protocol_)
184
0
{
185
0
    if (_options.mechanism == ZMQ_NULL && (strcmp ("ZWS2.0", protocol_) == 0)) {
186
0
        _next_msg = static_cast<int (stream_engine_base_t::*) (msg_t *)> (
187
0
          &ws_engine_t::routing_id_msg);
188
0
        _process_msg = static_cast<int (stream_engine_base_t::*) (msg_t *)> (
189
0
          &ws_engine_t::process_routing_id_msg);
190
191
        // No mechanism in place, enabling heartbeat
192
0
        if (_options.heartbeat_interval > 0 && !_has_heartbeat_timer) {
193
0
            add_timer (_options.heartbeat_interval, heartbeat_ivl_timer_id);
194
0
            _has_heartbeat_timer = true;
195
0
        }
196
197
0
        return true;
198
0
    }
199
0
    if (_options.mechanism == ZMQ_NULL
200
0
        && strcmp ("ZWS2.0/NULL", protocol_) == 0) {
201
0
        _mechanism = new (std::nothrow)
202
0
          null_mechanism_t (session (), _peer_address, _options);
203
0
        alloc_assert (_mechanism);
204
0
        return true;
205
0
    } else if (_options.mechanism == ZMQ_PLAIN
206
0
               && strcmp ("ZWS2.0/PLAIN", protocol_) == 0) {
207
0
        if (_options.as_server)
208
0
            _mechanism = new (std::nothrow)
209
0
              plain_server_t (session (), _peer_address, _options);
210
0
        else
211
0
            _mechanism =
212
0
              new (std::nothrow) plain_client_t (session (), _options);
213
0
        alloc_assert (_mechanism);
214
0
        return true;
215
0
    }
216
0
#ifdef ZMQ_HAVE_CURVE
217
0
    else if (_options.mechanism == ZMQ_CURVE
218
0
             && strcmp ("ZWS2.0/CURVE", protocol_) == 0) {
219
0
        if (_options.as_server)
220
0
            _mechanism = new (std::nothrow)
221
0
              curve_server_t (session (), _peer_address, _options, false);
222
0
        else
223
0
            _mechanism =
224
0
              new (std::nothrow) curve_client_t (session (), _options, false);
225
0
        alloc_assert (_mechanism);
226
0
        return true;
227
0
    }
228
0
#endif
229
230
0
    return false;
231
0
}
232
233
bool zmq::ws_engine_t::handshake ()
234
0
{
235
0
    bool complete;
236
237
0
    if (_client)
238
0
        complete = client_handshake ();
239
0
    else
240
0
        complete = server_handshake ();
241
242
0
    if (complete) {
243
0
        _encoder =
244
0
          new (std::nothrow) ws_encoder_t (_options.out_batch_size, _client);
245
0
        alloc_assert (_encoder);
246
247
0
        _decoder = new (std::nothrow)
248
0
          ws_decoder_t (_options.in_batch_size, _options.maxmsgsize,
249
0
                        _options.zero_copy, !_client);
250
0
        alloc_assert (_decoder);
251
252
0
        socket ()->event_handshake_succeeded (_endpoint_uri_pair, 0);
253
254
0
        set_pollout ();
255
0
    }
256
257
0
    return complete;
258
0
}
259
260
bool zmq::ws_engine_t::server_handshake ()
261
0
{
262
0
    const int nbytes = read (_read_buffer, WS_BUFFER_SIZE);
263
0
    if (nbytes == -1) {
264
0
        if (errno != EAGAIN)
265
0
            error (zmq::i_engine::connection_error);
266
0
        return false;
267
0
    }
268
269
0
    _inpos = _read_buffer;
270
0
    _insize = nbytes;
271
272
0
    while (_insize > 0) {
273
0
        const char c = static_cast<char> (*_inpos);
274
275
0
        switch (_server_handshake_state) {
276
0
            case handshake_initial:
277
0
                if (c == 'G')
278
0
                    _server_handshake_state = request_line_G;
279
0
                else
280
0
                    _server_handshake_state = handshake_error;
281
0
                break;
282
0
            case request_line_G:
283
0
                if (c == 'E')
284
0
                    _server_handshake_state = request_line_GE;
285
0
                else
286
0
                    _server_handshake_state = handshake_error;
287
0
                break;
288
0
            case request_line_GE:
289
0
                if (c == 'T')
290
0
                    _server_handshake_state = request_line_GET;
291
0
                else
292
0
                    _server_handshake_state = handshake_error;
293
0
                break;
294
0
            case request_line_GET:
295
0
                if (c == ' ')
296
0
                    _server_handshake_state = request_line_GET_space;
297
0
                else
298
0
                    _server_handshake_state = handshake_error;
299
0
                break;
300
0
            case request_line_GET_space:
301
0
                if (c == '\r' || c == '\n')
302
0
                    _server_handshake_state = handshake_error;
303
                // TODO: instead of check what is not allowed check what is allowed
304
0
                if (c != ' ')
305
0
                    _server_handshake_state = request_line_resource;
306
0
                else
307
0
                    _server_handshake_state = request_line_GET_space;
308
0
                break;
309
0
            case request_line_resource:
310
0
                if (c == '\r' || c == '\n')
311
0
                    _server_handshake_state = handshake_error;
312
0
                else if (c == ' ')
313
0
                    _server_handshake_state = request_line_resource_space;
314
0
                else
315
0
                    _server_handshake_state = request_line_resource;
316
0
                break;
317
0
            case request_line_resource_space:
318
0
                if (c == 'H')
319
0
                    _server_handshake_state = request_line_H;
320
0
                else
321
0
                    _server_handshake_state = handshake_error;
322
0
                break;
323
0
            case request_line_H:
324
0
                if (c == 'T')
325
0
                    _server_handshake_state = request_line_HT;
326
0
                else
327
0
                    _server_handshake_state = handshake_error;
328
0
                break;
329
0
            case request_line_HT:
330
0
                if (c == 'T')
331
0
                    _server_handshake_state = request_line_HTT;
332
0
                else
333
0
                    _server_handshake_state = handshake_error;
334
0
                break;
335
0
            case request_line_HTT:
336
0
                if (c == 'P')
337
0
                    _server_handshake_state = request_line_HTTP;
338
0
                else
339
0
                    _server_handshake_state = handshake_error;
340
0
                break;
341
0
            case request_line_HTTP:
342
0
                if (c == '/')
343
0
                    _server_handshake_state = request_line_HTTP_slash;
344
0
                else
345
0
                    _server_handshake_state = handshake_error;
346
0
                break;
347
0
            case request_line_HTTP_slash:
348
0
                if (c == '1')
349
0
                    _server_handshake_state = request_line_HTTP_slash_1;
350
0
                else
351
0
                    _server_handshake_state = handshake_error;
352
0
                break;
353
0
            case request_line_HTTP_slash_1:
354
0
                if (c == '.')
355
0
                    _server_handshake_state = request_line_HTTP_slash_1_dot;
356
0
                else
357
0
                    _server_handshake_state = handshake_error;
358
0
                break;
359
0
            case request_line_HTTP_slash_1_dot:
360
0
                if (c == '1')
361
0
                    _server_handshake_state = request_line_HTTP_slash_1_dot_1;
362
0
                else
363
0
                    _server_handshake_state = handshake_error;
364
0
                break;
365
0
            case request_line_HTTP_slash_1_dot_1:
366
0
                if (c == '\r')
367
0
                    _server_handshake_state = request_line_cr;
368
0
                else
369
0
                    _server_handshake_state = handshake_error;
370
0
                break;
371
0
            case request_line_cr:
372
0
                if (c == '\n')
373
0
                    _server_handshake_state = header_field_begin_name;
374
0
                else
375
0
                    _server_handshake_state = handshake_error;
376
0
                break;
377
0
            case header_field_begin_name:
378
0
                switch (c) {
379
0
                    case '\r':
380
0
                        _server_handshake_state = handshake_end_line_cr;
381
0
                        break;
382
0
                    case '\n':
383
0
                        _server_handshake_state = handshake_error;
384
0
                        break;
385
0
                    default:
386
0
                        _header_name[0] = c;
387
0
                        _header_name_position = 1;
388
0
                        _server_handshake_state = header_field_name;
389
0
                        break;
390
0
                }
391
0
                break;
392
0
            case header_field_name:
393
0
                if (c == '\r' || c == '\n')
394
0
                    _server_handshake_state = handshake_error;
395
0
                else if (c == ':') {
396
0
                    _header_name[_header_name_position] = '\0';
397
0
                    _server_handshake_state = header_field_colon;
398
0
                } else if (_header_name_position + 1 > MAX_HEADER_NAME_LENGTH)
399
0
                    _server_handshake_state = handshake_error;
400
0
                else {
401
0
                    _header_name[_header_name_position] = c;
402
0
                    _header_name_position++;
403
0
                    _server_handshake_state = header_field_name;
404
0
                }
405
0
                break;
406
0
            case header_field_colon:
407
0
            case header_field_value_trailing_space:
408
0
                if (c == '\n')
409
0
                    _server_handshake_state = handshake_error;
410
0
                else if (c == '\r')
411
0
                    _server_handshake_state = header_field_cr;
412
0
                else if (c == ' ')
413
0
                    _server_handshake_state = header_field_value_trailing_space;
414
0
                else {
415
0
                    _header_value[0] = c;
416
0
                    _header_value_position = 1;
417
0
                    _server_handshake_state = header_field_value;
418
0
                }
419
0
                break;
420
0
            case header_field_value:
421
0
                if (c == '\n')
422
0
                    _server_handshake_state = handshake_error;
423
0
                else if (c == '\r') {
424
0
                    _header_value[_header_value_position] = '\0';
425
426
0
                    if (strcasecmp ("upgrade", _header_name) == 0)
427
0
                        _header_upgrade_websocket =
428
0
                          strcasecmp ("websocket", _header_value) == 0;
429
0
                    else if (strcasecmp ("connection", _header_name) == 0) {
430
0
                        char *rest = NULL;
431
0
                        char *element = strtok_r (_header_value, ",", &rest);
432
0
                        while (element != NULL) {
433
0
                            while (*element == ' ')
434
0
                                element++;
435
0
                            if (strcasecmp ("upgrade", element) == 0) {
436
0
                                _header_connection_upgrade = true;
437
0
                                break;
438
0
                            }
439
0
                            element = strtok_r (NULL, ",", &rest);
440
0
                        }
441
0
                    } else if (strcasecmp ("Sec-WebSocket-Key", _header_name)
442
0
                               == 0)
443
0
                        strcpy_s (_websocket_key, _header_value);
444
0
                    else if (strcasecmp ("Sec-WebSocket-Protocol", _header_name)
445
0
                             == 0) {
446
                        // Currently only the ZWS2.0 is supported
447
                        // Sec-WebSocket-Protocol can appear multiple times or be a comma separated list
448
                        // if _websocket_protocol is already set we skip the check
449
0
                        if (_websocket_protocol[0] == '\0') {
450
0
                            char *rest = NULL;
451
0
                            char *p = strtok_r (_header_value, ",", &rest);
452
0
                            while (p != NULL) {
453
0
                                if (*p == ' ')
454
0
                                    p++;
455
456
0
                                if (select_protocol (p)) {
457
0
                                    strcpy_s (_websocket_protocol, p);
458
0
                                    break;
459
0
                                }
460
461
0
                                p = strtok_r (NULL, ",", &rest);
462
0
                            }
463
0
                        }
464
0
                    }
465
466
0
                    _server_handshake_state = header_field_cr;
467
0
                } else if (_header_value_position + 1 > MAX_HEADER_VALUE_LENGTH)
468
0
                    _server_handshake_state = handshake_error;
469
0
                else {
470
0
                    _header_value[_header_value_position] = c;
471
0
                    _header_value_position++;
472
0
                    _server_handshake_state = header_field_value;
473
0
                }
474
0
                break;
475
0
            case header_field_cr:
476
0
                if (c == '\n')
477
0
                    _server_handshake_state = header_field_begin_name;
478
0
                else
479
0
                    _server_handshake_state = handshake_error;
480
0
                break;
481
0
            case handshake_end_line_cr:
482
0
                if (c == '\n') {
483
0
                    if (_header_connection_upgrade && _header_upgrade_websocket
484
0
                        && _websocket_protocol[0] != '\0'
485
0
                        && _websocket_key[0] != '\0') {
486
0
                        _server_handshake_state = handshake_complete;
487
488
0
                        unsigned char hash[SHA_DIGEST_LENGTH];
489
0
                        compute_accept_key (_websocket_key, hash);
490
491
0
                        const int accept_key_len = encode_base64 (
492
0
                          hash, SHA_DIGEST_LENGTH, _websocket_accept,
493
0
                          MAX_HEADER_VALUE_LENGTH);
494
0
                        assert (accept_key_len > 0);
495
0
                        _websocket_accept[accept_key_len] = '\0';
496
497
0
                        const int written =
498
0
                          snprintf (reinterpret_cast<char *> (_write_buffer),
499
0
                                    WS_BUFFER_SIZE,
500
0
                                    "HTTP/1.1 101 Switching Protocols\r\n"
501
0
                                    "Upgrade: websocket\r\n"
502
0
                                    "Connection: Upgrade\r\n"
503
0
                                    "Sec-WebSocket-Accept: %s\r\n"
504
0
                                    "Sec-WebSocket-Protocol: %s\r\n"
505
0
                                    "\r\n",
506
0
                                    _websocket_accept, _websocket_protocol);
507
0
                        assert (written >= 0 && written < WS_BUFFER_SIZE);
508
0
                        _outpos = _write_buffer;
509
0
                        _outsize = written;
510
511
0
                        _inpos++;
512
0
                        _insize--;
513
514
0
                        return true;
515
0
                    }
516
0
                    _server_handshake_state = handshake_error;
517
0
                } else
518
0
                    _server_handshake_state = handshake_error;
519
0
                break;
520
0
            default:
521
0
                assert (false);
522
0
        }
523
524
0
        _inpos++;
525
0
        _insize--;
526
527
0
        if (_server_handshake_state == handshake_error) {
528
            // TODO: send bad request
529
530
0
            socket ()->event_handshake_failed_protocol (
531
0
              _endpoint_uri_pair, ZMQ_PROTOCOL_ERROR_WS_UNSPECIFIED);
532
533
0
            error (zmq::i_engine::protocol_error);
534
0
            return false;
535
0
        }
536
0
    }
537
0
    return false;
538
0
}
539
540
bool zmq::ws_engine_t::client_handshake ()
541
0
{
542
0
    const int nbytes = read (_read_buffer, WS_BUFFER_SIZE);
543
0
    if (nbytes == -1) {
544
0
        if (errno != EAGAIN)
545
0
            error (zmq::i_engine::connection_error);
546
0
        return false;
547
0
    }
548
549
0
    _inpos = _read_buffer;
550
0
    _insize = nbytes;
551
552
0
    while (_insize > 0) {
553
0
        const char c = static_cast<char> (*_inpos);
554
555
0
        switch (_client_handshake_state) {
556
0
            case client_handshake_initial:
557
0
                if (c == 'H')
558
0
                    _client_handshake_state = response_line_H;
559
0
                else
560
0
                    _client_handshake_state = client_handshake_error;
561
0
                break;
562
0
            case response_line_H:
563
0
                if (c == 'T')
564
0
                    _client_handshake_state = response_line_HT;
565
0
                else
566
0
                    _client_handshake_state = client_handshake_error;
567
0
                break;
568
0
            case response_line_HT:
569
0
                if (c == 'T')
570
0
                    _client_handshake_state = response_line_HTT;
571
0
                else
572
0
                    _client_handshake_state = client_handshake_error;
573
0
                break;
574
0
            case response_line_HTT:
575
0
                if (c == 'P')
576
0
                    _client_handshake_state = response_line_HTTP;
577
0
                else
578
0
                    _client_handshake_state = client_handshake_error;
579
0
                break;
580
0
            case response_line_HTTP:
581
0
                if (c == '/')
582
0
                    _client_handshake_state = response_line_HTTP_slash;
583
0
                else
584
0
                    _client_handshake_state = client_handshake_error;
585
0
                break;
586
0
            case response_line_HTTP_slash:
587
0
                if (c == '1')
588
0
                    _client_handshake_state = response_line_HTTP_slash_1;
589
0
                else
590
0
                    _client_handshake_state = client_handshake_error;
591
0
                break;
592
0
            case response_line_HTTP_slash_1:
593
0
                if (c == '.')
594
0
                    _client_handshake_state = response_line_HTTP_slash_1_dot;
595
0
                else
596
0
                    _client_handshake_state = client_handshake_error;
597
0
                break;
598
0
            case response_line_HTTP_slash_1_dot:
599
0
                if (c == '1')
600
0
                    _client_handshake_state = response_line_HTTP_slash_1_dot_1;
601
0
                else
602
0
                    _client_handshake_state = client_handshake_error;
603
0
                break;
604
0
            case response_line_HTTP_slash_1_dot_1:
605
0
                if (c == ' ')
606
0
                    _client_handshake_state =
607
0
                      response_line_HTTP_slash_1_dot_1_space;
608
0
                else
609
0
                    _client_handshake_state = client_handshake_error;
610
0
                break;
611
0
            case response_line_HTTP_slash_1_dot_1_space:
612
0
                if (c == ' ')
613
0
                    _client_handshake_state =
614
0
                      response_line_HTTP_slash_1_dot_1_space;
615
0
                else if (c == '1')
616
0
                    _client_handshake_state = response_line_status_1;
617
0
                else
618
0
                    _client_handshake_state = client_handshake_error;
619
0
                break;
620
0
            case response_line_status_1:
621
0
                if (c == '0')
622
0
                    _client_handshake_state = response_line_status_10;
623
0
                else
624
0
                    _client_handshake_state = client_handshake_error;
625
0
                break;
626
0
            case response_line_status_10:
627
0
                if (c == '1')
628
0
                    _client_handshake_state = response_line_status_101;
629
0
                else
630
0
                    _client_handshake_state = client_handshake_error;
631
0
                break;
632
0
            case response_line_status_101:
633
0
                if (c == ' ')
634
0
                    _client_handshake_state = response_line_status_101_space;
635
0
                else
636
0
                    _client_handshake_state = client_handshake_error;
637
0
                break;
638
0
            case response_line_status_101_space:
639
0
                if (c == ' ')
640
0
                    _client_handshake_state = response_line_status_101_space;
641
0
                else if (c == 'S')
642
0
                    _client_handshake_state = response_line_s;
643
0
                else
644
0
                    _client_handshake_state = client_handshake_error;
645
0
                break;
646
0
            case response_line_s:
647
0
                if (c == 'w')
648
0
                    _client_handshake_state = response_line_sw;
649
0
                else
650
0
                    _client_handshake_state = client_handshake_error;
651
0
                break;
652
0
            case response_line_sw:
653
0
                if (c == 'i')
654
0
                    _client_handshake_state = response_line_swi;
655
0
                else
656
0
                    _client_handshake_state = client_handshake_error;
657
0
                break;
658
0
            case response_line_swi:
659
0
                if (c == 't')
660
0
                    _client_handshake_state = response_line_swit;
661
0
                else
662
0
                    _client_handshake_state = client_handshake_error;
663
0
                break;
664
0
            case response_line_swit:
665
0
                if (c == 'c')
666
0
                    _client_handshake_state = response_line_switc;
667
0
                else
668
0
                    _client_handshake_state = client_handshake_error;
669
0
                break;
670
0
            case response_line_switc:
671
0
                if (c == 'h')
672
0
                    _client_handshake_state = response_line_switch;
673
0
                else
674
0
                    _client_handshake_state = client_handshake_error;
675
0
                break;
676
0
            case response_line_switch:
677
0
                if (c == 'i')
678
0
                    _client_handshake_state = response_line_switchi;
679
0
                else
680
0
                    _client_handshake_state = client_handshake_error;
681
0
                break;
682
0
            case response_line_switchi:
683
0
                if (c == 'n')
684
0
                    _client_handshake_state = response_line_switchin;
685
0
                else
686
0
                    _client_handshake_state = client_handshake_error;
687
0
                break;
688
0
            case response_line_switchin:
689
0
                if (c == 'g')
690
0
                    _client_handshake_state = response_line_switching;
691
0
                else
692
0
                    _client_handshake_state = client_handshake_error;
693
0
                break;
694
0
            case response_line_switching:
695
0
                if (c == ' ')
696
0
                    _client_handshake_state = response_line_switching_space;
697
0
                else
698
0
                    _client_handshake_state = client_handshake_error;
699
0
                break;
700
0
            case response_line_switching_space:
701
0
                if (c == 'P')
702
0
                    _client_handshake_state = response_line_p;
703
0
                else
704
0
                    _client_handshake_state = client_handshake_error;
705
0
                break;
706
0
            case response_line_p:
707
0
                if (c == 'r')
708
0
                    _client_handshake_state = response_line_pr;
709
0
                else
710
0
                    _client_handshake_state = client_handshake_error;
711
0
                break;
712
0
            case response_line_pr:
713
0
                if (c == 'o')
714
0
                    _client_handshake_state = response_line_pro;
715
0
                else
716
0
                    _client_handshake_state = client_handshake_error;
717
0
                break;
718
0
            case response_line_pro:
719
0
                if (c == 't')
720
0
                    _client_handshake_state = response_line_prot;
721
0
                else
722
0
                    _client_handshake_state = client_handshake_error;
723
0
                break;
724
0
            case response_line_prot:
725
0
                if (c == 'o')
726
0
                    _client_handshake_state = response_line_proto;
727
0
                else
728
0
                    _client_handshake_state = client_handshake_error;
729
0
                break;
730
0
            case response_line_proto:
731
0
                if (c == 'c')
732
0
                    _client_handshake_state = response_line_protoc;
733
0
                else
734
0
                    _client_handshake_state = client_handshake_error;
735
0
                break;
736
0
            case response_line_protoc:
737
0
                if (c == 'o')
738
0
                    _client_handshake_state = response_line_protoco;
739
0
                else
740
0
                    _client_handshake_state = client_handshake_error;
741
0
                break;
742
0
            case response_line_protoco:
743
0
                if (c == 'l')
744
0
                    _client_handshake_state = response_line_protocol;
745
0
                else
746
0
                    _client_handshake_state = client_handshake_error;
747
0
                break;
748
0
            case response_line_protocol:
749
0
                if (c == 's')
750
0
                    _client_handshake_state = response_line_protocols;
751
0
                else
752
0
                    _client_handshake_state = client_handshake_error;
753
0
                break;
754
0
            case response_line_protocols:
755
0
                if (c == '\r')
756
0
                    _client_handshake_state = response_line_cr;
757
0
                else
758
0
                    _client_handshake_state = client_handshake_error;
759
0
                break;
760
0
            case response_line_cr:
761
0
                if (c == '\n')
762
0
                    _client_handshake_state = client_header_field_begin_name;
763
0
                else
764
0
                    _client_handshake_state = client_handshake_error;
765
0
                break;
766
0
            case client_header_field_begin_name:
767
0
                switch (c) {
768
0
                    case '\r':
769
0
                        _client_handshake_state = client_handshake_end_line_cr;
770
0
                        break;
771
0
                    case '\n':
772
0
                        _client_handshake_state = client_handshake_error;
773
0
                        break;
774
0
                    default:
775
0
                        _header_name[0] = c;
776
0
                        _header_name_position = 1;
777
0
                        _client_handshake_state = client_header_field_name;
778
0
                        break;
779
0
                }
780
0
                break;
781
0
            case client_header_field_name:
782
0
                if (c == '\r' || c == '\n')
783
0
                    _client_handshake_state = client_handshake_error;
784
0
                else if (c == ':') {
785
0
                    _header_name[_header_name_position] = '\0';
786
0
                    _client_handshake_state = client_header_field_colon;
787
0
                } else if (_header_name_position + 1 > MAX_HEADER_NAME_LENGTH)
788
0
                    _client_handshake_state = client_handshake_error;
789
0
                else {
790
0
                    _header_name[_header_name_position] = c;
791
0
                    _header_name_position++;
792
0
                    _client_handshake_state = client_header_field_name;
793
0
                }
794
0
                break;
795
0
            case client_header_field_colon:
796
0
            case client_header_field_value_trailing_space:
797
0
                if (c == '\n')
798
0
                    _client_handshake_state = client_handshake_error;
799
0
                else if (c == '\r')
800
0
                    _client_handshake_state = client_header_field_cr;
801
0
                else if (c == ' ')
802
0
                    _client_handshake_state =
803
0
                      client_header_field_value_trailing_space;
804
0
                else {
805
0
                    _header_value[0] = c;
806
0
                    _header_value_position = 1;
807
0
                    _client_handshake_state = client_header_field_value;
808
0
                }
809
0
                break;
810
0
            case client_header_field_value:
811
0
                if (c == '\n')
812
0
                    _client_handshake_state = client_handshake_error;
813
0
                else if (c == '\r') {
814
0
                    _header_value[_header_value_position] = '\0';
815
816
0
                    if (strcasecmp ("upgrade", _header_name) == 0)
817
0
                        _header_upgrade_websocket =
818
0
                          strcasecmp ("websocket", _header_value) == 0;
819
0
                    else if (strcasecmp ("connection", _header_name) == 0)
820
0
                        _header_connection_upgrade =
821
0
                          strcasecmp ("upgrade", _header_value) == 0;
822
0
                    else if (strcasecmp ("Sec-WebSocket-Accept", _header_name)
823
0
                             == 0)
824
0
                        strcpy_s (_websocket_accept, _header_value);
825
0
                    else if (strcasecmp ("Sec-WebSocket-Protocol", _header_name)
826
0
                             == 0) {
827
0
                        if (_mechanism) {
828
0
                            _client_handshake_state = client_handshake_error;
829
0
                            break;
830
0
                        }
831
0
                        if (select_protocol (_header_value))
832
0
                            strcpy_s (_websocket_protocol, _header_value);
833
0
                    }
834
0
                    _client_handshake_state = client_header_field_cr;
835
0
                } else if (_header_value_position + 1 > MAX_HEADER_VALUE_LENGTH)
836
0
                    _client_handshake_state = client_handshake_error;
837
0
                else {
838
0
                    _header_value[_header_value_position] = c;
839
0
                    _header_value_position++;
840
0
                    _client_handshake_state = client_header_field_value;
841
0
                }
842
0
                break;
843
0
            case client_header_field_cr:
844
0
                if (c == '\n')
845
0
                    _client_handshake_state = client_header_field_begin_name;
846
0
                else
847
0
                    _client_handshake_state = client_handshake_error;
848
0
                break;
849
0
            case client_handshake_end_line_cr:
850
0
                if (c == '\n') {
851
0
                    if (_header_connection_upgrade && _header_upgrade_websocket
852
0
                        && _websocket_protocol[0] != '\0'
853
0
                        && _websocket_accept[0] != '\0') {
854
0
                        _client_handshake_state = client_handshake_complete;
855
856
                        // TODO: validate accept key
857
858
0
                        _inpos++;
859
0
                        _insize--;
860
861
0
                        return true;
862
0
                    }
863
0
                    _client_handshake_state = client_handshake_error;
864
0
                } else
865
0
                    _client_handshake_state = client_handshake_error;
866
0
                break;
867
0
            default:
868
0
                assert (false);
869
0
        }
870
871
0
        _inpos++;
872
0
        _insize--;
873
874
0
        if (_client_handshake_state == client_handshake_error) {
875
0
            socket ()->event_handshake_failed_protocol (
876
0
              _endpoint_uri_pair, ZMQ_PROTOCOL_ERROR_WS_UNSPECIFIED);
877
878
0
            error (zmq::i_engine::protocol_error);
879
0
            return false;
880
0
        }
881
0
    }
882
883
0
    return false;
884
0
}
885
886
//  Processes one decoded incoming message and forwards it to the session.
//  Returns 0 once the session has accepted the message and -1 otherwise.
//  When the session push fails with EAGAIN, _process_msg is switched to
//  push_one_then_decode_and_push before returning.
int zmq::ws_engine_t::decode_and_push (msg_t *msg_)
{
    zmq_assert (_mechanism != NULL);

    //  With the WS engine, ping, pong and close are control messages: they
    //  are handled by the engine and should not go through any mechanism.
    const bool is_control_frame =
      msg_->is_ping () || msg_->is_pong () || msg_->is_close_cmd ();
    const int rc = is_control_frame ? process_command_message (msg_)
                                    : _mechanism->decode (msg_);
    if (rc == -1)
        return -1;

    //  Disarm the heartbeat timeout timer if one is currently running.
    if (_has_timeout_timer) {
        _has_timeout_timer = false;
        cancel_timer (heartbeat_timeout_timer_id);
    }

    //  Command messages other than ping/pong/close are also given to the
    //  command handler. The flags are read again here instead of reusing
    //  the result computed before the mechanism ran.
    const bool has_command_flag = (msg_->flags () & msg_t::command) != 0;
    if (has_command_flag
        && !(msg_->is_ping () || msg_->is_pong () || msg_->is_close_cmd ()))
        process_command_message (msg_);

    if (_metadata)
        msg_->set_metadata (_metadata);

    if (session ()->push_msg (msg_) != -1)
        return 0;

    //  The push failed; on EAGAIN switch to the handler that pushes the
    //  pending message first.
    if (errno == EAGAIN)
        _process_msg = &ws_engine_t::push_one_then_decode_and_push;
    return -1;
}
915
916
//  Moves the stored close message (_close_msg) into msg_ and arranges for
//  produce_no_msg_after_close to handle the next request for output.
//  Returns the result of the move.
int zmq::ws_engine_t::produce_close_message (msg_t *msg_)
{
    typedef int (stream_engine_base_t::*next_msg_fn_t) (msg_t *);

    const int move_rc = msg_->move (_close_msg);
    errno_assert (move_rc == 0);

    _next_msg =
      static_cast<next_msg_fn_t> (&ws_engine_t::produce_no_msg_after_close);

    return move_rc;
}
926
927
//  Produces no message: always fails with errno set to EAGAIN, after
//  selecting close_connection_after_close as the next message producer.
int zmq::ws_engine_t::produce_no_msg_after_close (msg_t *msg_)
{
    typedef int (stream_engine_base_t::*next_msg_fn_t) (msg_t *);

    LIBZMQ_UNUSED (msg_);

    _next_msg =
      static_cast<next_msg_fn_t> (&ws_engine_t::close_connection_after_close);

    errno = EAGAIN;
    return -1;
}
936
937
//  Reports a connection error on the engine and always fails with errno
//  set to ECONNRESET. The message argument is not used.
int zmq::ws_engine_t::close_connection_after_close (msg_t *msg_)
{
    LIBZMQ_UNUSED (msg_);

    error (connection_error);

    //  errno is assigned only after error () has returned.
    errno = ECONNRESET;
    return -1;
}
944
945
//  Initialises msg_ as an empty message flagged as a ping command, makes
//  pull_and_encode the next message producer and, when a heartbeat timeout
//  is configured and no timeout timer is running, starts that timer.
//  Returns the result of the message initialisation.
int zmq::ws_engine_t::produce_ping_message (msg_t *msg_)
{
    const int init_rc = msg_->init ();
    errno_assert (init_rc == 0);
    msg_->set_flags (msg_t::command | msg_t::ping);

    _next_msg = &ws_engine_t::pull_and_encode;

    const bool arm_timeout = _heartbeat_timeout > 0 && !_has_timeout_timer;
    if (arm_timeout) {
        add_timer (_heartbeat_timeout, heartbeat_timeout_timer_id);
        _has_timeout_timer = true;
    }

    return init_rc;
}
959
960
961
//  Initialises msg_ as an empty message flagged as a pong command and makes
//  pull_and_encode the next message producer. Returns the result of the
//  message initialisation.
int zmq::ws_engine_t::produce_pong_message (msg_t *msg_)
{
    const int init_rc = msg_->init ();
    errno_assert (init_rc == 0);

    msg_->set_flags (msg_t::command | msg_t::pong);
    _next_msg = &ws_engine_t::pull_and_encode;

    return init_rc;
}
970
971
972
//  Reacts to a ping or close command message:
//   - ping:  schedules a pong as the next outgoing message;
//   - close: keeps a copy of the message in _close_msg and schedules it as
//            the next outgoing message.
//  In both cases out_event () is invoked afterwards. Any other message is
//  ignored. Always returns 0.
int zmq::ws_engine_t::process_command_message (msg_t *msg_)
{
    typedef int (stream_engine_base_t::*next_msg_fn_t) (msg_t *);

    if (msg_->is_ping ()) {
        _next_msg =
          static_cast<next_msg_fn_t> (&ws_engine_t::produce_pong_message);
    } else if (msg_->is_close_cmd ()) {
        const int copy_rc = _close_msg.copy (*msg_);
        errno_assert (copy_rc == 0);
        _next_msg =
          static_cast<next_msg_fn_t> (&ws_engine_t::produce_close_message);
    } else {
        //  Nothing to send in response.
        return 0;
    }

    out_event ();
    return 0;
}
988
989
//  Encodes in_len_ bytes from in_ as padded base64 text into out_, followed
//  by a terminating null character.
//
//  out_len_ is the capacity of out_ in bytes, terminator included.
//  Returns the number of characters written (terminator excluded), or -1
//  when out_ is too small; in that case out_ may have been partly written
//  and is not null-terminated.
static int
encode_base64 (const unsigned char *in_, int in_len_, char *out_, int out_len_)
{
    static const unsigned char alphabet[65] =
      "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    int written = 0;      //  characters stored in out_ so far
    uint32_t bits = 0;    //  bit accumulator; only the low bits are consumed
    int pending_bits = 0; //  input bits accumulated but not yet emitted

    for (int pos = 0; pos < in_len_; ++pos) {
        bits = (bits << 8) | in_[pos];

        //  Emit one character for every complete group of six bits.
        for (pending_bits += 8; pending_bits >= 6;) {
            pending_bits -= 6;
            if (written >= out_len_)
                return -1; /* truncation is failure */
            out_[written] = alphabet[(bits >> pending_bits) & 0x3f];
            ++written;
        }
    }

    //  Left-align the remaining bits into a final six-bit group.
    if (pending_bits != 0) {
        if (written >= out_len_)
            return -1; /* truncation is failure */
        out_[written] = alphabet[(bits << (6 - pending_bits)) & 0x3f];
        ++written;
    }

    //  Pad with '=' up to a multiple of four characters.
    while ((written % 4) != 0) {
        if (written >= out_len_)
            return -1; /* truncation is failure */
        out_[written] = '=';
        ++written;
    }

    if (written >= out_len_)
        return -1; /* no room for null terminator */
    out_[written] = 0;

    return written;
}
1027
1028
//  Computes the SHA-1 digest of key_ followed by the fixed WebSocket
//  handshake GUID (RFC 6455) and stores it in hash_.
//
//  key_ must be a null-terminated string. hash_ receives the digest; the
//  NSS backend is told the buffer holds SHA_DIGEST_LENGTH bytes.
//  The SHA-1 backend (NSS, built-in or GnuTLS) is chosen at compile time.
static void compute_accept_key (char *key_, unsigned char *hash_)
{
    const char *ws_guid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
    const size_t key_len = strlen (key_);
    const size_t guid_len = strlen (ws_guid);
#ifdef ZMQ_USE_NSS
    unsigned int digest_len;
    HASH_HashType hash_type = HASH_GetHashTypeByOidTag (SEC_OID_SHA1);
    HASHContext *hash_ctx = HASH_Create (hash_type);
    assert (hash_ctx);

    HASH_Begin (hash_ctx);
    HASH_Update (hash_ctx, (unsigned char *) key_, (unsigned int) key_len);
    HASH_Update (hash_ctx, (unsigned char *) ws_guid,
                 (unsigned int) guid_len);
    HASH_End (hash_ctx, hash_, &digest_len, SHA_DIGEST_LENGTH);
    HASH_Destroy (hash_ctx);
#elif defined ZMQ_USE_BUILTIN_SHA1
    sha1_ctxt sha_ctx;
    SHA1_Init (&sha_ctx);
    SHA1_Update (&sha_ctx, (unsigned char *) key_, key_len);
    SHA1_Update (&sha_ctx, (unsigned char *) ws_guid, guid_len);

    SHA1_Final (hash_, &sha_ctx);
#elif defined ZMQ_USE_GNUTLS
    gnutls_hash_hd_t hash_handle;
    gnutls_hash_init (&hash_handle, GNUTLS_DIG_SHA1);
    gnutls_hash (hash_handle, key_, key_len);
    gnutls_hash (hash_handle, ws_guid, guid_len);
    //  Deinitialising the handle also writes the digest into hash_.
    gnutls_hash_deinit (hash_handle, hash_);
#else
#error "No sha1 implementation set"
#endif
}