1
#pragma once
2

            
3
#include "envoy/network/transport_socket.h"
4

            
5
#include "source/common/buffer/buffer_impl.h"
6
#include "source/common/buffer/watermark_buffer.h"
7
#include "source/common/network/raw_buffer_socket.h"
8
#include "source/common/network/transport_socket_options_impl.h"
9
#include "source/extensions/transport_sockets/alts/noop_transport_socket_callbacks.h"
10
#include "source/extensions/transport_sockets/alts/tsi_frame_protector.h"
11
#include "source/extensions/transport_sockets/alts/tsi_handshaker.h"
12

            
13
namespace Envoy {
14
namespace Extensions {
15
namespace TransportSockets {
16
namespace Alts {
17

            
18
/**
 * Peer information produced by a completed TSI handshake. An instance is
 * handed to the HandshakeValidator so it can decide whether the peer is valid.
 */
struct TsiInfo {
  // Identity of the connection's peer (presumably filled in from the handshake
  // result by the socket implementation — populated outside this header).
  std::string peer_identity_;
};
21

            
22
/**
23
 * A factory function to create TsiHandshaker
24
 * @param dispatcher the dispatcher for the thread where the socket is running on.
25
 * @param local_address the local address of the connection.
26
 * @param remote_address the remote address of the connection.
27
 */
28
using HandshakerFactory = std::function<TsiHandshakerPtr(
29
    Event::Dispatcher& dispatcher, const Network::Address::InstanceConstSharedPtr& local_address,
30
    const Network::Address::InstanceConstSharedPtr& remote_address)>;
31

            
32
/**
33
 * A function to validate the peer of the connection.
34
 * @param err an error message to indicate why the peer is invalid. This is an
35
 * output param that should be populated by the function implementation.
36
 * @return true if the peer is valid or false if the peer is invalid.
37
 */
38
using HandshakeValidator = std::function<bool(TsiInfo& tsi_info, std::string& err)>;
39

            
40
// Forward declaration: TsiSocket owns one of these through a std::unique_ptr
// member, while the full class definition appears later in this header.
class TsiTransportSocketCallbacks;
42

            
43
/**
44
 * A implementation of Network::TransportSocket based on gRPC TSI
45
 */
46
class TsiSocket : public Network::TransportSocket,
47
                  public TsiHandshakerCallbacks,
48
                  public Logger::Loggable<Logger::Id::connection> {
49
public:
50
  // For Test
51
  TsiSocket(HandshakerFactory handshaker_factory, HandshakeValidator handshake_validator,
52
            Network::TransportSocketPtr&& raw_socket_ptr, bool downstream);
53

            
54
  /**
55
   * @param handshaker_factory a function to initiate a TsiHandshaker
56
   * @param handshake_validator a function to validate the peer. Called right
57
   * after the handshake completed with peer data to do the peer validation.
58
   * The connection will be closed immediately if it returns false.
59
   * @param downstream is true for downstream transport socket.
60
   */
61
  TsiSocket(HandshakerFactory handshaker_factory, HandshakeValidator handshake_validator,
62
            bool downstream);
63
  ~TsiSocket() override;
64

            
65
  // Network::TransportSocket
66
  void setTransportSocketCallbacks(Envoy::Network::TransportSocketCallbacks& callbacks) override;
67
  std::string protocol() const override;
68
  absl::string_view failureReason() const override;
69
1
  bool canFlushClose() override { return handshake_complete_; }
70
18
  Envoy::Ssl::ConnectionInfoConstSharedPtr ssl() const override { return nullptr; }
71
1
  bool startSecureTransport() override { return false; }
72
1
  void configureInitialCongestionWindow(uint64_t, std::chrono::microseconds) override {}
73
  Network::IoResult doWrite(Buffer::Instance& buffer, bool end_stream) override;
74
  void closeSocket(Network::ConnectionEvent event) override;
75
  Network::IoResult doRead(Buffer::Instance& buffer) override;
76
  void onConnected() override;
77

            
78
  // TsiHandshakerCallbacks
79
  void onNextDone(NextResultPtr&& result) override;
80

            
81
  // This API should be called only after ALTS handshake finishes successfully.
82
  size_t actualFrameSizeToUse() { return actual_frame_size_to_use_; }
83
  // Set actual_frame_size_to_use_. Exposed for testing purpose.
84
  void setActualFrameSizeToUse(size_t frame_size) { actual_frame_size_to_use_ = frame_size; }
85
  // Set frame_overhead_size_. Exposed for testing purpose.
86
  void setFrameOverheadSize(size_t overhead_size) { frame_overhead_size_ = overhead_size; }
87

            
88
private:
89
  Network::PostIoAction doHandshake();
90
  Network::PostIoAction doHandshakeNext();
91
  Network::PostIoAction doHandshakeNextDone(NextResultPtr&& next_result);
92

            
93
  // Helper function to perform repeated read and unprotect operations.
94
  Network::IoResult repeatReadAndUnprotect(Buffer::Instance& buffer, Network::IoResult prev_result);
95
  // Helper function to perform repeated protect and write operations.
96
  Network::IoResult repeatProtectAndWrite(Buffer::Instance& buffer, bool end_stream);
97
  // Helper function to read from a raw socket and update status.
98
  Network::IoResult readFromRawSocket();
99

            
100
  HandshakerFactory handshaker_factory_;
101
  HandshakeValidator handshake_validator_;
102
  TsiHandshakerPtr handshaker_{};
103
  bool handshaker_next_calling_{};
104

            
105
  TsiFrameProtectorPtr frame_protector_;
106
  // default_max_frame_size_ is the maximum frame size supported by
107
  // TsiSocket.
108
  size_t default_max_frame_size_{16384};
109
  // actual_frame_size_to_use_ is the actual frame size used by
110
  // frame protector, which is the result of frame size negotiation.
111
  size_t actual_frame_size_to_use_{0};
112
  // frame_overhead_size_ includes 4 bytes frame message type and 16 bytes tag length.
113
  // It is consistent with gRPC ALTS zero copy frame protector implementation.
114
  // The maximum size of data that can be protected for each frame is equal to
115
  // actual_frame_size_to_use_ - frame_overhead_size_.
116
  size_t frame_overhead_size_{20};
117

            
118
  Envoy::Network::TransportSocketCallbacks* callbacks_{};
119
  std::unique_ptr<TsiTransportSocketCallbacks> tsi_callbacks_;
120
  Network::TransportSocketPtr raw_buffer_socket_;
121
  const bool downstream_;
122

            
123
1
  Buffer::WatermarkBuffer raw_read_buffer_{[]() {}, []() {}, []() {}};
124
  Envoy::Buffer::OwnedImpl raw_write_buffer_;
125
  bool handshake_complete_{};
126
  bool end_stream_read_{};
127
  bool read_error_{};
128
  uint64_t prev_bytes_to_drain_{};
129
};
130

            
131
/**
132
 * An implementation of Network::UpstreamTransportSocketFactory for TsiSocket
133
 */
134
class TsiSocketFactory : public Network::DownstreamTransportSocketFactory,
135
                         public Network::CommonUpstreamTransportSocketFactory {
136
public:
137
  TsiSocketFactory(HandshakerFactory handshaker_factory, HandshakeValidator handshake_validator);
138

            
139
  bool implementsSecureTransport() const override;
140
  absl::string_view defaultServerNameIndication() const override { return ""; }
141

            
142
  Network::TransportSocketPtr
143
  createTransportSocket(Network::TransportSocketOptionsConstSharedPtr options,
144
                        Upstream::HostDescriptionConstSharedPtr) const override;
145

            
146
  Network::TransportSocketPtr createDownstreamTransportSocket() const override;
147

            
148
private:
149
  HandshakerFactory handshaker_factory_;
150
  HandshakeValidator handshake_validator_;
151
};
152

            
153
/**
154
 * An implementation of Network::TransportSocketCallbacks for TsiSocket
155
 */
156
class TsiTransportSocketCallbacks : public NoOpTransportSocketCallbacks {
157
public:
158
  TsiTransportSocketCallbacks(Network::TransportSocketCallbacks& parent,
159
                              const Buffer::WatermarkBuffer& read_buffer)
160
54
      : NoOpTransportSocketCallbacks(parent), raw_read_buffer_(read_buffer) {}
161
24
  bool shouldDrainReadBuffer() override {
162
24
    return raw_read_buffer_.length() >= raw_read_buffer_.highWatermark();
163
24
  }
164

            
165
private:
166
  const Buffer::WatermarkBuffer& raw_read_buffer_;
167
};
168

            
169
} // namespace Alts
170
} // namespace TransportSockets
171
} // namespace Extensions
172
} // namespace Envoy