1
#include "source/extensions/transport_sockets/alts/tsi_socket.h"
2

            
3
#include <algorithm>
4
#include <memory>
5
#include <string>
6
#include <utility>
7

            
8
#include "source/common/common/assert.h"
9
#include "source/common/common/cleanup.h"
10
#include "source/common/common/empty_string.h"
11
#include "source/common/common/enum_to_int.h"
12
#include "source/common/network/raw_buffer_socket.h"
13

            
14
namespace Envoy {
15
namespace Extensions {
16
namespace TransportSockets {
17
namespace Alts {
18

            
19
// Builds a TSI socket on top of `raw_socket`, which carries both the handshake bytes and the
// protected frames exchanged with the peer.
// @param handshaker_factory invoked lazily to create the handshaker on the first handshake step.
// @param handshake_validator optional peer-identity check run once the handshake completes.
// @param raw_socket the underlying transport socket.
// @param downstream true on the accepting (server) side; only the client starts the handshake.
// The factory and validator are taken by value, so they are moved into the members rather than
// copied a second time.
TsiSocket::TsiSocket(HandshakerFactory handshaker_factory, HandshakeValidator handshake_validator,
                     Network::TransportSocketPtr&& raw_socket, bool downstream)
    : handshaker_factory_(std::move(handshaker_factory)),
      handshake_validator_(std::move(handshake_validator)),
      raw_buffer_socket_(std::move(raw_socket)), downstream_(downstream) {}
23

            
24
// Convenience constructor that uses a plain RawBufferSocket as the underlying transport.
// Before the handshake, the raw read buffer's watermark is set to default_max_frame_size_.
// doHandshakeNextDone() resets it once the handshake completes.
TsiSocket::TsiSocket(HandshakerFactory handshaker_factory, HandshakeValidator handshake_validator,
                     bool downstream)
    : TsiSocket(std::move(handshaker_factory), std::move(handshake_validator),
                std::make_unique<Network::RawBufferSocket>(), downstream) {
  raw_read_buffer_.setWatermarks(default_max_frame_size_);
}
30

            
31
54
// By destruction time the handshaker must already have been handed to deferred deletion
// (closeSocket() releases it).
TsiSocket::~TsiSocket() { ASSERT(handshaker_ == nullptr); }
32

            
33
54
// Installs the connection's callbacks. The raw socket is not given `callbacks` directly. Instead
// it talks to a TsiTransportSocketCallbacks wrapper that also holds raw_read_buffer_. The wrapper
// is owned by this socket.
// May only be called once.
void TsiSocket::setTransportSocketCallbacks(Envoy::Network::TransportSocketCallbacks& callbacks) {
  ASSERT(callbacks_ == nullptr);
  callbacks_ = &callbacks;

  auto wrapper = std::make_unique<TsiTransportSocketCallbacks>(callbacks, raw_read_buffer_);
  tsi_callbacks_ = std::move(wrapper);
  raw_buffer_socket_->setTransportSocketCallbacks(*tsi_callbacks_);
}
40

            
41
4
// Always returns an empty string: no application layer protocol is reported.
std::string TsiSocket::protocol() const {
  // TSI offers no generic way to expose the negotiated application layer protocol.
  // TODO(lizan): support application layer protocol from TSI for known TSIs.
  return std::string();
}
46

            
47
17
// Always returns an empty reason; TSI failures are not yet surfaced here.
absl::string_view TsiSocket::failureReason() const {
  // TODO(htuch): Implement error reason for TSI.
  return absl::string_view(EMPTY_STRING);
}
51

            
52
98
// Advances the handshake from the doRead()/doWrite() path. A new handshaker step starts only when
// no step is currently in flight and peer bytes are buffered to feed it. Otherwise the connection
// is simply kept open.
Network::PostIoAction TsiSocket::doHandshake() {
  ASSERT(!handshake_complete_);
  ENVOY_CONN_LOG(debug, "TSI: doHandshake", callbacks_->connection());

  const bool step_in_flight = handshaker_next_calling_;
  const bool has_peer_bytes = raw_read_buffer_.length() > 0;
  if (step_in_flight || !has_peer_bytes) {
    return Network::PostIoAction::KeepOpen;
  }
  return doHandshakeNext();
}
60

            
61
71
// Runs one handshaker step and feeds it every byte currently in raw_read_buffer_. The buffer is
// moved out, leaving it empty. The handshaker is created on first use; if creation fails, the
// connection is closed here and Close is returned.
// handshaker_next_calling_ is set before next() is called. onNextDone(), which receives the step's
// result, clears it again. Returns Close if next() reports an error, otherwise KeepOpen.
Network::PostIoAction TsiSocket::doHandshakeNext() {
  auto& connection = callbacks_->connection();
  ENVOY_CONN_LOG(debug, "TSI: doHandshake next: received: {}", connection,
                 raw_read_buffer_.length());

  if (handshaker_ == nullptr) {
    const auto& addresses = connection.connectionInfoProvider();
    handshaker_ = handshaker_factory_(connection.dispatcher(), addresses.localAddress(),
                                      addresses.remoteAddress());
    if (handshaker_ == nullptr) {
      ENVOY_CONN_LOG(warn, "TSI: failed to create handshaker", connection);
      connection.close(Network::ConnectionCloseType::NoFlush, "failed_creating_handshaker");
      return Network::PostIoAction::Close;
    }
    handshaker_->setHandshakerCallbacks(*this);
  }

  handshaker_next_calling_ = true;
  Buffer::OwnedImpl handshake_input;
  handshake_input.move(raw_read_buffer_);
  const absl::Status next_status = handshaker_->next(handshake_input);
  if (next_status.ok()) {
    return Network::PostIoAction::KeepOpen;
  }
  ENVOY_CONN_LOG(debug, "TSI: Handshake failed: status: {}", connection, next_status);
  return Network::PostIoAction::Close;
}
90

            
91
67
// Processes the outcome of one handshaker next() step, as delivered through onNextDone().
//  1. Returns Close if the step reported an error.
//  2. Queues any bytes the handshaker produced for the peer in raw_write_buffer_.
//  3. A non-null handshake result means the handshake is complete. In that case:
//     - the peer identity is checked by handshake_validator_, when one is configured;
//     - on success the identity is published as "peer_identity" dynamic metadata;
//     - bytes received past the handshake are put back into raw_read_buffer_;
//     - the read watermarks are reset and the frame protector is installed.
//  4. Returns Close on a raw-socket read error, or on end of stream before completion.
//  5. Marks the socket readable if raw_read_buffer_ still holds bytes, then writes queued
//     handshake bytes to the raw socket.
// Connected is raised once the handshake is complete. This happens immediately if nothing is left
// to send. Otherwise it happens after the remaining handshake bytes are handed to the raw socket,
// unless that write returns Close.
Network::PostIoAction TsiSocket::doHandshakeNextDone(NextResultPtr&& next_result) {
  ASSERT(next_result);

  ENVOY_CONN_LOG(debug, "TSI: doHandshake next done: status: {} to_send: {}",
                 callbacks_->connection(), next_result->status_, next_result->to_send_->length());

  const absl::Status status = next_result->status_;
  AltsHandshakeResult* handshake_result = next_result->result_.get();
  if (!status.ok()) {
    ENVOY_CONN_LOG(debug, "TSI: Handshake failed: status: {}", callbacks_->connection(), status);
    return Network::PostIoAction::Close;
  }

  // Queue the bytes the handshaker wants delivered to the peer; they are written further below.
  if (next_result->to_send_->length() > 0) {
    raw_write_buffer_.move(*next_result->to_send_);
  }

  // A non-null result means the handshake has finished. `status` is known to be OK at this point.
  if (handshake_result != nullptr) {
    if (handshake_validator_) {
      std::string err;
      TsiInfo tsi_info;
      tsi_info.peer_identity_ = handshake_result->peer_identity;
      const bool peer_validated = handshake_validator_(tsi_info, err);
      if (peer_validated) {
        ENVOY_CONN_LOG(debug, "TSI: Handshake validation succeeded.", callbacks_->connection());
      } else {
        ENVOY_CONN_LOG(debug, "TSI: Handshake validation failed: {}", callbacks_->connection(),
                       err);
        return Network::PostIoAction::Close;
      }
      // Publish the validated peer identity as connection-level dynamic metadata.
      Protobuf::Struct dynamic_metadata;
      Protobuf::Value val;
      val.set_string_value(tsi_info.peer_identity_);
      dynamic_metadata.mutable_fields()->insert({std::string("peer_identity"), val});
      callbacks_->connection().streamInfo().setDynamicMetadata(
          "envoy.transport_sockets.peer_information", dynamic_metadata);
      ENVOY_CONN_LOG(debug, "TSI handshake with peer: {}", callbacks_->connection(),
                     tsi_info.peer_identity_);
    } else {
      ENVOY_CONN_LOG(debug, "TSI: Handshake validation skipped.", callbacks_->connection());
    }

    if (!handshake_result->unused_bytes.empty()) {
      // All handshake data is consumed.
      ASSERT(raw_read_buffer_.length() == 0);
      // Bytes the peer sent after its final handshake message are protected frames. Put them
      // back so repeatReadAndUnprotect() processes them.
      absl::string_view unused_bytes(
          reinterpret_cast<const char*>(handshake_result->unused_bytes.data()),
          handshake_result->unused_bytes.size());
      raw_read_buffer_.prepend(unused_bytes);
    }
    ENVOY_CONN_LOG(debug, "TSI: Handshake successful: unused_bytes: {}", callbacks_->connection(),
                   handshake_result->unused_bytes.size());
    // Reset the watermarks with actual negotiated max frame size.
    raw_read_buffer_.setWatermarks(
        std::max<size_t>(actual_frame_size_to_use_, callbacks_->connection().bufferLimit()));
    frame_protector_ = std::move(handshake_result->frame_protector);

    handshake_complete_ = true;
    // With nothing left to send, the connection is usable right away. Otherwise Connected is
    // raised after the remaining handshake bytes are written below.
    if (raw_write_buffer_.length() == 0) {
      callbacks_->raiseEvent(Network::ConnectionEvent::Connected);
    }
  }

  if (read_error_) {
    ENVOY_CONN_LOG(debug, "TSI: Handshake failed: read error on the raw socket",
                   callbacks_->connection());
    return Network::PostIoAction::Close;
  }
  if (!handshake_complete_ && end_stream_read_) {
    ENVOY_CONN_LOG(debug, "TSI: Handshake failed: end of stream without enough data",
                   callbacks_->connection());
    return Network::PostIoAction::Close;
  }

  // Bytes may have been buffered while this step was in flight, or restored from unused_bytes.
  // They need another read pass to be processed.
  if (raw_read_buffer_.length() > 0) {
    callbacks_->setTransportSocketIsReadable();
  }

  // Write the raw buffer as soon as the step is done, even when this is not called from the
  // do[Read|Write] stack.
  if (raw_write_buffer_.length() > 0) {
    Network::IoResult result = raw_buffer_socket_->doWrite(raw_write_buffer_, false);
    if (handshake_complete_ && result.action_ != Network::PostIoAction::Close) {
      callbacks_->raiseEvent(Network::ConnectionEvent::Connected);
    }
    return result.action_;
  }

  return Network::PostIoAction::KeepOpen;
}
175

            
176
// Post-handshake read loop. It unprotects whatever is in raw_read_buffer_ into `buffer`, then
// reads again from the raw socket. The loop stops when any of the following happens:
//  - a read returns no bytes;
//  - end of stream has been seen;
//  - unprotect fails or the raw read returns Close;
//  - the connection asks for its read buffer to be drained.
// `prev_result` is the result of a raw read the caller may already have done. In the returned
// result, bytes_processed_ is the number of plaintext bytes appended to `buffer`.
Network::IoResult TsiSocket::repeatReadAndUnprotect(Buffer::Instance& buffer,
                                                    Network::IoResult prev_result) {
  Network::IoResult result = prev_result;
  uint64_t plaintext_bytes = 0;

  for (;;) {
    if (raw_read_buffer_.length() > 0) {
      const uint64_t length_before = buffer.length();
      ENVOY_CONN_LOG(debug, "TSI: unprotecting buffer size: {}", callbacks_->connection(),
                     raw_read_buffer_.length());
      const tsi_result unprotect_status = frame_protector_->unprotect(raw_read_buffer_, buffer);
      if (unprotect_status != TSI_OK) {
        ENVOY_CONN_LOG(debug, "TSI: unprotect failed: status: {}", callbacks_->connection(),
                       tsi_result_to_string(unprotect_status));
        result.action_ = Network::PostIoAction::Close;
        break;
      }
      // The frame protector is expected to consume every buffered byte.
      ASSERT(raw_read_buffer_.length() == 0);
      ENVOY_CONN_LOG(debug, "TSI: unprotected buffer left: {} result: {}", callbacks_->connection(),
                     raw_read_buffer_.length(), tsi_result_to_string(unprotect_status));
      plaintext_bytes += buffer.length() - length_before;

      // Give control back if the connection wants the plaintext drained before reading more.
      if (callbacks_->shouldDrainReadBuffer()) {
        callbacks_->setTransportSocketIsReadable();
        break;
      }
    }

    if (result.action_ == Network::PostIoAction::Close) {
      break;
    }

    // A previous raw read already reported end of stream; there is nothing more to read.
    if (end_stream_read_) {
      result.end_stream_read_ = true;
      break;
    }

    result = readFromRawSocket();
    if (result.bytes_processed_ == 0) {
      break;
    }
  }

  result.bytes_processed_ = plaintext_bytes;
  ENVOY_CONN_LOG(debug, "TSI: do read result action {} bytes {} end_stream {}",
                 callbacks_->connection(), enumToInt(result.action_), result.bytes_processed_,
                 result.end_stream_read_);
  return result;
}
228

            
229
99
// Reads from the raw socket into raw_read_buffer_. It records whether that read hit end of stream
// and whether it asked for a close; each call overwrites the previous values.
Network::IoResult TsiSocket::readFromRawSocket() {
  Network::IoResult raw_result = raw_buffer_socket_->doRead(raw_read_buffer_);
  read_error_ = (raw_result.action_ == Network::PostIoAction::Close);
  end_stream_read_ = raw_result.end_stream_read_;
  return raw_result;
}
235

            
236
72
// Transport read entry point.
// While the handshake is incomplete, it reads raw bytes (unless end of stream or an error was
// already seen) and advances the handshake. No plaintext is produced in that phase.
// Once the handshake is complete (possibly as a result of this very call), it unprotects data
// into `buffer`.
Network::IoResult TsiSocket::doRead(Buffer::Instance& buffer) {
  Network::IoResult result = {Network::PostIoAction::KeepOpen, 0, false};

  if (!handshake_complete_) {
    const bool raw_socket_readable = !end_stream_read_ && !read_error_;
    if (raw_socket_readable) {
      result = readFromRawSocket();
      ENVOY_CONN_LOG(debug, "TSI: raw read result action {} bytes {} end_stream {}",
                     callbacks_->connection(), enumToInt(result.action_), result.bytes_processed_,
                     result.end_stream_read_);
      const bool read_nothing = result.bytes_processed_ == 0;
      if (read_nothing && result.action_ == Network::PostIoAction::Close) {
        return result;
      }
      // The peer ended the stream before the handshake could finish.
      if (read_nothing && result.end_stream_read_) {
        return {Network::PostIoAction::Close, result.bytes_processed_, result.end_stream_read_};
      }
    }

    const Network::PostIoAction handshake_action = doHandshake();
    const bool still_handshaking = !handshake_complete_;
    if (handshake_action == Network::PostIoAction::Close || still_handshaking) {
      return {handshake_action, 0, false};
    }
  }

  // Handshake finishes.
  ASSERT(handshake_complete_);
  ASSERT(frame_protector_);
  return repeatReadAndUnprotect(buffer, result);
}
262

            
263
31
// Post-handshake write loop. It splits `buffer` into chunks of at most
// (actual_frame_size_to_use_ - frame_overhead_size_) bytes. Each chunk is protected into
// raw_write_buffer_ and written to the raw socket.
// A chunk is drained from `buffer` only after its protected frame has been fully written. On a
// short write, the chunk size is kept in prev_bytes_to_drain_. The next call then finishes
// writing the same frame instead of protecting the data again.
// When `end_stream` is set, it is passed to the raw socket together with the frame that carries
// the last bytes of `buffer`.
// NOTE(review): an empty `buffer` with `end_stream` set still issues no raw write, so that
// half-close is not propagated.
// Returns the raw socket's last action and the number of plaintext bytes drained from `buffer`.
// Returns Close if protecting a chunk fails.
Network::IoResult TsiSocket::repeatProtectAndWrite(Buffer::Instance& buffer, bool end_stream) {
  uint64_t total_bytes_written = 0;
  Network::IoResult result = {Network::PostIoAction::KeepOpen, 0, false};
  // There should be no handshake bytes in raw_write_buffer_.
  ASSERT(!(raw_write_buffer_.length() > 0 && prev_bytes_to_drain_ == 0));
  while (true) {
    const uint64_t bytes_to_drain_this_iteration =
        prev_bytes_to_drain_ > 0
            ? prev_bytes_to_drain_
            : std::min<uint64_t>(buffer.length(), actual_frame_size_to_use_ - frame_overhead_size_);
    // Consumed all data. Exit.
    if (bytes_to_drain_this_iteration == 0) {
      break;
    }
    // Short write did not occur previously.
    if (raw_write_buffer_.length() == 0) {
      ASSERT(frame_protector_);
      ASSERT(prev_bytes_to_drain_ == 0);

      // Do protect.
      ENVOY_CONN_LOG(debug, "TSI: protecting buffer size: {}", callbacks_->connection(),
                     bytes_to_drain_this_iteration);
      const tsi_result status = frame_protector_->protect(
          grpc_slice_from_static_buffer(buffer.linearize(bytes_to_drain_this_iteration),
                                        bytes_to_drain_this_iteration),
          raw_write_buffer_);
      ENVOY_CONN_LOG(debug, "TSI: protected buffer left: {} result: {}", callbacks_->connection(),
                     bytes_to_drain_this_iteration, tsi_result_to_string(status));
      if (status != TSI_OK) {
        // Without this check the chunk would be drained and reported as written although it
        // never reached the peer. Discard any partial frame so it is never sent.
        ENVOY_CONN_LOG(debug, "TSI: protect failed: status: {}", callbacks_->connection(),
                       tsi_result_to_string(status));
        raw_write_buffer_.drain(raw_write_buffer_.length());
        return {Network::PostIoAction::Close, total_bytes_written, false};
      }
    }

    // `buffer` is drained only after a complete write (below). This chunk therefore carries the
    // final bytes of `buffer` exactly when it is all that is left.
    const bool last_chunk = buffer.length() == bytes_to_drain_this_iteration;

    // Write raw_write_buffer_ to network.
    ENVOY_CONN_LOG(debug, "TSI: raw_write length {} end_stream {}", callbacks_->connection(),
                   raw_write_buffer_.length(), end_stream);
    result = raw_buffer_socket_->doWrite(raw_write_buffer_, end_stream && last_chunk);

    // Short write. Exit.
    if (raw_write_buffer_.length() > 0) {
      prev_bytes_to_drain_ = bytes_to_drain_this_iteration;
      break;
    } else {
      buffer.drain(bytes_to_drain_this_iteration);
      prev_bytes_to_drain_ = 0;
      total_bytes_written += bytes_to_drain_this_iteration;
    }
  }

  return {result.action_, total_bytes_written, false};
}
311

            
312
81
// Transport write entry point.
// Before the handshake completes, application data is not consumed; the call only tries to
// advance the handshake.
// Afterwards, handshake bytes left over from doHandshakeNextDone() are written first. They sit in
// raw_write_buffer_ with no pending protected frame. Then `buffer` is protected and written.
Network::IoResult TsiSocket::doWrite(Buffer::Instance& buffer, bool end_stream) {
  if (!handshake_complete_) {
    const Network::PostIoAction action = doHandshake();
    ASSERT(!handshake_complete_);
    return {action, 0, false};
  }

  ASSERT(frame_protector_);
  // Check if we need to flush outstanding handshake bytes.
  const bool has_pending_handshake_bytes =
      raw_write_buffer_.length() > 0 && prev_bytes_to_drain_ == 0;
  if (has_pending_handshake_bytes) {
    ENVOY_CONN_LOG(debug, "TSI: raw_write length {} end_stream {}", callbacks_->connection(),
                   raw_write_buffer_.length(), end_stream);
    const Network::IoResult flush_result =
        raw_buffer_socket_->doWrite(raw_write_buffer_, end_stream && (buffer.length() == 0));
    // Short write: no application bytes have been consumed yet.
    if (raw_write_buffer_.length() > 0) {
      return {flush_result.action_, 0, false};
    }
  }
  return repeatProtectAndWrite(buffer, end_stream);
}
333

            
334
52
// Called when the connection closes. If a handshaker was created, this releases ownership of it
// and schedules it for deferred deletion. That leaves handshaker_ null, as ~TsiSocket() asserts.
void TsiSocket::closeSocket(Network::ConnectionEvent) {
  ENVOY_CONN_LOG(debug, "TSI: closing socket", callbacks_->connection());
  if (handshaker_ == nullptr) {
    return;
  }
  handshaker_.release()->deferredDelete();
}
340

            
341
31
void TsiSocket::onConnected() {
342
31
  ASSERT(!handshake_complete_);
343
  // Client initiates the handshake, so ignore onConnect call on the downstream.
344
31
  if (!downstream_) {
345
23
    doHandshakeNext();
346
23
  }
347
31
}
348

            
349
67
// Handshaker callback: one next() step has finished.
// The in-flight flag is cleared first, so doHandshake() may start another step. The result is
// then processed, and the connection is closed if processing asks for it.
void TsiSocket::onNextDone(NextResultPtr&& result) {
  handshaker_next_calling_ = false;

  const Network::PostIoAction action = doHandshakeNextDone(std::move(result));
  if (action != Network::PostIoAction::Close) {
    return;
  }
  callbacks_->connection().close(Network::ConnectionCloseType::NoFlush, "tsi_handshake_failed");
}
357

            
358
TsiSocketFactory::TsiSocketFactory(HandshakerFactory handshaker_factory,
359
                                   HandshakeValidator handshake_validator)
360
18
    : handshaker_factory_(std::move(handshaker_factory)),
361
18
      handshake_validator_(std::move(handshake_validator)) {}
362

            
363
2
// Application data on TSI sockets is always sent through the frame protector after the
// handshake, so the transport is reported as secure.
bool TsiSocketFactory::implementsSecureTransport() const {
  constexpr bool kSecure = true;
  return kSecure;
}
364

            
365
// Creates an upstream (client-side) TSI socket. The transport socket options and host description
// are not used.
Network::TransportSocketPtr
TsiSocketFactory::createTransportSocket(Network::TransportSocketOptionsConstSharedPtr,
                                        Upstream::HostDescriptionConstSharedPtr) const {
  constexpr bool kIsDownstream = false;
  return std::make_unique<TsiSocket>(handshaker_factory_, handshake_validator_, kIsDownstream);
}
370

            
371
8
// Creates a downstream (server-side) TSI socket. That side waits for the client to start the
// handshake.
Network::TransportSocketPtr TsiSocketFactory::createDownstreamTransportSocket() const {
  constexpr bool kIsDownstream = true;
  return std::make_unique<TsiSocket>(handshaker_factory_, handshake_validator_, kIsDownstream);
}
374

            
375
} // namespace Alts
376
} // namespace TransportSockets
377
} // namespace Extensions
378
} // namespace Envoy