1
#include "source/common/common/logger.h"
2
#include "source/extensions/common/wasm/ext/declare_property.pb.h"
3
#include "source/extensions/common/wasm/ext/set_envoy_filter_state.pb.h"
4
#include "source/extensions/common/wasm/ext/sign.pb.h"
5
#include "source/extensions/common/wasm/ext/verify_signature.pb.h"
6
#include "source/extensions/common/wasm/wasm.h"
7

            
8
#if defined(WASM_USE_CEL_PARSER)
9
#include "eval/public/builtin_func_registrar.h"
10
#include "eval/public/cel_expr_builder_factory.h"
11
#include "parser/parser.h"
12
#endif
13
#include "source/common/crypto/crypto_impl.h"
14
#include "source/common/crypto/utility.h"
15

            
16
#include "absl/types/span.h"
17
#include "zlib.h"
18

            
19
using proxy_wasm::RegisterForeignFunction;
20
using proxy_wasm::WasmForeignFunction;
21

            
22
namespace Envoy {
23

            
24
namespace {
25
// Helper function to import public key from either PEM or DER format
26
Envoy::Common::Crypto::PKeyObjectPtr
27
importPublicKey(Envoy::Common::Crypto::Utility& crypto_util,
28
12
                const envoy::source::extensions::common::wasm::VerifySignatureArguments& args) {
29
12
  bool has_pem = !args.public_key_pem().empty();
30
12
  bool has_der = !args.public_key().empty();
31

            
32
12
  if (has_pem && has_der) {
33
2
    return nullptr; // Both PEM and DER keys provided
34
2
  }
35

            
36
10
  if (has_pem) {
37
    return crypto_util.importPublicKeyPEM(args.public_key_pem());
38
10
  } else if (has_der) {
39
10
    auto key_str = args.public_key();
40
10
    return crypto_util.importPublicKeyDER(
41
10
        absl::MakeSpan(reinterpret_cast<const uint8_t*>(key_str.data()), key_str.size()));
42
10
  } else {
43
    return nullptr; // No key provided
44
  }
45
10
}
46

            
47
// Helper function to import private key from either PEM or DER format
48
Envoy::Common::Crypto::PKeyObjectPtr
49
importPrivateKey(Envoy::Common::Crypto::Utility& crypto_util,
50
8
                 const envoy::source::extensions::common::wasm::SignArguments& args) {
51
8
  bool has_pem = !args.private_key_pem().empty();
52
8
  bool has_der = !args.private_key().empty();
53

            
54
8
  if (has_pem && has_der) {
55
    return nullptr; // Both PEM and DER keys provided
56
  }
57

            
58
8
  if (has_pem) {
59
    return crypto_util.importPrivateKeyPEM(args.private_key_pem());
60
8
  } else if (has_der) {
61
8
    auto key_str = args.private_key();
62
8
    return crypto_util.importPrivateKeyDER(
63
8
        absl::MakeSpan(reinterpret_cast<const uint8_t*>(key_str.data()), key_str.size()));
64
8
  } else {
65
    return nullptr; // No key provided
66
  }
67
8
}
68
} // namespace
69
namespace Extensions {
70
namespace Common {
71
namespace Wasm {
72

            
73
using CelStateType = Filters::Common::Expr::CelStateType;
74

            
75
392
template <typename T> WasmForeignFunction createFromClass() {
76
392
  auto c = std::make_shared<T>();
77
392
  return c->create(c);
78
392
}
79

            
80
// Translates the LifeSpan enum used by the foreign-function protos into Envoy's
// StreamInfo::FilterState::LifeSpan. Values not listed here fall back to FilterChain.
inline StreamInfo::FilterState::LifeSpan
toFilterStateLifeSpan(envoy::source::extensions::common::wasm::LifeSpan span) {
  using WasmLifeSpan = envoy::source::extensions::common::wasm::LifeSpan;
  using EnvoyLifeSpan = StreamInfo::FilterState::LifeSpan;

  switch (span) {
  case WasmLifeSpan::DownstreamRequest:
    return EnvoyLifeSpan::Request;
  case WasmLifeSpan::DownstreamConnection:
    return EnvoyLifeSpan::Connection;
  case WasmLifeSpan::FilterChain:
  default:
    return EnvoyLifeSpan::FilterChain;
  }
}
93

            
94
// Foreign function "verify_signature".
//
// `arguments` is a serialized VerifySignatureArguments message; exactly one of
// `public_key_pem` / `public_key` (DER) must be set. The verification outcome is written,
// as a serialized VerifySignatureResult, into a buffer obtained from `alloc_result`:
// `result` is true when verification succeeded, otherwise `error` carries the message.
//
// Returns BadArgument if the message does not parse or importPublicKey yields nullptr;
// Ok otherwise (a failed verification is reported inside the result, not as an error).
RegisterForeignFunction registerVerifySignatureForeignFunction(
    "verify_signature",
    [](WasmBase&, std::string_view arguments,
       const std::function<void*(size_t size)>& alloc_result) -> WasmResult {
      envoy::source::extensions::common::wasm::VerifySignatureArguments args;
      if (!args.ParseFromString(arguments)) {
        return WasmResult::BadArgument;
      }

      // Bind by reference: signature and text are only viewed through spans below, so
      // copying them (the text may be arbitrarily large) is unnecessary.
      const auto& hash = args.hash_function();
      const auto& signature_str = args.signature();
      const auto& text_str = args.text();

      auto& crypto_util = Envoy::Common::Crypto::UtilitySingleton::get();
      auto crypto_ptr = importPublicKey(crypto_util, args);
      if (!crypto_ptr) {
        return WasmResult::BadArgument;
      }

      auto output = crypto_util.verifySignature(
          hash, *crypto_ptr,
          absl::MakeSpan(reinterpret_cast<const uint8_t*>(signature_str.data()),
                         signature_str.size()),
          absl::MakeSpan(reinterpret_cast<const uint8_t*>(text_str.data()), text_str.size()));

      envoy::source::extensions::common::wasm::VerifySignatureResult verification_result;
      if (output.ok()) {
        verification_result.set_result(true);
        verification_result.set_error("");
      } else {
        verification_result.set_result(false);
        verification_result.set_error(output.message());
      }

      auto size = verification_result.ByteSizeLong();
      auto result = alloc_result(size);
      verification_result.SerializeToArray(result, static_cast<int>(size));
      return WasmResult::Ok;
    });
132

            
133
// Foreign function "sign".
//
// `arguments` is a serialized SignArguments message; exactly one of `private_key_pem` /
// `private_key` (DER) must be set. The outcome of signing `text` with `hash_function` is
// written, as a serialized SignResult, into a buffer obtained from `alloc_result`: on
// success `result` is true and `signature` holds the signature bytes; otherwise `result`
// is false and `error` carries the failure message.
//
// Returns BadArgument if the message does not parse or importPrivateKey yields nullptr;
// Ok otherwise (a signing failure is reported inside the result, not as an error).
RegisterForeignFunction registerSignForeignFunction(
    "sign",
    [](WasmBase&, std::string_view arguments,
       const std::function<void*(size_t size)>& alloc_result) -> WasmResult {
      envoy::source::extensions::common::wasm::SignArguments args;
      if (!args.ParseFromString(arguments)) {
        return WasmResult::BadArgument;
      }

      // Bind by reference: the text is only viewed through a span below, so copying it
      // is unnecessary.
      const auto& hash = args.hash_function();
      const auto& text_str = args.text();

      auto& crypto_util = Envoy::Common::Crypto::UtilitySingleton::get();
      auto crypto_ptr = importPrivateKey(crypto_util, args);
      if (!crypto_ptr) {
        return WasmResult::BadArgument;
      }

      auto output = crypto_util.sign(
          hash, *crypto_ptr,
          absl::MakeSpan(reinterpret_cast<const uint8_t*>(text_str.data()), text_str.size()));

      envoy::source::extensions::common::wasm::SignResult signing_result;
      if (output.ok()) {
        signing_result.set_result(true);
        signing_result.set_signature(output->data(), output->size());
        signing_result.set_error("");
      } else {
        signing_result.set_result(false);
        signing_result.set_error(output.status().message());
      }

      auto size = signing_result.ByteSizeLong();
      auto result = alloc_result(size);
      signing_result.SerializeToArray(result, static_cast<int>(size));
      return WasmResult::Ok;
    });
169

            
170
// Foreign function "compress": runs zlib's one-shot compress() over the raw `arguments`
// bytes and copies the compressed bytes into a buffer obtained from `alloc_result`.
// Returns SerializationFailure when zlib does not report Z_OK.
RegisterForeignFunction registerCompressForeignFunction(
    "compress",
    [](WasmBase&, std::string_view arguments,
       const std::function<void*(size_t size)>& alloc_result) -> WasmResult {
      const auto* source = reinterpret_cast<const unsigned char*>(arguments.data());
      // compressBound() is zlib's worst-case output size, so one pass is enough.
      unsigned long compressed_len = compressBound(arguments.size());
      std::unique_ptr<unsigned char[]> scratch(new unsigned char[compressed_len]);

      const int rc = compress(scratch.get(), &compressed_len, source, arguments.size());
      if (rc != Z_OK) {
        return WasmResult::SerializationFailure;
      }

      void* out = alloc_result(compressed_len);
      memcpy(out, scratch.get(), compressed_len); // NOLINT(safe-memcpy)
      return WasmResult::Ok;
    });
184

            
185
// Foreign function "uncompress": inflates the zlib-compressed `arguments` and copies the
// result into a buffer obtained from `alloc_result`.
//
// The decompressed size is not known in advance, so decompression starts from an
// estimate and retries with a doubled buffer whenever zlib returns Z_BUF_ERROR. Any
// other non-Z_OK code yields SerializationFailure.
// NOTE(review): the buffer growth has no upper bound, so a highly compressible payload
// can drive very large allocations.
RegisterForeignFunction registerUncompressForeignFunction(
    "uncompress",
    [](WasmBase&, std::string_view arguments,
       const std::function<void*(size_t size)>& alloc_result) -> WasmResult {
      const auto* source = reinterpret_cast<const unsigned char*>(arguments.data());
      unsigned long capacity = arguments.size() * 2 + 2; // output estimate.
      for (;;) {
        std::unique_ptr<unsigned char[]> scratch(new unsigned char[capacity]);
        const auto rc = uncompress(scratch.get(), &capacity, source, arguments.size());
        if (rc == Z_BUF_ERROR) {
          // Output buffer too small: grow and try again.
          capacity *= 2;
          continue;
        }
        if (rc != Z_OK) {
          return WasmResult::SerializationFailure;
        }
        // On Z_OK zlib has updated `capacity` to the actual decompressed length.
        void* out = alloc_result(capacity);
        memcpy(out, scratch.get(), capacity); // NOLINT(safe-memcpy)
        return WasmResult::Ok;
      }
    });
206

            
207
// Foreign function "set_envoy_filter_state".
//
// Parses a SetEnvoyFilterStateArguments message and forwards its `path`, `value` and
// translated `span` to Context::setEnvoyFilterState. The target context is
// proxy_wasm::contextOrEffectiveContext() when the runtime feature
// "envoy.reloadable_features.wasm_use_effective_ctx_for_foreign_functions" is enabled, and
// proxy_wasm::current_context_ otherwise. Returns BadArgument if the message does not
// parse; otherwise the result of setEnvoyFilterState.
RegisterForeignFunction registerSetEnvoyFilterStateForeignFunction(
    "set_envoy_filter_state",
    [](WasmBase&, std::string_view arguments,
       const std::function<void*(size_t size)>&) -> WasmResult {
      envoy::source::extensions::common::wasm::SetEnvoyFilterStateArguments args;
      if (!args.ParseFromString(arguments)) {
        return WasmResult::BadArgument;
      }

      const bool use_effective_ctx = Runtime::runtimeFeatureEnabled(
          "envoy.reloadable_features.wasm_use_effective_ctx_for_foreign_functions");
      auto* target = static_cast<Context*>(use_effective_ctx
                                               ? proxy_wasm::contextOrEffectiveContext()
                                               : proxy_wasm::current_context_);
      const auto life_span = toFilterStateLifeSpan(args.span());
      return target->setEnvoyFilterState(args.path(), args.value(), life_span);
    });
223

            
224
// Foreign function "clear_route_cache": ignores its arguments and calls clearRouteCache()
// on the target context. The target is proxy_wasm::contextOrEffectiveContext() when the
// runtime feature "envoy.reloadable_features.wasm_use_effective_ctx_for_foreign_functions"
// is enabled, and proxy_wasm::current_context_ otherwise. Always returns Ok.
RegisterForeignFunction registerClearRouteCacheForeignFunction(
    "clear_route_cache",
    [](WasmBase&, std::string_view, const std::function<void*(size_t size)>&) -> WasmResult {
      const bool use_effective_ctx = Runtime::runtimeFeatureEnabled(
          "envoy.reloadable_features.wasm_use_effective_ctx_for_foreign_functions");
      auto* target = static_cast<Context*>(use_effective_ctx
                                               ? proxy_wasm::contextOrEffectiveContext()
                                               : proxy_wasm::current_context_);
      target->clearRouteCache();
      return WasmResult::Ok;
    });
235

            
236
#if defined(WASM_USE_CEL_PARSER)
237
// Shared base for the CEL expression foreign functions (expr_create / expr_evaluate /
// expr_delete). Provides access to the per-context expression store.
class ExpressionFactory : public Logger::Loggable<Logger::Id::wasm> {
protected:
  // One stored CEL expression: the parsed AST and the expression compiled from it,
  // kept side by side.
  struct ExpressionData {
    cel::expr::ParsedExpr parsed_expr_;
    Filters::Common::Expr::ExpressionPtr compiled_expr_;
  };

