Coverage Report

Created: 2024-09-19 09:45

/proc/self/cwd/source/common/tls/ssl_socket.cc
Line
Count
Source (jump to first uncovered line)
1
#include "source/common/tls/ssl_socket.h"
2
3
#include <algorithm>
#include <limits>

#include "envoy/stats/scope.h"
4
5
#include "source/common/common/assert.h"
6
#include "source/common/common/empty_string.h"
7
#include "source/common/common/hex.h"
8
#include "source/common/http/headers.h"
9
#include "source/common/tls/io_handle_bio.h"
10
#include "source/common/tls/ssl_handshaker.h"
11
#include "source/common/tls/utility.h"
12
13
#include "absl/strings/str_replace.h"
14
#include "openssl/err.h"
15
#include "openssl/x509v3.h"
16
17
using Envoy::Network::PostIoAction;
18
19
namespace Envoy {
20
namespace Extensions {
21
namespace TransportSockets {
22
namespace Tls {
23
24
namespace {
25
26
constexpr absl::string_view NotReadyReason{"TLS error: Secret is not supplied by SDS"};
27
28
} // namespace
29
30
0
absl::string_view NotReadySslSocket::failureReason() const {
  // Always the same fixed reason; see NotReadyReason above.
  return NotReadyReason;
}
31
32
absl::StatusOr<std::unique_ptr<SslSocket>>
SslSocket::create(Envoy::Ssl::ContextSharedPtr ctx, InitialState state,
                  const Network::TransportSocketOptionsConstSharedPtr& transport_socket_options,
                  Ssl::HandshakerFactoryCb handshaker_factory_cb) {
  // Two-phase construction: build the object, then run the fallible initialization and hand
  // back either the ready socket or the error that initialize() reported.
  std::unique_ptr<SslSocket> socket(new SslSocket(ctx, transport_socket_options));
  const absl::Status init_status = socket->initialize(state, handshaker_factory_cb);
  if (!init_status.ok()) {
    return init_status;
  }
  return socket;
}
44
45
SslSocket::SslSocket(Envoy::Ssl::ContextSharedPtr ctx,
46
                     const Network::TransportSocketOptionsConstSharedPtr& transport_socket_options)
47
    : transport_socket_options_(transport_socket_options),
48
0
      ctx_(std::dynamic_pointer_cast<ContextImpl>(ctx)) {}
49
50
/**
 * Creates the SSL object for this socket, wraps it in a handshaker and selects client/server mode.
 *
 * @param state whether this end acts as the TLS client or the TLS server.
 * @param handshaker_factory_cb factory that takes ownership of the new SSL object and returns the
 *        handshaker; the result must be an SslHandshakerImpl.
 * @return OkStatus on success; the error from ContextImpl::newSsl(); or InvalidArgumentError if
 *         the context or the produced handshaker is not of the concrete type this class needs.
 */
absl::Status SslSocket::initialize(InitialState state,
                                   Ssl::HandshakerFactoryCb handshaker_factory_cb) {
  // ctx_ comes from a dynamic_pointer_cast in the constructor and is null when the supplied
  // context is not a ContextImpl. Fail cleanly instead of dereferencing it below.
  if (ctx_ == nullptr) {
    return absl::InvalidArgumentError("TLS context is not a ContextImpl");
  }

  auto status_or_ssl = ctx_->newSsl(transport_socket_options_);
  if (!status_or_ssl.ok()) {
    return status_or_ssl.status();
  }

  info_ = std::dynamic_pointer_cast<SslHandshakerImpl>(handshaker_factory_cb(
      std::move(status_or_ssl.value()), ctx_->sslExtendedSocketInfoIndex(), this));
  // The cast yields null when the factory returns a handshaker that is not an SslHandshakerImpl.
  // info_ is dereferenced unconditionally throughout this class (e.g. info_->state()), and
  // rawSsl() presumably reads through it as well, so reject that case here.
  if (info_ == nullptr) {
    return absl::InvalidArgumentError("handshaker factory did not return an SslHandshakerImpl");
  }

  if (state == InitialState::Client) {
    SSL_set_connect_state(rawSsl());
  } else {
    ASSERT(state == InitialState::Server);
    SSL_set_accept_state(rawSsl());
  }

  return absl::OkStatus();
}
69
70
0
void SslSocket::setTransportSocketCallbacks(Network::TransportSocketCallbacks& callbacks) {
  // Callbacks may only be installed once per socket.
  ASSERT(!callbacks_);
  callbacks_ = &callbacks;

  // Every certificate may carry its own private key method provider; register this SSL
  // connection with each of them.
  for (const auto& key_provider : ctx_->getPrivateKeyMethodProviders()) {
    key_provider->registerPrivateKeyMethod(rawSsl(), *this, callbacks.connection().dispatcher());
  }

  // Route all TLS record I/O through a BIO backed by the connection's IoHandle; the same BIO is
  // used for both reading and writing.
  BIO* io_bio = BIO_new_io_handle(&callbacks.ioHandle());
  SSL_set_bio(rawSsl(), io_bio, io_bio);

  // Stash the callbacks on the SSL object so they can be found from the SSL* later.
  SSL_set_ex_data(rawSsl(), ContextImpl::sslSocketIndex(), static_cast<void*>(callbacks_));
}
85
86
0
/**
 * Decrypts as much data as possible into `slice` by calling SSL_read() repeatedly.
 *
 * @param slice destination memory; filled from the front.
 * @return bytes_read_ is the number of bytes written into the slice. error_ holds the first
 *         non-positive SSL_read() return code (to be classified by SSL_get_error()); it is unset
 *         when the slice was filled completely.
 */
SslSocket::ReadResult SslSocket::sslReadIntoSlice(Buffer::RawSlice& slice) {
  // SSL_read() takes an int length. Never request more than INT_MAX bytes per call, so that the
  // size_t -> int conversion cannot wrap to a negative or truncated value on very large slices.
  constexpr size_t MaxReadPerCall = static_cast<size_t>(std::numeric_limits<int>::max());

  ReadResult result;
  uint8_t* cursor = static_cast<uint8_t*>(slice.mem_);
  size_t remaining = slice.len_;
  while (remaining > 0) {
    const int request = static_cast<int>(std::min(remaining, MaxReadPerCall));
    const int rc = SSL_read(rawSsl(), cursor, request);
    ENVOY_CONN_LOG(trace, "ssl read returns: {}", callbacks_->connection(), rc);
    if (rc <= 0) {
      // Hand the raw return code back; the caller classifies it with SSL_get_error().
      result.error_ = absl::make_optional<int>(rc);
      break;
    }
    ASSERT(static_cast<size_t>(rc) <= remaining);
    cursor += rc;
    remaining -= rc;
    result.bytes_read_ += rc;
  }

  return result;
}
106
107
0
/**
 * Drives the handshake if needed, then decrypts data into `read_buffer` until SSL reports an
 * error/would-block condition or the connection asks for the read buffer to be drained.
 *
 * @return the post-I/O action, the total number of plaintext bytes read, and whether the peer
 *         ended the stream (close_notify, or the underlying socket closing with rc == 0 on
 *         SSL_ERROR_SYSCALL).
 */
Network::IoResult SslSocket::doRead(Buffer::Instance& read_buffer) {
  // The handshake is driven first unless it has completed or shutdown has already been sent.
  const bool needs_handshake = info_->state() != Ssl::SocketState::HandshakeComplete &&
                               info_->state() != Ssl::SocketState::ShutdownSent;
  if (needs_handshake) {
    const PostIoAction handshake_action = doHandshake();
    if (handshake_action == PostIoAction::Close ||
        info_->state() != Ssl::SocketState::HandshakeComplete) {
      // end_stream is false because either a hard error occurred (action == Close) or
      // the handshake isn't complete, so a half-close cannot occur yet.
      return {handshake_action, 0, false};
    }
  }

  PostIoAction action = PostIoAction::KeepOpen;
  bool end_stream = false;
  uint64_t total_read = 0;
  bool more = true;
  while (more) {
    Buffer::Reservation reservation = read_buffer.reserveForRead();
    uint64_t read_this_pass = 0;
    for (uint64_t slice_index = 0; slice_index < reservation.numSlices(); ++slice_index) {
      const auto result = sslReadIntoSlice(reservation.slices()[slice_index]);
      read_this_pass += result.bytes_read_;
      if (!result.error_.has_value()) {
        continue;
      }

      // Any SSL_read() failure ends this read event, whatever its cause.
      more = false;
      const int rc = result.error_.value();
      const int err = SSL_get_error(rawSsl(), rc);
      ENVOY_CONN_LOG(trace, "ssl error occurred while read: {}", callbacks_->connection(),
                     Utility::getErrorDescription(err));

      // Graceful shutdown via close_notify, or non-graceful shutdown where the underlying socket
      // was closed (SSL_ERROR_SYSCALL with a zero return code).
      const bool peer_closed =
          err == SSL_ERROR_ZERO_RETURN || (err == SSL_ERROR_SYSCALL && rc == 0);
      if (peer_closed) {
        end_stream = true;
      } else if (err != SSL_ERROR_WANT_READ) {
        // WANT_READ simply means no more data for now. Everything else is fatal, including
        // WANT_WRITE (renegotiation has started, which is not handled).
        drainErrorQueue();
        action = PostIoAction::Close;
      }
      break;
    }

    reservation.commit(read_this_pass);
    if (read_this_pass > 0 && callbacks_->shouldDrainReadBuffer()) {
      // Yield to let the buffer drain; ask to be called again since SSL may hold more data.
      callbacks_->setTransportSocketIsReadable();
      more = false;
    }

    total_read += read_this_pass;
  }

  ENVOY_CONN_LOG(trace, "ssl read {} bytes", callbacks_->connection(), total_read);

  return {action, total_read, end_stream};
}
172
173
0
void SslSocket::onPrivateKeyMethodComplete() {
  // An asynchronous private key operation finished: continue the paused handshake.
  resumeHandshake();
}
174
175
0
void SslSocket::resumeHandshake() {
176
0
  ASSERT(callbacks_ != nullptr && callbacks_->connection().dispatcher().isThreadSafe());
177
0
  ASSERT(info_->state() == Ssl::SocketState::HandshakeInProgress);
178
179
  // Resume handshake.
180
0
  PostIoAction action = doHandshake();
181
0
  if (action == PostIoAction::Close) {
182
0
    ENVOY_CONN_LOG(debug, "async handshake completion error", callbacks_->connection());
183
0
    callbacks_->connection().close(Network::ConnectionCloseType::FlushWrite,
184
0
                                   "failed_resuming_async_handshake");
185
0
  }
186
0
}
187
188
0
Network::Connection& SslSocket::connection() const {
  // The owning connection, as exposed by the transport socket callbacks.
  return callbacks_->connection();
}
189
190
0
void SslSocket::onSuccess(SSL* ssl) {
  ctx_->logHandshake(ssl);

  // Record handshake completion in upstream timing if this connection has upstream info,
  // otherwise in downstream timing, then signal that the connection is established.
  auto& conn = callbacks_->connection();
  auto& stream_info = conn.streamInfo();
  const auto& upstream_info = stream_info.upstreamInfo();
  if (upstream_info) {
    upstream_info->upstreamTiming().onUpstreamHandshakeComplete(conn.dispatcher().timeSource());
  } else {
    stream_info.downstreamTiming().onDownstreamHandshakeComplete(conn.dispatcher().timeSource());
  }

  callbacks_->raiseEvent(Network::ConnectionEvent::Connected);
}
204
205
0
void SslSocket::onFailure() {
  // Handshake failed: consume the OpenSSL error queue into failure_reason_ and stats.
  drainErrorQueue();
}
206
207
0
PostIoAction SslSocket::doHandshake() {
  // Handshaking is delegated entirely to the handshaker object.
  return info_->doHandshake();
}
208
209
0
void SslSocket::drainErrorQueue() {
210
0
  bool saw_error = false;
211
0
  bool saw_counted_error = false;
212
0
  while (uint64_t err = ERR_get_error()) {
213
0
    if (ERR_GET_LIB(err) == ERR_LIB_SSL) {
214
0
      if (ERR_GET_REASON(err) == SSL_R_PEER_DID_NOT_RETURN_A_CERTIFICATE) {
215
0
        ctx_->stats().fail_verify_no_cert_.inc();
216
0
        saw_counted_error = true;
217
0
      } else if (ERR_GET_REASON(err) == SSL_R_CERTIFICATE_VERIFY_FAILED) {
218
0
        saw_counted_error = true;
219
0
      }
220
0
    } else if (ERR_GET_LIB(err) == ERR_LIB_SYS) {
221
      // Any syscall errors that result in connection closure are already tracked in other
222
      // connection related stats. We will still retain the specific syscall failure for
223
      // transport failure reasons.
224
0
      saw_counted_error = true;
225
0
    }
226
0
    saw_error = true;
227
228
0
    if (failure_reason_.empty()) {
229
0
      failure_reason_ = "TLS_error:";
230
0
    }
231
232
0
    absl::StrAppend(&failure_reason_, "|", err, ":",
233
0
                    absl::NullSafeStringView(ERR_lib_error_string(err)), ":",
234
0
                    absl::NullSafeStringView(ERR_func_error_string(err)), ":",
235
0
                    absl::NullSafeStringView(ERR_reason_error_string(err)));
236
0
  }
237
238
0
  if (!saw_error) {
239
0
    return;
240
0
  }
241
242
0
  if (!failure_reason_.empty()) {
243
0
    absl::StrAppend(&failure_reason_, ":TLS_error_end");
244
0
    ENVOY_CONN_LOG(debug, "remote address:{},{}", callbacks_->connection(),
245
0
                   callbacks_->connection().connectionInfoProvider().remoteAddress()->asString(),
246
0
                   failure_reason_);
247
0
  }
248
249
0
  if (!saw_counted_error) {
250
0
    ctx_->stats().connection_error_.inc();
251
0
  }
252
0
}
253
254
0
/**
 * Drives the handshake if needed, then encrypts and writes `write_buffer` in chunks of at most
 * 16384 bytes. When everything has been written and `end_stream` is set, an SSL shutdown is sent.
 *
 * @return Close with the bytes written so far on a fatal SSL error; otherwise KeepOpen with the
 *         total bytes written. end_stream in the result is always false.
 */
Network::IoResult SslSocket::doWrite(Buffer::Instance& write_buffer, bool end_stream) {
  ASSERT(info_->state() != Ssl::SocketState::ShutdownSent || write_buffer.length() == 0);

  const bool needs_handshake = info_->state() != Ssl::SocketState::HandshakeComplete &&
                               info_->state() != Ssl::SocketState::ShutdownSent;
  if (needs_handshake) {
    const PostIoAction handshake_action = doHandshake();
    if (handshake_action == PostIoAction::Close ||
        info_->state() != Ssl::SocketState::HandshakeComplete) {
      return {handshake_action, 0, false};
    }
  }

  // Largest amount handed to a single SSL_write() call (presumably chosen to match the 2^14-byte
  // TLS record plaintext limit).
  constexpr uint64_t MaxWriteChunk = 16384;

  // A previous SSL_write() that returned SSL_ERROR_WANT_WRITE must be retried with the same size.
  uint64_t chunk;
  if (bytes_to_retry_ == 0) {
    chunk = std::min(write_buffer.length(), MaxWriteChunk);
  } else {
    chunk = bytes_to_retry_;
    bytes_to_retry_ = 0;
  }

  uint64_t written = 0;
  while (chunk > 0) {
    // TODO(mattklein123): As it relates to our fairness efforts, we might want to limit the number
    // of iterations of this loop, either by pure iterations, bytes written, etc.

    // SSL_write() requires that if a previous call returns SSL_ERROR_WANT_WRITE, we need to call
    // it again with the same parameters. This is done by tracking last write size, but not write
    // data, since linearize() will return the same undrained data anyway.
    ASSERT(chunk <= write_buffer.length());
    const int rc = SSL_write(rawSsl(), write_buffer.linearize(chunk), chunk);
    ENVOY_CONN_LOG(trace, "ssl write returns: {}", callbacks_->connection(), rc);

    if (rc <= 0) {
      const int err = SSL_get_error(rawSsl(), rc);
      ENVOY_CONN_LOG(trace, "ssl error occurred while write: {}", callbacks_->connection(),
                     Utility::getErrorDescription(err));
      if (err != SSL_ERROR_WANT_WRITE) {
        // SSL_ERROR_WANT_READ means renegotiation has started, which is not handled; it is treated
        // as fatal like every other error.
        drainErrorQueue();
        return {PostIoAction::Close, written, false};
      }
      // Socket is not writable right now: remember the size for the mandatory retry.
      bytes_to_retry_ = chunk;
      break;
    }

    ASSERT(rc == static_cast<int>(chunk));
    written += rc;
    write_buffer.drain(rc);
    chunk = std::min(write_buffer.length(), MaxWriteChunk);
  }

  if (end_stream && write_buffer.length() == 0) {
    shutdownSsl();
  }

  return {PostIoAction::KeepOpen, written, false};
}
313
314
0
void SslSocket::onConnected() {
  // Nothing to do at TCP connect time; only checks that the handshake has not started yet.
  ASSERT(info_->state() == Ssl::SocketState::PreHandshake);
}
315
316
0
Ssl::ConnectionInfoConstSharedPtr SslSocket::ssl() const {
  // The handshaker doubles as the connection's TLS info object.
  return info_;
}
317
318
0
void SslSocket::shutdownSsl() {
319
0
  ASSERT(info_->state() != Ssl::SocketState::PreHandshake);
320
0
  if (info_->state() != Ssl::SocketState::ShutdownSent &&
321
0
      callbacks_->connection().state() != Network::Connection::State::Closed) {
322
0
    int rc = SSL_shutdown(rawSsl());
323
0
    if constexpr (Event::PlatformDefaultTriggerType == Event::FileTriggerType::EmulatedEdge) {
324
      // Windows operate under `EmulatedEdge`. These are level events that are artificially
325
      // made to behave like edge events. And if the rc is 0 then in that case we want read
326
      // activation resumption. This code is protected with an `constexpr` if, to minimize the tax
327
      // on POSIX systems that operate in Edge events.
328
0
      if (rc == 0) {
329
        // See https://www.openssl.org/docs/manmaster/man3/SSL_shutdown.html
330
        // if return value is 0,  Call SSL_read() to do a bidirectional shutdown.
331
0
        callbacks_->setTransportSocketIsReadable();
332
0
      }
333
0
    }
334
0
    ENVOY_CONN_LOG(debug, "SSL shutdown: rc={}", callbacks_->connection(), rc);
335
0
    drainErrorQueue();
336
0
    info_->setState(Ssl::SocketState::ShutdownSent);
337
0
  }
338
0
}
339
340
0
void SslSocket::shutdownBasic() {
341
0
  if (info_->state() != Ssl::SocketState::ShutdownSent &&
342
0
      callbacks_->connection().state() != Network::Connection::State::Closed) {
343
0
    callbacks_->ioHandle().shutdown(ENVOY_SHUT_WR);
344
0
    drainErrorQueue();
345
0
    info_->setState(Ssl::SocketState::ShutdownSent);
346
0
  }
347
0
}
348
349
0
void SslSocket::closeSocket(Network::ConnectionEvent) {
  // Unregister the SSL connection object from private key method providers.
  for (const auto& key_provider : ctx_->getPrivateKeyMethodProviders()) {
    key_provider->unregisterPrivateKeyMethod(rawSsl());
  }

  // Attempt to send a shutdown before closing the socket. It's possible this won't go out if
  // there is no room on the socket. We can extend the state machine to handle this at some point
  // if needed. A full SSL shutdown is only possible once the handshake has started; otherwise a
  // basic shutdown is performed to flush any outstanding alerts.
  const auto current_state = info_->state();
  const bool can_shutdown_ssl = current_state == Ssl::SocketState::HandshakeInProgress ||
                                current_state == Ssl::SocketState::HandshakeComplete;
  if (can_shutdown_ssl) {
    shutdownSsl();
  } else {
    shutdownBasic();
  }
}
367
368
0
std::string SslSocket::protocol() const {
  // The application protocol is whatever ALPN value the connection info reports.
  return ssl()->alpn();
}
369
370
0
absl::string_view SslSocket::failureReason() const {
  // View into the reason accumulated by drainErrorQueue(); empty if no error was recorded.
  return failure_reason_;
}
371
372
0
void SslSocket::onAsynchronousCertValidationComplete() {
373
0
  ENVOY_CONN_LOG(debug, "Async cert validation completed", callbacks_->connection());
374
0
  if (info_->state() == Ssl::SocketState::HandshakeInProgress) {
375
0
    resumeHandshake();
376
0
  }
377
0
}
378
379
0
void SslSocket::onAsynchronousCertificateSelectionComplete() {
380
0
  ENVOY_CONN_LOG(debug, "Async cert selection completed", callbacks_->connection());
381
0
  if (info_->state() != Ssl::SocketState::HandshakeInProgress) {
382
0
    IS_ENVOY_BUG(fmt::format("unexpected handshake state: {}", static_cast<int>(info_->state())));
383
0
    return;
384
0
  }
385
0
  resumeHandshake();
386
0
}
387
388
} // namespace Tls
389
} // namespace TransportSockets
390
} // namespace Extensions
391
} // namespace Envoy