1
#pragma once
2

            
3
#include <memory>
4
#include <string>
5

            
6
#include "envoy/buffer/buffer.h"
7
#include "envoy/event/deferred_deletable.h"
8
#include "envoy/http/codec.h"
9
#include "envoy/network/connection.h"
10
#include "envoy/network/filter.h"
11
#include "envoy/upstream/upstream.h"
12

            
13
#include "source/common/common/logger.h"
14
#include "source/common/http/http1/codec_impl.h"
15
#include "source/common/http/response_decoder_impl_base.h"
16
#include "source/common/network/filter_impl.h"
17

            
18
#include "absl/strings/str_cat.h"
19
#include "absl/strings/string_view.h"
20
#include "absl/types/optional.h"
21

            
22
namespace Envoy {
23
namespace Extensions {
24
namespace Bootstrap {
25
namespace ReverseConnection {
26

            
27
// Forward declarations.
28
class ReverseConnectionIOHandle;
29
class ReverseTunnelInitiatorExtension;
30

            
31
/**
32
 * Class representing handshake failure with type and context.
33
 * Provides methods to generate detailed error messages and stat names.
34
 */
35
class HandshakeFailureReason {
36
public:
37
  enum class Type {
38
    HttpStatusError, // HTTP response with non-200 status code
39
    EncodeError,     // HTTP request encoding failed
40
  };
41

            
42
  /**
43
   * Create a handshake failure reason for HTTP status errors.
44
   * @param status_code the HTTP status code received
45
   */
46
2
  static HandshakeFailureReason httpStatusError(absl::string_view status_code) {
47
2
    return {Type::HttpStatusError, status_code};
48
2
  }
49

            
50
  /**
51
   * Create a handshake failure reason for encoding errors.
52
   */
53
1
  static HandshakeFailureReason encodeError() { return {Type::EncodeError, ""}; }
54

            
55
  /**
56
   * Get a detailed human-readable error message.
57
   * @return detailed error message string
58
   */
59
3
  std::string getDetailedName() const {
60
3
    switch (type_) {
61
2
    case Type::HttpStatusError:
62
2
      return absl::StrCat("HTTP handshake failed with status ", context_);
63
1
    case Type::EncodeError:
64
1
      return "HTTP handshake encode failed";
65
3
    }
66
    return "Unknown handshake failure";
67
3
  }
68

            
69
  /**
70
   * Get the stat name suffix for this failure.
71
   * @return stat name suffix (e.g., "http.401", "encode_error")
72
   */
73
3
  std::string getNameForStats() const {
74
3
    switch (type_) {
75
2
    case Type::HttpStatusError:
76
2
      return absl::StrCat("http.", context_);
77
1
    case Type::EncodeError:
78
1
      return "encode_error";
79
3
    }
80
    return "unknown";
81
3
  }
82

            
83
private:
84
3
  HandshakeFailureReason(Type type, absl::string_view context) : type_(type), context_(context) {}
85

            
86
  Type type_;
87
  std::string context_;
88
};
89

            
90
/**
91
 * Simple read filter for handling reverse connection handshake responses.
92
 * This filter processes the HTTP response from the upstream server during handshake.
93
 */
94
class SimpleConnReadFilter : public Network::ReadFilterBaseImpl,
95
                             public Logger::Loggable<Logger::Id::main> {
96
public:
97
  /**
98
   * Constructor that stores pointer to parent wrapper.
99
   */
100
48
  explicit SimpleConnReadFilter(void* parent) : parent_(parent) {}
101

            
102
  // Network::ReadFilter overrides
103
  Network::FilterStatus onData(Buffer::Instance& buffer, bool end_stream) override;
104

            
105
private:
106
  void* parent_; // Pointer to RCConnectionWrapper to avoid circular dependency.
107
};
108

            
109
/**
110
 * Wrapper for reverse connections that manages the connection lifecycle and handshake.
111
 * It handles the handshake process (both gRPC and HTTP fallback) and manages connection
112
 * callbacks and cleanup.
113
 */
114
class RCConnectionWrapper : public Network::ConnectionCallbacks,
115
                            public Event::DeferredDeletable,
116
                            public Logger::Loggable<Logger::Id::main>,
117
                            public Http::ResponseDecoderImplBase,
118
                            public Http::ConnectionCallbacks {
119
  friend class SimpleConnReadFilterTest;
120

            
121
public:
122
  /**
123
   * Constructor for RCConnectionWrapper.
124
   * @param parent reference to the parent ReverseConnectionIOHandle
125
   * @param connection the client connection to wrap
126
   * @param host the upstream host description
127
   * @param cluster_name the name of the cluster
128
   */
129
  RCConnectionWrapper(ReverseConnectionIOHandle& parent, Network::ClientConnectionPtr connection,
130
                      Upstream::HostDescriptionConstSharedPtr host,
131
                      const std::string& cluster_name);
132

            
133
  /**
134
   * Destructor for RCConnectionWrapper.
135
   * Performs defensive cleanup to prevent crashes during shutdown.
136
   */
137
  ~RCConnectionWrapper() override;
138

            
139
  // Network::ConnectionCallbacks overrides
140
  void onEvent(Network::ConnectionEvent event) override;
141
2
  void onAboveWriteBufferHighWatermark() override {}
142
2
  void onBelowWriteBufferLowWatermark() override {}
143

            
144
  // Http::ResponseDecoder overrides
145
1
  void decode1xxHeaders(Http::ResponseHeaderMapPtr&&) override {}
146
  void decodeHeaders(Http::ResponseHeaderMapPtr&& headers, bool end_stream) override;
147
2
  void decodeData(Buffer::Instance&, bool) override {}
148
1
  void decodeTrailers(Http::ResponseTrailerMapPtr&&) override {}
149
1
  void decodeMetadata(Http::MetadataMapPtr&&) override {}
150
2
  void dumpState(std::ostream&, int) const override {}
151

            
152
  // Http::ConnectionCallbacks overrides
153
2
  void onGoAway(Http::GoAwayErrorCode) override {}
154
1
  void onSettings(Http::ReceivedSettings&) override {}
155
2
  void onMaxStreamsChanged(uint32_t) override {}
156

            
157
  /**
158
   * Initiate the reverse connection handshake (HTTP only).
159
   * @param src_tenant_id the tenant identifier
160
   * @param src_cluster_id the cluster identifier
161
   * @param src_node_id the node identifier
162
   * @return the local address as string
163
   */
164
  std::string connect(const std::string& src_tenant_id, const std::string& src_cluster_id,
165
                      const std::string& src_node_id);
166

            
167
  /**
168
   * Release ownership of the connection.
169
   * @return the connection pointer (ownership transferred to caller)
170
   */
171
24
  Network::ClientConnectionPtr releaseConnection() { return std::move(connection_); }
172

            
173
  /**
174
   * Process HTTP response from upstream.
175
   * @param buffer the response data
176
   * @param end_stream whether this is the end of the stream
177
   */
178
  void processHttpResponse(Buffer::Instance& buffer, bool end_stream);
179

            
180
  /**
181
   * Handle successful handshake completion.
182
   */
183
  void onHandshakeSuccess();
184

            
185
  /**
186
   * Handle handshake failure.
187
   * @param reason the failure reason with type and context
188
   */
189
  void onHandshakeFailure(const HandshakeFailureReason& reason);
190

            
191
  /**
192
   * Perform graceful shutdown of the connection.
193
   */
194
  void shutdown();
195

            
196
  /**
197
   * Get the underlying connection.
198
   * @return pointer to the client connection
199
   */
200
68
  Network::ClientConnection* getConnection() { return connection_.get(); }
201

            
202
  /**
203
   * Get the host description.
204
   * @return shared pointer to the host description
205
   */
206
1
  Upstream::HostDescriptionConstSharedPtr getHost() { return host_; }
207

            
208
private:
209
  ReverseConnectionIOHandle& parent_;
210
  Network::ClientConnectionPtr connection_;
211
  Upstream::HostDescriptionConstSharedPtr host_;
212
  std::string cluster_name_;
213
  std::string connection_key_;
214
  bool http_handshake_sent_{false};
215
  bool handshake_completed_{false};
216
  bool shutdown_called_{false};
217

            
218
  /**
219
   * Get the downstream extension for accessing stats.
220
   * @return pointer to ReverseTunnelInitiatorExtension
221
   */
222
  ReverseTunnelInitiatorExtension* getDownstreamExtension() const;
223

            
224
public:
225
  // Dispatch incoming bytes to HTTP/1 codec.
226
  void dispatchHttp1(Buffer::Instance& buffer);
227

            
228
private:
229
  // HTTP/1 codec used to send request and parse response.
230
  std::unique_ptr<Http::Http1::ClientConnectionImpl> http1_client_codec_;
231
  // Base interface pointer used to call dispatch via public API.
232
  Http::Connection* http1_parse_connection_{nullptr};
233
};
234

            
235
} // namespace ReverseConnection
236
} // namespace Bootstrap
237
} // namespace Extensions
238
} // namespace Envoy