  // Expression store attached to a Context as foreign data. Expressions are addressed by
  // an opaque uint32_t token that is handed to the Wasm module.
  class ExpressionContext : public StorageObject {
  public:
    friend class ExpressionFactory;
    explicit ExpressionContext(Filters::Common::Expr::BuilderConstPtr builder)
        : builder_(std::move(builder)) {}

    // Returns a token that is not currently mapped to an expression. Tokens come from a
    // monotonically increasing counter; values still in use (possible once the uint32_t
    // counter wraps around) are skipped.
    uint32_t createToken() {
      uint32_t token = next_expr_token_++;
      while (expr_.contains(token)) {
        token = next_expr_token_++;
      }
      return token;
    }
    bool hasExpression(uint32_t token) { return expr_.contains(token); }
    // Returns the entry for `token`, default-constructing it if absent.
    ExpressionData& getExpression(uint32_t token) { return expr_[token]; }
    // Removes the entry for `token`; erasing an absent token is a no-op.
    void deleteExpression(uint32_t token) { expr_.erase(token); }
    const Filters::Common::Expr::Builder* builder() const { return builder_.get(); }

  private:
    const Filters::Common::Expr::BuilderConstPtr builder_{};
    uint32_t next_expr_token_ = 0;
    absl::flat_hash_map<uint32_t, ExpressionData> expr_;
  };

