Line data Source code
1 : #include "source/extensions/transport_sockets/alts/tsi_socket.h"
2 :
3 : #include <algorithm>
4 : #include <memory>
5 : #include <string>
6 : #include <utility>
7 :
8 : #include "source/common/common/assert.h"
9 : #include "source/common/common/cleanup.h"
10 : #include "source/common/common/empty_string.h"
11 : #include "source/common/common/enum_to_int.h"
12 : #include "source/common/network/raw_buffer_socket.h"
13 :
14 : namespace Envoy {
15 : namespace Extensions {
16 : namespace TransportSockets {
17 : namespace Alts {
18 :
19 : TsiSocket::TsiSocket(HandshakerFactory handshaker_factory, HandshakeValidator handshake_validator,
20 : Network::TransportSocketPtr&& raw_socket, bool downstream)
21 : : handshaker_factory_(handshaker_factory), handshake_validator_(handshake_validator),
22 0 : raw_buffer_socket_(std::move(raw_socket)), downstream_(downstream) {}
23 :
24 : TsiSocket::TsiSocket(HandshakerFactory handshaker_factory, HandshakeValidator handshake_validator,
25 : bool downstream)
26 : : TsiSocket(handshaker_factory, handshake_validator,
27 0 : std::make_unique<Network::RawBufferSocket>(), downstream) {
28 0 : raw_read_buffer_.setWatermarks(default_max_frame_size_);
29 0 : }
30 :
31 0 : TsiSocket::~TsiSocket() { ASSERT(!handshaker_); }
32 :
33 0 : void TsiSocket::setTransportSocketCallbacks(Envoy::Network::TransportSocketCallbacks& callbacks) {
34 0 : ASSERT(!callbacks_);
35 0 : callbacks_ = &callbacks;
36 :
37 0 : tsi_callbacks_ = std::make_unique<TsiTransportSocketCallbacks>(callbacks, raw_read_buffer_);
38 0 : raw_buffer_socket_->setTransportSocketCallbacks(*tsi_callbacks_);
39 0 : }
40 :
41 0 : std::string TsiSocket::protocol() const {
42 : // TSI doesn't have a generic way to indicate application layer protocol.
43 : // TODO(lizan): support application layer protocol from TSI for known TSIs.
44 0 : return EMPTY_STRING;
45 0 : }
46 :
47 0 : absl::string_view TsiSocket::failureReason() const {
48 : // TODO(htuch): Implement error reason for TSI.
49 0 : return EMPTY_STRING;
50 0 : }
51 :
52 0 : Network::PostIoAction TsiSocket::doHandshake() {
53 0 : ASSERT(!handshake_complete_);
54 0 : ENVOY_CONN_LOG(debug, "TSI: doHandshake", callbacks_->connection());
55 0 : if (!handshaker_next_calling_ && raw_read_buffer_.length() > 0) {
56 0 : return doHandshakeNext();
57 0 : }
58 0 : return Network::PostIoAction::KeepOpen;
59 0 : }
60 :
61 0 : Network::PostIoAction TsiSocket::doHandshakeNext() {
62 0 : ENVOY_CONN_LOG(debug, "TSI: doHandshake next: received: {}", callbacks_->connection(),
63 0 : raw_read_buffer_.length());
64 :
65 0 : if (!handshaker_) {
66 0 : handshaker_ =
67 0 : handshaker_factory_(callbacks_->connection().dispatcher(),
68 0 : callbacks_->connection().connectionInfoProvider().localAddress(),
69 0 : callbacks_->connection().connectionInfoProvider().remoteAddress());
70 0 : if (!handshaker_) {
71 0 : ENVOY_CONN_LOG(warn, "TSI: failed to create handshaker", callbacks_->connection());
72 0 : callbacks_->connection().close(Network::ConnectionCloseType::NoFlush,
73 0 : "failed_creating_handshaker");
74 0 : return Network::PostIoAction::Close;
75 0 : }
76 :
77 0 : handshaker_->setHandshakerCallbacks(*this);
78 0 : }
79 :
80 0 : handshaker_next_calling_ = true;
81 0 : Buffer::OwnedImpl handshaker_buffer;
82 0 : handshaker_buffer.move(raw_read_buffer_);
83 0 : absl::Status status = handshaker_->next(handshaker_buffer);
84 0 : if (!status.ok()) {
85 0 : ENVOY_CONN_LOG(debug, "TSI: Handshake failed: status: {}", callbacks_->connection(), status);
86 0 : return Network::PostIoAction::Close;
87 0 : }
88 0 : return Network::PostIoAction::KeepOpen;
89 0 : }
90 :
91 0 : Network::PostIoAction TsiSocket::doHandshakeNextDone(NextResultPtr&& next_result) {
92 0 : ASSERT(next_result);
93 :
94 0 : ENVOY_CONN_LOG(debug, "TSI: doHandshake next done: status: {} to_send: {}",
95 0 : callbacks_->connection(), next_result->status_, next_result->to_send_->length());
96 :
97 0 : absl::Status status = next_result->status_;
98 0 : AltsHandshakeResult* handshake_result = next_result->result_.get();
99 0 : if (!status.ok()) {
100 0 : ENVOY_CONN_LOG(debug, "TSI: Handshake failed: status: {}", callbacks_->connection(), status);
101 0 : return Network::PostIoAction::Close;
102 0 : }
103 :
104 0 : if (next_result->to_send_->length() > 0) {
105 0 : raw_write_buffer_.move(*next_result->to_send_);
106 0 : }
107 :
108 0 : if (status.ok() && handshake_result != nullptr) {
109 0 : if (handshake_validator_) {
110 0 : std::string err;
111 0 : TsiInfo tsi_info;
112 0 : tsi_info.peer_identity_ = handshake_result->peer_identity;
113 0 : const bool peer_validated = handshake_validator_(tsi_info, err);
114 0 : if (peer_validated) {
115 0 : ENVOY_CONN_LOG(debug, "TSI: Handshake validation succeeded.", callbacks_->connection());
116 0 : } else {
117 0 : ENVOY_CONN_LOG(debug, "TSI: Handshake validation failed: {}", callbacks_->connection(),
118 0 : err);
119 0 : return Network::PostIoAction::Close;
120 0 : }
121 0 : ProtobufWkt::Struct dynamic_metadata;
122 0 : ProtobufWkt::Value val;
123 0 : val.set_string_value(tsi_info.peer_identity_);
124 0 : dynamic_metadata.mutable_fields()->insert({std::string("peer_identity"), val});
125 0 : callbacks_->connection().streamInfo().setDynamicMetadata(
126 0 : "envoy.transport_sockets.peer_information", dynamic_metadata);
127 0 : ENVOY_CONN_LOG(debug, "TSI handshake with peer: {}", callbacks_->connection(),
128 0 : tsi_info.peer_identity_);
129 0 : } else {
130 0 : ENVOY_CONN_LOG(debug, "TSI: Handshake validation skipped.", callbacks_->connection());
131 0 : }
132 :
133 0 : if (!handshake_result->unused_bytes.empty()) {
134 : // All handshake data is consumed.
135 0 : ASSERT(raw_read_buffer_.length() == 0);
136 0 : absl::string_view unused_bytes(
137 0 : reinterpret_cast<const char*>(handshake_result->unused_bytes.data()),
138 0 : handshake_result->unused_bytes.size());
139 0 : raw_read_buffer_.prepend(unused_bytes);
140 0 : }
141 0 : ENVOY_CONN_LOG(debug, "TSI: Handshake successful: unused_bytes: {}", callbacks_->connection(),
142 0 : handshake_result->unused_bytes.size());
143 : // Reset the watermarks with actual negotiated max frame size.
144 0 : raw_read_buffer_.setWatermarks(
145 0 : std::max<size_t>(actual_frame_size_to_use_, callbacks_->connection().bufferLimit()));
146 0 : frame_protector_ = std::move(handshake_result->frame_protector);
147 :
148 0 : handshake_complete_ = true;
149 0 : if (raw_write_buffer_.length() == 0) {
150 0 : callbacks_->raiseEvent(Network::ConnectionEvent::Connected);
151 0 : }
152 0 : }
153 :
154 0 : if (read_error_ || (!handshake_complete_ && end_stream_read_)) {
155 0 : ENVOY_CONN_LOG(debug, "TSI: Handshake failed: end of stream without enough data",
156 0 : callbacks_->connection());
157 0 : return Network::PostIoAction::Close;
158 0 : }
159 :
160 0 : if (raw_read_buffer_.length() > 0) {
161 0 : callbacks_->setTransportSocketIsReadable();
162 0 : }
163 :
164 : // Try to write raw buffer when next call is done, even this is not in do[Read|Write] stack.
165 0 : if (raw_write_buffer_.length() > 0) {
166 0 : Network::IoResult result = raw_buffer_socket_->doWrite(raw_write_buffer_, false);
167 0 : if (handshake_complete_ && result.action_ != Network::PostIoAction::Close) {
168 0 : callbacks_->raiseEvent(Network::ConnectionEvent::Connected);
169 0 : }
170 0 : return result.action_;
171 0 : }
172 :
173 0 : return Network::PostIoAction::KeepOpen;
174 0 : }
175 :
176 : Network::IoResult TsiSocket::repeatReadAndUnprotect(Buffer::Instance& buffer,
177 0 : Network::IoResult prev_result) {
178 0 : Network::IoResult result = prev_result;
179 0 : uint64_t total_bytes_processed = 0;
180 :
181 0 : while (true) {
182 : // Do unprotect.
183 0 : if (raw_read_buffer_.length() > 0) {
184 0 : uint64_t prev_size = buffer.length();
185 0 : ENVOY_CONN_LOG(debug, "TSI: unprotecting buffer size: {}", callbacks_->connection(),
186 0 : raw_read_buffer_.length());
187 0 : tsi_result status = frame_protector_->unprotect(raw_read_buffer_, buffer);
188 0 : if (status != TSI_OK) {
189 0 : ENVOY_CONN_LOG(debug, "TSI: unprotect failed: status: {}", callbacks_->connection(),
190 0 : status);
191 0 : result.action_ = Network::PostIoAction::Close;
192 0 : break;
193 0 : }
194 0 : ASSERT(raw_read_buffer_.length() == 0);
195 0 : ENVOY_CONN_LOG(debug, "TSI: unprotected buffer left: {} result: {}", callbacks_->connection(),
196 0 : raw_read_buffer_.length(), tsi_result_to_string(status));
197 0 : total_bytes_processed += buffer.length() - prev_size;
198 :
199 : // Check if buffer needs to be drained.
200 0 : if (callbacks_->shouldDrainReadBuffer()) {
201 0 : callbacks_->setTransportSocketIsReadable();
202 0 : break;
203 0 : }
204 0 : }
205 :
206 0 : if (result.action_ == Network::PostIoAction::Close) {
207 0 : break;
208 0 : }
209 :
210 : // End of stream is reached in the previous read.
211 0 : if (end_stream_read_) {
212 0 : result.end_stream_read_ = true;
213 0 : break;
214 0 : }
215 : // Do another read.
216 0 : result = readFromRawSocket();
217 : // No data is read.
218 0 : if (result.bytes_processed_ == 0) {
219 0 : break;
220 0 : }
221 0 : };
222 0 : result.bytes_processed_ = total_bytes_processed;
223 0 : ENVOY_CONN_LOG(debug, "TSI: do read result action {} bytes {} end_stream {}",
224 0 : callbacks_->connection(), enumToInt(result.action_), result.bytes_processed_,
225 0 : result.end_stream_read_);
226 0 : return result;
227 0 : }
228 :
229 0 : Network::IoResult TsiSocket::readFromRawSocket() {
230 0 : Network::IoResult result = raw_buffer_socket_->doRead(raw_read_buffer_);
231 0 : end_stream_read_ = result.end_stream_read_;
232 0 : read_error_ = result.action_ == Network::PostIoAction::Close;
233 0 : return result;
234 0 : }
235 :
236 0 : Network::IoResult TsiSocket::doRead(Buffer::Instance& buffer) {
237 0 : Network::IoResult result = {Network::PostIoAction::KeepOpen, 0, false};
238 0 : if (!handshake_complete_) {
239 0 : if (!end_stream_read_ && !read_error_) {
240 0 : result = readFromRawSocket();
241 0 : ENVOY_CONN_LOG(debug, "TSI: raw read result action {} bytes {} end_stream {}",
242 0 : callbacks_->connection(), enumToInt(result.action_), result.bytes_processed_,
243 0 : result.end_stream_read_);
244 0 : if (result.action_ == Network::PostIoAction::Close && result.bytes_processed_ == 0) {
245 0 : return result;
246 0 : }
247 :
248 0 : if (result.end_stream_read_ && result.bytes_processed_ == 0) {
249 0 : return {Network::PostIoAction::Close, result.bytes_processed_, result.end_stream_read_};
250 0 : }
251 0 : }
252 0 : Network::PostIoAction action = doHandshake();
253 0 : if (action == Network::PostIoAction::Close || !handshake_complete_) {
254 0 : return {action, 0, false};
255 0 : }
256 0 : }
257 : // Handshake finishes.
258 0 : ASSERT(handshake_complete_);
259 0 : ASSERT(frame_protector_);
260 0 : return repeatReadAndUnprotect(buffer, result);
261 0 : }
262 :
263 0 : Network::IoResult TsiSocket::repeatProtectAndWrite(Buffer::Instance& buffer, bool end_stream) {
264 0 : uint64_t total_bytes_written = 0;
265 0 : Network::IoResult result = {Network::PostIoAction::KeepOpen, 0, false};
266 : // There should be no handshake bytes in raw_write_buffer_.
267 0 : ASSERT(!(raw_write_buffer_.length() > 0 && prev_bytes_to_drain_ == 0));
268 0 : while (true) {
269 0 : uint64_t bytes_to_drain_this_iteration =
270 0 : prev_bytes_to_drain_ > 0
271 0 : ? prev_bytes_to_drain_
272 0 : : std::min<uint64_t>(buffer.length(), actual_frame_size_to_use_ - frame_overhead_size_);
273 : // Consumed all data. Exit.
274 0 : if (bytes_to_drain_this_iteration == 0) {
275 0 : break;
276 0 : }
277 : // Short write did not occur previously.
278 0 : if (raw_write_buffer_.length() == 0) {
279 0 : ASSERT(frame_protector_);
280 0 : ASSERT(prev_bytes_to_drain_ == 0);
281 :
282 : // Do protect.
283 0 : ENVOY_CONN_LOG(debug, "TSI: protecting buffer size: {}", callbacks_->connection(),
284 0 : bytes_to_drain_this_iteration);
285 0 : tsi_result status = frame_protector_->protect(
286 0 : grpc_slice_from_static_buffer(buffer.linearize(bytes_to_drain_this_iteration),
287 0 : bytes_to_drain_this_iteration),
288 0 : raw_write_buffer_);
289 0 : ENVOY_CONN_LOG(debug, "TSI: protected buffer left: {} result: {}", callbacks_->connection(),
290 0 : bytes_to_drain_this_iteration, tsi_result_to_string(status));
291 0 : }
292 :
293 : // Write raw_write_buffer_ to network.
294 0 : ENVOY_CONN_LOG(debug, "TSI: raw_write length {} end_stream {}", callbacks_->connection(),
295 0 : raw_write_buffer_.length(), end_stream);
296 0 : result = raw_buffer_socket_->doWrite(raw_write_buffer_, end_stream && (buffer.length() == 0));
297 :
298 : // Short write. Exit.
299 0 : if (raw_write_buffer_.length() > 0) {
300 0 : prev_bytes_to_drain_ = bytes_to_drain_this_iteration;
301 0 : break;
302 0 : } else {
303 0 : buffer.drain(bytes_to_drain_this_iteration);
304 0 : prev_bytes_to_drain_ = 0;
305 0 : total_bytes_written += bytes_to_drain_this_iteration;
306 0 : }
307 0 : }
308 :
309 0 : return {result.action_, total_bytes_written, false};
310 0 : }
311 :
312 0 : Network::IoResult TsiSocket::doWrite(Buffer::Instance& buffer, bool end_stream) {
313 0 : if (!handshake_complete_) {
314 0 : Network::PostIoAction action = doHandshake();
315 0 : ASSERT(!handshake_complete_);
316 0 : return {action, 0, false};
317 0 : } else {
318 0 : ASSERT(frame_protector_);
319 : // Check if we need to flush outstanding handshake bytes.
320 0 : if (raw_write_buffer_.length() > 0 && prev_bytes_to_drain_ == 0) {
321 0 : ENVOY_CONN_LOG(debug, "TSI: raw_write length {} end_stream {}", callbacks_->connection(),
322 0 : raw_write_buffer_.length(), end_stream);
323 0 : Network::IoResult result =
324 0 : raw_buffer_socket_->doWrite(raw_write_buffer_, end_stream && (buffer.length() == 0));
325 : // Check if short write occurred.
326 0 : if (raw_write_buffer_.length() > 0) {
327 0 : return {result.action_, 0, false};
328 0 : }
329 0 : }
330 0 : return repeatProtectAndWrite(buffer, end_stream);
331 0 : }
332 0 : }
333 :
334 0 : void TsiSocket::closeSocket(Network::ConnectionEvent) {
335 0 : ENVOY_CONN_LOG(debug, "TSI: closing socket", callbacks_->connection());
336 0 : if (handshaker_) {
337 0 : handshaker_.release()->deferredDelete();
338 0 : }
339 0 : }
340 :
341 0 : void TsiSocket::onConnected() {
342 0 : ASSERT(!handshake_complete_);
343 : // Client initiates the handshake, so ignore onConnect call on the downstream.
344 0 : if (!downstream_) {
345 0 : doHandshakeNext();
346 0 : }
347 0 : }
348 :
349 0 : void TsiSocket::onNextDone(NextResultPtr&& result) {
350 0 : handshaker_next_calling_ = false;
351 :
352 0 : Network::PostIoAction action = doHandshakeNextDone(std::move(result));
353 0 : if (action == Network::PostIoAction::Close) {
354 0 : callbacks_->connection().close(Network::ConnectionCloseType::NoFlush, "tsi_handshake_failed");
355 0 : }
356 0 : }
357 :
358 : TsiSocketFactory::TsiSocketFactory(HandshakerFactory handshaker_factory,
359 : HandshakeValidator handshake_validator)
360 : : handshaker_factory_(std::move(handshaker_factory)),
361 0 : handshake_validator_(std::move(handshake_validator)) {}
362 :
363 0 : bool TsiSocketFactory::implementsSecureTransport() const { return true; }
364 :
365 : Network::TransportSocketPtr
366 : TsiSocketFactory::createTransportSocket(Network::TransportSocketOptionsConstSharedPtr,
367 0 : Upstream::HostDescriptionConstSharedPtr) const {
368 0 : return std::make_unique<TsiSocket>(handshaker_factory_, handshake_validator_, false);
369 0 : }
370 :
371 0 : Network::TransportSocketPtr TsiSocketFactory::createDownstreamTransportSocket() const {
372 0 : return std::make_unique<TsiSocket>(handshaker_factory_, handshake_validator_, true);
373 0 : }
374 :
375 : } // namespace Alts
376 : } // namespace TransportSockets
377 : } // namespace Extensions
378 : } // namespace Envoy
|