1
#include "source/extensions/string_matcher/lua/match.h"
2

            
3
#include "envoy/extensions/string_matcher/lua/v3/lua.pb.h"
4
#include "envoy/extensions/string_matcher/lua/v3/lua.pb.validate.h"
5

            
6
#include "source/common/config/datasource.h"
7
#include "source/common/config/utility.h"
8
#include "source/common/protobuf/message_validator_impl.h"
9

            
10
namespace Envoy {
11
namespace Extensions {
12
namespace StringMatcher {
13
namespace Lua {
14

            
15
11
// Creates a fresh Lua state, loads the standard libraries, runs `code`, and pins the global
// function `envoy_match` in the registry (matcher_func_ref_) so match() can call it later.
//
// Throws EnvoyException if the chunk fails to load or run, or if it does not define a global
// function named `envoy_match`.
LuaStringMatcher::LuaStringMatcher(const std::string& code) : state_(luaL_newstate()) {
  RELEASE_ASSERT(state_.get() != nullptr, "unable to create new Lua state object");
  luaL_openlibs(state_.get());

  int rc = luaL_dostring(state_.get(), code.c_str());
  if (rc != 0) {
    // On failure the error object is on top of the stack; use its text when it is a string.
    absl::string_view error("unknown");
    if (lua_isstring(state_.get(), -1)) {
      size_t len = 0;
      const char* err = lua_tolstring(state_.get(), -1, &len);
      error = absl::string_view(err, len);
    }
    throw EnvoyException(absl::StrCat("Failed to load lua code in Lua StringMatcher:", error));
  }

  // luaL_dostring runs the chunk with LUA_MULTRET, so anything the script `return`s is left on
  // the stack. Drop it: after a MULTRET call Lua only guarantees room for the results themselves,
  // and match() pushes the function plus its argument without calling lua_checkstack().
  lua_settop(state_.get(), 0);

  lua_getglobal(state_.get(), "envoy_match");
  if (!lua_isfunction(state_.get(), -1)) {
    throw EnvoyException("Lua code did not contain a global function named 'envoy_match'");
  }
  // luaL_ref pops the function off the stack and stores it in the registry under the returned key.
  matcher_func_ref_ = luaL_ref(state_.get(), LUA_REGISTRYINDEX);
}
36

            
37
8
// Calls the script's `envoy_match` function with `value` and returns true only when it returns the
// boolean `true`. A runtime error or a non-boolean result is logged through
// ENVOY_LOG_PERIODIC_MISC (5 second period) and treated as "no match". The Lua stack depth is the
// same on exit as on entry.
bool LuaStringMatcher::match(absl::string_view value) const {
  auto* const lua = state_.get();
  const int depth_on_entry = lua_gettop(lua);

  const bool matched = [&]() -> bool {
    // Push the pinned envoy_match function, then its single string argument.
    lua_rawgeti(lua, LUA_REGISTRYINDEX, matcher_func_ref_);
    ASSERT(lua_isfunction(lua, -1)); // Validated in constructor
    lua_pushlstring(lua, value.data(), value.size());

    // One argument in, one result out; on failure the error object takes the result's place.
    if (lua_pcall(lua, 1, 1, 0) != 0) {
      absl::string_view error_message("unknown");
      if (lua_isstring(lua, -1)) {
        size_t length = 0;
        const char* const raw = lua_tolstring(lua, -1, &length);
        error_message = absl::string_view(raw, length);
      }
      ENVOY_LOG_PERIODIC_MISC(error, std::chrono::seconds(5),
                              "Lua StringMatcher error running script: {}", error_message);
      lua_pop(lua, 1); // Discard the error object.
      return false;
    }

    const bool returned_boolean = lua_isboolean(lua, -1);
    if (!returned_boolean) {
      ENVOY_LOG_PERIODIC_MISC(error, std::chrono::seconds(5),
                              "Lua StringMatcher match function did not return a boolean");
    }
    const bool result = returned_boolean && lua_toboolean(lua, -1) != 0;
    lua_pop(lua, 1); // Discard the single result.
    return result;
  }();

  // Validate that the stack is restored to its original state; nothing added or removed.
  ASSERT(lua_gettop(lua) == depth_on_entry);
  return matched;
}
77

            
78
// Lua state is not thread safe, so a state needs to be stored in thread local storage.
79
class LuaStringMatcherThreadWrapper : public Matchers::StringMatcher {
80
public:
81
1
  LuaStringMatcherThreadWrapper(const std::string& code, ThreadLocal::SlotAllocator& tls) {
82
    // Validate that there are no errors while creating on the main thread.
83
1
    LuaStringMatcher validator(code);
84

            
85
1
    tls_slot_ = ThreadLocal::TypedSlot<LuaStringMatcher>::makeUnique(tls);
86
2
    tls_slot_->set([code](Event::Dispatcher&) -> std::shared_ptr<LuaStringMatcher> {
87
2
      return std::make_shared<LuaStringMatcher>(code);
88
2
    });
89
1
  }
90

            
91
  // To avoid hiding other implementations of match.
92
  using Matchers::StringMatcher::match;
93

            
94
2
  bool match(absl::string_view value) const override { return (*tls_slot_)->match(value); }
95

            
96
private:
97
  ThreadLocal::TypedSlotPtr<LuaStringMatcher> tls_slot_;
98
};
99

            
100
// Creates a Lua-backed StringMatcher from a typed
// ::envoy::extensions::string_matcher::lua::v3::Lua config.
//
// The script text is fetched with Config::DataSource::read; an empty source is not allowed.
// Throws EnvoyException when the source cannot be read. Script compilation errors are raised by
// LuaStringMatcherThreadWrapper's constructor.
Matchers::StringMatcherPtr
LuaStringMatcherFactory::createStringMatcher(const Protobuf::Message& untyped_config,
                                             Server::Configuration::CommonFactoryContext& context) {
  using LuaConfig = ::envoy::extensions::string_matcher::lua::v3::Lua;
  const LuaConfig& config = MessageUtil::downcastAndValidate<const LuaConfig&>(
      untyped_config, context.messageValidationContext().staticValidationVisitor());

  // NOTE(review): max_size=0 presumably means "no size limit" — confirm in DataSource::read.
  absl::StatusOr<std::string> source_code =
      Config::DataSource::read(config.source_code(), /*allow_empty=*/false, context.api(),
                               /*max_size=*/0);
  if (source_code.ok()) {
    return std::make_unique<LuaStringMatcherThreadWrapper>(*source_code, context.threadLocal());
  }
  throw EnvoyException(
      fmt::format("Failed to get lua string matcher code from source: {}", source_code.status()));
}
115

            
116
3
// Returns a default-constructed Lua string matcher config message.
ProtobufTypes::MessagePtr LuaStringMatcherFactory::createEmptyConfigProto() {
  using LuaConfig = ::envoy::extensions::string_matcher::lua::v3::Lua;
  return std::make_unique<LuaConfig>();
}
119

            
120
// Registers LuaStringMatcherFactory as a Matchers::StringMatcherExtensionFactory implementation.
REGISTER_FACTORY(LuaStringMatcherFactory, Matchers::StringMatcherExtensionFactory);
121

            
122
} // namespace Lua
123
} // namespace StringMatcher
124
} // namespace Extensions
125
} // namespace Envoy