1
#include "source/common/tls/ssl_socket.h"
2

            
3
#include "envoy/stats/scope.h"
4

            
5
#include "source/common/common/assert.h"
6
#include "source/common/common/empty_string.h"
7
#include "source/common/common/hex.h"
8
#include "source/common/http/headers.h"
9
#include "source/common/tls/io_handle_bio.h"
10
#include "source/common/tls/ssl_handshaker.h"
11
#include "source/common/tls/utility.h"
12

            
13
#include "absl/strings/str_replace.h"
14
#include "openssl/err.h"
15
#include "openssl/x509v3.h"
16

            
17
using Envoy::Network::PostIoAction;
18

            
19
namespace Envoy {
20
namespace Extensions {
21
namespace TransportSockets {
22
namespace Tls {
23

            
24
namespace {
25

            
26
constexpr absl::string_view NotReadyReason{"TLS error: Secret is not supplied by SDS"};
27

            
28
} // namespace
29

            
30
8
absl::string_view NotReadySslSocket::failureReason() const {
  // A not-ready socket always fails for the same reason, so hand back the shared constant.
  return NotReadyReason;
}
31

            
32
// Factory for SslSocket: constructs the socket, then runs the fallible initialize() step.
//
// Returns the initialized socket, or the error status produced by initialize() (for example
// when the context fails to create a new SSL object).
//
// `ctx`, `handshaker_factory_cb` and `host` are by-value sinks that are not used again here, so
// they are moved into their consumers instead of being copied. This avoids extra shared_ptr
// reference-count updates and a copy of the callback object.
absl::StatusOr<std::unique_ptr<SslSocket>>
SslSocket::create(Envoy::Ssl::ContextSharedPtr ctx, InitialState state,
                  const Network::TransportSocketOptionsConstSharedPtr& transport_socket_options,
                  Ssl::HandshakerFactoryCb handshaker_factory_cb,
                  Upstream::HostDescriptionConstSharedPtr host) {
  // Plain `new` rather than std::make_unique: presumably the constructor is not publicly
  // accessible so that construction goes through this factory. NOTE(review): confirm in header.
  std::unique_ptr<SslSocket> socket(new SslSocket(std::move(ctx), transport_socket_options));
  absl::Status status =
      socket->initialize(state, std::move(handshaker_factory_cb), std::move(host));
  if (!status.ok()) {
    return status;
  }
  return socket;
}
45

            
46
SslSocket::SslSocket(Envoy::Ssl::ContextSharedPtr ctx,
47
                     const Network::TransportSocketOptionsConstSharedPtr& transport_socket_options)
48
2587
    : transport_socket_options_(transport_socket_options),
49
2587
      ctx_(std::dynamic_pointer_cast<ContextImpl>(ctx)) {}
50

            
51
// Second construction phase, invoked by create(). It creates the per-connection SSL object,
// wraps it in a handshaker, and puts it into client or server mode.
// Returns the error status from ctx_->newSsl() on failure, OkStatus otherwise.
absl::Status SslSocket::initialize(InitialState state,
                                   Ssl::HandshakerFactoryCb handshaker_factory_cb,
                                   Upstream::HostDescriptionConstSharedPtr host) {
  // Ask the context for a fresh SSL object; bail out with its error if that fails.
  auto ssl_or_error = ctx_->newSsl(transport_socket_options_, host);
  if (!ssl_or_error.ok()) {
    return ssl_or_error.status();
  }

  // Hand the SSL object to the handshaker built by the factory callback. The result is downcast
  // to SslHandshakerImpl, which is also what ssl() exposes as this socket's connection info.
  // NOTE(review): a factory returning another handshaker type would leave info_ null.
  info_ = std::dynamic_pointer_cast<SslHandshakerImpl>(handshaker_factory_cb(
      std::move(*ssl_or_error), ctx_->sslExtendedSocketInfoIndex(), this));

  // Select the TLS role: accept (server) unless this socket initiates the connection.
  if (state != InitialState::Client) {
    ASSERT(state == InitialState::Server);
    SSL_set_accept_state(rawSsl());
  } else {
    SSL_set_connect_state(rawSsl());
  }

  return absl::OkStatus();
}
71

            
72
2503
// Installs the transport socket callbacks (exactly once) and wires the SSL object to the
// connection. It registers with private key method providers, attaches an IoHandle-backed BIO,
// and stores the callbacks as SSL ex-data.
void SslSocket::setTransportSocketCallbacks(Network::TransportSocketCallbacks& callbacks) {
  ASSERT(!callbacks_);
  callbacks_ = &callbacks;

  SSL* const ssl = rawSsl();

  // Each certificate may carry its own private key method provider; register this connection
  // with every one of them, using this connection's dispatcher.
  for (const auto& provider : ctx_->getPrivateKeyMethodProviders()) {
    provider->registerPrivateKeyMethod(ssl, *this, callbacks_->connection().dispatcher());
  }

  // All TLS I/O goes through a custom BIO reading from / writing to the connection's IoHandle;
  // the same BIO is installed for both directions.
  BIO* const io_bio = BIO_new_io_handle(&callbacks_->ioHandle());
  SSL_set_bio(ssl, io_bio, io_bio);

  // Make the transport callbacks retrievable from the SSL object.
  SSL_set_ex_data(ssl, ContextImpl::sslSocketIndex(), static_cast<void*>(callbacks_));
}
87

            
88
22378
// Fills `slice` with decrypted bytes via repeated SSL_read() calls.
// Returns the number of bytes written into the slice. If SSL_read() returned a non-positive
// value before the slice was full, that raw return code is stored in `error_` for the caller
// to classify with SSL_get_error().
SslSocket::ReadResult SslSocket::sslReadIntoSlice(Buffer::RawSlice& slice) {
  ReadResult result;
  auto* cursor = static_cast<uint8_t*>(slice.mem_);
  size_t space_left = slice.len_;

  while (space_left != 0) {
    const int rc = SSL_read(rawSsl(), cursor, space_left);
    ENVOY_CONN_LOG(trace, "ssl read returns: {}", callbacks_->connection(), rc);
    if (rc <= 0) {
      result.error_ = absl::make_optional<int>(rc);
      return result;
    }
    ASSERT(static_cast<size_t>(rc) <= space_left);
    cursor += rc;
    space_left -= rc;
    result.bytes_read_ += rc;
  }

  return result;
}
108

            
109
11551
// Reads decrypted application data into `read_buffer`.
// If the handshake has not completed (and no shutdown was sent), the handshake is driven first.
// Returns Close with 0 bytes if it fails, or KeepOpen/0 bytes if it is still in progress.
// Otherwise it returns the total bytes read, Close on a fatal SSL error, and end_stream=true
// when the peer shut down (close_notify or a plain socket close).
Network::IoResult SslSocket::doRead(Buffer::Instance& read_buffer) {
  const bool needs_handshake = info_->state() != Ssl::SocketState::HandshakeComplete &&
                               info_->state() != Ssl::SocketState::ShutdownSent;
  if (needs_handshake) {
    const PostIoAction handshake_action = doHandshake();
    if (handshake_action == PostIoAction::Close ||
        info_->state() != Ssl::SocketState::HandshakeComplete) {
      // end_stream is false because either a hard error occurred (action == Close) or
      // the handshake isn't complete, so a half-close cannot occur yet.
      return {handshake_action, 0, false};
    }
  }

  PostIoAction action = PostIoAction::KeepOpen;
  bool end_stream = false;
  uint64_t total_bytes_read = 0;

  // Each pass reserves space in the buffer and fills its slices. Reading stops when SSL_read()
  // reports a non-positive result, or when the connection asks for the read buffer to be drained.
  bool keep_reading = true;
  while (keep_reading) {
    Buffer::Reservation reservation = read_buffer.reserveForRead();
    uint64_t pass_bytes = 0;

    for (uint64_t slice_idx = 0; slice_idx < reservation.numSlices(); ++slice_idx) {
      const auto slice_result = sslReadIntoSlice(reservation.slices()[slice_idx]);
      pass_bytes += slice_result.bytes_read_;
      if (!slice_result.error_.has_value()) {
        continue;
      }

      // SSL_read() returned <= 0: classify the condition, then stop reading entirely.
      keep_reading = false;
      const int rc = slice_result.error_.value();
      const int err = SSL_get_error(rawSsl(), rc);
      ENVOY_CONN_LOG(trace, "ssl error occurred while read: {}", callbacks_->connection(),
                     Utility::getErrorDescription(err));
      if (err == SSL_ERROR_WANT_READ) {
        // SSL needs more input from the socket; nothing else to do for now.
      } else if (err == SSL_ERROR_ZERO_RETURN) {
        // Graceful shutdown using close_notify TLS alert.
        end_stream = true;
      } else if (err == SSL_ERROR_SYSCALL && rc == 0) {
        // Non-graceful shutdown by closing the underlying socket.
        end_stream = true;
      } else {
        // Everything else is fatal. That includes SSL_ERROR_WANT_WRITE, which means a
        // renegotiation has started; renegotiation is not handled.
        drainErrorQueue();
        action = PostIoAction::Close;
      }
      break;
    }

    reservation.commit(pass_bytes);
    if (pass_bytes > 0 && callbacks_->shouldDrainReadBuffer()) {
      // Stop for now, but mark the transport socket readable so reading can resume later.
      callbacks_->setTransportSocketIsReadable();
      keep_reading = false;
    }
    total_bytes_read += pass_bytes;
  }

  ENVOY_CONN_LOG(trace, "ssl read {} bytes", callbacks_->connection(), total_bytes_read);
  return {action, total_bytes_read, end_stream};
}
174

            
175
9
void SslSocket::onPrivateKeyMethodComplete() {
  // An asynchronous private key operation finished; continue the paused handshake.
  resumeHandshake();
}
176

            
177
45
void SslSocket::resumeHandshake() {
178
45
  ASSERT(callbacks_ != nullptr && callbacks_->connection().dispatcher().isThreadSafe());
179
45
  ASSERT(info_->state() == Ssl::SocketState::HandshakeInProgress);
180

            
181
  // Resume handshake.
182
45
  PostIoAction action = doHandshake();
183
45
  if (action == PostIoAction::Close) {
184
9
    ENVOY_CONN_LOG(debug, "async handshake completion error", callbacks_->connection());
185
9
    callbacks_->connection().close(Network::ConnectionCloseType::FlushWrite,
186
9
                                   "failed_resuming_async_handshake");
187
9
  }
188
45
}
189

            
190
4700
Network::Connection& SslSocket::connection() const {
  // Reached through the callbacks installed by setTransportSocketCallbacks().
  return callbacks_->connection();
}
191

            
192
2127
// Handshake-success hook: logs the handshake and records the completion time, then raises the
// Connected event.
void SslSocket::onSuccess(SSL* ssl) {
  ctx_->logHandshake(ssl);

  auto& conn = callbacks_->connection();
  auto& stream_info = conn.streamInfo();
  // Connections carrying upstream info record the time on the upstream timing; all others on
  // the downstream timing.
  if (stream_info.upstreamInfo()) {
    stream_info.upstreamInfo()->upstreamTiming().onUpstreamHandshakeComplete(
        conn.dispatcher().timeSource());
  } else {
    stream_info.downstreamTiming().onDownstreamHandshakeComplete(conn.dispatcher().timeSource());
  }

  callbacks_->raiseEvent(Network::ConnectionEvent::Connected);
}
206

            
207
273
void SslSocket::onFailure() {
  // Failure counterpart of onSuccess(): collect the pending SSL errors into failure_reason_/stats.
  drainErrorQueue();
}
208

            
209
7323
PostIoAction SslSocket::doHandshake() {
  // The handshake itself is driven by the handshaker; simply delegate.
  return info_->doHandshake();
}
210

            
211
2792
void SslSocket::drainErrorQueue() {
212
2792
  bool saw_error = false;
213
2792
  bool saw_counted_error = false;
214
2792
  bool saw_cert_verify_failed = false;
215
2792
  bool saw_no_client_cert = false;
216
3094
  while (uint64_t err = ERR_get_error()) {
217
302
    if (ERR_GET_LIB(err) == ERR_LIB_SSL) {
218
248
      if (ERR_GET_REASON(err) == SSL_R_PEER_DID_NOT_RETURN_A_CERTIFICATE) {
219
14
        ctx_->stats().fail_verify_no_cert_.inc();
220
14
        saw_counted_error = true;
221
14
        saw_no_client_cert = true;
222
234
      } else if (ERR_GET_REASON(err) == SSL_R_CERTIFICATE_VERIFY_FAILED) {
223
56
        saw_counted_error = true;
224
56
        saw_cert_verify_failed = true;
225
56
      }
226
295
    } else if (ERR_GET_LIB(err) == ERR_LIB_SYS) {
227
      // Any syscall errors that result in connection closure are already tracked in other
228
      // connection related stats. We will still retain the specific syscall failure for
229
      // transport failure reasons.
230
52
      saw_counted_error = true;
231
52
    }
232
302
    saw_error = true;
233

            
234
302
    if (failure_reason_.empty()) {
235
281
      failure_reason_ = "TLS_error:";
236
281
    }
237

            
238
302
    absl::StrAppend(&failure_reason_, "|", err, ":",
239
302
                    absl::NullSafeStringView(ERR_lib_error_string(err)), ":",
240
302
                    absl::NullSafeStringView(ERR_func_error_string(err)), ":",
241
302
                    absl::NullSafeStringView(ERR_reason_error_string(err)));
242
302
  }
243

            
244
2792
  if (!saw_error) {
245
2511
    return;
246
2511
  }
247

            
248
  // Append detailed error info for certificate-related failures.
249
281
  bool added_detail = false;
250
281
  if (saw_cert_verify_failed) {
251
56
    auto* extended_socket_info = reinterpret_cast<Envoy::Ssl::SslExtendedSocketInfo*>(
252
56
        SSL_get_ex_data(rawSsl(), ContextImpl::sslExtendedSocketInfoIndex()));
253
56
    if (extended_socket_info != nullptr) {
254
56
      absl::string_view cert_validation_error = extended_socket_info->certificateValidationError();
255
56
      if (!cert_validation_error.empty()) {
256
56
        absl::StrAppend(&failure_reason_, ":", cert_validation_error);
257
56
        added_detail = true;
258
56
      }
259
56
    }
260
56
  }
261
281
  if (!added_detail && saw_no_client_cert) {
262
14
    absl::StrAppend(&failure_reason_, ":peer did not provide required client certificate");
263
14
  }
264

            
265
281
  if (!failure_reason_.empty()) {
266
281
    absl::StrAppend(&failure_reason_, ":TLS_error_end");
267
281
    ENVOY_CONN_LOG(debug, "remote address:{},{}", callbacks_->connection(),
268
281
                   callbacks_->connection().connectionInfoProvider().remoteAddress()->asString(),
269
281
                   failure_reason_);
270
281
  }
271

            
272
281
  if (!saw_counted_error) {
273
159
    ctx_->stats().connection_error_.inc();
274
159
  }
275
281
}
276

            
277
21960
// Encrypts and writes data from `write_buffer`, draining what was written.
// The handshake is driven first if it has not completed. A write that would block is remembered
// in bytes_to_retry_ so the next call repeats it with the same length. Any other SSL error
// closes the connection. Once the buffer is empty and `end_stream` is set, an SSL shutdown is sent.
Network::IoResult SslSocket::doWrite(Buffer::Instance& write_buffer, bool end_stream) {
  ASSERT(info_->state() != Ssl::SocketState::ShutdownSent || write_buffer.length() == 0);

  const bool needs_handshake = info_->state() != Ssl::SocketState::HandshakeComplete &&
                               info_->state() != Ssl::SocketState::ShutdownSent;
  if (needs_handshake) {
    const PostIoAction handshake_action = doHandshake();
    if (handshake_action == PostIoAction::Close ||
        info_->state() != Ssl::SocketState::HandshakeComplete) {
      return {handshake_action, 0, false};
    }
  }

  // Upper bound on bytes handed to a single SSL_write(); presumably chosen to match the maximum
  // TLS record plaintext size (2^14).
  constexpr uint64_t MaxWriteChunk = 16384;

  uint64_t bytes_to_write = 0;
  if (bytes_to_retry_ != 0) {
    bytes_to_write = bytes_to_retry_;
    bytes_to_retry_ = 0;
  } else {
    bytes_to_write = std::min<uint64_t>(write_buffer.length(), MaxWriteChunk);
  }

  uint64_t total_bytes_written = 0;
  while (bytes_to_write > 0) {
    // TODO(mattklein123): As it relates to our fairness efforts, we might want to limit the number
    // of iterations of this loop, either by pure iterations, bytes written, etc.

    // SSL_write() requires that if a previous call returns SSL_ERROR_WANT_WRITE, we need to call
    // it again with the same parameters. This is done by tracking last write size, but not write
    // data, since linearize() will return the same undrained data anyway.
    ASSERT(bytes_to_write <= write_buffer.length());
    const int rc = SSL_write(rawSsl(), write_buffer.linearize(bytes_to_write), bytes_to_write);
    ENVOY_CONN_LOG(trace, "ssl write returns: {}", callbacks_->connection(), rc);
    if (rc <= 0) {
      const int err = SSL_get_error(rawSsl(), rc);
      ENVOY_CONN_LOG(trace, "ssl error occurred while write: {}", callbacks_->connection(),
                     Utility::getErrorDescription(err));
      if (err != SSL_ERROR_WANT_WRITE) {
        // Fatal. This includes SSL_ERROR_WANT_READ, which signals a renegotiation; renegotiation
        // is not handled.
        drainErrorQueue();
        return {PostIoAction::Close, total_bytes_written, false};
      }
      // The transport cannot take more right now; retry this exact length on the next call.
      bytes_to_retry_ = bytes_to_write;
      break;
    }
    ASSERT(rc == static_cast<int>(bytes_to_write));
    total_bytes_written += rc;
    write_buffer.drain(rc);
    bytes_to_write = std::min<uint64_t>(write_buffer.length(), MaxWriteChunk);
  }

  if (end_stream && write_buffer.length() == 0) {
    shutdownSsl();
  }

  return {PostIoAction::KeepOpen, total_bytes_written, false};
}
336

            
337
1411
void SslSocket::onConnected() {
  // The handshake is started by doRead()/doWrite(), so here we only check it has not begun.
  ASSERT(info_->state() == Ssl::SocketState::PreHandshake);
}
338

            
339
2650
Ssl::ConnectionInfoConstSharedPtr SslSocket::ssl() const {
  // The handshaker object doubles as the connection info exposed to callers.
  return info_;
}
340

            
341
2184
void SslSocket::shutdownSsl() {
342
2184
  ASSERT(info_->state() != Ssl::SocketState::PreHandshake);
343
2184
  if (info_->state() != Ssl::SocketState::ShutdownSent &&
344
2184
      callbacks_->connection().state() != Network::Connection::State::Closed) {
345
2141
    int rc = SSL_shutdown(rawSsl());
346
    if constexpr (Event::PlatformDefaultTriggerType == Event::FileTriggerType::EmulatedEdge) {
347
      // Windows operate under `EmulatedEdge`. These are level events that are artificially
348
      // made to behave like edge events. And if the rc is 0 then in that case we want read
349
      // activation resumption. This code is protected with an `constexpr` if, to minimize the tax
350
      // on POSIX systems that operate in Edge events.
351
      if (rc == 0) {
352
        // See https://www.openssl.org/docs/manmaster/man3/SSL_shutdown.html
353
        // if return value is 0,  Call SSL_read() to do a bidirectional shutdown.
354
        callbacks_->setTransportSocketIsReadable();
355
      }
356
    }
357
2141
    ENVOY_CONN_LOG(debug, "SSL shutdown: rc={}", callbacks_->connection(), rc);
358
2141
    drainErrorQueue();
359
2141
    info_->setState(Ssl::SocketState::ShutdownSent);
360
2141
  }
361
2184
}
362

            
363
408
void SslSocket::shutdownBasic() {
364
408
  if (info_->state() != Ssl::SocketState::ShutdownSent &&
365
408
      callbacks_->connection().state() != Network::Connection::State::Closed) {
366
361
    callbacks_->ioHandle().shutdown(ENVOY_SHUT_WR);
367
361
    drainErrorQueue();
368
361
    info_->setState(Ssl::SocketState::ShutdownSent);
369
361
  }
370
408
}
371

            
372
2502
// Tears down the TLS side of the connection: unregisters from the private key method providers
// registered in setTransportSocketCallbacks(), then attempts the most complete shutdown the
// current handshake state allows.
void SslSocket::closeSocket(Network::ConnectionEvent) {
  for (const auto& provider : ctx_->getPrivateKeyMethodProviders()) {
    provider->unregisterPrivateKeyMethod(rawSsl());
  }

  // Attempt to send a shutdown before closing the socket. It's possible this won't go out if
  // there is no room on the socket. We can extend the state machine to handle this at some point
  // if needed.
  switch (info_->state()) {
  case Ssl::SocketState::HandshakeInProgress:
  case Ssl::SocketState::HandshakeComplete:
    shutdownSsl();
    break;
  default:
    // We're not in a state to do the full SSL shutdown so perform a basic shutdown to flush any
    // outstanding alerts.
    shutdownBasic();
    break;
  }
}
390

            
391
75
std::string SslSocket::protocol() const {
  // Application protocol negotiated via ALPN, as reported by the connection info.
  return ssl()->alpn();
}
392

            
393
4107
absl::string_view SslSocket::failureReason() const {
  // Accumulated by drainErrorQueue(); presumably empty while no TLS error has been recorded.
  return failure_reason_;
}
394

            
395
6
void SslSocket::onAsynchronousCertValidationComplete() {
396
6
  ENVOY_CONN_LOG(debug, "Async cert validation completed", callbacks_->connection());
397
6
  if (info_->state() == Ssl::SocketState::HandshakeInProgress) {
398
5
    resumeHandshake();
399
5
  }
400
6
}
401

            
402
31
void SslSocket::onAsynchronousCertificateSelectionComplete() {
403
31
  ENVOY_CONN_LOG(debug, "Async cert selection completed", callbacks_->connection());
404
31
  if (info_->state() != Ssl::SocketState::HandshakeInProgress) {
405
    IS_ENVOY_BUG(fmt::format("unexpected handshake state: {}", static_cast<int>(info_->state())));
406
    return;
407
  }
408
31
  resumeHandshake();
409
31
}
410

            
411
} // namespace Tls
412
} // namespace TransportSockets
413
} // namespace Extensions
414
} // namespace Envoy