Coverage Report

Created: 2025-05-08 06:06

/src/brpc/src/brpc/details/ssl_helper.cpp
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
19
20
#include <openssl/bio.h>
21
#ifndef USE_MESALINK
22
23
#include <sys/socket.h>                // recv
24
#include <openssl/ssl.h>
25
#include <openssl/err.h>
26
#include <openssl/x509.h>
27
#include <openssl/x509v3.h>
28
#include "butil/unique_ptr.h"
29
#include "butil/logging.h"
30
#include "butil/ssl_compat.h"
31
#include "butil/string_splitter.h"
32
#include "brpc/socket.h"
33
#include "brpc/details/ssl_helper.h"
34
35
namespace brpc {
36
37
#ifndef OPENSSL_NO_DH
38
static DH* g_dh_1024 = NULL;
39
static DH* g_dh_2048 = NULL;
40
static DH* g_dh_4096 = NULL;
41
static DH* g_dh_8192 = NULL;
42
#endif  // OPENSSL_NO_DH
43
44
static const char* const PEM_START = "-----BEGIN";
45
46
0
static bool IsPemString(const std::string& input) {
47
0
    for (const char* s = input.c_str(); *s != '\0'; ++s) {
48
0
        if (*s != '\n') {
49
0
            return strncmp(s, PEM_START, strlen(PEM_START)) == 0;
50
0
        } 
51
0
    }
52
0
    return false;
53
0
}
54
55
0
const char* SSLStateToString(SSLState s) {
56
0
    switch (s) {
57
0
    case SSL_UNKNOWN:
58
0
        return "SSL_UNKNOWN";
59
0
    case SSL_OFF:
60
0
        return "SSL_OFF";
61
0
    case SSL_CONNECTING:
62
0
        return "SSL_CONNECTING";
63
0
    case SSL_CONNECTED:
64
0
        return "SSL_CONNECTED";
65
0
    }
66
0
    return "Bad SSLState";
67
0
}
68
69
0
static int ParseSSLProtocols(const std::string& str_protocol) {
70
0
    int protocol_flag = 0;
71
0
    butil::StringSplitter sp(str_protocol.data(),
72
0
                             str_protocol.data() + str_protocol.size(), ',');
73
0
    for (; sp; ++sp) {
74
0
        butil::StringPiece protocol(sp.field(), sp.length());
75
0
        protocol.trim_spaces();
76
0
        if (strncasecmp(protocol.data(), "SSLv3", protocol.size()) == 0) {
77
0
            protocol_flag |= SSLv3;
78
0
        } else if (strncasecmp(protocol.data(), "TLSv1", protocol.size()) == 0) {
79
0
            protocol_flag |= TLSv1;
80
0
        } else if (strncasecmp(protocol.data(), "TLSv1.1", protocol.size()) == 0) {
81
0
            protocol_flag |= TLSv1_1;
82
0
        } else if (strncasecmp(protocol.data(), "TLSv1.2", protocol.size()) == 0) {
83
0
            protocol_flag |= TLSv1_2;
84
0
        } else {
85
0
            LOG(ERROR) << "Unknown SSL protocol=" << protocol;
86
0
            return -1;
87
0
        }
88
0
    }
89
0
    return protocol_flag;
90
0
}
91
92
0
std::ostream& operator<<(std::ostream& os, const SSLError& ssl) {
93
0
    char buf[128];  // Should be enough
94
0
    ERR_error_string_n(ssl.error, buf, sizeof(buf));
95
0
    return os << buf;
96
0
}
97
98
0
std::ostream& operator<<(std::ostream& os, const CertInfo& cert) {
99
0
    os << "certificate[";
100
0
    if (IsPemString(cert.certificate)) {
101
0
        size_t pos = cert.certificate.find('\n');
102
0
        if (pos == std::string::npos) {
103
0
            pos = 0;
104
0
        } else {
105
0
            pos++;
106
0
        }
107
0
        os << cert.certificate.substr(pos, 16) << "...";
108
0
    } else {
109
0
        os << cert.certificate;
110
0
    } 
111
112
0
    os << "] private-key[";
113
0
    if (IsPemString(cert.private_key)) {
114
0
        size_t pos = cert.private_key.find('\n');
115
0
        if (pos == std::string::npos) {
116
0
            pos = 0;
117
0
        } else {
118
0
            pos++;
119
0
        }
120
0
        os << cert.private_key.substr(pos, 16) << "...";
121
0
    } else {
122
0
        os << cert.private_key;
123
0
    }
124
0
    os << "]";
125
0
    return os;
126
0
}
127
128
0
static void SSLInfoCallback(const SSL* ssl, int where, int ret) {
129
0
    (void)ret;
130
0
    SocketUniquePtr s;
131
0
    SocketId id = (SocketId)SSL_get_app_data((SSL*)ssl);
132
0
    if (Socket::Address(id, &s) != 0) {
133
        // Already failed
134
0
        return;
135
0
    }
136
137
0
    if (where & SSL_CB_HANDSHAKE_START) {
138
0
        if (s->ssl_state() == SSL_CONNECTED) {
139
            // Disable renegotiation (CVE-2009-3555)
140
0
            LOG(ERROR) << "Close " << *s << " due to insecure "
141
0
                       << "renegotiation detected (CVE-2009-3555)";
142
0
            s->SetFailed();
143
0
        }
144
0
    }
145
0
}
146
147
static void SSLMessageCallback(int write_p, int version, int content_type,
148
0
                               const void* buf, size_t len, SSL* ssl, void* arg) {
149
0
    (void)version;
150
0
    (void)arg;
151
#ifdef TLS1_RT_HEARTBEAT
152
    // Test heartbeat received (write_p is set to 0 for a received record)
153
    if ((content_type == TLS1_RT_HEARTBEAT) && (write_p == 0)) {
154
        const unsigned char* p = (const unsigned char*)buf;
155
156
        // Check if this is a CVE-2014-0160 exploitation attempt. 
157
        if (*p != TLS1_HB_REQUEST) {
158
            return;
159
        }
160
161
        // 1 type + 2 size + 0 payload + 16 padding
162
        if (len >= 1 + 2 + 16) {
163
            unsigned int payload = (p[1] * 256) + p[2];
164
            if (3 + payload + 16 <= len) {
165
                return;               // OK no problem
166
            }
167
        }
168
        
169
        // We have a clear heartbleed attack (CVE-2014-0160), the
170
        // advertised payload is larger than the advertised packet
171
        // length, so we have garbage in the buffer between the
172
        // payload and the end of the buffer (p+len). We can't know
173
        // if the SSL stack is patched, and we don't know if we can
174
        // safely wipe out the area between p+3+len and payload.
175
        // So instead, we prevent the response from being sent by
176
        // setting the max_send_fragment to 0 and we report an SSL
177
        // error, which will kill this connection. It will be reported
178
        // above as SSL_ERROR_SSL while an other handshake failure with
179
        // a heartbeat message will be reported as SSL_ERROR_SYSCALL.
180
        ssl->max_send_fragment = 0;
181
        SSLerr(SSL_F_TLS1_HEARTBEAT, SSL_R_SSL_HANDSHAKE_FAILURE);
182
        return;
183
    }
184
#endif // TLS1_RT_HEARTBEAT
185
0
}
186
187
#ifndef OPENSSL_NO_DH
188
0
static DH* SSLGetDHCallback(SSL* ssl, int exp, int keylen) {
189
0
    (void)exp;
190
0
    EVP_PKEY* pkey = SSL_get_privatekey(ssl);
191
0
    int type = pkey ? EVP_PKEY_base_id(pkey) : EVP_PKEY_NONE;
192
193
    // The keylen supplied by OpenSSL can only be 512 or 1024.
194
    // See ssl3_send_server_key_exchange() in ssl/s3_srvr.c
195
0
    if (type == EVP_PKEY_RSA || type == EVP_PKEY_DSA) {
196
0
        keylen = EVP_PKEY_bits(pkey);
197
0
    }
198
199
0
    if (keylen >= 8192) {
200
0
        return g_dh_8192;
201
0
    } else if (keylen >= 4096) {
202
0
        return g_dh_4096;
203
0
    } else if (keylen >= 2048) {
204
0
        return g_dh_2048;
205
0
    } else {
206
0
        return g_dh_1024;
207
0
    }
208
0
}
209
#endif  // OPENSSL_NO_DH
210
211
0
void ExtractHostnames(X509* x, std::vector<std::string>* hostnames) {
212
0
#ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME
213
0
    STACK_OF(GENERAL_NAME)* names = (STACK_OF(GENERAL_NAME)*)
214
0
            X509_get_ext_d2i(x, NID_subject_alt_name, NULL, NULL);
215
0
    if (names) {
216
0
        for (size_t i = 0; i < static_cast<size_t>(sk_GENERAL_NAME_num(names)); i++) {
217
0
            char* str = NULL;
218
0
            GENERAL_NAME* name = sk_GENERAL_NAME_value(names, i);
219
0
            if (name->type == GEN_DNS) {
220
0
                if (ASN1_STRING_to_UTF8((unsigned char**)&str,
221
0
                                        name->d.dNSName) >= 0) {
222
0
                    std::string hostname(str);
223
0
                    hostnames->push_back(hostname);
224
0
                    OPENSSL_free(str);
225
0
                }
226
0
            }
227
0
        }
228
0
        sk_GENERAL_NAME_pop_free(names, GENERAL_NAME_free);
229
0
    }
230
0
#endif // SSL_CTRL_SET_TLSEXT_HOSTNAME 
231
232
0
    int i = -1;
233
0
    X509_NAME* xname = X509_get_subject_name(x);
234
0
    while ((i = X509_NAME_get_index_by_NID(xname, NID_commonName, i)) != -1) {
235
0
        char* str = NULL;
236
0
        X509_NAME_ENTRY* entry = X509_NAME_get_entry(xname, i);
237
0
        const int len = ASN1_STRING_to_UTF8((unsigned char**)&str, 
238
0
                                            X509_NAME_ENTRY_get_data(entry));
239
0
        if (len >= 0) {
240
0
            std::string hostname(str, len);
241
0
            hostnames->push_back(hostname);
242
0
            OPENSSL_free(str);
243
0
        }
244
0
    }
245
0
}
246
247
struct FreeSSL {
248
0
    inline void operator()(SSL* ssl) const {
249
0
        if (ssl != NULL) {
250
0
            SSL_free(ssl);
251
0
        }
252
0
    }
253
};
254
255
struct FreeBIO {
256
0
    inline void operator()(BIO* io) const {
257
0
        if (io != NULL) {
258
0
            BIO_free(io);
259
0
        }
260
0
    }
261
};
262
263
struct FreeX509 {
264
0
    inline void operator()(X509* x) const {
265
0
        if (x != NULL) {
266
0
            X509_free(x);
267
0
        }
268
0
    }
269
};
270
271
struct FreeEVPKEY {
272
0
    inline void operator()(EVP_PKEY* k) const {
273
0
        if (k != NULL) {
274
0
            EVP_PKEY_free(k);
275
0
        }
276
0
    }
277
};
278
279
static int LoadCertificate(SSL_CTX* ctx,
280
                           const std::string& certificate,
281
                           const std::string& private_key,
282
0
                           std::vector<std::string>* hostnames) {
283
    // Load the private key
284
0
    if (IsPemString(private_key)) {
285
0
        std::unique_ptr<BIO, FreeBIO> kbio(
286
0
            BIO_new_mem_buf((void*)private_key.c_str(), -1));
287
0
        std::unique_ptr<EVP_PKEY, FreeEVPKEY> key(
288
0
            PEM_read_bio_PrivateKey(kbio.get(), NULL, 0, NULL));
289
0
        if (SSL_CTX_use_PrivateKey(ctx, key.get()) != 1) {
290
0
            LOG(ERROR) << "Fail to load " << private_key << ": "
291
0
                       << SSLError(ERR_get_error());
292
0
            return -1;
293
0
        }
294
295
0
    } else {
296
0
        if (SSL_CTX_use_PrivateKey_file(
297
0
                ctx, private_key.c_str(), SSL_FILETYPE_PEM) != 1) {
298
0
            LOG(ERROR) << "Fail to load " << private_key << ": "
299
0
                       << SSLError(ERR_get_error());
300
0
            return -1;
301
0
        }
302
0
    }
303
304
    // Open & Read certificate
305
0
    std::unique_ptr<BIO, FreeBIO> cbio;
306
0
    if (IsPemString(certificate)) {
307
0
        cbio.reset(BIO_new_mem_buf((void*)certificate.c_str(), -1));
308
0
    } else {
309
0
        cbio.reset(BIO_new(BIO_s_file()));
310
0
        if (BIO_read_filename(cbio.get(), certificate.c_str()) <= 0) {
311
0
            LOG(ERROR) << "Fail to read " << certificate << ": "
312
0
                       << SSLError(ERR_get_error());
313
0
            return -1;
314
0
        }
315
0
    }
316
0
    std::unique_ptr<X509, FreeX509> x(
317
0
        PEM_read_bio_X509_AUX(cbio.get(), NULL, 0, NULL));
318
0
    if (!x) {
319
0
        LOG(ERROR) << "Fail to parse " << certificate << ": "
320
0
                   << SSLError(ERR_get_error());
321
0
        return -1;
322
0
    }
323
    
324
    // Load the main certificate
325
0
    if (SSL_CTX_use_certificate(ctx, x.get()) != 1) {
326
0
        LOG(ERROR) << "Fail to load " << certificate << ": "
327
0
                   << SSLError(ERR_get_error());
328
0
        return -1;
329
0
    }
330
331
    // Load the certificate chain
332
0
#if (OPENSSL_VERSION_NUMBER >= 0x10002000L)
333
0
    SSL_CTX_clear_chain_certs(ctx);
334
#else
335
    if (ctx->extra_certs != NULL) {
336
        sk_X509_pop_free(ctx->extra_certs, X509_free);
337
        ctx->extra_certs = NULL;
338
    }
339
#endif
340
0
    X509* ca = NULL;
341
0
    while ((ca = PEM_read_bio_X509(cbio.get(), NULL, 0, NULL))) {
342
0
        if (SSL_CTX_add_extra_chain_cert(ctx, ca) != 1) {
343
0
            LOG(ERROR) << "Fail to load chain certificate in "
344
0
                       << certificate << ": " << SSLError(ERR_get_error());
345
0
            X509_free(ca);
346
0
            return -1;
347
0
        }
348
0
    }
349
350
0
    int err = ERR_get_error();
351
0
    if (err != 0 && (ERR_GET_LIB(err) != ERR_LIB_PEM
352
0
                     || ERR_GET_REASON(err) != PEM_R_NO_START_LINE)) {
353
0
        LOG(ERROR) << "Fail to read chain certificate in "
354
0
                   << certificate << ": " << SSLError(err);
355
0
        return -1;
356
0
    }
357
0
    ERR_clear_error();
358
359
    // Validate certificate and private key 
360
0
    if (SSL_CTX_check_private_key(ctx) != 1) {
361
0
        LOG(ERROR) << "Fail to verify " << private_key << ": "
362
0
                   << SSLError(ERR_get_error());
363
0
        return -1;
364
0
    }
365
366
0
    if (hostnames != NULL) {
367
0
        ExtractHostnames(x.get(), hostnames);
368
0
    }
369
0
    return 0;
370
0
}
371
372
static int SetSSLOptions(SSL_CTX* ctx, const std::string& ciphers,
373
0
                         int protocols, const VerifyOptions& verify) {
374
    long ssloptions = SSL_OP_ALL    // All known workarounds for bugs
375
0
            | SSL_OP_NO_SSLv2
376
0
#ifdef SSL_OP_NO_COMPRESSION
377
0
            | SSL_OP_NO_COMPRESSION
378
0
#endif  // SSL_OP_NO_COMPRESSION
379
0
            | SSL_OP_CIPHER_SERVER_PREFERENCE
380
0
            | SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION;
381
382
0
    if (!(protocols & SSLv3)) {
383
0
        ssloptions |= SSL_OP_NO_SSLv3;
384
0
    }
385
0
    if (!(protocols & TLSv1)) {
386
0
        ssloptions |= SSL_OP_NO_TLSv1;
387
0
    }
388
389
0
#ifdef SSL_OP_NO_TLSv1_1
390
0
    if (!(protocols & TLSv1_1)) {
391
0
        ssloptions |= SSL_OP_NO_TLSv1_1;
392
0
    }
393
0
#endif  // SSL_OP_NO_TLSv1_1
394
395
0
#ifdef SSL_OP_NO_TLSv1_2
396
0
    if (!(protocols & TLSv1_2)) {
397
0
        ssloptions |= SSL_OP_NO_TLSv1_2;
398
0
    }
399
0
#endif  // SSL_OP_NO_TLSv1_2
400
0
    SSL_CTX_set_options(ctx, ssloptions);
401
402
0
    long sslmode = SSL_MODE_ENABLE_PARTIAL_WRITE
403
0
            | SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER;
404
0
    SSL_CTX_set_mode(ctx, sslmode);
405
406
0
    if (!ciphers.empty() &&
407
0
        SSL_CTX_set_cipher_list(ctx, ciphers.c_str()) != 1) {
408
0
        LOG(ERROR) << "Fail to set cipher list to " << ciphers
409
0
                   << ": " << SSLError(ERR_get_error());
410
0
        return -1;
411
0
    }
412
413
    // TODO: Verify the CNAME in certificate matches the requesting host
414
0
    if (verify.verify_depth > 0) {
415
0
        SSL_CTX_set_verify(ctx, (SSL_VERIFY_PEER
416
0
                                 | SSL_VERIFY_FAIL_IF_NO_PEER_CERT), NULL);
417
0
        SSL_CTX_set_verify_depth(ctx, verify.verify_depth);
418
0
        std::string cafile = verify.ca_file_path;
419
0
        if (cafile.empty()) {
420
0
            cafile = X509_get_default_cert_area() + std::string("/cert.pem");
421
0
        }
422
0
        if (SSL_CTX_load_verify_locations(ctx, cafile.c_str(), NULL) == 0) {
423
0
            if (verify.ca_file_path.empty()) {
424
0
                LOG(WARNING) << "Fail to load default CA file " << cafile
425
0
                             << ": " << SSLError(ERR_get_error());
426
0
            } else {
427
0
                LOG(ERROR) << "Fail to load CA file " << cafile
428
0
                           << ": " << SSLError(ERR_get_error());
429
0
                return -1;
430
0
            }
431
0
        }
432
0
    } else {
433
0
        SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, NULL);
434
0
    }
435
436
0
    SSL_CTX_set_info_callback(ctx, SSLInfoCallback);
437
0
#if OPENSSL_VERSION_NUMBER >= 0x00907000L
438
    // To detect and protect from heartbleed attack
439
0
    SSL_CTX_set_msg_callback(ctx, SSLMessageCallback);
440
0
#endif
441
442
0
    return 0;
443
0
}
444
445
static int ServerALPNCallback(
446
        SSL* ssl, const unsigned char** out, unsigned char* outlen,
447
0
        const unsigned char* in, unsigned int inlen, void* arg) {
448
0
    const std::string* alpns = static_cast<const std::string*>(arg);
449
0
    if (alpns == nullptr) {
450
0
        return SSL_TLSEXT_ERR_NOACK;
451
0
    }
452
453
    // Use OpenSSL standard select API.
454
0
    int select_result = SSL_select_next_proto(
455
0
            const_cast<unsigned char**>(out), outlen, 
456
0
            reinterpret_cast<const unsigned char*>(alpns->data()), alpns->size(),
457
0
            in, inlen);
458
0
    return (select_result == OPENSSL_NPN_NEGOTIATED) 
459
0
                ? SSL_TLSEXT_ERR_OK : SSL_TLSEXT_ERR_NOACK;
460
0
}
461
462
0
static int SetServerALPNCallback(SSL_CTX* ssl_ctx, const std::string* alpns) {
463
0
    if (ssl_ctx == nullptr) {
464
0
        LOG(ERROR) << "Fail to set server ALPN callback, ssl_ctx is nullptr.";
465
0
        return -1;
466
0
    }
467
468
    // Server set alpn callback when openssl version is more than 1.0.2
469
0
#if (OPENSSL_VERSION_NUMBER >= SSL_VERSION_NUMBER(1, 0, 2))
470
0
    SSL_CTX_set_alpn_select_cb(ssl_ctx, ServerALPNCallback,
471
0
            const_cast<std::string*>(alpns));
472
#else
473
    LOG(WARNING) << "OpenSSL version=" << OPENSSL_VERSION_STR 
474
            << " is lower than 1.0.2, ignore server alpn.";
475
#endif
476
0
    return 0;
477
0
}
478
479
0
SSL_CTX* CreateClientSSLContext(const ChannelSSLOptions& options) {
480
0
    std::unique_ptr<SSL_CTX, FreeSSLCTX> ssl_ctx(
481
0
        SSL_CTX_new(SSLv23_client_method()));
482
0
    if (!ssl_ctx) {
483
0
        LOG(ERROR) << "Fail to new SSL_CTX: " << SSLError(ERR_get_error());
484
0
        return NULL;
485
0
    }
486
487
0
    if (!options.client_cert.certificate.empty()
488
0
        && LoadCertificate(ssl_ctx.get(),
489
0
                           options.client_cert.certificate,
490
0
                           options.client_cert.private_key, NULL) != 0) {
491
0
        return NULL;
492
0
    }
493
494
0
    int protocols = ParseSSLProtocols(options.protocols);
495
0
    if (protocols < 0
496
0
        || SetSSLOptions(ssl_ctx.get(), options.ciphers,
497
0
                         protocols, options.verify) != 0) {
498
0
        return NULL;
499
0
    }
500
501
0
    if (!options.alpn_protocols.empty()) {
502
0
        std::vector<unsigned char> alpn_list;
503
0
        if (!BuildALPNProtocolList(options.alpn_protocols, alpn_list)) {
504
0
            return NULL;
505
0
        }
506
0
        SSL_CTX_set_alpn_protos(ssl_ctx.get(), alpn_list.data(), alpn_list.size());
507
0
    }
508
509
0
    SSL_CTX_set_session_cache_mode(ssl_ctx.get(), SSL_SESS_CACHE_CLIENT);
510
0
    return ssl_ctx.release();
511
0
}
512
513
SSL_CTX* CreateServerSSLContext(const std::string& certificate,
514
                                const std::string& private_key,
515
                                const ServerSSLOptions& options,
516
                                const std::string* alpns,
517
0
                                std::vector<std::string>* hostnames) {
518
0
    std::unique_ptr<SSL_CTX, FreeSSLCTX> ssl_ctx(
519
0
        SSL_CTX_new(SSLv23_server_method()));
520
0
    if (!ssl_ctx) {
521
0
        LOG(ERROR) << "Fail to new SSL_CTX: " << SSLError(ERR_get_error());
522
0
        return NULL;
523
0
    }
524
525
0
    if (LoadCertificate(ssl_ctx.get(), certificate,
526
0
                        private_key, hostnames) != 0) {
527
0
        return NULL;
528
0
    }
529
530
0
    int protocols = TLSv1 | TLSv1_1 | TLSv1_2;
531
0
    if (!options.disable_ssl3) {
532
0
        protocols |= SSLv3;
533
0
    }
534
0
    if (SetSSLOptions(ssl_ctx.get(), options.ciphers,
535
0
                      protocols, options.verify) != 0) {
536
0
        return NULL;
537
0
    }
538
539
0
#ifdef SSL_MODE_RELEASE_BUFFERS
540
0
    if (options.release_buffer) {
541
0
        long sslmode = SSL_CTX_get_mode(ssl_ctx.get());
542
0
        sslmode |= SSL_MODE_RELEASE_BUFFERS;
543
0
        SSL_CTX_set_mode(ssl_ctx.get(), sslmode);
544
0
    }
545
0
#endif  // SSL_MODE_RELEASE_BUFFERS
546
547
0
    SSL_CTX_set_timeout(ssl_ctx.get(), options.session_lifetime_s);
548
0
    SSL_CTX_sess_set_cache_size(ssl_ctx.get(), options.session_cache_size);
549
550
0
#ifndef OPENSSL_NO_DH
551
0
    SSL_CTX_set_tmp_dh_callback(ssl_ctx.get(), SSLGetDHCallback);
552
553
0
#if !defined(OPENSSL_NO_ECDH) && defined(SSL_CTX_set_tmp_ecdh)
554
0
    EC_KEY* ecdh = NULL;
555
0
    int i = OBJ_sn2nid(options.ecdhe_curve_name.c_str());
556
0
    if (!i || ((ecdh = EC_KEY_new_by_curve_name(i)) == NULL)) {
557
0
        LOG(ERROR) << "Fail to find ECDHE named curve="
558
0
                   << options.ecdhe_curve_name
559
0
                   << ": " << SSLError(ERR_get_error());
560
0
        return NULL;
561
0
    }
562
0
    SSL_CTX_set_tmp_ecdh(ssl_ctx.get(), ecdh);
563
0
    EC_KEY_free(ecdh);
564
0
#endif
565
566
0
#endif  // OPENSSL_NO_DH
567
568
    // Set ALPN callback to choose application protocol when alpns is not empty.
569
0
    if (alpns != nullptr && !alpns->empty()) {
570
0
        if (SetServerALPNCallback(ssl_ctx.get(), alpns) != 0) {
571
0
            return NULL; 
572
0
        }
573
0
    }
574
0
    return ssl_ctx.release();
575
0
}
576
577
0
SSL* CreateSSLSession(SSL_CTX* ctx, SocketId id, int fd, bool server_mode) {
578
0
    if (ctx == NULL) {
579
0
        LOG(WARNING) << "Lack SSL_ctx to create an SSL session";
580
0
        return NULL;
581
0
    }
582
0
    SSL* ssl = SSL_new(ctx);
583
0
    if (ssl == NULL) {
584
0
        LOG(ERROR) << "Fail to SSL_new: " << SSLError(ERR_get_error());
585
0
        return NULL;
586
0
    }
587
0
    if (SSL_set_fd(ssl, fd) != 1) {
588
0
        LOG(ERROR) << "Fail to SSL_set_fd: " << SSLError(ERR_get_error());
589
0
        SSL_free(ssl);
590
0
        return NULL;
591
0
    }
592
593
0
    if (server_mode) {
594
0
        SSL_set_accept_state(ssl);
595
0
    } else {
596
0
        SSL_set_connect_state(ssl);
597
0
    }
598
0
    SSL_set_app_data(ssl, id);
599
0
    return ssl;
600
0
}
601
602
0
void AddBIOBuffer(SSL* ssl, int fd, int bufsize) {
603
#if defined(OPENSSL_IS_BORINGSSL)
604
    BIO *rbio = BIO_new(BIO_s_mem());
605
    BIO *wbio = BIO_new(BIO_s_mem());
606
#else
607
0
    BIO *rbio = BIO_new(BIO_f_buffer());
608
0
    BIO_set_buffer_size(rbio, bufsize);
609
0
    BIO *wbio = BIO_new(BIO_f_buffer());
610
0
    BIO_set_buffer_size(wbio, bufsize);
611
0
#endif
612
0
    BIO* rfd = BIO_new(BIO_s_fd());
613
0
    BIO_set_fd(rfd, fd, 0);
614
0
    rbio  = BIO_push(rbio, rfd);
615
0
    BIO* wfd = BIO_new(BIO_s_fd());
616
0
    BIO_set_fd(wfd, fd, 0);
617
0
    wbio = BIO_push(wbio, wfd);
618
0
    SSL_set_bio(ssl, rbio, wbio);
619
0
}
620
621
0
SSLState DetectSSLState(int fd, int* error_code) {
622
    // Peek the first few bytes inside socket to detect whether
623
    // it's an SSL connection. If it is, create an SSL session
624
    // which will be used to read/write after
625
626
    // Header format of SSLv2
627
    // +-----------+------+-----
628
    // | 2B header | 0x01 | etc.
629
    // +-----------+------+-----
630
    // The first bit of header is always 1, with the following
631
    // 15 bits are the length of data
632
633
    // Header format of SSLv3 or TLSv1.0, 1.1, 1.2
634
    // +------+------------+-----------+------+-----
635
    // | 0x16 | 2B version | 2B length | 0x01 | etc.
636
    // +------+------------+-----------+------+-----
637
0
    char header[6];
638
0
    const ssize_t nr = recv(fd, header, sizeof(header), MSG_PEEK);
639
0
    if (nr < (ssize_t)sizeof(header)) {
640
0
        if (nr < 0) {
641
0
            if (errno == ENOTSOCK) {
642
0
                return SSL_OFF;
643
0
            }
644
0
            *error_code = errno;   // Including EAGAIN and EINTR
645
0
        } else if (nr == 0) {      // EOF
646
0
            *error_code = 0;
647
0
        } else {                   // Not enough data, need retry
648
0
            *error_code = EAGAIN;
649
0
        }
650
0
        return SSL_UNKNOWN;
651
0
    }
652
    
653
0
    if ((header[0] == 0x16 && header[5] == 0x01) // SSLv3 or TLSv1.0, 1.1, 1.2
654
0
        || ((header[0] & 0x80) == 0x80 && header[2] == 0x01)) {  // SSLv2
655
0
        return SSL_CONNECTING;
656
0
    } else {
657
0
        return SSL_OFF;
658
0
    }
659
0
}
660
661
#if OPENSSL_VERSION_NUMBER < 0x10100000L
662
663
// NOTE: Can't find a macro for CRYPTO_THREADID
664
//       Fallback to use CRYPTO_LOCK_ECDH as flag
665
#ifdef CRYPTO_LOCK_ECDH
666
static void SSLGetThreadId(CRYPTO_THREADID* tid) {
667
    CRYPTO_THREADID_set_numeric(tid, (unsigned long)pthread_self());
668
}
669
#else
670
static unsigned long SSLGetThreadId() {
671
    return pthread_self();
672
}
673
#endif  // CRYPTO_LOCK_ECDH
674
675
// Locks for SSL library
676
// NOTE: If we replace this with bthread_mutex_t, SSL routines
677
// may crash probably due to some TLS data used inside OpenSSL
678
// Also according to performance test, there is little difference
679
// between pthread mutex and bthread mutex
680
static butil::Mutex* g_ssl_mutexs = NULL;
681
682
static void SSLLockCallback(int mode, int n, const char* file, int line) {
683
    (void)file;
684
    (void)line;
685
    // Following log is too anonying even for verbose logs.
686
    // RPC_VLOG << "[" << file << ':' << line << "] SSL"
687
    //          << (mode & CRYPTO_LOCK ? "locks" : "unlocks")
688
    //          << " thread=" << CRYPTO_thread_id();
689
    if (mode & CRYPTO_LOCK) {
690
        g_ssl_mutexs[n].lock();
691
    } else {
692
        g_ssl_mutexs[n].unlock();
693
    }
694
}
695
#endif // OPENSSL_VERSION_NUMBER < 0x10100000L
696
697
0
int SSLThreadInit() {
698
#if OPENSSL_VERSION_NUMBER < 0x10100000L
699
    g_ssl_mutexs = new butil::Mutex[CRYPTO_num_locks()];
700
    CRYPTO_set_locking_callback(SSLLockCallback);
701
# ifdef CRYPTO_LOCK_ECDH
702
    CRYPTO_THREADID_set_callback(SSLGetThreadId);
703
# else
704
    CRYPTO_set_id_callback(SSLGetThreadId);
705
# endif  // CRYPTO_LOCK_ECDH
706
#endif // OPENSSL_VERSION_NUMBER < 0x10100000L 
707
0
    return 0;
708
0
}
709
710
#ifndef OPENSSL_NO_DH
711
712
0
static DH* SSLGetDH1024() {
713
0
    BIGNUM* p = get_rfc2409_prime_1024(NULL);
714
0
    if (!p) {
715
0
        return NULL;
716
0
    }
717
    // See RFC 2409, Section 6 "Oakley Groups"
718
    // for the reason why 2 is used as generator.
719
0
    BIGNUM* g = NULL;
720
0
    BN_dec2bn(&g, "2");
721
0
    if (!g) {
722
0
        BN_free(p);
723
0
        return NULL;
724
0
    }
725
0
    DH *dh = DH_new();
726
0
    if (!dh) {
727
0
        BN_free(p);
728
0
        BN_free(g);
729
0
        return NULL;
730
0
    }
731
0
    DH_set0_pqg(dh, p, NULL, g);
732
0
    return dh;
733
0
}
734
735
0
static DH* SSLGetDH2048() {
736
0
    BIGNUM* p = get_rfc3526_prime_2048(NULL);
737
0
    if (!p) {
738
0
        return NULL;
739
0
    }
740
    // See RFC 3526, Section 3 "2048-bit MODP Group"
741
    // for the reason why 2 is used as generator.
742
0
    BIGNUM* g = NULL;
743
0
    BN_dec2bn(&g, "2");
744
0
    if (!g) {
745
0
        BN_free(p);
746
0
        return NULL;
747
0
    }
748
0
    DH* dh = DH_new();
749
0
    if (!dh) {
750
0
        BN_free(p);
751
0
        BN_free(g);
752
0
        return NULL;
753
0
    }
754
0
    DH_set0_pqg(dh, p, NULL, g);
755
0
    return dh;
756
0
}
757
758
0
static DH* SSLGetDH4096() {
759
0
    BIGNUM* p = get_rfc3526_prime_4096(NULL);
760
0
    if (!p) {
761
0
        return NULL;
762
0
    }
763
    // See RFC 3526, Section 5 "4096-bit MODP Group"
764
    // for the reason why 2 is used as generator.
765
0
    BIGNUM* g = NULL;
766
0
    BN_dec2bn(&g, "2");
767
0
    if (!g) {
768
0
        BN_free(p);
769
0
        return NULL;
770
0
    }
771
0
    DH *dh = DH_new();
772
0
    if (!dh) {
773
0
        BN_free(p);
774
0
        BN_free(g);
775
0
        return NULL;
776
0
    }
777
0
    DH_set0_pqg(dh, p, NULL, g);
778
0
    return dh;
779
0
}
780
781
0
static DH* SSLGetDH8192() {
782
0
    BIGNUM* p = get_rfc3526_prime_8192(NULL);
783
0
    if (!p) {
784
0
        return NULL;
785
0
    }
786
    // See RFC 3526, Section 7 "8192-bit MODP Group"
787
    // for the reason why 2 is used as generator.
788
0
    BIGNUM* g = NULL;
789
0
    BN_dec2bn(&g, "2");
790
0
    if (!g) {
791
0
        BN_free(g);
792
0
        return NULL;
793
0
    }
794
0
    DH *dh = DH_new();
795
0
    if (!dh) {
796
0
        BN_free(p);
797
0
        BN_free(g);
798
0
        return NULL;
799
0
    }
800
0
    DH_set0_pqg(dh, p, NULL, g);
801
0
    return dh;
802
0
}
803
804
#endif  // OPENSSL_NO_DH
805
806
0
int SSLDHInit() {
807
0
#ifndef OPENSSL_NO_DH
808
0
    if ((g_dh_1024 = SSLGetDH1024()) == NULL) {
809
0
        LOG(ERROR) << "Fail to initialize DH-1024";
810
0
        return -1;
811
0
    }
812
0
    if ((g_dh_2048 = SSLGetDH2048()) == NULL) {
813
0
        LOG(ERROR) << "Fail to initialize DH-2048";
814
0
        return -1;
815
0
    }
816
0
    if ((g_dh_4096 = SSLGetDH4096()) == NULL) {
817
0
        LOG(ERROR) << "Fail to initialize DH-4096";
818
0
        return -1;
819
0
    }
820
0
    if ((g_dh_8192 = SSLGetDH8192()) == NULL) {
821
0
        LOG(ERROR) << "Fail to initialize DH-8192";
822
0
        return -1;
823
0
    }
824
0
#endif  // OPENSSL_NO_DH
825
0
    return 0;
826
0
}
827
828
0
static std::string GetNextLevelSeparator(const char* sep) {
829
0
    if (sep[0] != '\n') {
830
0
        return sep;
831
0
    }
832
0
    const size_t left_len = strlen(sep + 1);
833
0
    if (left_len == 0) {
834
0
        return "\n ";
835
0
    }
836
0
    std::string new_sep;
837
0
    new_sep.reserve(left_len * 2 + 1);
838
0
    new_sep.append(sep, left_len + 1);
839
0
    new_sep.append(sep + 1, left_len);
840
0
    return new_sep;
841
0
}
842
843
0
void Print(std::ostream& os, SSL* ssl, const char* sep) {
844
0
    os << "cipher=" << SSL_get_cipher(ssl) << sep
845
0
       << "protocol=" << SSL_get_version(ssl) << sep
846
0
       << "verify=" << (SSL_get_verify_mode(ssl) & SSL_VERIFY_PEER
847
0
                        ? "success" : "none");
848
0
    X509* cert = SSL_get_peer_certificate(ssl);
849
0
    if (cert) {
850
0
        os << sep << "peer_certificate={";
851
0
        const std::string new_sep = GetNextLevelSeparator(sep);
852
0
        if (sep[0] == '\n') {
853
0
            os << new_sep;
854
0
        }
855
0
        Print(os, cert, new_sep.c_str());
856
0
        if (sep[0] == '\n') {
857
0
            os << sep;
858
0
        }
859
0
        os << '}';
860
0
    }
861
0
}
862
863
0
void Print(std::ostream& os, X509* cert, const char* sep) {
864
0
    BIO* buf = BIO_new(BIO_s_mem());
865
0
    if (buf == NULL) {
866
0
        return;
867
0
    }
868
0
    BIO_printf(buf, "subject=");
869
0
    X509_NAME_print(buf, X509_get_subject_name(cert), 0);
870
0
    BIO_printf(buf, "%sstart_date=", sep);
871
0
    ASN1_TIME_print(buf, X509_get_notBefore(cert));
872
0
    BIO_printf(buf, "%sexpire_date=", sep);
873
0
    ASN1_TIME_print(buf, X509_get_notAfter(cert));
874
875
0
    BIO_printf(buf, "%scommon_name=", sep);
876
0
    std::vector<std::string> hostnames;
877
0
    brpc::ExtractHostnames(cert, &hostnames);
878
0
    for (size_t i = 0; i < hostnames.size(); ++i) {
879
0
        BIO_printf(buf, "%s;", hostnames[i].c_str());
880
0
    }
881
882
0
    BIO_printf(buf, "%sissuer=", sep);
883
0
    X509_NAME_print(buf, X509_get_issuer_name(cert), 0);
884
885
0
    char* bufp = NULL;
886
0
    int len = BIO_get_mem_data(buf, &bufp);
887
0
    os << butil::StringPiece(bufp, len);
888
0
}
889
890
0
std::string ALPNProtocolToString(const AdaptiveProtocolType& protocol) {
891
0
    butil::StringPiece name = protocol.name();
892
    // Default use http 1.1 version
893
0
    if (name.starts_with("http")) {
894
0
        name.set("http/1.1");
895
0
    }
896
897
    // ALPN extension uses 1 byte to record the protocol length
898
    // and it's maximum length is 255.
899
0
    if (name.size() > CHAR_MAX) {
900
0
        name = name.substr(0, CHAR_MAX); 
901
0
    }
902
903
0
    char length = static_cast<char>(name.size());
904
0
    return std::string(&length, 1) + name.data(); 
905
0
}
906
907
bool BuildALPNProtocolList(
908
    const std::vector<std::string>& alpn_protocols,
909
    std::vector<unsigned char>& result
910
0
) {
911
0
    size_t alpn_list_length = 0;
912
0
    for (const auto& alpn_protocol : alpn_protocols) {
913
0
        if (alpn_protocol.size() > UCHAR_MAX) {
914
0
            LOG(ERROR) << "Fail to build ALPN procotol list: "
915
0
                       << "protocol name length " << alpn_protocol.size() << " too long, "
916
0
                       << "max 255 supported.";
917
0
            return false;
918
0
        }
919
0
        alpn_list_length += alpn_protocol.size() + 1;
920
0
    }
921
922
0
    result.resize(alpn_list_length);
923
0
    for (size_t curr = 0, i = 0; i < alpn_protocols.size(); i++) {
924
0
        result[curr++] = static_cast<unsigned char>(
925
0
            alpn_protocols[i].size()
926
0
        );
927
0
        std::copy(
928
0
            alpn_protocols[i].begin(),
929
0
            alpn_protocols[i].end(),
930
0
            result.begin() + curr
931
0
        );
932
0
        curr += alpn_protocols[i].size();
933
0
    }
934
0
    return true;
935
0
}
936
937
} // namespace brpc
938
939
#endif // USE_MESALINK