1
#pragma once
2

            
3
#include <utility>

#include "envoy/extensions/router/cluster_specifiers/lua/v3/lua.pb.h"
#include "envoy/router/cluster_specifier_plugin.h"

#include "source/common/config/datasource.h"
#include "source/common/runtime/runtime_features.h"
#include "source/extensions/filters/common/lua/wrappers.h"
9

            
10
namespace Envoy {
11
namespace Extensions {
12
namespace Router {
13
namespace Lua {
14

            
15
using LuaClusterSpecifierConfigProto =
16
    envoy::extensions::router::cluster_specifiers::lua::v3::LuaConfig;
17

            
18
/**
 * Bundles a Filters::Common::Lua::ThreadLocalState with the slot used to look up the cluster
 * function reference in that state. Constructed from Lua source code and a thread local slot
 * allocator.
 */
class PerLuaCodeSetup : private Logger::Loggable<Logger::Id::lua> {
public:
  PerLuaCodeSetup(const std::string& lua_code, ThreadLocal::SlotAllocator& tls);

  /**
   * @return a coroutine created by the underlying Lua state.
   */
  Extensions::Filters::Common::Lua::CoroutinePtr createCoroutine() {
    return lua_state_.createCoroutine();
  }

  /**
   * @return the global reference that the Lua state holds for the cluster function slot.
   */
  int clusterFunctionRef() {
    return lua_state_.getGlobalRef(cluster_function_slot_);
  }

  /**
   * Forwards to the runtimeGC() call of the underlying Lua state.
   */
  void runtimeGC() {
    lua_state_.runtimeGC();
  }

private:
  // Slot handed to getGlobalRef() to resolve the cluster function; value-initialized to zero here.
  uint64_t cluster_function_slot_{};

  // Lua state that backs every call made through this object.
  Filters::Common::Lua::ThreadLocalState lua_state_;
};
35

            
36
// Uniquely-owning pointer to a PerLuaCodeSetup.
using PerLuaCodeSetupPtr = std::unique_ptr<PerLuaCodeSetup>;
37

            
38
/**
 * Lua wrapper around a constant header map. The only function exported to scripts is "get".
 */
class HeaderMapWrapper : public Filters::Common::Lua::BaseLuaObject<HeaderMapWrapper> {
public:
  HeaderMapWrapper(const Http::HeaderMap& headers) : headers_(headers) {}

  /**
   * @return the table of functions that this wrapper exposes to Lua.
   */
  static ExportedFunctions exportedFunctions() {
    return {
        {"get", static_luaGet},
    };
  }

private:
  /**
   * Look up a header value in the wrapped map.
   * @param 1 (string): header name.
   * @return string value if found or nil.
   */
  DECLARE_LUA_FUNCTION(HeaderMapWrapper, luaGet);

  // Stored by reference: the wrapper does not own the header map.
  const Http::HeaderMap& headers_;
};
54

            
55
// LuaDeathRef handle for a HeaderMapWrapper.
using HeaderMapRef = Filters::Common::Lua::LuaDeathRef<HeaderMapWrapper>;
56

            
57
/**
 * Lua wrapper around a cluster info shared pointer. It exports the "numConnections",
 * "numRequests" and "numPendingRequests" functions to scripts.
 */
class ClusterWrapper : public Filters::Common::Lua::BaseLuaObject<ClusterWrapper> {
public:
  /**
   * @param cluster the cluster info to wrap. It is taken by value and moved into the member, so
   *        storing it does not cost an additional shared pointer copy.
   */
  ClusterWrapper(Upstream::ClusterInfoConstSharedPtr cluster) : cluster_(std::move(cluster)) {}

  /**
   * @return the table of functions that this wrapper exposes to Lua.
   */
  static ExportedFunctions exportedFunctions() {
    return {
        {"numConnections", static_luaNumConnections},
        {"numRequests", static_luaNumRequests},
        {"numPendingRequests", static_luaNumPendingRequests},
    };
  }

  // Release the wrapped cluster info as soon as the object is marked dead.
  void onMarkDead() override { cluster_.reset(); }

private:
  // Lua entry points behind the exported names above; the implementations live outside this
  // header.
  DECLARE_LUA_FUNCTION(ClusterWrapper, luaNumConnections);
  DECLARE_LUA_FUNCTION(ClusterWrapper, luaNumRequests);
  DECLARE_LUA_FUNCTION(ClusterWrapper, luaNumPendingRequests);

