#pragma once

#include <unistd.h>

#include <algorithm>
#include <atomic>
#include <chrono>
#include <cstdint>
#include <list>
#include <string>
#include <vector>

#include "envoy/event/dispatcher.h"
#include "envoy/event/timer.h"
#include "envoy/network/io_handle.h"
#include "envoy/network/socket.h"
#include "envoy/thread_local/thread_local.h"

#include "source/common/common/logger.h"
#include "source/common/common/random_generator.h"

#include "absl/container/flat_hash_map.h"
#include "absl/synchronization/mutex.h"
namespace Envoy {
21
namespace Extensions {
22
namespace Bootstrap {
23
namespace ReverseConnection {
24

            
25
// Forward declarations
26
class ReverseTunnelAcceptorExtension;
27

            
28
/**
29
 * Thread-local socket manager for upstream reverse connections.
30
 */
31
class UpstreamSocketManager : public ThreadLocal::ThreadLocalObject,
32
                              public Logger::Loggable<Logger::Id::filter> {
33
  // Friend class for testing
34
  friend class TestUpstreamSocketManager;
35
  friend class TestUpstreamSocketManagerRebalancing;
36

            
37
public:
38
  UpstreamSocketManager(Event::Dispatcher& dispatcher,
39
                        ReverseTunnelAcceptorExtension* extension = nullptr);
40

            
41
  ~UpstreamSocketManager();
42

            
43
  /**
44
   * Add accepted connection to socket manager.
45
   * @param node_id node_id of initiating node.
46
   * @param cluster_id cluster_id of receiving cluster.
47
   * @param socket the socket to be added.
48
   * @param ping_interval the interval at which ping keepalives are sent.
49
   * @param rebalanced true if adding socket after rebalancing.
50
   */
51
  void addConnectionSocket(const std::string& node_id, const std::string& cluster_id,
52
                           Network::ConnectionSocketPtr socket,
53
                           const std::chrono::seconds& ping_interval, bool rebalanced = true);
54

            
55
  /**
56
   * Hand off a socket to this socket manager's dispatcher.
57
   * Used for cross-thread rebalancing of reverse connection sockets.
58
   * @param node_id node_id of initiating node.
59
   * @param cluster_id cluster_id of receiving cluster.
60
   * @param socket the socket to be added.
61
   * @param ping_interval the interval at which ping keepalives are sent.
62
   */
63
  void handoffSocketToWorker(const std::string& node_id, const std::string& cluster_id,
64
                             Network::ConnectionSocketPtr socket,
65
                             const std::chrono::seconds& ping_interval);
66

            
67
  /**
68
   * Get an available reverse connection socket.
69
   * @param node_id the node ID to get a socket for.
70
   * @return the connection socket, or nullptr if none available.
71
   */
72
  Network::ConnectionSocketPtr getConnectionSocket(const std::string& node_id);
73

            
74
  /**
75
   * Mark connection socket dead and remove from internal maps.
76
   * @param fd the FD for the socket to be marked dead.
77
   */
78
  void markSocketDead(const int fd);
79

            
80
  /**
81
   * Send a ping keepalive for a single reverse connection.
82
   * @param fd the file descriptor of the connection to ping.
83
   */
84
  void sendPingForConnection(int fd);
85

            
86
  /**
87
   * Clean up stale node entries when no active sockets remain.
88
   * @param node_id the node ID to clean up.
89
   */
90
  void cleanStaleNodeEntry(const std::string& node_id);
91

            
92
  /**
93
   * Handle ping response from a reverse connection.
94
   * @param io_handle the IO handle for the socket that sent the ping response.
95
   */
96
  void onPingResponse(Network::IoHandle& io_handle);
97

            
98
  /**
99
   * Handle ping response timeout for a specific socket.
100
   * Increments miss count and marks socket dead if threshold reached.
101
   * @param fd the file descriptor whose ping timed out.
102
   */
103
  void onPingTimeout(int fd);
104

            
105
  /**
106
   * Set the miss threshold (consecutive misses before marking a socket dead).
107
   * @param threshold minimum value 1.
108
   */
109
120
  void setMissThreshold(uint32_t threshold) { miss_threshold_ = std::max<uint32_t>(1, threshold); }
110
130
  void setTenantIsolationEnabled(bool enabled) { tenant_isolation_enabled_ = enabled; }
111
257
  bool tenantIsolationEnabled() const { return tenant_isolation_enabled_; }
112

            
113
  /**
114
   * Get the upstream extension for stats integration.
115
   * @return pointer to the upstream extension or nullptr if not available.
116
   */
117
360
  ReverseTunnelAcceptorExtension* getUpstreamExtension() const { return extension_; }
118

            
119
  /**
120
   * Get a node that has a socket (idle or used) for the given key.
121
   * If the key is found in the cluster_to_node_info_map_, assume it is the cluster ID and return a
122
   * node in that cluster in a round-robin manner. If the key is not found in the
123
   * cluster_to_node_info_map_, assume it is the node ID and return it as-is.
124
   * @param key the cluster ID or node ID to lookup.
125
   * @return the node ID, or the key itself if it cannot be resolved.
126
   */
127
  std::string getNodeWithSocket(const std::string& key);
128

            
129
  /**
130
   * Pick the least loaded socket manager across all worker threads for a given node.
131
   * @param node_id the node ID to find the least loaded manager for.
132
   * @param cluster_id the cluster ID for logging purposes.
133
   * @return reference to the least loaded socket manager.
134
   */
135
  UpstreamSocketManager& pickLeastLoadedSocketManager(const std::string& node_id,
136
                                                      const std::string& cluster_id);
137

            
138
private:
139
  /**
140
   * Helper method to check if a node has any reverse connection sockets (idle or used).
141
   * @param node_id the node ID to check.
142
   * @return true if the node has any sockets, false otherwise.
143
   */
144
  bool hasAnySocketsForNode(const std::string& node_id);
145

            
146
  /**
147
   * Compute the ping interval in milliseconds with 15% jitter applied.
148
   * @return jittered interval in milliseconds.
149
   */
150
  uint64_t pingIntervalWithJitterMs();
151

            
152
  /**
153
   * Re-arm the per-connection ping send timer for the given fd with jitter.
154
   * No-op if the fd has no entry in fd_to_ping_send_timer_map_.
155
   * @param fd the file descriptor whose send timer to re-arm.
156
   */
157
  void rearmPingSendTimer(int fd);
158

            
159
  // Thread local dispatcher instance.
160
  Event::Dispatcher& dispatcher_;
161
  Random::RandomGeneratorPtr random_generator_;
162

            
163
  // Map of node IDs to connection sockets.
164
  absl::flat_hash_map<std::string, std::list<Network::ConnectionSocketPtr>>
165
      accepted_reverse_connections_;
166

            
167
  // Map from file descriptor to node ID. An entry is added when a reverse tunnel is accepted from a
168
  // node and is removed when the socket dies.
169
  absl::flat_hash_map<int, std::string> fd_to_node_map_;
170

            
171
  // Map from FD to its iterator in accepted_reverse_connections_, used to avoid linear scans.
172
  absl::flat_hash_map<int, std::list<Network::ConnectionSocketPtr>::iterator> fd_to_socket_it_map_;
173

            
174
  // Map from file descriptor to cluster ID. An entry is added when a reverse tunnel is accepted
175
  // from a node and is removed when the socket dies.
176
  absl::flat_hash_map<int, std::string> fd_to_cluster_map_;
177

            
178
  // Map of node ID to cluster, for all nodes that have a reverse tunnel socket.
179
  absl::flat_hash_map<std::string, std::string> node_to_cluster_map_;
180

            
181
  // Cluster information for tracking member nodes.
182
  struct ClusterInfo {
183
    // List of node IDs that belong to this cluster and have any sockets (idle or used).
184
    std::vector<std::string> nodes;
185
    // Round-robin index for load distribution when selecting member nodes.
186
    size_t round_robin_index = 0;
187
  };
188

            
189
  // Map of cluster IDs to cluster node information.
190
  // A cluster entry is added when a reverse tunnel is accepted from a node in that cluster
191
  // and is removed only when all nodes in the cluster have no remaining sockets.
192
  absl::flat_hash_map<std::string, ClusterInfo> cluster_to_node_info_map_;
193

            
194
  // File events and timers for ping functionality.
195
  absl::flat_hash_map<int, Event::FileEventPtr> fd_to_event_map_;
196
  absl::flat_hash_map<int, Event::TimerPtr> fd_to_timer_map_;
197

            
198
  // Per-connection send timers that schedule individual ping sends with jitter.
199
  absl::flat_hash_map<int, Event::TimerPtr> fd_to_ping_send_timer_map_;
200

            
201
  // Track consecutive ping misses per file descriptor.
202
  absl::flat_hash_map<int, uint32_t> fd_to_miss_count_;
203
  // Miss threshold before declaring a socket dead.
204
  static constexpr uint32_t kDefaultMissThreshold = 3;
205
  uint32_t miss_threshold_{kDefaultMissThreshold};
206

            
207
  std::chrono::seconds ping_interval_{0};
208

            
209
  // Per node counter for total active FDs.
210
  absl::flat_hash_map<std::string, uint32_t> node_to_active_fd_count_;
211

            
212
  // Upstream extension for stats integration.
213
  ReverseTunnelAcceptorExtension* extension_;
214

            
215
  // Map of node IDs to the number of total accepted reverse connections
216
  // for the node. This is used to rebalance a request to accept reverse
217
  // connections to a different worker thread.
218
  absl::flat_hash_map<std::string, int> node_to_conn_count_map_;
219

            
220
  bool tenant_isolation_enabled_{false};
221

            
222
  // Global list of all socket managers across threads for rebalancing.
223
  static std::vector<UpstreamSocketManager*> socket_managers_;
224
  static absl::Mutex socket_manager_lock;
225
};
226

            
227
} // namespace ReverseConnection
228
} // namespace Bootstrap
229
} // namespace Extensions
230
} // namespace Envoy