  // Returns the ExpressionContext stored under foreign-data name "cel" on `context_base`,
  // creating it on first use with a CEL expression builder that has the built-in
  // functions registered. A failure to register the built-ins is logged and otherwise
  // ignored.
  static ExpressionContext& getOrCreateContext(ContextBase* context_base) {
    auto context = static_cast<Context*>(context_base);
    std::string data_name = "cel";
    auto expr_context = context->getForeignData<ExpressionContext>(data_name);
    if (!expr_context) {
      google::api::expr::runtime::InterpreterOptions options;
      auto builder = google::api::expr::runtime::CreateCelExpressionBuilder(options);
      auto status =
          google::api::expr::runtime::RegisterBuiltinFunctions(builder->GetRegistry(), options);
      if (!status.ok()) {
        ENVOY_LOG(warn, "failed to register built-in functions: {}", status.message());
      }
      auto new_context = std::make_unique<ExpressionContext>(std::move(builder));
      expr_context = new_context.get();
      context->setForeignData(data_name, std::move(new_context));
    }
    return *expr_context;
  }
};
289

            
290
class CreateExpressionFactory : public ExpressionFactory {
291
public:
292
98
  WasmForeignFunction create(std::shared_ptr<CreateExpressionFactory> self) const {
293
98
    WasmForeignFunction f =
294
98
        [self](WasmBase&, std::string_view expr,
295
108
               const std::function<void*(size_t size)>& alloc_result) -> WasmResult {
296
12
      auto parse_status = google::api::expr::parser::Parse(std::string(expr));
297
12
      if (!parse_status.ok()) {
298
2
        ENVOY_LOG(info, "expr_create parse error: {}", parse_status.status().message());
299
2
        return WasmResult::BadArgument;
300
2
      }
301

            
302
10
      auto& expr_context = getOrCreateContext(proxy_wasm::current_context_->root_context());
303
10
      auto token = expr_context.createToken();
304
10
      auto& handler = expr_context.getExpression(token);
305

            
306
10
      const auto& parsed_expr = parse_status.value();
307
10
      handler.parsed_expr_ = parsed_expr;
308

            
309
10
      std::vector<absl::Status> warnings;
310
10
      auto cel_expression_status = expr_context.builder()->CreateExpression(
311
10
          &handler.parsed_expr_.expr(), &handler.parsed_expr_.source_info(), &warnings);
312

            
313
10
      if (!cel_expression_status.ok()) {
314
2
        ENVOY_LOG(info, "expr_create compile error: {}", cel_expression_status.status().message());
315
2
        expr_context.deleteExpression(token);
316
2
        return WasmResult::BadArgument;
317
2
      }
318

            
319
8
      handler.compiled_expr_ = std::move(cel_expression_status.value());
320

            
321
8
      auto result = reinterpret_cast<uint32_t*>(alloc_result(sizeof(uint32_t)));
322
8
      *result = token;
323
8
      return WasmResult::Ok;
324
10
    };
325
98
    return f;
326
98
  }
327
};
328
RegisterForeignFunction
329
    registerCreateExpressionForeignFunction("expr_create",
330
                                            createFromClass<CreateExpressionFactory>());
331

            
332
class EvaluateExpressionFactory : public ExpressionFactory {
333
public:
334
98
  WasmForeignFunction create(std::shared_ptr<EvaluateExpressionFactory> self) const {
335
98
    WasmForeignFunction f =
336
98
        [self](WasmBase&, std::string_view argument,
337
109
               const std::function<void*(size_t size)>& alloc_result) -> WasmResult {
338
14
      auto& expr_context = getOrCreateContext(proxy_wasm::current_context_->root_context());
339
14
      if (argument.size() != sizeof(uint32_t)) {
340
1
        return WasmResult::BadArgument;
341
1
      }
342
13
      uint32_t token = *reinterpret_cast<const uint32_t*>(argument.data());
343
13
      if (!expr_context.hasExpression(token)) {
344
5
        return WasmResult::NotFound;
345
5
      }
346
8
      Protobuf::Arena arena;
347
8
      auto& handler = expr_context.getExpression(token);
348
8
      auto context = static_cast<Context*>(proxy_wasm::current_context_);
349
8
      auto eval_status = handler.compiled_expr_->Evaluate(*context, &arena);
350
8
      if (!eval_status.ok()) {
351
        ENVOY_LOG(debug, "expr_evaluate error: {}", eval_status.status().message());
352
        return WasmResult::InternalFailure;
353
      }
354
8
      auto value = eval_status.value();
355
8
      if (value.IsError()) {
356
4
        ENVOY_LOG(debug, "expr_evaluate value error: {}", value.ErrorOrDie()->message());
357
4
        return WasmResult::InternalFailure;
358
4
      }
359
4
      std::string result;
360
4
      auto serialize_status = serializeValue(value, &result);
361
4
      if (serialize_status != WasmResult::Ok) {
362
        return serialize_status;
363
      }
364
4
      auto output = alloc_result(result.size());
365
4
      memcpy(output, result.data(), result.size()); // NOLINT(safe-memcpy)
366
4
      return WasmResult::Ok;
367
4
    };
368
98
    return f;
369
98
  }
370
};
371
RegisterForeignFunction
372
    registerEvaluateExpressionForeignFunction("expr_evaluate",
373
                                              createFromClass<EvaluateExpressionFactory>());
374

            
375
class DeleteExpressionFactory : public ExpressionFactory {
376
public:
377
98
  WasmForeignFunction create(std::shared_ptr<DeleteExpressionFactory> self) const {
378
98
    WasmForeignFunction f = [self](WasmBase&, std::string_view argument,
379
106
                                   const std::function<void*(size_t size)>&) -> WasmResult {
380
11
      auto& expr_context = getOrCreateContext(proxy_wasm::current_context_->root_context());
381
11
      if (argument.size() != sizeof(uint32_t)) {
382
1
        return WasmResult::BadArgument;
383
1
      }
384
10
      uint32_t token = *reinterpret_cast<const uint32_t*>(argument.data());
385
10
      expr_context.deleteExpression(token);
386
10
      return WasmResult::Ok;
387
11
    };
388
98
    return f;
389
98
  }
390
};
391
RegisterForeignFunction
392
    registerDeleteExpressionForeignFunction("expr_delete",
393
                                            createFromClass<DeleteExpressionFactory>());
394
#endif
395

            
396
// TODO(kyessenov) The factories should be separated into individual compilation units.
397
// TODO(kyessenov) Leverage the host argument marshaller instead of the protobuf argument list.
398
class DeclarePropertyFactory {
399
public:
400
98
  WasmForeignFunction create(std::shared_ptr<DeclarePropertyFactory> self) const {
401
98
    WasmForeignFunction f = [self](WasmBase&, std::string_view arguments,
402
108
                                   const std::function<void*(size_t size)>&) -> WasmResult {
403
12
      envoy::source::extensions::common::wasm::DeclarePropertyArguments args;
404
12
      if (args.ParseFromString(arguments)) {
405
10
        CelStateType type = CelStateType::Bytes;
406
10
        switch (args.type()) {
407
4
        case envoy::source::extensions::common::wasm::WasmType::Bytes:
408
4
          type = CelStateType::Bytes;
409
4
          break;
410
2
        case envoy::source::extensions::common::wasm::WasmType::Protobuf:
411
2
          type = CelStateType::Protobuf;
412
2
          break;
413
2
        case envoy::source::extensions::common::wasm::WasmType::String:
414
2
          type = CelStateType::String;
415
2
          break;
416
2
        case envoy::source::extensions::common::wasm::WasmType::FlatBuffers:
417
2
          type = CelStateType::FlatBuffers;
418
2
          break;
419
        default:
420
          // do nothing
421
          break;
422
10
        }
423
10
        StreamInfo::FilterState::LifeSpan span = toFilterStateLifeSpan(args.span());
424
10
        auto context = static_cast<Context*>(proxy_wasm::current_context_);
425
10
        return context->declareProperty(
426
10
            args.name(), std::make_unique<const Filters::Common::Expr::CelStatePrototype>(
427
10
                             args.readonly(), type, args.schema(), span));
428
10
      }
429
2
      return WasmResult::BadArgument;
430
12
    };
431
98
    return f;
432
98
  }
433
};
434
RegisterForeignFunction
435
    registerDeclarePropertyForeignFunction("declare_property",
436
                                           createFromClass<DeclarePropertyFactory>());
437

            
438
} // namespace Wasm
439
} // namespace Common
440
} // namespace Extensions
441
} // namespace Envoy