  // Wrapped cluster info; reset in onMarkDead().
  Upstream::ClusterInfoConstSharedPtr cluster_;
};
78

            
79
// LuaDeathRef handle for a ClusterWrapper.
using ClusterRef = Filters::Common::Lua::LuaDeathRef<ClusterWrapper>;
80

            
81
/**
 * Lua wrapper that exports the "headers" and "getCluster" functions to scripts. It is built from
 * a header map and a cluster manager, both of which it holds by reference.
 */
class RouteHandleWrapper : public Filters::Common::Lua::BaseLuaObject<RouteHandleWrapper> {
public:
  RouteHandleWrapper(const Http::HeaderMap& headers, Upstream::ClusterManager& cm)
      : headers_(headers), cm_(cm) {}

  /**
   * @return the table of functions that this wrapper exposes to Lua.
   */
  static ExportedFunctions exportedFunctions() {
    return {
        {"headers", static_luaHeaders},
        {"getCluster", static_luaGetCluster},
    };
  }

  // Every embedded reference is dropped here, at the moment the object is marked dead, rather
  // than in the destructor. The destructor may run after the lua_State of the referenced
  // coroutine has been closed, and resetting the references at that point would crash.
  void onMarkDead() override {
    headers_wrapper_.reset();
    clusters_.clear();
  }

private:
  /**
   * @return a handle to the headers.
   */
  DECLARE_LUA_FUNCTION(RouteHandleWrapper, luaHeaders);
  // Lua entry point behind the exported "getCluster" name; the implementation lives outside this
  // header.
  DECLARE_LUA_FUNCTION(RouteHandleWrapper, luaGetCluster);

  // Not owned; both references are supplied to the constructor.
  const Http::HeaderMap& headers_;
  Upstream::ClusterManager& cm_;

  // Embedded references released in onMarkDead().
  HeaderMapRef headers_wrapper_;
  std::vector<ClusterRef> clusters_;
};
113

            
114
// LuaDeathRef handle for a RouteHandleWrapper.
using RouteHandleRef = Filters::Common::Lua::LuaDeathRef<RouteHandleWrapper>;
115

            
116
/**
 * Configuration object for the Lua cluster specifier. It is built from the LuaConfig proto and a
 * common factory context, and gives access to the Lua code setup, the default cluster name and
 * the cluster manager.
 */
class LuaClusterSpecifierConfig : private Logger::Loggable<Logger::Id::lua> {
public:
  LuaClusterSpecifierConfig(const LuaClusterSpecifierConfigProto& config,
                            Server::Configuration::CommonFactoryContext& context);

  /**
   * @return a non-owning pointer to the Lua code setup held by this config.
   */
  PerLuaCodeSetup* perLuaCodeSetup() const {
    return per_lua_code_setup_ptr_.get();
  }

  /**
   * @return the default cluster name stored in this config.
   */
  const std::string& defaultCluster() const {
    return default_cluster_;
  }

  /**
   * @return the cluster manager this config refers to.
   */
  Upstream::ClusterManager& clusterManager() {
    return cm_;
  }

private:
  // Not owned.
  Upstream::ClusterManager& cm_;
  // Owned Lua code setup; exposed as a raw pointer by perLuaCodeSetup().
  PerLuaCodeSetupPtr per_lua_code_setup_ptr_;
  const std::string default_cluster_;
};
130

            
131
// Shared-ownership pointer to a LuaClusterSpecifierConfig.
using LuaClusterSpecifierConfigSharedPtr = std::shared_ptr<LuaClusterSpecifierConfig>;
132

            
133
class LuaClusterSpecifierPlugin : public Envoy::Router::ClusterSpecifierPlugin,
134
                                  Logger::Loggable<Logger::Id::lua> {
135
public:
136
  LuaClusterSpecifierPlugin(LuaClusterSpecifierConfigSharedPtr config);
137
  Envoy::Router::RouteConstSharedPtr route(Envoy::Router::RouteEntryAndRouteConstSharedPtr parent,
138
                                           const Http::RequestHeaderMap& header,
139
                                           const StreamInfo::StreamInfo&, uint64_t) const override;
140

            
141
private:
142
  std::string startLua(const Http::HeaderMap& headers) const;
143

            
144
  LuaClusterSpecifierConfigSharedPtr config_;
145
  const int function_ref_;
146
};
147

            
148
} // namespace Lua
149
} // namespace Router
150
} // namespace Extensions
151
} // namespace Envoy