/src/brpc/src/brpc/details/ssl_helper.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | |
19 | | |
20 | | #include <openssl/bio.h> |
21 | | #ifndef USE_MESALINK |
22 | | |
23 | | #include <sys/socket.h> // recv |
24 | | #include <openssl/ssl.h> |
25 | | #include <openssl/err.h> |
26 | | #include <openssl/x509.h> |
27 | | #include <openssl/x509v3.h> |
28 | | #include "butil/unique_ptr.h" |
29 | | #include "butil/logging.h" |
30 | | #include "butil/ssl_compat.h" |
31 | | #include "butil/string_splitter.h" |
32 | | #include "brpc/socket.h" |
33 | | #include "brpc/details/ssl_helper.h" |
34 | | |
35 | | namespace brpc { |
36 | | |
37 | | #ifndef OPENSSL_NO_DH |
38 | | static DH* g_dh_1024 = NULL; |
39 | | static DH* g_dh_2048 = NULL; |
40 | | static DH* g_dh_4096 = NULL; |
41 | | static DH* g_dh_8192 = NULL; |
42 | | #endif // OPENSSL_NO_DH |
43 | | |
44 | | static const char* const PEM_START = "-----BEGIN"; |
45 | | |
46 | 0 | static bool IsPemString(const std::string& input) { |
47 | 0 | for (const char* s = input.c_str(); *s != '\0'; ++s) { |
48 | 0 | if (*s != '\n') { |
49 | 0 | return strncmp(s, PEM_START, strlen(PEM_START)) == 0; |
50 | 0 | } |
51 | 0 | } |
52 | 0 | return false; |
53 | 0 | } |
54 | | |
55 | 0 | const char* SSLStateToString(SSLState s) { |
56 | 0 | switch (s) { |
57 | 0 | case SSL_UNKNOWN: |
58 | 0 | return "SSL_UNKNOWN"; |
59 | 0 | case SSL_OFF: |
60 | 0 | return "SSL_OFF"; |
61 | 0 | case SSL_CONNECTING: |
62 | 0 | return "SSL_CONNECTING"; |
63 | 0 | case SSL_CONNECTED: |
64 | 0 | return "SSL_CONNECTED"; |
65 | 0 | } |
66 | 0 | return "Bad SSLState"; |
67 | 0 | } |
68 | | |
69 | 0 | static int ParseSSLProtocols(const std::string& str_protocol) { |
70 | 0 | int protocol_flag = 0; |
71 | 0 | butil::StringSplitter sp(str_protocol.data(), |
72 | 0 | str_protocol.data() + str_protocol.size(), ','); |
73 | 0 | for (; sp; ++sp) { |
74 | 0 | butil::StringPiece protocol(sp.field(), sp.length()); |
75 | 0 | protocol.trim_spaces(); |
76 | 0 | if (strncasecmp(protocol.data(), "SSLv3", protocol.size()) == 0) { |
77 | 0 | protocol_flag |= SSLv3; |
78 | 0 | } else if (strncasecmp(protocol.data(), "TLSv1", protocol.size()) == 0) { |
79 | 0 | protocol_flag |= TLSv1; |
80 | 0 | } else if (strncasecmp(protocol.data(), "TLSv1.1", protocol.size()) == 0) { |
81 | 0 | protocol_flag |= TLSv1_1; |
82 | 0 | } else if (strncasecmp(protocol.data(), "TLSv1.2", protocol.size()) == 0) { |
83 | 0 | protocol_flag |= TLSv1_2; |
84 | 0 | } else { |
85 | 0 | LOG(ERROR) << "Unknown SSL protocol=" << protocol; |
86 | 0 | return -1; |
87 | 0 | } |
88 | 0 | } |
89 | 0 | return protocol_flag; |
90 | 0 | } |
91 | | |
92 | 0 | std::ostream& operator<<(std::ostream& os, const SSLError& ssl) { |
93 | 0 | char buf[128]; // Should be enough |
94 | 0 | ERR_error_string_n(ssl.error, buf, sizeof(buf)); |
95 | 0 | return os << buf; |
96 | 0 | } |
97 | | |
98 | 0 | std::ostream& operator<<(std::ostream& os, const CertInfo& cert) { |
99 | 0 | os << "certificate["; |
100 | 0 | if (IsPemString(cert.certificate)) { |
101 | 0 | size_t pos = cert.certificate.find('\n'); |
102 | 0 | if (pos == std::string::npos) { |
103 | 0 | pos = 0; |
104 | 0 | } else { |
105 | 0 | pos++; |
106 | 0 | } |
107 | 0 | os << cert.certificate.substr(pos, 16) << "..."; |
108 | 0 | } else { |
109 | 0 | os << cert.certificate; |
110 | 0 | } |
111 | |
|
112 | 0 | os << "] private-key["; |
113 | 0 | if (IsPemString(cert.private_key)) { |
114 | 0 | size_t pos = cert.private_key.find('\n'); |
115 | 0 | if (pos == std::string::npos) { |
116 | 0 | pos = 0; |
117 | 0 | } else { |
118 | 0 | pos++; |
119 | 0 | } |
120 | 0 | os << cert.private_key.substr(pos, 16) << "..."; |
121 | 0 | } else { |
122 | 0 | os << cert.private_key; |
123 | 0 | } |
124 | 0 | os << "]"; |
125 | 0 | return os; |
126 | 0 | } |
127 | | |
128 | 0 | static void SSLInfoCallback(const SSL* ssl, int where, int ret) { |
129 | 0 | (void)ret; |
130 | 0 | SocketUniquePtr s; |
131 | 0 | SocketId id = (SocketId)SSL_get_app_data((SSL*)ssl); |
132 | 0 | if (Socket::Address(id, &s) != 0) { |
133 | | // Already failed |
134 | 0 | return; |
135 | 0 | } |
136 | | |
137 | 0 | if (where & SSL_CB_HANDSHAKE_START) { |
138 | 0 | if (s->ssl_state() == SSL_CONNECTED) { |
139 | | // Disable renegotiation (CVE-2009-3555) |
140 | 0 | LOG(ERROR) << "Close " << *s << " due to insecure " |
141 | 0 | << "renegotiation detected (CVE-2009-3555)"; |
142 | 0 | s->SetFailed(); |
143 | 0 | } |
144 | 0 | } |
145 | 0 | } |
146 | | |
147 | | static void SSLMessageCallback(int write_p, int version, int content_type, |
148 | 0 | const void* buf, size_t len, SSL* ssl, void* arg) { |
149 | 0 | (void)version; |
150 | 0 | (void)arg; |
151 | | #ifdef TLS1_RT_HEARTBEAT |
152 | | // Test heartbeat received (write_p is set to 0 for a received record) |
153 | | if ((content_type == TLS1_RT_HEARTBEAT) && (write_p == 0)) { |
154 | | const unsigned char* p = (const unsigned char*)buf; |
155 | | |
156 | | // Check if this is a CVE-2014-0160 exploitation attempt. |
157 | | if (*p != TLS1_HB_REQUEST) { |
158 | | return; |
159 | | } |
160 | | |
161 | | // 1 type + 2 size + 0 payload + 16 padding |
162 | | if (len >= 1 + 2 + 16) { |
163 | | unsigned int payload = (p[1] * 256) + p[2]; |
164 | | if (3 + payload + 16 <= len) { |
165 | | return; // OK no problem |
166 | | } |
167 | | } |
168 | | |
169 | | // We have a clear heartbleed attack (CVE-2014-0160), the |
170 | | // advertised payload is larger than the advertised packet |
171 | | // length, so we have garbage in the buffer between the |
172 | | // payload and the end of the buffer (p+len). We can't know |
173 | | // if the SSL stack is patched, and we don't know if we can |
174 | | // safely wipe out the area between p+3+len and payload. |
175 | | // So instead, we prevent the response from being sent by |
176 | | // setting the max_send_fragment to 0 and we report an SSL |
177 | | // error, which will kill this connection. It will be reported |
178 | | // above as SSL_ERROR_SSL while an other handshake failure with |
179 | | // a heartbeat message will be reported as SSL_ERROR_SYSCALL. |
180 | | ssl->max_send_fragment = 0; |
181 | | SSLerr(SSL_F_TLS1_HEARTBEAT, SSL_R_SSL_HANDSHAKE_FAILURE); |
182 | | return; |
183 | | } |
184 | | #endif // TLS1_RT_HEARTBEAT |
185 | 0 | } |
186 | | |
187 | | #ifndef OPENSSL_NO_DH |
188 | 0 | static DH* SSLGetDHCallback(SSL* ssl, int exp, int keylen) { |
189 | 0 | (void)exp; |
190 | 0 | EVP_PKEY* pkey = SSL_get_privatekey(ssl); |
191 | 0 | int type = pkey ? EVP_PKEY_base_id(pkey) : EVP_PKEY_NONE; |
192 | | |
193 | | // The keylen supplied by OpenSSL can only be 512 or 1024. |
194 | | // See ssl3_send_server_key_exchange() in ssl/s3_srvr.c |
195 | 0 | if (type == EVP_PKEY_RSA || type == EVP_PKEY_DSA) { |
196 | 0 | keylen = EVP_PKEY_bits(pkey); |
197 | 0 | } |
198 | |
|
199 | 0 | if (keylen >= 8192) { |
200 | 0 | return g_dh_8192; |
201 | 0 | } else if (keylen >= 4096) { |
202 | 0 | return g_dh_4096; |
203 | 0 | } else if (keylen >= 2048) { |
204 | 0 | return g_dh_2048; |
205 | 0 | } else { |
206 | 0 | return g_dh_1024; |
207 | 0 | } |
208 | 0 | } |
209 | | #endif // OPENSSL_NO_DH |
210 | | |
211 | 0 | void ExtractHostnames(X509* x, std::vector<std::string>* hostnames) { |
212 | 0 | #ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME |
213 | 0 | STACK_OF(GENERAL_NAME)* names = (STACK_OF(GENERAL_NAME)*) |
214 | 0 | X509_get_ext_d2i(x, NID_subject_alt_name, NULL, NULL); |
215 | 0 | if (names) { |
216 | 0 | for (size_t i = 0; i < static_cast<size_t>(sk_GENERAL_NAME_num(names)); i++) { |
217 | 0 | char* str = NULL; |
218 | 0 | GENERAL_NAME* name = sk_GENERAL_NAME_value(names, i); |
219 | 0 | if (name->type == GEN_DNS) { |
220 | 0 | if (ASN1_STRING_to_UTF8((unsigned char**)&str, |
221 | 0 | name->d.dNSName) >= 0) { |
222 | 0 | std::string hostname(str); |
223 | 0 | hostnames->push_back(hostname); |
224 | 0 | OPENSSL_free(str); |
225 | 0 | } |
226 | 0 | } |
227 | 0 | } |
228 | 0 | sk_GENERAL_NAME_pop_free(names, GENERAL_NAME_free); |
229 | 0 | } |
230 | 0 | #endif // SSL_CTRL_SET_TLSEXT_HOSTNAME |
231 | |
|
232 | 0 | int i = -1; |
233 | 0 | X509_NAME* xname = X509_get_subject_name(x); |
234 | 0 | while ((i = X509_NAME_get_index_by_NID(xname, NID_commonName, i)) != -1) { |
235 | 0 | char* str = NULL; |
236 | 0 | X509_NAME_ENTRY* entry = X509_NAME_get_entry(xname, i); |
237 | 0 | const int len = ASN1_STRING_to_UTF8((unsigned char**)&str, |
238 | 0 | X509_NAME_ENTRY_get_data(entry)); |
239 | 0 | if (len >= 0) { |
240 | 0 | std::string hostname(str, len); |
241 | 0 | hostnames->push_back(hostname); |
242 | 0 | OPENSSL_free(str); |
243 | 0 | } |
244 | 0 | } |
245 | 0 | } |
246 | | |
247 | | struct FreeSSL { |
248 | 0 | inline void operator()(SSL* ssl) const { |
249 | 0 | if (ssl != NULL) { |
250 | 0 | SSL_free(ssl); |
251 | 0 | } |
252 | 0 | } |
253 | | }; |
254 | | |
255 | | struct FreeBIO { |
256 | 0 | inline void operator()(BIO* io) const { |
257 | 0 | if (io != NULL) { |
258 | 0 | BIO_free(io); |
259 | 0 | } |
260 | 0 | } |
261 | | }; |
262 | | |
263 | | struct FreeX509 { |
264 | 0 | inline void operator()(X509* x) const { |
265 | 0 | if (x != NULL) { |
266 | 0 | X509_free(x); |
267 | 0 | } |
268 | 0 | } |
269 | | }; |
270 | | |
271 | | struct FreeEVPKEY { |
272 | 0 | inline void operator()(EVP_PKEY* k) const { |
273 | 0 | if (k != NULL) { |
274 | 0 | EVP_PKEY_free(k); |
275 | 0 | } |
276 | 0 | } |
277 | | }; |
278 | | |
279 | | static int LoadCertificate(SSL_CTX* ctx, |
280 | | const std::string& certificate, |
281 | | const std::string& private_key, |
282 | 0 | std::vector<std::string>* hostnames) { |
283 | | // Load the private key |
284 | 0 | if (IsPemString(private_key)) { |
285 | 0 | std::unique_ptr<BIO, FreeBIO> kbio( |
286 | 0 | BIO_new_mem_buf((void*)private_key.c_str(), -1)); |
287 | 0 | std::unique_ptr<EVP_PKEY, FreeEVPKEY> key( |
288 | 0 | PEM_read_bio_PrivateKey(kbio.get(), NULL, 0, NULL)); |
289 | 0 | if (SSL_CTX_use_PrivateKey(ctx, key.get()) != 1) { |
290 | 0 | LOG(ERROR) << "Fail to load " << private_key << ": " |
291 | 0 | << SSLError(ERR_get_error()); |
292 | 0 | return -1; |
293 | 0 | } |
294 | |
|
295 | 0 | } else { |
296 | 0 | if (SSL_CTX_use_PrivateKey_file( |
297 | 0 | ctx, private_key.c_str(), SSL_FILETYPE_PEM) != 1) { |
298 | 0 | LOG(ERROR) << "Fail to load " << private_key << ": " |
299 | 0 | << SSLError(ERR_get_error()); |
300 | 0 | return -1; |
301 | 0 | } |
302 | 0 | } |
303 | | |
304 | | // Open & Read certificate |
305 | 0 | std::unique_ptr<BIO, FreeBIO> cbio; |
306 | 0 | if (IsPemString(certificate)) { |
307 | 0 | cbio.reset(BIO_new_mem_buf((void*)certificate.c_str(), -1)); |
308 | 0 | } else { |
309 | 0 | cbio.reset(BIO_new(BIO_s_file())); |
310 | 0 | if (BIO_read_filename(cbio.get(), certificate.c_str()) <= 0) { |
311 | 0 | LOG(ERROR) << "Fail to read " << certificate << ": " |
312 | 0 | << SSLError(ERR_get_error()); |
313 | 0 | return -1; |
314 | 0 | } |
315 | 0 | } |
316 | 0 | std::unique_ptr<X509, FreeX509> x( |
317 | 0 | PEM_read_bio_X509_AUX(cbio.get(), NULL, 0, NULL)); |
318 | 0 | if (!x) { |
319 | 0 | LOG(ERROR) << "Fail to parse " << certificate << ": " |
320 | 0 | << SSLError(ERR_get_error()); |
321 | 0 | return -1; |
322 | 0 | } |
323 | | |
324 | | // Load the main certificate |
325 | 0 | if (SSL_CTX_use_certificate(ctx, x.get()) != 1) { |
326 | 0 | LOG(ERROR) << "Fail to load " << certificate << ": " |
327 | 0 | << SSLError(ERR_get_error()); |
328 | 0 | return -1; |
329 | 0 | } |
330 | | |
331 | | // Load the certificate chain |
332 | 0 | #if (OPENSSL_VERSION_NUMBER >= 0x10002000L) |
333 | 0 | SSL_CTX_clear_chain_certs(ctx); |
334 | | #else |
335 | | if (ctx->extra_certs != NULL) { |
336 | | sk_X509_pop_free(ctx->extra_certs, X509_free); |
337 | | ctx->extra_certs = NULL; |
338 | | } |
339 | | #endif |
340 | 0 | X509* ca = NULL; |
341 | 0 | while ((ca = PEM_read_bio_X509(cbio.get(), NULL, 0, NULL))) { |
342 | 0 | if (SSL_CTX_add_extra_chain_cert(ctx, ca) != 1) { |
343 | 0 | LOG(ERROR) << "Fail to load chain certificate in " |
344 | 0 | << certificate << ": " << SSLError(ERR_get_error()); |
345 | 0 | X509_free(ca); |
346 | 0 | return -1; |
347 | 0 | } |
348 | 0 | } |
349 | | |
350 | 0 | int err = ERR_get_error(); |
351 | 0 | if (err != 0 && (ERR_GET_LIB(err) != ERR_LIB_PEM |
352 | 0 | || ERR_GET_REASON(err) != PEM_R_NO_START_LINE)) { |
353 | 0 | LOG(ERROR) << "Fail to read chain certificate in " |
354 | 0 | << certificate << ": " << SSLError(err); |
355 | 0 | return -1; |
356 | 0 | } |
357 | 0 | ERR_clear_error(); |
358 | | |
359 | | // Validate certificate and private key |
360 | 0 | if (SSL_CTX_check_private_key(ctx) != 1) { |
361 | 0 | LOG(ERROR) << "Fail to verify " << private_key << ": " |
362 | 0 | << SSLError(ERR_get_error()); |
363 | 0 | return -1; |
364 | 0 | } |
365 | | |
366 | 0 | if (hostnames != NULL) { |
367 | 0 | ExtractHostnames(x.get(), hostnames); |
368 | 0 | } |
369 | 0 | return 0; |
370 | 0 | } |
371 | | |
372 | | static int SetSSLOptions(SSL_CTX* ctx, const std::string& ciphers, |
373 | 0 | int protocols, const VerifyOptions& verify) { |
374 | | long ssloptions = SSL_OP_ALL // All known workarounds for bugs |
375 | 0 | | SSL_OP_NO_SSLv2 |
376 | 0 | #ifdef SSL_OP_NO_COMPRESSION |
377 | 0 | | SSL_OP_NO_COMPRESSION |
378 | 0 | #endif // SSL_OP_NO_COMPRESSION |
379 | 0 | | SSL_OP_CIPHER_SERVER_PREFERENCE |
380 | 0 | | SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION; |
381 | |
|
382 | 0 | if (!(protocols & SSLv3)) { |
383 | 0 | ssloptions |= SSL_OP_NO_SSLv3; |
384 | 0 | } |
385 | 0 | if (!(protocols & TLSv1)) { |
386 | 0 | ssloptions |= SSL_OP_NO_TLSv1; |
387 | 0 | } |
388 | |
|
389 | 0 | #ifdef SSL_OP_NO_TLSv1_1 |
390 | 0 | if (!(protocols & TLSv1_1)) { |
391 | 0 | ssloptions |= SSL_OP_NO_TLSv1_1; |
392 | 0 | } |
393 | 0 | #endif // SSL_OP_NO_TLSv1_1 |
394 | |
|
395 | 0 | #ifdef SSL_OP_NO_TLSv1_2 |
396 | 0 | if (!(protocols & TLSv1_2)) { |
397 | 0 | ssloptions |= SSL_OP_NO_TLSv1_2; |
398 | 0 | } |
399 | 0 | #endif // SSL_OP_NO_TLSv1_2 |
400 | 0 | SSL_CTX_set_options(ctx, ssloptions); |
401 | |
|
402 | 0 | long sslmode = SSL_MODE_ENABLE_PARTIAL_WRITE |
403 | 0 | | SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER; |
404 | 0 | SSL_CTX_set_mode(ctx, sslmode); |
405 | |
|
406 | 0 | if (!ciphers.empty() && |
407 | 0 | SSL_CTX_set_cipher_list(ctx, ciphers.c_str()) != 1) { |
408 | 0 | LOG(ERROR) << "Fail to set cipher list to " << ciphers |
409 | 0 | << ": " << SSLError(ERR_get_error()); |
410 | 0 | return -1; |
411 | 0 | } |
412 | | |
413 | | // TODO: Verify the CNAME in certificate matches the requesting host |
414 | 0 | if (verify.verify_depth > 0) { |
415 | 0 | SSL_CTX_set_verify(ctx, (SSL_VERIFY_PEER |
416 | 0 | | SSL_VERIFY_FAIL_IF_NO_PEER_CERT), NULL); |
417 | 0 | SSL_CTX_set_verify_depth(ctx, verify.verify_depth); |
418 | 0 | std::string cafile = verify.ca_file_path; |
419 | 0 | if (cafile.empty()) { |
420 | 0 | cafile = X509_get_default_cert_area() + std::string("/cert.pem"); |
421 | 0 | } |
422 | 0 | if (SSL_CTX_load_verify_locations(ctx, cafile.c_str(), NULL) == 0) { |
423 | 0 | if (verify.ca_file_path.empty()) { |
424 | 0 | LOG(WARNING) << "Fail to load default CA file " << cafile |
425 | 0 | << ": " << SSLError(ERR_get_error()); |
426 | 0 | } else { |
427 | 0 | LOG(ERROR) << "Fail to load CA file " << cafile |
428 | 0 | << ": " << SSLError(ERR_get_error()); |
429 | 0 | return -1; |
430 | 0 | } |
431 | 0 | } |
432 | 0 | } else { |
433 | 0 | SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, NULL); |
434 | 0 | } |
435 | | |
436 | 0 | SSL_CTX_set_info_callback(ctx, SSLInfoCallback); |
437 | 0 | #if OPENSSL_VERSION_NUMBER >= 0x00907000L |
438 | | // To detect and protect from heartbleed attack |
439 | 0 | SSL_CTX_set_msg_callback(ctx, SSLMessageCallback); |
440 | 0 | #endif |
441 | |
|
442 | 0 | return 0; |
443 | 0 | } |
444 | | |
445 | | static int ServerALPNCallback( |
446 | | SSL* ssl, const unsigned char** out, unsigned char* outlen, |
447 | 0 | const unsigned char* in, unsigned int inlen, void* arg) { |
448 | 0 | const std::string* alpns = static_cast<const std::string*>(arg); |
449 | 0 | if (alpns == nullptr) { |
450 | 0 | return SSL_TLSEXT_ERR_NOACK; |
451 | 0 | } |
452 | | |
453 | | // Use OpenSSL standard select API. |
454 | 0 | int select_result = SSL_select_next_proto( |
455 | 0 | const_cast<unsigned char**>(out), outlen, |
456 | 0 | reinterpret_cast<const unsigned char*>(alpns->data()), alpns->size(), |
457 | 0 | in, inlen); |
458 | 0 | return (select_result == OPENSSL_NPN_NEGOTIATED) |
459 | 0 | ? SSL_TLSEXT_ERR_OK : SSL_TLSEXT_ERR_NOACK; |
460 | 0 | } |
461 | | |
462 | 0 | static int SetServerALPNCallback(SSL_CTX* ssl_ctx, const std::string* alpns) { |
463 | 0 | if (ssl_ctx == nullptr) { |
464 | 0 | LOG(ERROR) << "Fail to set server ALPN callback, ssl_ctx is nullptr."; |
465 | 0 | return -1; |
466 | 0 | } |
467 | | |
468 | | // Server set alpn callback when openssl version is more than 1.0.2 |
469 | 0 | #if (OPENSSL_VERSION_NUMBER >= SSL_VERSION_NUMBER(1, 0, 2)) |
470 | 0 | SSL_CTX_set_alpn_select_cb(ssl_ctx, ServerALPNCallback, |
471 | 0 | const_cast<std::string*>(alpns)); |
472 | | #else |
473 | | LOG(WARNING) << "OpenSSL version=" << OPENSSL_VERSION_STR |
474 | | << " is lower than 1.0.2, ignore server alpn."; |
475 | | #endif |
476 | 0 | return 0; |
477 | 0 | } |
478 | | |
479 | 0 | SSL_CTX* CreateClientSSLContext(const ChannelSSLOptions& options) { |
480 | 0 | std::unique_ptr<SSL_CTX, FreeSSLCTX> ssl_ctx( |
481 | 0 | SSL_CTX_new(SSLv23_client_method())); |
482 | 0 | if (!ssl_ctx) { |
483 | 0 | LOG(ERROR) << "Fail to new SSL_CTX: " << SSLError(ERR_get_error()); |
484 | 0 | return NULL; |
485 | 0 | } |
486 | | |
487 | 0 | if (!options.client_cert.certificate.empty() |
488 | 0 | && LoadCertificate(ssl_ctx.get(), |
489 | 0 | options.client_cert.certificate, |
490 | 0 | options.client_cert.private_key, NULL) != 0) { |
491 | 0 | return NULL; |
492 | 0 | } |
493 | | |
494 | 0 | int protocols = ParseSSLProtocols(options.protocols); |
495 | 0 | if (protocols < 0 |
496 | 0 | || SetSSLOptions(ssl_ctx.get(), options.ciphers, |
497 | 0 | protocols, options.verify) != 0) { |
498 | 0 | return NULL; |
499 | 0 | } |
500 | | |
501 | 0 | if (!options.alpn_protocols.empty()) { |
502 | 0 | std::vector<unsigned char> alpn_list; |
503 | 0 | if (!BuildALPNProtocolList(options.alpn_protocols, alpn_list)) { |
504 | 0 | return NULL; |
505 | 0 | } |
506 | 0 | SSL_CTX_set_alpn_protos(ssl_ctx.get(), alpn_list.data(), alpn_list.size()); |
507 | 0 | } |
508 | | |
509 | 0 | SSL_CTX_set_session_cache_mode(ssl_ctx.get(), SSL_SESS_CACHE_CLIENT); |
510 | 0 | return ssl_ctx.release(); |
511 | 0 | } |
512 | | |
513 | | SSL_CTX* CreateServerSSLContext(const std::string& certificate, |
514 | | const std::string& private_key, |
515 | | const ServerSSLOptions& options, |
516 | | const std::string* alpns, |
517 | 0 | std::vector<std::string>* hostnames) { |
518 | 0 | std::unique_ptr<SSL_CTX, FreeSSLCTX> ssl_ctx( |
519 | 0 | SSL_CTX_new(SSLv23_server_method())); |
520 | 0 | if (!ssl_ctx) { |
521 | 0 | LOG(ERROR) << "Fail to new SSL_CTX: " << SSLError(ERR_get_error()); |
522 | 0 | return NULL; |
523 | 0 | } |
524 | | |
525 | 0 | if (LoadCertificate(ssl_ctx.get(), certificate, |
526 | 0 | private_key, hostnames) != 0) { |
527 | 0 | return NULL; |
528 | 0 | } |
529 | | |
530 | 0 | int protocols = TLSv1 | TLSv1_1 | TLSv1_2; |
531 | 0 | if (!options.disable_ssl3) { |
532 | 0 | protocols |= SSLv3; |
533 | 0 | } |
534 | 0 | if (SetSSLOptions(ssl_ctx.get(), options.ciphers, |
535 | 0 | protocols, options.verify) != 0) { |
536 | 0 | return NULL; |
537 | 0 | } |
538 | | |
539 | 0 | #ifdef SSL_MODE_RELEASE_BUFFERS |
540 | 0 | if (options.release_buffer) { |
541 | 0 | long sslmode = SSL_CTX_get_mode(ssl_ctx.get()); |
542 | 0 | sslmode |= SSL_MODE_RELEASE_BUFFERS; |
543 | 0 | SSL_CTX_set_mode(ssl_ctx.get(), sslmode); |
544 | 0 | } |
545 | 0 | #endif // SSL_MODE_RELEASE_BUFFERS |
546 | |
|
547 | 0 | SSL_CTX_set_timeout(ssl_ctx.get(), options.session_lifetime_s); |
548 | 0 | SSL_CTX_sess_set_cache_size(ssl_ctx.get(), options.session_cache_size); |
549 | |
|
550 | 0 | #ifndef OPENSSL_NO_DH |
551 | 0 | SSL_CTX_set_tmp_dh_callback(ssl_ctx.get(), SSLGetDHCallback); |
552 | |
|
553 | 0 | #if !defined(OPENSSL_NO_ECDH) && defined(SSL_CTX_set_tmp_ecdh) |
554 | 0 | EC_KEY* ecdh = NULL; |
555 | 0 | int i = OBJ_sn2nid(options.ecdhe_curve_name.c_str()); |
556 | 0 | if (!i || ((ecdh = EC_KEY_new_by_curve_name(i)) == NULL)) { |
557 | 0 | LOG(ERROR) << "Fail to find ECDHE named curve=" |
558 | 0 | << options.ecdhe_curve_name |
559 | 0 | << ": " << SSLError(ERR_get_error()); |
560 | 0 | return NULL; |
561 | 0 | } |
562 | 0 | SSL_CTX_set_tmp_ecdh(ssl_ctx.get(), ecdh); |
563 | 0 | EC_KEY_free(ecdh); |
564 | 0 | #endif |
565 | |
|
566 | 0 | #endif // OPENSSL_NO_DH |
567 | | |
568 | | // Set ALPN callback to choose application protocol when alpns is not empty. |
569 | 0 | if (alpns != nullptr && !alpns->empty()) { |
570 | 0 | if (SetServerALPNCallback(ssl_ctx.get(), alpns) != 0) { |
571 | 0 | return NULL; |
572 | 0 | } |
573 | 0 | } |
574 | 0 | return ssl_ctx.release(); |
575 | 0 | } |
576 | | |
577 | 0 | SSL* CreateSSLSession(SSL_CTX* ctx, SocketId id, int fd, bool server_mode) { |
578 | 0 | if (ctx == NULL) { |
579 | 0 | LOG(WARNING) << "Lack SSL_ctx to create an SSL session"; |
580 | 0 | return NULL; |
581 | 0 | } |
582 | 0 | SSL* ssl = SSL_new(ctx); |
583 | 0 | if (ssl == NULL) { |
584 | 0 | LOG(ERROR) << "Fail to SSL_new: " << SSLError(ERR_get_error()); |
585 | 0 | return NULL; |
586 | 0 | } |
587 | 0 | if (SSL_set_fd(ssl, fd) != 1) { |
588 | 0 | LOG(ERROR) << "Fail to SSL_set_fd: " << SSLError(ERR_get_error()); |
589 | 0 | SSL_free(ssl); |
590 | 0 | return NULL; |
591 | 0 | } |
592 | | |
593 | 0 | if (server_mode) { |
594 | 0 | SSL_set_accept_state(ssl); |
595 | 0 | } else { |
596 | 0 | SSL_set_connect_state(ssl); |
597 | 0 | } |
598 | 0 | SSL_set_app_data(ssl, id); |
599 | 0 | return ssl; |
600 | 0 | } |
601 | | |
602 | 0 | void AddBIOBuffer(SSL* ssl, int fd, int bufsize) { |
603 | | #if defined(OPENSSL_IS_BORINGSSL) |
604 | | BIO *rbio = BIO_new(BIO_s_mem()); |
605 | | BIO *wbio = BIO_new(BIO_s_mem()); |
606 | | #else |
607 | 0 | BIO *rbio = BIO_new(BIO_f_buffer()); |
608 | 0 | BIO_set_buffer_size(rbio, bufsize); |
609 | 0 | BIO *wbio = BIO_new(BIO_f_buffer()); |
610 | 0 | BIO_set_buffer_size(wbio, bufsize); |
611 | 0 | #endif |
612 | 0 | BIO* rfd = BIO_new(BIO_s_fd()); |
613 | 0 | BIO_set_fd(rfd, fd, 0); |
614 | 0 | rbio = BIO_push(rbio, rfd); |
615 | 0 | BIO* wfd = BIO_new(BIO_s_fd()); |
616 | 0 | BIO_set_fd(wfd, fd, 0); |
617 | 0 | wbio = BIO_push(wbio, wfd); |
618 | 0 | SSL_set_bio(ssl, rbio, wbio); |
619 | 0 | } |
620 | | |
621 | 0 | SSLState DetectSSLState(int fd, int* error_code) { |
622 | | // Peek the first few bytes inside socket to detect whether |
623 | | // it's an SSL connection. If it is, create an SSL session |
624 | | // which will be used to read/write after |
625 | | |
626 | | // Header format of SSLv2 |
627 | | // +-----------+------+----- |
628 | | // | 2B header | 0x01 | etc. |
629 | | // +-----------+------+----- |
630 | | // The first bit of header is always 1, with the following |
631 | | // 15 bits are the length of data |
632 | | |
633 | | // Header format of SSLv3 or TLSv1.0, 1.1, 1.2 |
634 | | // +------+------------+-----------+------+----- |
635 | | // | 0x16 | 2B version | 2B length | 0x01 | etc. |
636 | | // +------+------------+-----------+------+----- |
637 | 0 | char header[6]; |
638 | 0 | const ssize_t nr = recv(fd, header, sizeof(header), MSG_PEEK); |
639 | 0 | if (nr < (ssize_t)sizeof(header)) { |
640 | 0 | if (nr < 0) { |
641 | 0 | if (errno == ENOTSOCK) { |
642 | 0 | return SSL_OFF; |
643 | 0 | } |
644 | 0 | *error_code = errno; // Including EAGAIN and EINTR |
645 | 0 | } else if (nr == 0) { // EOF |
646 | 0 | *error_code = 0; |
647 | 0 | } else { // Not enough data, need retry |
648 | 0 | *error_code = EAGAIN; |
649 | 0 | } |
650 | 0 | return SSL_UNKNOWN; |
651 | 0 | } |
652 | | |
653 | 0 | if ((header[0] == 0x16 && header[5] == 0x01) // SSLv3 or TLSv1.0, 1.1, 1.2 |
654 | 0 | || ((header[0] & 0x80) == 0x80 && header[2] == 0x01)) { // SSLv2 |
655 | 0 | return SSL_CONNECTING; |
656 | 0 | } else { |
657 | 0 | return SSL_OFF; |
658 | 0 | } |
659 | 0 | } |
660 | | |
661 | | #if OPENSSL_VERSION_NUMBER < 0x10100000L |
662 | | |
663 | | // NOTE: Can't find a macro for CRYPTO_THREADID |
664 | | // Fallback to use CRYPTO_LOCK_ECDH as flag |
665 | | #ifdef CRYPTO_LOCK_ECDH |
666 | | static void SSLGetThreadId(CRYPTO_THREADID* tid) { |
667 | | CRYPTO_THREADID_set_numeric(tid, (unsigned long)pthread_self()); |
668 | | } |
669 | | #else |
670 | | static unsigned long SSLGetThreadId() { |
671 | | return pthread_self(); |
672 | | } |
673 | | #endif // CRYPTO_LOCK_ECDH |
674 | | |
675 | | // Locks for SSL library |
676 | | // NOTE: If we replace this with bthread_mutex_t, SSL routines |
677 | | // may crash probably due to some TLS data used inside OpenSSL |
678 | | // Also according to performance test, there is little difference |
679 | | // between pthread mutex and bthread mutex |
680 | | static butil::Mutex* g_ssl_mutexs = NULL; |
681 | | |
682 | | static void SSLLockCallback(int mode, int n, const char* file, int line) { |
683 | | (void)file; |
684 | | (void)line; |
685 | | // Following log is too anonying even for verbose logs. |
686 | | // RPC_VLOG << "[" << file << ':' << line << "] SSL" |
687 | | // << (mode & CRYPTO_LOCK ? "locks" : "unlocks") |
688 | | // << " thread=" << CRYPTO_thread_id(); |
689 | | if (mode & CRYPTO_LOCK) { |
690 | | g_ssl_mutexs[n].lock(); |
691 | | } else { |
692 | | g_ssl_mutexs[n].unlock(); |
693 | | } |
694 | | } |
695 | | #endif // OPENSSL_VERSION_NUMBER < 0x10100000L |
696 | | |
697 | 0 | int SSLThreadInit() { |
698 | | #if OPENSSL_VERSION_NUMBER < 0x10100000L |
699 | | g_ssl_mutexs = new butil::Mutex[CRYPTO_num_locks()]; |
700 | | CRYPTO_set_locking_callback(SSLLockCallback); |
701 | | # ifdef CRYPTO_LOCK_ECDH |
702 | | CRYPTO_THREADID_set_callback(SSLGetThreadId); |
703 | | # else |
704 | | CRYPTO_set_id_callback(SSLGetThreadId); |
705 | | # endif // CRYPTO_LOCK_ECDH |
706 | | #endif // OPENSSL_VERSION_NUMBER < 0x10100000L |
707 | 0 | return 0; |
708 | 0 | } |
709 | | |
710 | | #ifndef OPENSSL_NO_DH |
711 | | |
712 | 0 | static DH* SSLGetDH1024() { |
713 | 0 | BIGNUM* p = get_rfc2409_prime_1024(NULL); |
714 | 0 | if (!p) { |
715 | 0 | return NULL; |
716 | 0 | } |
717 | | // See RFC 2409, Section 6 "Oakley Groups" |
718 | | // for the reason why 2 is used as generator. |
719 | 0 | BIGNUM* g = NULL; |
720 | 0 | BN_dec2bn(&g, "2"); |
721 | 0 | if (!g) { |
722 | 0 | BN_free(p); |
723 | 0 | return NULL; |
724 | 0 | } |
725 | 0 | DH *dh = DH_new(); |
726 | 0 | if (!dh) { |
727 | 0 | BN_free(p); |
728 | 0 | BN_free(g); |
729 | 0 | return NULL; |
730 | 0 | } |
731 | 0 | DH_set0_pqg(dh, p, NULL, g); |
732 | 0 | return dh; |
733 | 0 | } |
734 | | |
735 | 0 | static DH* SSLGetDH2048() { |
736 | 0 | BIGNUM* p = get_rfc3526_prime_2048(NULL); |
737 | 0 | if (!p) { |
738 | 0 | return NULL; |
739 | 0 | } |
740 | | // See RFC 3526, Section 3 "2048-bit MODP Group" |
741 | | // for the reason why 2 is used as generator. |
742 | 0 | BIGNUM* g = NULL; |
743 | 0 | BN_dec2bn(&g, "2"); |
744 | 0 | if (!g) { |
745 | 0 | BN_free(p); |
746 | 0 | return NULL; |
747 | 0 | } |
748 | 0 | DH* dh = DH_new(); |
749 | 0 | if (!dh) { |
750 | 0 | BN_free(p); |
751 | 0 | BN_free(g); |
752 | 0 | return NULL; |
753 | 0 | } |
754 | 0 | DH_set0_pqg(dh, p, NULL, g); |
755 | 0 | return dh; |
756 | 0 | } |
757 | | |
758 | 0 | static DH* SSLGetDH4096() { |
759 | 0 | BIGNUM* p = get_rfc3526_prime_4096(NULL); |
760 | 0 | if (!p) { |
761 | 0 | return NULL; |
762 | 0 | } |
763 | | // See RFC 3526, Section 5 "4096-bit MODP Group" |
764 | | // for the reason why 2 is used as generator. |
765 | 0 | BIGNUM* g = NULL; |
766 | 0 | BN_dec2bn(&g, "2"); |
767 | 0 | if (!g) { |
768 | 0 | BN_free(p); |
769 | 0 | return NULL; |
770 | 0 | } |
771 | 0 | DH *dh = DH_new(); |
772 | 0 | if (!dh) { |
773 | 0 | BN_free(p); |
774 | 0 | BN_free(g); |
775 | 0 | return NULL; |
776 | 0 | } |
777 | 0 | DH_set0_pqg(dh, p, NULL, g); |
778 | 0 | return dh; |
779 | 0 | } |
780 | | |
781 | 0 | static DH* SSLGetDH8192() { |
782 | 0 | BIGNUM* p = get_rfc3526_prime_8192(NULL); |
783 | 0 | if (!p) { |
784 | 0 | return NULL; |
785 | 0 | } |
786 | | // See RFC 3526, Section 7 "8192-bit MODP Group" |
787 | | // for the reason why 2 is used as generator. |
788 | 0 | BIGNUM* g = NULL; |
789 | 0 | BN_dec2bn(&g, "2"); |
790 | 0 | if (!g) { |
791 | 0 | BN_free(g); |
792 | 0 | return NULL; |
793 | 0 | } |
794 | 0 | DH *dh = DH_new(); |
795 | 0 | if (!dh) { |
796 | 0 | BN_free(p); |
797 | 0 | BN_free(g); |
798 | 0 | return NULL; |
799 | 0 | } |
800 | 0 | DH_set0_pqg(dh, p, NULL, g); |
801 | 0 | return dh; |
802 | 0 | } |
803 | | |
804 | | #endif // OPENSSL_NO_DH |
805 | | |
806 | 0 | int SSLDHInit() { |
807 | 0 | #ifndef OPENSSL_NO_DH |
808 | 0 | if ((g_dh_1024 = SSLGetDH1024()) == NULL) { |
809 | 0 | LOG(ERROR) << "Fail to initialize DH-1024"; |
810 | 0 | return -1; |
811 | 0 | } |
812 | 0 | if ((g_dh_2048 = SSLGetDH2048()) == NULL) { |
813 | 0 | LOG(ERROR) << "Fail to initialize DH-2048"; |
814 | 0 | return -1; |
815 | 0 | } |
816 | 0 | if ((g_dh_4096 = SSLGetDH4096()) == NULL) { |
817 | 0 | LOG(ERROR) << "Fail to initialize DH-4096"; |
818 | 0 | return -1; |
819 | 0 | } |
820 | 0 | if ((g_dh_8192 = SSLGetDH8192()) == NULL) { |
821 | 0 | LOG(ERROR) << "Fail to initialize DH-8192"; |
822 | 0 | return -1; |
823 | 0 | } |
824 | 0 | #endif // OPENSSL_NO_DH |
825 | 0 | return 0; |
826 | 0 | } |
827 | | |
// Returns the separator to use one nesting level deeper than |sep|.
// Separators that do not start with '\n' are reused unchanged. For "\n"
// followed by an indent, the indent is doubled; a bare "\n" becomes "\n ".
static std::string GetNextLevelSeparator(const char* sep) {
    if (*sep != '\n') {
        return std::string(sep);
    }
    const std::string indent(sep + 1);
    if (indent.empty()) {
        return "\n ";
    }
    return "\n" + indent + indent;
}
842 | | |
843 | 0 | void Print(std::ostream& os, SSL* ssl, const char* sep) { |
844 | 0 | os << "cipher=" << SSL_get_cipher(ssl) << sep |
845 | 0 | << "protocol=" << SSL_get_version(ssl) << sep |
846 | 0 | << "verify=" << (SSL_get_verify_mode(ssl) & SSL_VERIFY_PEER |
847 | 0 | ? "success" : "none"); |
848 | 0 | X509* cert = SSL_get_peer_certificate(ssl); |
849 | 0 | if (cert) { |
850 | 0 | os << sep << "peer_certificate={"; |
851 | 0 | const std::string new_sep = GetNextLevelSeparator(sep); |
852 | 0 | if (sep[0] == '\n') { |
853 | 0 | os << new_sep; |
854 | 0 | } |
855 | 0 | Print(os, cert, new_sep.c_str()); |
856 | 0 | if (sep[0] == '\n') { |
857 | 0 | os << sep; |
858 | 0 | } |
859 | 0 | os << '}'; |
860 | 0 | } |
861 | 0 | } |
862 | | |
863 | 0 | void Print(std::ostream& os, X509* cert, const char* sep) { |
864 | 0 | BIO* buf = BIO_new(BIO_s_mem()); |
865 | 0 | if (buf == NULL) { |
866 | 0 | return; |
867 | 0 | } |
868 | 0 | BIO_printf(buf, "subject="); |
869 | 0 | X509_NAME_print(buf, X509_get_subject_name(cert), 0); |
870 | 0 | BIO_printf(buf, "%sstart_date=", sep); |
871 | 0 | ASN1_TIME_print(buf, X509_get_notBefore(cert)); |
872 | 0 | BIO_printf(buf, "%sexpire_date=", sep); |
873 | 0 | ASN1_TIME_print(buf, X509_get_notAfter(cert)); |
874 | |
|
875 | 0 | BIO_printf(buf, "%scommon_name=", sep); |
876 | 0 | std::vector<std::string> hostnames; |
877 | 0 | brpc::ExtractHostnames(cert, &hostnames); |
878 | 0 | for (size_t i = 0; i < hostnames.size(); ++i) { |
879 | 0 | BIO_printf(buf, "%s;", hostnames[i].c_str()); |
880 | 0 | } |
881 | |
|
882 | 0 | BIO_printf(buf, "%sissuer=", sep); |
883 | 0 | X509_NAME_print(buf, X509_get_issuer_name(cert), 0); |
884 | |
|
885 | 0 | char* bufp = NULL; |
886 | 0 | int len = BIO_get_mem_data(buf, &bufp); |
887 | 0 | os << butil::StringPiece(bufp, len); |
888 | 0 | } |
889 | | |
890 | 0 | std::string ALPNProtocolToString(const AdaptiveProtocolType& protocol) { |
891 | 0 | butil::StringPiece name = protocol.name(); |
892 | | // Default use http 1.1 version |
893 | 0 | if (name.starts_with("http")) { |
894 | 0 | name.set("http/1.1"); |
895 | 0 | } |
896 | | |
897 | | // ALPN extension uses 1 byte to record the protocol length |
898 | | // and it's maximum length is 255. |
899 | 0 | if (name.size() > CHAR_MAX) { |
900 | 0 | name = name.substr(0, CHAR_MAX); |
901 | 0 | } |
902 | |
|
903 | 0 | char length = static_cast<char>(name.size()); |
904 | 0 | return std::string(&length, 1) + name.data(); |
905 | 0 | } |
906 | | |
907 | | bool BuildALPNProtocolList( |
908 | | const std::vector<std::string>& alpn_protocols, |
909 | | std::vector<unsigned char>& result |
910 | 0 | ) { |
911 | 0 | size_t alpn_list_length = 0; |
912 | 0 | for (const auto& alpn_protocol : alpn_protocols) { |
913 | 0 | if (alpn_protocol.size() > UCHAR_MAX) { |
914 | 0 | LOG(ERROR) << "Fail to build ALPN procotol list: " |
915 | 0 | << "protocol name length " << alpn_protocol.size() << " too long, " |
916 | 0 | << "max 255 supported."; |
917 | 0 | return false; |
918 | 0 | } |
919 | 0 | alpn_list_length += alpn_protocol.size() + 1; |
920 | 0 | } |
921 | | |
922 | 0 | result.resize(alpn_list_length); |
923 | 0 | for (size_t curr = 0, i = 0; i < alpn_protocols.size(); i++) { |
924 | 0 | result[curr++] = static_cast<unsigned char>( |
925 | 0 | alpn_protocols[i].size() |
926 | 0 | ); |
927 | 0 | std::copy( |
928 | 0 | alpn_protocols[i].begin(), |
929 | 0 | alpn_protocols[i].end(), |
930 | 0 | result.begin() + curr |
931 | 0 | ); |
932 | 0 | curr += alpn_protocols[i].size(); |
933 | 0 | } |
934 | 0 | return true; |
935 | 0 | } |
936 | | |
937 | | } // namespace brpc |
938 | | |
939 | | #endif // USE_MESALINK |