LCOV - code coverage report
Current view: top level - source/extensions/common/wasm - foreign.cc (source / functions) Hit Total Coverage
Test: coverage.dat Lines: 26 195 13.3 %
Date: 2024-01-05 06:35:25 Functions: 8 23 34.8 %

          Line data    Source code
       1             : #include "source/common/common/logger.h"
       2             : #include "source/extensions/common/wasm/ext/declare_property.pb.h"
       3             : #include "source/extensions/common/wasm/ext/set_envoy_filter_state.pb.h"
       4             : #include "source/extensions/common/wasm/wasm.h"
       5             : 
       6             : #if defined(WASM_USE_CEL_PARSER)
       7             : #include "eval/public/builtin_func_registrar.h"
       8             : #include "eval/public/cel_expr_builder_factory.h"
       9             : #include "parser/parser.h"
      10             : #endif
      11             : #include "zlib.h"
      12             : 
      13             : using proxy_wasm::RegisterForeignFunction;
      14             : using proxy_wasm::WasmForeignFunction;
      15             : 
      16             : namespace Envoy {
      17             : namespace Extensions {
      18             : namespace Common {
      19             : namespace Wasm {
      20             : 
      21             : using CelStateType = Filters::Common::Expr::CelStateType;
      22             : 
      23          12 : template <typename T> WasmForeignFunction createFromClass() {
      24          12 :   auto c = std::make_shared<T>();
      25          12 :   return c->create(c);
      26          12 : }
      27             : 
      28             : inline StreamInfo::FilterState::LifeSpan
      29           0 : toFilterStateLifeSpan(envoy::source::extensions::common::wasm::LifeSpan span) {
      30           0 :   switch (span) {
      31           0 :   case envoy::source::extensions::common::wasm::LifeSpan::FilterChain:
      32           0 :     return StreamInfo::FilterState::LifeSpan::FilterChain;
      33           0 :   case envoy::source::extensions::common::wasm::LifeSpan::DownstreamRequest:
      34           0 :     return StreamInfo::FilterState::LifeSpan::Request;
      35           0 :   case envoy::source::extensions::common::wasm::LifeSpan::DownstreamConnection:
      36           0 :     return StreamInfo::FilterState::LifeSpan::Connection;
      37           0 :   default:
      38           0 :     return StreamInfo::FilterState::LifeSpan::FilterChain;
      39           0 :   }
      40           0 : }
      41             : 
      42             : RegisterForeignFunction registerCompressForeignFunction(
      43             :     "compress",
      44             :     [](WasmBase&, std::string_view arguments,
      45           0 :        const std::function<void*(size_t size)>& alloc_result) -> WasmResult {
      46           0 :       unsigned long dest_len = compressBound(arguments.size());
      47           0 :       std::unique_ptr<unsigned char[]> b(new unsigned char[dest_len]);
      48           0 :       if (compress(b.get(), &dest_len, reinterpret_cast<const unsigned char*>(arguments.data()),
      49           0 :                    arguments.size()) != Z_OK) {
      50           0 :         return WasmResult::SerializationFailure;
      51           0 :       }
      52           0 :       auto result = alloc_result(dest_len);
      53           0 :       memcpy(result, b.get(), dest_len); // NOLINT(safe-memcpy)
      54           0 :       return WasmResult::Ok;
      55           0 :     });
      56             : 
      57             : RegisterForeignFunction registerUncompressForeignFunction(
      58             :     "uncompress",
      59             :     [](WasmBase&, std::string_view arguments,
      60           0 :        const std::function<void*(size_t size)>& alloc_result) -> WasmResult {
      61           0 :       unsigned long dest_len = arguments.size() * 2 + 2; // output estimate.
      62           0 :       while (true) {
      63           0 :         std::unique_ptr<unsigned char[]> b(new unsigned char[dest_len]);
      64           0 :         auto r =
      65           0 :             uncompress(b.get(), &dest_len, reinterpret_cast<const unsigned char*>(arguments.data()),
      66           0 :                        arguments.size());
      67           0 :         if (r == Z_OK) {
      68           0 :           auto result = alloc_result(dest_len);
      69           0 :           memcpy(result, b.get(), dest_len); // NOLINT(safe-memcpy)
      70           0 :           return WasmResult::Ok;
      71           0 :         }
      72           0 :         if (r != Z_BUF_ERROR) {
      73           0 :           return WasmResult::SerializationFailure;
      74           0 :         }
      75           0 :         dest_len = dest_len * 2;
      76           0 :       }
      77           0 :     });
      78             : 
      79             : RegisterForeignFunction registerSetEnvoyFilterStateForeignFunction(
      80             :     "set_envoy_filter_state",
      81             :     [](WasmBase&, std::string_view arguments,
      82           0 :        const std::function<void*(size_t size)>&) -> WasmResult {
      83           0 :       envoy::source::extensions::common::wasm::SetEnvoyFilterStateArguments args;
      84           0 :       if (args.ParseFromArray(arguments.data(), arguments.size())) {
      85           0 :         auto context = static_cast<Context*>(proxy_wasm::current_context_);
      86           0 :         return context->setEnvoyFilterState(args.path(), args.value(),
      87           0 :                                             toFilterStateLifeSpan(args.span()));
      88           0 :       }
      89           0 :       return WasmResult::BadArgument;
      90           0 :     });
      91             : 
      92             : #if defined(WASM_USE_CEL_PARSER)
      93             : class ExpressionFactory : public Logger::Loggable<Logger::Id::wasm> {
      94             : protected:
      95             :   struct ExpressionData {
      96             :     google::api::expr::v1alpha1::ParsedExpr parsed_expr_;
      97             :     Filters::Common::Expr::ExpressionPtr compiled_expr_;
      98             :   };
      99             : 
     100             :   class ExpressionContext : public StorageObject {
     101             :   public:
     102             :     friend class ExpressionFactory;
     103           0 :     ExpressionContext(Filters::Common::Expr::BuilderPtr builder) : builder_(std::move(builder)) {}
     104           0 :     uint32_t createToken() {
     105           0 :       uint32_t token = next_expr_token_++;
     106           0 :       for (;;) {
     107           0 :         if (!expr_.count(token)) {
     108           0 :           break;
     109           0 :         }
     110           0 :         token = next_expr_token_++;
     111           0 :       }
     112           0 :       return token;
     113           0 :     }
     114           0 :     bool hasExpression(uint32_t token) { return expr_.contains(token); }
     115           0 :     ExpressionData& getExpression(uint32_t token) { return expr_[token]; }
     116           0 :     void deleteExpression(uint32_t token) { expr_.erase(token); }
     117           0 :     Filters::Common::Expr::Builder* builder() { return builder_.get(); }
     118             : 
     119             :   private:
     120             :     Filters::Common::Expr::BuilderPtr builder_{};
     121             :     uint32_t next_expr_token_ = 0;
     122             :     absl::flat_hash_map<uint32_t, ExpressionData> expr_;
     123             :   };
     124             : 
     125           0 :   static ExpressionContext& getOrCreateContext(ContextBase* context_base) {
     126           0 :     auto context = static_cast<Context*>(context_base);
     127           0 :     std::string data_name = "cel";
     128           0 :     auto expr_context = context->getForeignData<ExpressionContext>(data_name);
     129           0 :     if (!expr_context) {
     130           0 :       google::api::expr::runtime::InterpreterOptions options;
     131           0 :       auto builder = google::api::expr::runtime::CreateCelExpressionBuilder(options);
     132           0 :       auto status =
     133           0 :           google::api::expr::runtime::RegisterBuiltinFunctions(builder->GetRegistry(), options);
     134           0 :       if (!status.ok()) {
     135           0 :         ENVOY_LOG(warn, "failed to register built-in functions: {}", status.message());
     136           0 :       }
     137           0 :       auto new_context = std::make_unique<ExpressionContext>(std::move(builder));
     138           0 :       expr_context = new_context.get();
     139           0 :       context->setForeignData(data_name, std::move(new_context));
     140           0 :     }
     141           0 :     return *expr_context;
     142           0 :   }
     143             : };
     144             : 
     145             : class CreateExpressionFactory : public ExpressionFactory {
     146             : public:
     147           3 :   WasmForeignFunction create(std::shared_ptr<CreateExpressionFactory> self) const {
     148           3 :     WasmForeignFunction f =
     149           3 :         [self](WasmBase&, std::string_view expr,
     150           3 :                const std::function<void*(size_t size)>& alloc_result) -> WasmResult {
     151           0 :       auto parse_status = google::api::expr::parser::Parse(std::string(expr));
     152           0 :       if (!parse_status.ok()) {
     153           0 :         ENVOY_LOG(info, "expr_create parse error: {}", parse_status.status().message());
     154           0 :         return WasmResult::BadArgument;
     155           0 :       }
     156             : 
     157           0 :       auto& expr_context = getOrCreateContext(proxy_wasm::current_context_->root_context());
     158           0 :       auto token = expr_context.createToken();
     159           0 :       auto& handler = expr_context.getExpression(token);
     160             : 
     161           0 :       handler.parsed_expr_ = parse_status.value();
     162           0 :       auto cel_expression_status = expr_context.builder()->CreateExpression(
     163           0 :           &handler.parsed_expr_.expr(), &handler.parsed_expr_.source_info());
     164           0 :       if (!cel_expression_status.ok()) {
     165           0 :         ENVOY_LOG(info, "expr_create compile error: {}", cel_expression_status.status().message());
     166           0 :         expr_context.deleteExpression(token);
     167           0 :         return WasmResult::BadArgument;
     168           0 :       }
     169             : 
     170           0 :       handler.compiled_expr_ = std::move(cel_expression_status.value());
     171           0 :       auto result = reinterpret_cast<uint32_t*>(alloc_result(sizeof(uint32_t)));
     172           0 :       *result = token;
     173           0 :       return WasmResult::Ok;
     174           0 :     };
     175           3 :     return f;
     176           3 :   }
     177             : };
     178             : RegisterForeignFunction
     179             :     registerCreateExpressionForeignFunction("expr_create",
     180             :                                             createFromClass<CreateExpressionFactory>());
     181             : 
     182             : class EvaluateExpressionFactory : public ExpressionFactory {
     183             : public:
     184           3 :   WasmForeignFunction create(std::shared_ptr<EvaluateExpressionFactory> self) const {
     185           3 :     WasmForeignFunction f =
     186           3 :         [self](WasmBase&, std::string_view argument,
     187           3 :                const std::function<void*(size_t size)>& alloc_result) -> WasmResult {
     188           0 :       auto& expr_context = getOrCreateContext(proxy_wasm::current_context_->root_context());
     189           0 :       if (argument.size() != sizeof(uint32_t)) {
     190           0 :         return WasmResult::BadArgument;
     191           0 :       }
     192           0 :       uint32_t token = *reinterpret_cast<const uint32_t*>(argument.data());
     193           0 :       if (!expr_context.hasExpression(token)) {
     194           0 :         return WasmResult::NotFound;
     195           0 :       }
     196           0 :       Protobuf::Arena arena;
     197           0 :       auto& handler = expr_context.getExpression(token);
     198           0 :       auto context = static_cast<Context*>(proxy_wasm::current_context_);
     199           0 :       auto eval_status = handler.compiled_expr_->Evaluate(*context, &arena);
     200           0 :       if (!eval_status.ok()) {
     201           0 :         ENVOY_LOG(debug, "expr_evaluate error: {}", eval_status.status().message());
     202           0 :         return WasmResult::InternalFailure;
     203           0 :       }
     204           0 :       auto value = eval_status.value();
     205           0 :       if (value.IsError()) {
     206           0 :         ENVOY_LOG(debug, "expr_evaluate value error: {}", value.ErrorOrDie()->message());
     207           0 :         return WasmResult::InternalFailure;
     208           0 :       }
     209           0 :       std::string result;
     210           0 :       auto serialize_status = serializeValue(value, &result);
     211           0 :       if (serialize_status != WasmResult::Ok) {
     212           0 :         return serialize_status;
     213           0 :       }
     214           0 :       auto output = alloc_result(result.size());
     215           0 :       memcpy(output, result.data(), result.size()); // NOLINT(safe-memcpy)
     216           0 :       return WasmResult::Ok;
     217           0 :     };
     218           3 :     return f;
     219           3 :   }
     220             : };
     221             : RegisterForeignFunction
     222             :     registerEvaluateExpressionForeignFunction("expr_evaluate",
     223             :                                               createFromClass<EvaluateExpressionFactory>());
     224             : 
     225             : class DeleteExpressionFactory : public ExpressionFactory {
     226             : public:
     227           3 :   WasmForeignFunction create(std::shared_ptr<DeleteExpressionFactory> self) const {
     228           3 :     WasmForeignFunction f = [self](WasmBase&, std::string_view argument,
     229           3 :                                    const std::function<void*(size_t size)>&) -> WasmResult {
     230           0 :       auto& expr_context = getOrCreateContext(proxy_wasm::current_context_->root_context());
     231           0 :       if (argument.size() != sizeof(uint32_t)) {
     232           0 :         return WasmResult::BadArgument;
     233           0 :       }
     234           0 :       uint32_t token = *reinterpret_cast<const uint32_t*>(argument.data());
     235           0 :       expr_context.deleteExpression(token);
     236           0 :       return WasmResult::Ok;
     237           0 :     };
     238           3 :     return f;
     239           3 :   }
     240             : };
     241             : RegisterForeignFunction
     242             :     registerDeleteExpressionForeignFunction("expr_delete",
     243             :                                             createFromClass<DeleteExpressionFactory>());
     244             : #endif
     245             : 
     246             : // TODO(kyessenov) The factories should be separated into individual compilation units.
     247             : // TODO(kyessenov) Leverage the host argument marshaller instead of the protobuf argument list.
     248             : class DeclarePropertyFactory {
     249             : public:
     250           3 :   WasmForeignFunction create(std::shared_ptr<DeclarePropertyFactory> self) const {
     251           3 :     WasmForeignFunction f = [self](WasmBase&, std::string_view arguments,
     252           3 :                                    const std::function<void*(size_t size)>&) -> WasmResult {
     253           0 :       envoy::source::extensions::common::wasm::DeclarePropertyArguments args;
     254           0 :       if (args.ParseFromArray(arguments.data(), arguments.size())) {
     255           0 :         CelStateType type = CelStateType::Bytes;
     256           0 :         switch (args.type()) {
     257           0 :         case envoy::source::extensions::common::wasm::WasmType::Bytes:
     258           0 :           type = CelStateType::Bytes;
     259           0 :           break;
     260           0 :         case envoy::source::extensions::common::wasm::WasmType::Protobuf:
     261           0 :           type = CelStateType::Protobuf;
     262           0 :           break;
     263           0 :         case envoy::source::extensions::common::wasm::WasmType::String:
     264           0 :           type = CelStateType::String;
     265           0 :           break;
     266           0 :         case envoy::source::extensions::common::wasm::WasmType::FlatBuffers:
     267           0 :           type = CelStateType::FlatBuffers;
     268           0 :           break;
     269           0 :         default:
     270             :           // do nothing
     271           0 :           break;
     272           0 :         }
     273           0 :         StreamInfo::FilterState::LifeSpan span = toFilterStateLifeSpan(args.span());
     274           0 :         auto context = static_cast<Context*>(proxy_wasm::current_context_);
     275           0 :         return context->declareProperty(
     276           0 :             args.name(), std::make_unique<const Filters::Common::Expr::CelStatePrototype>(
     277           0 :                              args.readonly(), type, args.schema(), span));
     278           0 :       }
     279           0 :       return WasmResult::BadArgument;
     280           0 :     };
     281           3 :     return f;
     282           3 :   }
     283             : };
     284             : RegisterForeignFunction
     285             :     registerDeclarePropertyForeignFunction("declare_property",
     286             :                                            createFromClass<DeclarePropertyFactory>());
     287             : 
     288             : } // namespace Wasm
     289             : } // namespace Common
     290             : } // namespace Extensions
     291             : } // namespace Envoy

Generated by: LCOV version 1.15