1
#include "source/common/tls/ssl_socket.h"
2

            
3
#include "envoy/stats/scope.h"
4

            
5
#include "source/common/common/assert.h"
6
#include "source/common/common/empty_string.h"
7
#include "source/common/common/hex.h"
8
#include "source/common/http/headers.h"
9
#include "source/common/tls/io_handle_bio.h"
10
#include "source/common/tls/ssl_handshaker.h"
11
#include "source/common/tls/utility.h"
12

            
13
#include "absl/strings/str_replace.h"
14
#include "openssl/err.h"
15
#include "openssl/x509v3.h"
16

            
17
using Envoy::Network::PostIoAction;
18

            
19
namespace Envoy {
20
namespace Extensions {
21
namespace TransportSockets {
22
namespace Tls {
23

            
24
namespace {
25

            
26
// Failure reason string returned by NotReadySslSocket::failureReason().
constexpr absl::string_view NotReadyReason{"TLS error: Secret is not supplied by SDS"};
27

            
28
} // namespace
29

            
30
8
// Always reports the file-level NotReadyReason constant.
absl::string_view NotReadySslSocket::failureReason() const {
  return NotReadyReason;
}
31

            
32
// Factory for SslSocket: constructs the socket and then runs initialize() on it.
//
// @param ctx the TLS context the socket is created from.
// @param state whether the socket starts as a TLS client or a TLS server.
// @param transport_socket_options per-connection transport socket options.
// @param handshaker_factory_cb callback used by initialize() to build the handshaker.
// @param host upstream host description forwarded to initialize().
// @return the initialized socket, or the error status reported by initialize().
absl::StatusOr<std::unique_ptr<SslSocket>>
SslSocket::create(Envoy::Ssl::ContextSharedPtr ctx, InitialState state,
                  const Network::TransportSocketOptionsConstSharedPtr& transport_socket_options,
                  Ssl::HandshakerFactoryCb handshaker_factory_cb,
                  Upstream::HostDescriptionConstSharedPtr host) {
  std::unique_ptr<SslSocket> new_socket(new SslSocket(ctx, transport_socket_options));

  absl::Status init_status = new_socket->initialize(state, handshaker_factory_cb, host);
  if (!init_status.ok()) {
    return init_status;
  }
  return new_socket;
}
45

            
46
SslSocket::SslSocket(Envoy::Ssl::ContextSharedPtr ctx,
47
                     const Network::TransportSocketOptionsConstSharedPtr& transport_socket_options)
48
2556
    : transport_socket_options_(transport_socket_options),
49
2556
      ctx_(std::dynamic_pointer_cast<ContextImpl>(ctx)) {}
50

            
51
// Creates the SSL object for this socket through the context, hands it to the handshaker
// factory, and puts the SSL object into connect (client) or accept (server) state.
//
// @param state the TLS role the socket starts in.
// @param handshaker_factory_cb callback that builds the handshaker owning the SSL object.
// @param host upstream host description passed to ContextImpl::newSsl().
// @return OkStatus on success, otherwise the error returned by ContextImpl::newSsl().
absl::Status SslSocket::initialize(InitialState state,
                                   Ssl::HandshakerFactoryCb handshaker_factory_cb,
                                   Upstream::HostDescriptionConstSharedPtr host) {
  auto ssl_or_status = ctx_->newSsl(transport_socket_options_, host);
  if (!ssl_or_status.ok()) {
    return ssl_or_status.status();
  }

  // The handshaker takes over the freshly created SSL object; this socket is passed along
  // as the third factory argument.
  info_ = std::dynamic_pointer_cast<SslHandshakerImpl>(handshaker_factory_cb(
      std::move(ssl_or_status.value()), ctx_->sslExtendedSocketInfoIndex(), this));

  const bool is_client = (state == InitialState::Client);
  if (!is_client) {
    ASSERT(state == InitialState::Server);
    SSL_set_accept_state(rawSsl());
  } else {
    SSL_set_connect_state(rawSsl());
  }

  return absl::OkStatus();
}
71

            
72
2472
// Binds this socket to its transport socket callbacks. Must only be called once (asserted).
// Besides remembering the callbacks, this registers the SSL object with every private key
// method provider of the context and wires the SSL object to the connection's IoHandle.
void SslSocket::setTransportSocketCallbacks(Network::TransportSocketCallbacks& callbacks) {
  ASSERT(!callbacks_);
  callbacks_ = &callbacks;

  // Associate this SSL connection with all the certificates (with their potentially different
  // private key methods).
  const auto& key_method_providers = ctx_->getPrivateKeyMethodProviders();
  for (const auto& key_method_provider : key_method_providers) {
    key_method_provider->registerPrivateKeyMethod(rawSsl(), *this,
                                                  callbacks_->connection().dispatcher());
  }

  // Use custom BIO that reads from/writes to IoHandle. The same BIO is installed for both
  // the read and the write side.
  BIO* io_handle_bio = BIO_new_io_handle(&callbacks_->ioHandle());
  SSL_set_bio(rawSsl(), io_handle_bio, io_handle_bio);

  // Store the callbacks pointer in the SSL object's ex-data under the socket index.
  SSL_set_ex_data(rawSsl(), ContextImpl::sslSocketIndex(), static_cast<void*>(callbacks_));
}
87

            
88
22418
// Fills `slice` with decrypted data by calling SSL_read() repeatedly until the slice is full
// or SSL_read() returns a non-positive value.
//
// @param slice destination memory (`mem_`) and capacity (`len_`).
// @return the number of bytes read plus, if SSL_read() returned <= 0, that return code in
//         `error_` for the caller to pass to SSL_get_error().
SslSocket::ReadResult SslSocket::sslReadIntoSlice(Buffer::RawSlice& slice) {
  ReadResult outcome;
  uint8_t* write_cursor = static_cast<uint8_t*>(slice.mem_);
  size_t space_left = slice.len_;

  while (space_left > 0) {
    const int rc = SSL_read(rawSsl(), write_cursor, space_left);
    ENVOY_CONN_LOG(trace, "ssl read returns: {}", callbacks_->connection(), rc);
    if (rc <= 0) {
      // Stop at the first failure and hand the raw return code back to the caller.
      outcome.error_ = absl::make_optional<int>(rc);
      break;
    }

    ASSERT(static_cast<size_t>(rc) <= space_left);
    write_cursor += rc;
    space_left -= rc;
    outcome.bytes_read_ += rc;
  }

  return outcome;
}
108

            
109
11588
// Reads decrypted application data from the TLS connection into `read_buffer`.
//
// If the handshake has not completed (and no shutdown has been sent) the handshake is driven
// first; the call returns right away unless that leaves the socket in HandshakeComplete.
// Afterwards buffer space is reserved and filled slice by slice until SSL_read() stops
// returning data or the callbacks ask for the read buffer to be drained.
//
// @param read_buffer buffer that receives the decrypted bytes.
// @return the post-IO action, the number of bytes committed to `read_buffer`, and whether the
//         peer ended the stream.
Network::IoResult SslSocket::doRead(Buffer::Instance& read_buffer) {
  const bool handshake_pending = info_->state() != Ssl::SocketState::HandshakeComplete &&
                                 info_->state() != Ssl::SocketState::ShutdownSent;
  if (handshake_pending) {
    const PostIoAction handshake_action = doHandshake();
    if (handshake_action == PostIoAction::Close ||
        info_->state() != Ssl::SocketState::HandshakeComplete) {
      // end_stream is false because either a hard error occurred (action == Close) or
      // the handshake isn't complete, so a half-close cannot occur yet.
      return {handshake_action, 0, false};
    }
  }

  PostIoAction action = PostIoAction::KeepOpen;
  bool end_stream = false;
  uint64_t total_bytes_read = 0;
  bool keep_reading = true;
  while (keep_reading) {
    Buffer::Reservation reservation = read_buffer.reserveForRead();
    uint64_t bytes_this_pass = 0;

    for (uint64_t slice_index = 0; slice_index < reservation.numSlices(); ++slice_index) {
      const ReadResult slice_result = sslReadIntoSlice(reservation.slices()[slice_index]);
      bytes_this_pass += slice_result.bytes_read_;
      if (!slice_result.error_.has_value()) {
        // Slice was filled completely; move on to the next one.
        continue;
      }

      // SSL_read() returned <= 0: stop reading and classify the condition.
      keep_reading = false;
      const int ssl_rc = slice_result.error_.value();
      const int err = SSL_get_error(rawSsl(), ssl_rc);
      ENVOY_CONN_LOG(trace, "ssl error occurred while read: {}", callbacks_->connection(),
                     Utility::getErrorDescription(err));

      // SSL_ERROR_ZERO_RETURN: graceful shutdown using close_notify TLS alert.
      // SSL_ERROR_SYSCALL with a zero return code: non-graceful shutdown by closing the
      // underlying socket.
      const bool peer_closed =
          err == SSL_ERROR_ZERO_RETURN || (err == SSL_ERROR_SYSCALL && ssl_rc == 0);
      if (peer_closed) {
        end_stream = true;
      } else if (err != SSL_ERROR_WANT_READ) {
        // Everything else is fatal. This includes SSL_ERROR_WANT_WRITE, which means
        // renegotiation has started; we don't handle renegotiation.
        drainErrorQueue();
        action = PostIoAction::Close;
      }
      // SSL_ERROR_WANT_READ needs no handling: there is simply no more data right now.
      break;
    }

    reservation.commit(bytes_this_pass);
    if (bytes_this_pass > 0 && callbacks_->shouldDrainReadBuffer()) {
      callbacks_->setTransportSocketIsReadable();
      keep_reading = false;
    }

    total_bytes_read += bytes_this_pass;
  }

  ENVOY_CONN_LOG(trace, "ssl read {} bytes", callbacks_->connection(), total_bytes_read);

  return {action, total_bytes_read, end_stream};
}
174

            
175
9
// Invoked when a private key method operation has completed; continues the handshake.
void SslSocket::onPrivateKeyMethodComplete() {
  resumeHandshake();
}
176

            
177
45
void SslSocket::resumeHandshake() {
178
45
  ASSERT(callbacks_ != nullptr && callbacks_->connection().dispatcher().isThreadSafe());
179
45
  ASSERT(info_->state() == Ssl::SocketState::HandshakeInProgress);
180

            
181
  // Resume handshake.
182
45
  PostIoAction action = doHandshake();
183
45
  if (action == PostIoAction::Close) {
184
9
    ENVOY_CONN_LOG(debug, "async handshake completion error", callbacks_->connection());
185
9
    callbacks_->connection().close(Network::ConnectionCloseType::FlushWrite,
186
9
                                   "failed_resuming_async_handshake");
187
9
  }
188
45
}
189

            
190
4652
// Returns the connection obtained from the transport socket callbacks.
Network::Connection& SslSocket::connection() const {
  return callbacks_->connection();
}
191

            
192
2099
// Handshake success hook: logs the handshake through the context, records the handshake
// completion time on the upstream timing (when the stream info has upstream info) or on the
// downstream timing otherwise, and raises the Connected event.
void SslSocket::onSuccess(SSL* ssl) {
  ctx_->logHandshake(ssl);

  Network::Connection& conn = callbacks_->connection();
  auto& stream_info = conn.streamInfo();
  if (!stream_info.upstreamInfo()) {
    stream_info.downstreamTiming().onDownstreamHandshakeComplete(
        conn.dispatcher().timeSource());
  } else {
    stream_info.upstreamInfo()->upstreamTiming().onUpstreamHandshakeComplete(
        conn.dispatcher().timeSource());
  }

  callbacks_->raiseEvent(Network::ConnectionEvent::Connected);
}
206

            
207
271
// Handshake failure hook: drains the error queue, which records the failure reason and
// updates the error stats.
void SslSocket::onFailure() {
  drainErrorQueue();
}
208

            
209
7261
// Delegates one handshake step to the handshaker and returns its post-IO action.
PostIoAction SslSocket::doHandshake() {
  return info_->doHandshake();
}
210

            
211
2775
void SslSocket::drainErrorQueue() {
212
2775
  bool saw_error = false;
213
2775
  bool saw_counted_error = false;
214
2775
  bool saw_cert_verify_failed = false;
215
2775
  bool saw_no_client_cert = false;
216
3075
  while (uint64_t err = ERR_get_error()) {
217
300
    if (ERR_GET_LIB(err) == ERR_LIB_SSL) {
218
248
      if (ERR_GET_REASON(err) == SSL_R_PEER_DID_NOT_RETURN_A_CERTIFICATE) {
219
14
        ctx_->stats().fail_verify_no_cert_.inc();
220
14
        saw_counted_error = true;
221
14
        saw_no_client_cert = true;
222
234
      } else if (ERR_GET_REASON(err) == SSL_R_CERTIFICATE_VERIFY_FAILED) {
223
56
        saw_counted_error = true;
224
56
        saw_cert_verify_failed = true;
225
56
      }
226
293
    } else if (ERR_GET_LIB(err) == ERR_LIB_SYS) {
227
      // Any syscall errors that result in connection closure are already tracked in other
228
      // connection related stats. We will still retain the specific syscall failure for
229
      // transport failure reasons.
230
50
      saw_counted_error = true;
231
50
    }
232
300
    saw_error = true;
233

            
234
300
    if (failure_reason_.empty()) {
235
279
      failure_reason_ = "TLS_error:";
236
279
    }
237

            
238
300
    absl::StrAppend(&failure_reason_, "|", err, ":",
239
300
                    absl::NullSafeStringView(ERR_lib_error_string(err)), ":",
240
300
                    absl::NullSafeStringView(ERR_func_error_string(err)), ":",
241
300
                    absl::NullSafeStringView(ERR_reason_error_string(err)));
242
300
  }
243

            
244
2775
  if (!saw_error) {
245
2496
    return;
246
2496
  }
247

            
248
  // Append detailed error info for certificate-related failures.
249
279
  bool added_detail = false;
250
279
  if (saw_cert_verify_failed) {
251
56
    auto* extended_socket_info = reinterpret_cast<Envoy::Ssl::SslExtendedSocketInfo*>(
252
56
        SSL_get_ex_data(rawSsl(), ContextImpl::sslExtendedSocketInfoIndex()));
253
56
    if (extended_socket_info != nullptr) {
254
56
      absl::string_view cert_validation_error = extended_socket_info->certificateValidationError();
255
56
      if (!cert_validation_error.empty()) {
256
56
        absl::StrAppend(&failure_reason_, ":", cert_validation_error);
257
56
        added_detail = true;
258
56
      }
259
56
    }
260
56
  }
261
279
  if (!added_detail && saw_no_client_cert) {
262
14
    absl::StrAppend(&failure_reason_, ":peer did not provide required client certificate");
263
14
  }
264

            
265
279
  if (!failure_reason_.empty()) {
266
279
    absl::StrAppend(&failure_reason_, ":TLS_error_end");
267
279
    ENVOY_CONN_LOG(debug, "remote address:{},{}", callbacks_->connection(),
268
279
                   callbacks_->connection().connectionInfoProvider().remoteAddress()->asString(),
269
279
                   failure_reason_);
270
279
  }
271

            
272
279
  if (!saw_counted_error) {
273
159
    ctx_->stats().connection_error_.inc();
274
159
  }
275
279
}
276

            
277
21845
// Encrypts and writes the contents of `write_buffer` to the TLS connection.
//
// If the handshake has not completed (and no shutdown has been sent) the handshake is driven
// first; the call returns right away unless that leaves the socket in HandshakeComplete.
// Data is then written in chunks of at most 16384 bytes and drained from `write_buffer` as
// each chunk is accepted by SSL_write().
//
// @param write_buffer data to send; successfully written bytes are drained from it.
// @param end_stream when true and the buffer has been fully written, the TLS shutdown is sent.
// @return Close with the bytes written so far on a fatal error, otherwise KeepOpen with the
//         number of bytes written. The end_stream field of the result is always false.
Network::IoResult SslSocket::doWrite(Buffer::Instance& write_buffer, bool end_stream) {
  // Upper bound on the number of bytes handed to a single SSL_write() call.
  constexpr uint64_t MaxBytesPerSslWrite = 16384;

  ASSERT(info_->state() != Ssl::SocketState::ShutdownSent || write_buffer.length() == 0);
  const bool handshake_pending = info_->state() != Ssl::SocketState::HandshakeComplete &&
                                 info_->state() != Ssl::SocketState::ShutdownSent;
  if (handshake_pending) {
    const PostIoAction handshake_action = doHandshake();
    if (handshake_action == PostIoAction::Close ||
        info_->state() != Ssl::SocketState::HandshakeComplete) {
      return {handshake_action, 0, false};
    }
  }

  // A previous SSL_ERROR_WANT_WRITE recorded the size that has to be retried; consume it.
  // Otherwise start with the next chunk of the buffer.
  uint64_t chunk_size = bytes_to_retry_;
  if (chunk_size != 0) {
    bytes_to_retry_ = 0;
  } else {
    chunk_size = std::min(write_buffer.length(), MaxBytesPerSslWrite);
  }

  uint64_t total_bytes_written = 0;
  while (chunk_size > 0) {
    // TODO(mattklein123): As it relates to our fairness efforts, we might want to limit the number
    // of iterations of this loop, either by pure iterations, bytes written, etc.

    // SSL_write() requires that if a previous call returns SSL_ERROR_WANT_WRITE, we need to call
    // it again with the same parameters. This is done by tracking last write size, but not write
    // data, since linearize() will return the same undrained data anyway.
    ASSERT(chunk_size <= write_buffer.length());
    const int rc = SSL_write(rawSsl(), write_buffer.linearize(chunk_size), chunk_size);
    ENVOY_CONN_LOG(trace, "ssl write returns: {}", callbacks_->connection(), rc);
    if (rc > 0) {
      ASSERT(rc == static_cast<int>(chunk_size));
      total_bytes_written += rc;
      write_buffer.drain(rc);
      chunk_size = std::min(write_buffer.length(), MaxBytesPerSslWrite);
      continue;
    }

    const int err = SSL_get_error(rawSsl(), rc);
    ENVOY_CONN_LOG(trace, "ssl error occurred while write: {}", callbacks_->connection(),
                   Utility::getErrorDescription(err));
    if (err != SSL_ERROR_WANT_WRITE) {
      // Everything but WANT_WRITE is fatal. This includes SSL_ERROR_WANT_READ, which means
      // renegotiation has started; we don't handle renegotiation.
      drainErrorQueue();
      return {PostIoAction::Close, total_bytes_written, false};
    }

    // The write would block: remember the size so the next call retries the same write.
    bytes_to_retry_ = chunk_size;
    break;
  }

  if (write_buffer.length() == 0 && end_stream) {
    shutdownSsl();
  }

  return {PostIoAction::KeepOpen, total_bytes_written, false};
}
336

            
337
1403
// Connection-established hook. Only checks (in assert-enabled builds) that the handshake has
// not started yet.
void SslSocket::onConnected() {
  ASSERT(info_->state() == Ssl::SocketState::PreHandshake);
}
338

            
339
2619
// Exposes the handshaker as the connection's read-only SSL connection info.
Ssl::ConnectionInfoConstSharedPtr SslSocket::ssl() const {
  return info_;
}
340

            
341
2156
void SslSocket::shutdownSsl() {
342
2156
  ASSERT(info_->state() != Ssl::SocketState::PreHandshake);
343
2156
  if (info_->state() != Ssl::SocketState::ShutdownSent &&
344
2156
      callbacks_->connection().state() != Network::Connection::State::Closed) {
345
2113
    int rc = SSL_shutdown(rawSsl());
346
    if constexpr (Event::PlatformDefaultTriggerType == Event::FileTriggerType::EmulatedEdge) {
347
      // Windows operate under `EmulatedEdge`. These are level events that are artificially
348
      // made to behave like edge events. And if the rc is 0 then in that case we want read
349
      // activation resumption. This code is protected with an `constexpr` if, to minimize the tax
350
      // on POSIX systems that operate in Edge events.
351
      if (rc == 0) {
352
        // See https://www.openssl.org/docs/manmaster/man3/SSL_shutdown.html
353
        // if return value is 0,  Call SSL_read() to do a bidirectional shutdown.
354
        callbacks_->setTransportSocketIsReadable();
355
      }
356
    }
357
2113
    ENVOY_CONN_LOG(debug, "SSL shutdown: rc={}", callbacks_->connection(), rc);
358
2113
    drainErrorQueue();
359
2113
    info_->setState(Ssl::SocketState::ShutdownSent);
360
2113
  }
361
2156
}
362

            
363
404
void SslSocket::shutdownBasic() {
364
404
  if (info_->state() != Ssl::SocketState::ShutdownSent &&
365
404
      callbacks_->connection().state() != Network::Connection::State::Closed) {
366
358
    callbacks_->ioHandle().shutdown(ENVOY_SHUT_WR);
367
358
    drainErrorQueue();
368
358
    info_->setState(Ssl::SocketState::ShutdownSent);
369
358
  }
370
404
}
371

            
372
2471
// Close hook for the transport socket. Unregisters the SSL object from every private key
// method provider and then sends a shutdown appropriate for the current handshake state.
// The connection event argument is unused.
void SslSocket::closeSocket(Network::ConnectionEvent) {
  // Unregister the SSL connection object from private key method providers.
  const auto& key_method_providers = ctx_->getPrivateKeyMethodProviders();
  for (const auto& key_method_provider : key_method_providers) {
    key_method_provider->unregisterPrivateKeyMethod(rawSsl());
  }

  // Attempt to send a shutdown before closing the socket. It's possible this won't go out if
  // there is no room on the socket. We can extend the state machine to handle this at some point
  // if needed.
  const bool can_do_ssl_shutdown = info_->state() == Ssl::SocketState::HandshakeInProgress ||
                                   info_->state() == Ssl::SocketState::HandshakeComplete;
  if (!can_do_ssl_shutdown) {
    // We're not in a state to do the full SSL shutdown so perform a basic shutdown to flush any
    // outstanding alerts.
    shutdownBasic();
  } else {
    shutdownSsl();
  }
}
390

            
391
75
std::string SslSocket::protocol() const { return ssl()->alpn(); }
392

            
393
4058
// Returns a view of the failure reason accumulated by drainErrorQueue(); empty if no error
// has been recorded.
absl::string_view SslSocket::failureReason() const {
  return failure_reason_;
}
394

            
395
6
void SslSocket::onAsynchronousCertValidationComplete() {
396
6
  ENVOY_CONN_LOG(debug, "Async cert validation completed", callbacks_->connection());
397
6
  if (info_->state() == Ssl::SocketState::HandshakeInProgress) {
398
5
    resumeHandshake();
399
5
  }
400
6
}
401

            
402
31
void SslSocket::onAsynchronousCertificateSelectionComplete() {
403
31
  ENVOY_CONN_LOG(debug, "Async cert selection completed", callbacks_->connection());
404
31
  if (info_->state() != Ssl::SocketState::HandshakeInProgress) {
405
    IS_ENVOY_BUG(fmt::format("unexpected handshake state: {}", static_cast<int>(info_->state())));
406
    return;
407
  }
408
31
  resumeHandshake();
409
31
}
410

            
411
} // namespace Tls
412
} // namespace TransportSockets
413
} // namespace Extensions
414
} // namespace Envoy