1
#include "source/extensions/common/wasm/context.h"
2

            
3
#include <algorithm>
4
#include <cctype>
5
#include <cstring>
6
#include <ctime>
7
#include <limits>
8
#include <memory>
9
#include <string>
10

            
11
#include "envoy/common/exception.h"
12
#include "envoy/extensions/wasm/v3/wasm.pb.validate.h"
13
#include "envoy/grpc/status.h"
14
#include "envoy/http/codes.h"
15
#include "envoy/local_info/local_info.h"
16
#include "envoy/network/filter.h"
17
#include "envoy/stats/sink.h"
18
#include "envoy/thread_local/thread_local.h"
19

            
20
#include "source/common/buffer/buffer_impl.h"
21
#include "source/common/common/assert.h"
22
#include "source/common/common/empty_string.h"
23
#include "source/common/common/enum_to_int.h"
24
#include "source/common/common/logger.h"
25
#include "source/common/common/safe_memcpy.h"
26
#include "source/common/http/header_map_impl.h"
27
#include "source/common/http/message_impl.h"
28
#include "source/common/http/utility.h"
29
#include "source/common/protobuf/utility.h"
30
#include "source/common/tracing/http_tracer_impl.h"
31
#include "source/extensions/common/wasm/plugin.h"
32
#include "source/extensions/common/wasm/wasm.h"
33
#include "source/extensions/filters/common/expr/context.h"
34

            
35
#include "absl/base/casts.h"
36
#include "absl/container/flat_hash_map.h"
37
#include "absl/container/node_hash_map.h"
38
#include "absl/strings/str_cat.h"
39
#include "absl/synchronization/mutex.h"
40

            
41
#if defined(__GNUC__)
42
#pragma GCC diagnostic push
43
#pragma GCC diagnostic ignored "-Woverloaded-virtual"
44
#endif
45

            
46
#include "eval/public/cel_value.h"
47
#include "eval/public/containers/field_access.h"
48
#include "eval/public/containers/field_backed_list_impl.h"
49
#include "eval/public/containers/field_backed_map_impl.h"
50
#include "eval/public/structs/cel_proto_wrapper.h"
51

            
52
#if defined(__GNUC__)
53
#pragma GCC diagnostic pop
54
#endif
55

            
56
#include "include/proxy-wasm/pairs_util.h"
57
#include "openssl/bytestring.h"
58
#include "openssl/hmac.h"
59
#include "openssl/sha.h"
60

            
61
using proxy_wasm::MetricType;
62
using proxy_wasm::Word;
63

            
64
namespace Envoy {
65
namespace Extensions {
66
namespace Common {
67
namespace Wasm {
68

            
69
namespace {
70

            
71
// FilterState prefix for CelState values.
72
constexpr absl::string_view CelStateKeyPrefix = "wasm.";
73

            
74
// Default behavior for Proxy-Wasm 0.2.* ABI is to not support StopIteration as
75
// a return value from onRequestHeaders() or onResponseHeaders() plugin
76
// callbacks.
77
constexpr bool DefaultAllowOnHeadersStopIteration = false;
78

            
79
using HashPolicy = envoy::config::route::v3::RouteAction::HashPolicy;
80
using CelState = Filters::Common::Expr::CelState;
81
using CelStatePrototype = Filters::Common::Expr::CelStatePrototype;
82

            
83
15
// Creates a new request trailer map containing a copy of every (name, value) pair; names go
// through Http::LowerCaseString.
Http::RequestTrailerMapPtr buildRequestTrailerMapFromPairs(const Pairs& pairs) {
  auto trailers = Http::RequestTrailerMapImpl::create();
  for (const auto& [name, value] : pairs) {
    // addCopy has no string_view overload, and there is no interface to add an entry with an
    // empty value and get it back, so both strings are copied into std::string temporaries.
    trailers->addCopy(Http::LowerCaseString(std::string(name)), std::string(value));
  }
  return trailers;
}
94

            
95
116
// Creates a new request header map containing a copy of every (name, value) pair; names go
// through Http::LowerCaseString.
Http::RequestHeaderMapPtr buildRequestHeaderMapFromPairs(const Pairs& pairs) {
  auto headers = Http::RequestHeaderMapImpl::create();
  for (const auto& [name, value] : pairs) {
    // addCopy has no string_view overload, and there is no interface to add an entry with an
    // empty value and get it back, so both strings are copied into std::string temporaries.
    headers->addCopy(Http::LowerCaseString(std::string(name)), std::string(value));
  }
  return headers;
}
106

            
107
286
// Returns p->size(), or 0 when `p` is null.
template <typename P> static uint32_t headerSize(const P& p) {
  if (!p) {
    return 0;
  }
  return p->size();
}
108

            
109
} // namespace
110

            
111
// Test support.
112

            
113
1179
size_t Buffer::size() const {
114
1179
  if (const_buffer_instance_) {
115
746
    return const_buffer_instance_->length();
116
746
  }
117
433
  return proxy_wasm::BufferBase::size();
118
1179
}
119

            
120
// Copies `length` bytes starting at `start` into freshly allocated VM memory and writes the
// allocation's address and size to `ptr_ptr` / `size_ptr` in the VM. Buffers without an attached
// read-only Envoy buffer defer to the proxy-wasm base implementation.
//
// Returns InvalidMemoryAccess if the range is out of bounds, the VM allocation fails, or either
// output word cannot be written.
WasmResult Buffer::copyTo(WasmBase* wasm, size_t start, size_t length, uint64_t ptr_ptr,
                          uint64_t size_ptr) const {
  if (!const_buffer_instance_) {
    return proxy_wasm::BufferBase::copyTo(wasm, start, length, ptr_ptr, size_ptr);
  }
  // Validate that the requested range is within bounds before allocating. The check uses two
  // comparisons so that a huge `start` or `length` cannot wrap `start + length` around and pass.
  const uint64_t available = const_buffer_instance_->length();
  if (start > available || length > available - start) {
    return WasmResult::InvalidMemoryAccess;
  }
  uint64_t pointer = 0;
  auto p = wasm->allocMemory(length, &pointer);
  if (!p) {
    return WasmResult::InvalidMemoryAccess;
  }
  const_buffer_instance_->copyOut(start, length, p);
  if (!wasm->wasm_vm()->setWord(ptr_ptr, Word(pointer))) {
    return WasmResult::InvalidMemoryAccess;
  }
  if (!wasm->wasm_vm()->setWord(size_ptr, Word(length))) {
    return WasmResult::InvalidMemoryAccess;
  }
  return WasmResult::Ok;
}
143

            
144
82
// Writes `data` into the buffer.
// - Mutable Envoy buffer: start == 0 replaces the first `length` bytes with `data`; a start at or
//   beyond the end appends `data`; any other start is rejected with BadArgument.
// - Read-only Envoy buffer: always BadArgument.
// - No Envoy buffer attached: defers to the proxy-wasm base implementation.
WasmResult Buffer::copyFrom(size_t start, size_t length, std::string_view data) {
  if (buffer_instance_ == nullptr) {
    if (const_buffer_instance_ != nullptr) {
      // This buffer is immutable.
      return WasmResult::BadArgument;
    }
    return proxy_wasm::BufferBase::copyFrom(start, length, data);
  }

  if (start == 0) {
    if (length > 0) {
      buffer_instance_->drain(length);
    }
    buffer_instance_->prepend(toAbslStringView(data));
    return WasmResult::Ok;
  }

  if (start < buffer_instance_->length()) {
    // Writing into the middle of the buffer is not supported.
    return WasmResult::BadArgument;
  }
  buffer_instance_->add(toAbslStringView(data));
  return WasmResult::Ok;
}
164

            
165
16
Context::Context() = default;

// Context bound to a Wasm instance only; inherits the instance's ABI version when present.
Context::Context(Wasm* wasm) : ContextBase(wasm) {
  if (wasm) {
    abi_version_ = wasm->abi_version_;
  }
}

// Context bound to a Wasm instance and a plugin. Records the plugin's LocalInfo as the root local
// info and reads allow_on_headers_stop_iteration from the plugin's config (defaulting to
// DefaultAllowOnHeadersStopIteration).
Context::Context(Wasm* wasm, const PluginSharedPtr& plugin) : ContextBase(wasm, plugin) {
  if (wasm) {
    abi_version_ = wasm->abi_version_;
  }
  // `plugin` (the parameter) shadows the accessor, hence the explicit this->.
  Plugin* const owning_plugin = this->plugin();
  root_local_info_ = &owning_plugin->localInfo();
  const auto& plugin_config = owning_plugin->wasmConfig().config();
  allow_on_headers_stop_iteration_ = PROTOBUF_GET_WRAPPED_OR_DEFAULT(
      plugin_config, allow_on_headers_stop_iteration, DefaultAllowOnHeadersStopIteration);
}

// Context created for a given root context id; keeps its own reference to `plugin_handle`.
Context::Context(Wasm* wasm, uint32_t root_context_id, PluginHandleSharedPtr plugin_handle)
    : ContextBase(wasm, root_context_id, plugin_handle), plugin_handle_(plugin_handle) {
  if (wasm) {
    abi_version_ = wasm->abi_version_;
  }
  const auto& plugin_config = plugin()->wasmConfig().config();
  allow_on_headers_stop_iteration_ = PROTOBUF_GET_WRAPPED_OR_DEFAULT(
      plugin_config, allow_on_headers_stop_iteration, DefaultAllowOnHeadersStopIteration);
}
189

            
190
7118
WasmBase* Context::wasm() const { return wasm_; }
191
1509
Wasm* Context::envoyWasm() const { return static_cast<Wasm*>(wasm_); }
192
1912
Plugin* Context::plugin() const { return static_cast<Plugin*>(plugin_.get()); }
193
130
Context* Context::rootContext() const { return static_cast<Context*>(root_context()); }
194
230
Upstream::ClusterManager& Context::clusterManager() const { return envoyWasm()->clusterManager(); }
195

            
196
2
void Context::error(std::string_view message) { ENVOY_LOG(trace, message); }
197

            
198
2
uint64_t Context::getCurrentTimeNanoseconds() {
199
2
  return std::chrono::duration_cast<std::chrono::nanoseconds>(
200
2
             envoyWasm()->time_source_.systemTime().time_since_epoch())
201
2
      .count();
202
2
}
203

            
204
2
uint64_t Context::getMonotonicTimeNanoseconds() {
205
2
  return std::chrono::duration_cast<std::chrono::nanoseconds>(
206
2
             envoyWasm()->time_source_.monotonicTime().time_since_epoch())
207
2
      .count();
208
2
}
209

            
210
36
void Context::onCloseTCP() {
211
36
  if (tcp_connection_closed_ || !in_vm_context_created_) {
212
9
    return;
213
9
  }
214
27
  tcp_connection_closed_ = true;
215
27
  onDone();
216
27
  onLog();
217
27
  onDelete();
218
27
}
219

            
220
void Context::onResolveDns(uint32_t token, Envoy::Network::DnsResolver::ResolutionStatus status,
221
7
                           std::list<Envoy::Network::DnsResponse>&& response) {
222
7
  proxy_wasm::DeferAfterCallActions actions(this);
223
7
  if (envoyWasm()->isFailed() || !envoyWasm()->on_resolve_dns_) {
224
1
    return;
225
1
  }
226
6
  if (status != Network::DnsResolver::ResolutionStatus::Completed) {
227
2
    buffer_.set("");
228
2
    envoyWasm()->on_resolve_dns_(this, id_, token, 0);
229
2
    return;
230
2
  }
231
  // buffer format:
232
  //    4 bytes number of entries = N
233
  //    N * 4 bytes TTL for each entry
234
  //    N * null-terminated addresses
235
4
  uint32_t s = 4; // length
236
4
  for (auto& e : response) {
237
4
    s += 4;                                                // for TTL
238
4
    s += e.addrInfo().address_->asStringView().size() + 1; // null terminated.
239
4
  }
240
4
  auto buffer = std::unique_ptr<char[]>(new char[s]);
241
4
  char* b = buffer.get();
242
4
  uint32_t n = response.size();
243
4
  safeMemcpyUnsafeDst(b, &n);
244
4
  b += sizeof(uint32_t);
245
4
  for (auto& e : response) {
246
4
    uint32_t ttl = e.addrInfo().ttl_.count();
247
4
    safeMemcpyUnsafeDst(b, &ttl);
248
4
    b += sizeof(uint32_t);
249
4
  };
250
4
  for (auto& e : response) {
251
4
    memcpy(b, e.addrInfo().address_->asStringView().data(), // NOLINT(safe-memcpy)
252
4
           e.addrInfo().address_->asStringView().size());
253
4
    b += e.addrInfo().address_->asStringView().size();
254
4
    *b++ = 0;
255
4
  };
256
4
  buffer_.set(std::move(buffer), s);
257
4
  envoyWasm()->on_resolve_dns_(this, id_, token, s);
258
4
}
259

            
260
8
// Rounds `i` up to the next multiple of sizeof(I). Assumes sizeof(I) is a power of two.
template <typename I> inline uint32_t align(uint32_t i) {
  constexpr size_t mask = sizeof(I) - 1;
  return static_cast<uint32_t>((i + mask) & ~mask);
}
263

            
264
8
// Rounds the address `p` up to the next multiple of sizeof(I). Assumes sizeof(I) is a power of
// two.
template <typename I> inline char* align(char* p) {
  constexpr uintptr_t mask = sizeof(I) - 1;
  const uintptr_t address = reinterpret_cast<uintptr_t>(p);
  return reinterpret_cast<char*>((address + mask) & ~mask);
}
268

            
269
7
// Serializes the used counters and gauges of `snapshot` into buffer_ and invokes the VM's
// on_stats_update callback with the total serialized size.
void Context::onStatsUpdate(Envoy::Stats::MetricSnapshot& snapshot) {
  proxy_wasm::DeferAfterCallActions actions(this);
  if (envoyWasm()->isFailed() || !envoyWasm()->on_stats_update_) {
    return;
  }
  // buffer format:
  //  uint32 size of block of this type
  //  uint32 type
  //  uint32 count
  //    uint32 length of name
  //    name
  //    8 byte alignment padding
  //    8 bytes of absolute value
  //    8 bytes of delta  (if appropriate, e.g. for counters)
  //  uint32 size of block of this type
  //  ...
  //
  // A counter block (type 1) is followed by a gauge block (type 2). Only metrics whose used() is
  // true are serialized, and `count` is the number of entries actually written. Padding is
  // computed from the offset within the whole buffer, and the sizing pass uses exactly the same
  // arithmetic as the writing pass, so the two always agree.
  constexpr uint32_t BlockHeaderSize = 3 * sizeof(uint32_t); // size, type, count
  constexpr uint32_t CounterType = 1;
  constexpr uint32_t GaugeType = 2;

  // Capture the used metrics once. Deciding used() a second time while writing could select a
  // different set than was sized, and overrun the buffer.
  struct MetricEntry {
    std::string name;
    uint64_t value;
    uint64_t delta;
  };
  std::vector<MetricEntry> counters;
  for (const auto& counter : snapshot.counters()) {
    const auto& metric = counter.counter_.get();
    if (metric.used()) {
      counters.push_back({std::string(metric.name()), metric.value(), counter.delta_});
    }
  }
  std::vector<MetricEntry> gauges;
  for (const auto& gauge : snapshot.gauges()) {
    const auto& metric = gauge.get();
    if (metric.used()) {
      gauges.push_back({std::string(metric.name()), metric.value(), 0});
    }
  }

  // Sizing pass: `end` is the absolute offset just past the data laid out so far.
  uint32_t end = BlockHeaderSize;
  for (const MetricEntry& entry : counters) {
    end += sizeof(uint32_t) + entry.name.size();
    end = align<uint64_t>(end) + 2 * sizeof(uint64_t);
  }
  const uint32_t counter_block_size = end;
  end += BlockHeaderSize;
  for (const MetricEntry& entry : gauges) {
    end += sizeof(uint32_t) + entry.name.size();
    end = align<uint64_t>(end) + sizeof(uint64_t);
  }
  const uint32_t gauge_block_size = end - counter_block_size;
  const uint32_t total_size = end;

  // std::make_unique<char[]> value-initializes (zeroes) the storage, so padding bytes never carry
  // stale heap contents into the VM.
  auto buffer = std::make_unique<char[]>(total_size);
  char* const base = buffer.get();
  uint32_t pos = 0;
  const auto put_u32 = [base, &pos](uint32_t word) {
    safeMemcpyUnsafeDst(base + pos, &word);
    pos += sizeof(uint32_t);
  };
  const auto put_u64 = [base, &pos](uint64_t word) {
    safeMemcpyUnsafeDst(base + pos, &word);
    pos += sizeof(uint64_t);
  };
  // Length-prefixed name followed by padding up to the next 8-byte boundary.
  const auto put_name = [base, &pos, &put_u32](const std::string& name) {
    put_u32(static_cast<uint32_t>(name.size()));
    memcpy(base + pos, name.data(), name.size()); // NOLINT(safe-memcpy)
    pos = align<uint64_t>(pos + static_cast<uint32_t>(name.size()));
  };

  put_u32(counter_block_size);
  put_u32(CounterType);
  put_u32(static_cast<uint32_t>(counters.size()));
  for (const MetricEntry& entry : counters) {
    put_name(entry.name);
    put_u64(entry.value);
    put_u64(entry.delta);
  }

  put_u32(gauge_block_size);
  put_u32(GaugeType);
  put_u32(static_cast<uint32_t>(gauges.size()));
  for (const MetricEntry& entry : gauges) {
    put_name(entry.name);
    put_u64(entry.value);
  }
  ASSERT(pos == total_size);

  buffer_.set(std::move(buffer), total_size);
  envoyWasm()->on_stats_update_(this, id_, total_size);
}
359

            
360
// Native serializer carrying over bit representation from CEL value to the extension.
361
// This implementation assumes that the value type is static and known to the consumer.
362
676
WasmResult serializeValue(Filters::Common::Expr::CelValue value, std::string* result) {
363
676
  using Filters::Common::Expr::CelValue;
364
676
  int64_t out_int64;
365
676
  uint64_t out_uint64;
366
676
  double out_double;
367
676
  bool out_bool;
368
676
  const Protobuf::Message* out_message;
369
676
  switch (value.type()) {
370
632
  case CelValue::Type::kString:
371
632
    result->assign(value.StringOrDie().value().data(), value.StringOrDie().value().size());
372
632
    return WasmResult::Ok;
373
8
  case CelValue::Type::kBytes:
374
8
    result->assign(value.BytesOrDie().value().data(), value.BytesOrDie().value().size());
375
8
    return WasmResult::Ok;
376
4
  case CelValue::Type::kInt64:
377
4
    out_int64 = value.Int64OrDie();
378
4
    result->assign(reinterpret_cast<const char*>(&out_int64), sizeof(int64_t));
379
4
    return WasmResult::Ok;
380
4
  case CelValue::Type::kUint64:
381
4
    out_uint64 = value.Uint64OrDie();
382
4
    result->assign(reinterpret_cast<const char*>(&out_uint64), sizeof(uint64_t));
383
4
    return WasmResult::Ok;
384
2
  case CelValue::Type::kDouble:
385
2
    out_double = value.DoubleOrDie();
386
2
    result->assign(reinterpret_cast<const char*>(&out_double), sizeof(double));
387
2
    return WasmResult::Ok;
388
2
  case CelValue::Type::kBool:
389
2
    out_bool = value.BoolOrDie();
390
2
    result->assign(reinterpret_cast<const char*>(&out_bool), sizeof(bool));
391
2
    return WasmResult::Ok;
392
3
  case CelValue::Type::kDuration:
393
    // Warning: loss of precision to nanoseconds
394
3
    out_int64 = absl::ToInt64Nanoseconds(value.DurationOrDie());
395
3
    result->assign(reinterpret_cast<const char*>(&out_int64), sizeof(int64_t));
396
3
    return WasmResult::Ok;
397
2
  case CelValue::Type::kTimestamp:
398
    // Warning: loss of precision to nanoseconds
399
2
    out_int64 = absl::ToUnixNanos(value.TimestampOrDie());
400
2
    result->assign(reinterpret_cast<const char*>(&out_int64), sizeof(int64_t));
401
2
    return WasmResult::Ok;
402
4
  case CelValue::Type::kMessage:
403
4
    out_message = value.MessageOrDie();
404
4
    result->clear();
405
4
    if (!out_message || out_message->SerializeToString(result)) {
406
4
      return WasmResult::Ok;
407
4
    }
408
    return WasmResult::SerializationFailure;
409
6
  case CelValue::Type::kMap: {
410
6
    const auto& map = *value.MapOrDie();
411
6
    auto keys_list = map.ListKeys();
412
6
    if (!keys_list.ok()) {
413
1
      return WasmResult::SerializationFailure;
414
1
    }
415
5
    const auto& keys = *keys_list.value();
416
5
    std::vector<std::pair<std::string, std::string>> pairs(map.size(), std::make_pair("", ""));
417
8
    for (auto i = 0; i < map.size(); i++) {
418
5
      if (serializeValue(keys[i], &pairs[i].first) != WasmResult::Ok) {
419
1
        return WasmResult::SerializationFailure;
420
1
      }
421
4
      if (serializeValue(map[keys[i]].value(), &pairs[i].second) != WasmResult::Ok) {
422
1
        return WasmResult::SerializationFailure;
423
1
      }
424
4
    }
425
3
    auto size = proxy_wasm::PairsUtil::pairsSize(pairs);
426
    // prevent string inlining which violates byte alignment
427
3
    result->resize(std::max(size, static_cast<size_t>(30)));
428
3
    if (!proxy_wasm::PairsUtil::marshalPairs(pairs, result->data(), size)) {
429
      return WasmResult::SerializationFailure;
430
    }
431
3
    result->resize(size);
432
3
    return WasmResult::Ok;
433
3
  }
434
4
  case CelValue::Type::kList: {
435
4
    const auto& list = *value.ListOrDie();
436
4
    std::vector<std::pair<std::string, std::string>> pairs(list.size(), std::make_pair("", ""));
437
9
    for (auto i = 0; i < list.size(); i++) {
438
6
      if (serializeValue(list[i], &pairs[i].first) != WasmResult::Ok) {
439
1
        return WasmResult::SerializationFailure;
440
1
      }
441
6
    }
442
3
    auto size = proxy_wasm::PairsUtil::pairsSize(pairs);
443
    // prevent string inlining which violates byte alignment
444
3
    if (size < 30) {
445
1
      result->reserve(30);
446
1
    }
447
3
    result->resize(size);
448
3
    if (!proxy_wasm::PairsUtil::marshalPairs(pairs, result->data(), size)) {
449
      return WasmResult::SerializationFailure;
450
    }
451
3
    return WasmResult::Ok;
452
3
  }
453
5
  default:
454
5
    break;
455
676
  }
456
5
  return WasmResult::SerializationFailure;
457
676
}
458

            
459
#define PROPERTY_TOKENS(_f) _f(PLUGIN_NAME) _f(PLUGIN_ROOT_ID) _f(PLUGIN_VM_ID) _f(CONNECTION_ID)
460

            
461
392
static inline std::string downCase(std::string s) {
462
4900
  std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) { return std::tolower(c); });
463
392
  return s;
464
392
}
465

            
466
#define _DECLARE(_t) _t,
467
enum class PropertyToken { PROPERTY_TOKENS(_DECLARE) };
468
#undef _DECLARE
469

            
470
#define _PAIR(_t) {downCase(#_t), PropertyToken::_t},
471
static absl::flat_hash_map<std::string, PropertyToken> property_tokens = {PROPERTY_TOKENS(_PAIR)};
472
#undef _PAIR
473

            
474
// Activation-style lookup (capitalized name). Delegates to findValue with `last` set to false.
absl::optional<google::api::expr::runtime::CelValue>
Context::FindValue(absl::string_view name, Protobuf::Arena* arena) const {
  constexpr bool last_segment = false;
  return findValue(name, arena, last_segment);
}
478

            
479
// Resolves a top-level property name, checking three sources in order:
//  1. StreamActivation::FindValue, with the activation members pointed at this context's state;
//  2. for names that are not built-in tokens, a CelState stored under "wasm.<name>" in the
//     downstream filter state, then the upstream filter state;
//  3. the built-in tokens (connection_id, plugin_name, plugin_root_id, plugin_vm_id).
// `last` is forwarded to CelState::exprValue. Returns an empty optional when nothing matches.
absl::optional<google::api::expr::runtime::CelValue>
Context::findValue(absl::string_view name, Protobuf::Arena* arena, bool last) const {
  using google::api::expr::runtime::CelValue;

  const StreamInfo::StreamInfo* info = getConstRequestStreamInfo();

  // To delegate to StreamActivation, its context properties must match this Wasm context's in
  // every callback (e.g. onLog or onEncodeHeaders), so set them for the duration of the call.
  if (root_local_info_) {
    local_info_ = root_local_info_;
  } else if (plugin_) {
    local_info_ = &plugin()->localInfo();
  }
  activation_info_ = info;
  activation_request_headers_ =
      (request_headers_ != nullptr) ? request_headers_ : access_log_request_headers_;
  activation_response_headers_ =
      (response_headers_ != nullptr) ? response_headers_ : access_log_response_headers_;
  activation_response_trailers_ =
      (response_trailers_ != nullptr) ? response_trailers_ : access_log_response_trailers_;
  auto activation_value = StreamActivation::FindValue(name, arena);
  resetActivation();
  if (activation_value.has_value()) {
    return activation_value;
  }

  // Convert into a dense token to enable a jump table implementation.
  const auto token = property_tokens.find(name);
  if (token == property_tokens.end()) {
    if (info == nullptr) {
      return {};
    }
    const std::string key = absl::StrCat(CelStateKeyPrefix, name);
    const CelState* state = info->filterState().getDataReadOnly<CelState>(key);
    if (state == nullptr && info->upstreamInfo().has_value()) {
      const auto& upstream_filter_state = info->upstreamInfo().value().get().upstreamFilterState();
      if (upstream_filter_state != nullptr) {
        state = upstream_filter_state->getDataReadOnly<CelState>(key);
      }
    }
    if (state == nullptr) {
      return {};
    }
    return state->exprValue(arena, last);
  }

  switch (token->second) {
  case PropertyToken::CONNECTION_ID: {
    auto connection = getConnection();
    if (connection) {
      return CelValue::CreateUint64(connection->id());
    }
    return {};
  }
  case PropertyToken::PLUGIN_NAME:
    if (plugin_) {
      return CelValue::CreateStringView(plugin()->name_);
    }
    return {};
  case PropertyToken::PLUGIN_ROOT_ID:
    return CelValue::CreateStringView(toAbslStringView(root_id()));
  case PropertyToken::PLUGIN_VM_ID:
    return CelValue::CreateStringView(toAbslStringView(envoyWasm()->vm_id()));
  }
  return {};
}
546

            
547
731
// Looks up a property path and serializes the resulting value into `result` (see serializeValue).
//
// `path` is a sequence of segments separated by '\0'. The first segment names a top-level
// property (see findValue). Each later segment selects one of:
// - a key of a map value,
// - a field of a message value,
// - a decimal index of a list value.
//
// Returns NotFound when any segment cannot be resolved, InternalFailure when a message field
// cannot be converted, and otherwise the result of serializeValue.
WasmResult Context::getProperty(std::string_view path, std::string* result) {
  using google::api::expr::runtime::CelValue;

  bool first = true;
  CelValue value;
  Protobuf::Arena arena;

  size_t start = 0;
  while (start < path.size()) {
    size_t end = path.find('\0', start);
    if (end == std::string_view::npos) {
      // The final segment runs to the end of the path.
      end = path.size();
    }
    auto part = path.substr(start, end - start);
    start = end + 1;

    if (first) {
      // top-level identifier
      first = false;
      auto top_value = findValue(toAbslStringView(part), &arena, start >= path.size());
      if (!top_value.has_value()) {
        return WasmResult::NotFound;
      }
      value = top_value.value();
    } else if (value.IsMap()) {
      auto& map = *value.MapOrDie();
      auto field = map[CelValue::CreateStringView(toAbslStringView(part))];
      if (!field.has_value()) {
        return WasmResult::NotFound;
      }
      value = field.value();
    } else if (value.IsMessage()) {
      auto msg = value.MessageOrDie();
      if (msg == nullptr) {
        return WasmResult::NotFound;
      }
      const Protobuf::Descriptor* desc = msg->GetDescriptor();
      const Protobuf::FieldDescriptor* field_desc = desc->FindFieldByName(std::string(part));
      if (field_desc == nullptr) {
        return WasmResult::NotFound;
      }
      if (field_desc->is_map()) {
        value = CelValue::CreateMap(
            Protobuf::Arena::Create<google::api::expr::runtime::FieldBackedMapImpl>(
                &arena, msg, field_desc, &arena));
      } else if (field_desc->is_repeated()) {
        value = CelValue::CreateList(
            Protobuf::Arena::Create<google::api::expr::runtime::FieldBackedListImpl>(
                &arena, msg, field_desc, &arena));
      } else {
        auto status =
            google::api::expr::runtime::CreateValueFromSingleField(msg, field_desc, &arena, &value);
        if (!status.ok()) {
          return WasmResult::InternalFailure;
        }
      }
    } else if (value.IsList()) {
      auto& list = *value.ListOrDie();
      int idx = 0;
      if (!absl::SimpleAtoi(toAbslStringView(part), &idx)) {
        return WasmResult::NotFound;
      }
      if (idx < 0 || idx >= list.size()) {
        return WasmResult::NotFound;
      }
      value = list[idx];
    } else {
      return WasmResult::NotFound;
    }
  }

  return serializeValue(value, result);
}
624

            
625
// Header/Trailer/Metadata Maps.
626
297
// Returns the mutable header/trailer map for `type`, or nullptr when it is unavailable. Request
// and response trailers are added lazily via the decoder/encoder callbacks when none exist yet
// and the body-buffer, end-of-stream and callback conditions below all hold.
Http::HeaderMap* Context::getMap(WasmHeaderMapType type) {
  if (type == WasmHeaderMapType::RequestHeaders) {
    return request_headers_;
  }
  if (type == WasmHeaderMapType::ResponseHeaders) {
    return response_headers_;
  }
  if (type == WasmHeaderMapType::RequestTrailers) {
    if (!request_trailers_ && request_body_buffer_ && end_of_stream_ && decoder_callbacks_) {
      request_trailers_ = &decoder_callbacks_->addDecodedTrailers();
    }
    return request_trailers_;
  }
  if (type == WasmHeaderMapType::ResponseTrailers) {
    if (!response_trailers_ && response_body_buffer_ && end_of_stream_ && encoder_callbacks_) {
      response_trailers_ = &encoder_callbacks_->addEncodedTrailers();
    }
    return response_trailers_;
  }
  // No other map type can be modified through this accessor.
  return nullptr;
}
648

            
649
294
// Returns the read-only map for `type`, or nullptr when it is unavailable. During the access log
// phase the request/response maps come from the access-log copies, and request trailers are never
// exposed. gRPC metadata and HTTP call response maps are read from the root context.
const Http::HeaderMap* Context::getConstMap(WasmHeaderMapType type) {
  switch (type) {
  case WasmHeaderMapType::RequestHeaders:
    return access_log_phase_ ? access_log_request_headers_ : request_headers_;
  case WasmHeaderMapType::RequestTrailers:
    return access_log_phase_ ? nullptr : request_trailers_;
  case WasmHeaderMapType::ResponseHeaders:
    return access_log_phase_ ? access_log_response_headers_ : response_headers_;
  case WasmHeaderMapType::ResponseTrailers:
    return access_log_phase_ ? access_log_response_trailers_ : response_trailers_;
  case WasmHeaderMapType::GrpcReceiveInitialMetadata:
    return rootContext()->grpc_receive_initial_metadata_.get();
  case WasmHeaderMapType::GrpcReceiveTrailingMetadata:
    return rootContext()->grpc_receive_trailing_metadata_.get();
  case WasmHeaderMapType::HttpCallResponseHeaders: {
    Envoy::Http::ResponseMessagePtr* response = rootContext()->http_call_response_;
    if (response == nullptr) {
      return nullptr;
    }
    return &(*response)->headers();
  }
  case WasmHeaderMapType::HttpCallResponseTrailers: {
    Envoy::Http::ResponseMessagePtr* response = rootContext()->http_call_response_;
    if (response == nullptr) {
      return nullptr;
    }
    return (*response)->trailers();
  }
  }
  IS_ENVOY_BUG("unexpected");
  return nullptr;
}
693

            
694
// Appends a copy of (key, value) to the mutable map for `type`. Returns BadArgument when that map
// is unavailable.
WasmResult Context::addHeaderMapValue(WasmHeaderMapType type, std::string_view key,
                                      std::string_view value) {
  Http::HeaderMap* const map = getMap(type);
  if (map == nullptr) {
    return WasmResult::BadArgument;
  }
  map->addCopy(Http::LowerCaseString(std::string(key)), std::string(value));
  onHeadersModified(type);
  return WasmResult::Ok;
}
705

            
706
// Reads the first value stored under `key` in the read-only map for `type`.
//
// For a missing map or key, ABI 0.1.0 reports an empty value with Ok, while later ABIs report
// NotFound. A missing map outside the access log phase is BadArgument.
WasmResult Context::getHeaderMapValue(WasmHeaderMapType type, std::string_view key,
                                      std::string_view* value) {
  const auto report_missing = [this, value]() {
    if (envoyWasm()->abiVersion() == proxy_wasm::AbiVersion::ProxyWasm_0_1_0) {
      *value = "";
      return WasmResult::Ok;
    }
    return WasmResult::NotFound;
  };

  const Http::HeaderMap* map = getConstMap(type);
  if (map == nullptr) {
    // Maps might point to nullptr in the access log phase. Otherwise the requested map type is
    // not currently available.
    return access_log_phase_ ? report_missing() : WasmResult::BadArgument;
  }

  const Http::LowerCaseString lower_key{std::string(key)};
  const auto entry = map->get(lower_key);
  if (entry.empty()) {
    return report_missing();
  }
  // TODO(kyessenov, PiotrSikora): This needs to either return a concatenated list of values, or
  // the ABI needs to be changed to return multiple values. This is a potential security issue.
  *value = toStdStringView(entry[0]->value().getStringView());
  return WasmResult::Ok;
}
737

            
738
6
// Returns views of every (key, value) entry of `map` in iteration order, or an empty list for a
// null map. The views refer to the map's storage.
Pairs headerMapToPairs(const Http::HeaderMap* map) {
  Pairs pairs;
  if (map == nullptr) {
    return pairs;
  }
  pairs.reserve(map->size());
  map->iterate([&pairs](const Http::HeaderEntry& entry) -> Http::HeaderMap::Iterate {
    pairs.emplace_back(toStdStringView(entry.key().getStringView()),
                       toStdStringView(entry.value().getStringView()));
    return Http::HeaderMap::Iterate::Continue;
  });
  return pairs;
}
751

            
752
6
// Copies all entries of the read-only map for `type` into `result`. An unavailable map yields an
// empty list; the call always returns Ok.
WasmResult Context::getHeaderMapPairs(WasmHeaderMapType type, Pairs* result) {
  const Http::HeaderMap* const map = getConstMap(type);
  *result = headerMapToPairs(map);
  return WasmResult::Ok;
}
756

            
757
7
// Replaces the contents of the mutable map for `type` with `pairs`. Returns BadArgument when that
// map is unavailable.
WasmResult Context::setHeaderMapPairs(WasmHeaderMapType type, const Pairs& pairs) {
  Http::HeaderMap* const map = getMap(type);
  if (map == nullptr) {
    return WasmResult::BadArgument;
  }
  // Collect the existing keys first so nothing is removed while the map is being iterated.
  std::vector<std::string> existing_keys;
  map->iterate([&existing_keys](const Http::HeaderEntry& entry) -> Http::HeaderMap::Iterate {
    existing_keys.push_back(std::string(entry.key().getStringView()));
    return Http::HeaderMap::Iterate::Continue;
  });
  for (const std::string& existing_key : existing_keys) {
    const Http::LowerCaseString lower_key{existing_key};
    map->remove(lower_key);
  }
  for (const auto& [name, value] : pairs) {
    const Http::LowerCaseString lower_key{std::string(name)};
    map->addCopy(lower_key, std::string(value));
  }
  onHeadersModified(type);
  return WasmResult::Ok;
}
778

            
779
31
// Removes `key` from the mutable map for `type`. Returns BadArgument when that map is unavailable.
WasmResult Context::removeHeaderMapValue(WasmHeaderMapType type, std::string_view key) {
  Http::HeaderMap* const map = getMap(type);
  if (map == nullptr) {
    return WasmResult::BadArgument;
  }
  map->remove(Http::LowerCaseString(std::string(key)));
  onHeadersModified(type);
  return WasmResult::Ok;
}
789

            
790
// Sets `key` to a copy of `value` (via setCopy) in the mutable map for `type`. Returns BadArgument
// when that map is unavailable.
WasmResult Context::replaceHeaderMapValue(WasmHeaderMapType type, std::string_view key,
                                          std::string_view value) {
  Http::HeaderMap* const map = getMap(type);
  if (map == nullptr) {
    return WasmResult::BadArgument;
  }
  map->setCopy(Http::LowerCaseString(std::string(key)), toAbslStringView(value));
  onHeadersModified(type);
  return WasmResult::Ok;
}
801

            
802
52
// Stores byteSize() of the mutable map for `type` in `result`. Returns BadArgument when that map
// is unavailable.
WasmResult Context::getHeaderMapSize(WasmHeaderMapType type, uint32_t* result) {
  Http::HeaderMap* const map = getMap(type);
  if (map == nullptr) {
    return WasmResult::BadArgument;
  }
  *result = map->byteSize();
  return WasmResult::Ok;
}
810

            
811
// Buffer
812

            
813
839
BufferInterface* Context::getBuffer(WasmBufferType type) {
814
839
  Envoy::Http::ResponseMessagePtr* response = nullptr;
815
839
  switch (type) {
816
8
  case WasmBufferType::CallData:
817
    // Set before the call.
818
8
    return &buffer_;
819
173
  case WasmBufferType::VmConfiguration:
820
173
    return buffer_.set(envoyWasm()->vm_configuration());
821
42
  case WasmBufferType::PluginConfiguration:
822
42
    if (temp_plugin_) {
823
32
      return buffer_.set(temp_plugin_->plugin_configuration_);
824
32
    }
825
10
    return nullptr;
826
371
  case WasmBufferType::HttpRequestBody:
827
371
    if (buffering_request_body_ && decoder_callbacks_) {
828
      // We need the mutable version, so capture it using a callback.
829
      // TODO: consider adding a mutableDecodingBuffer() interface.
830
118
      ::Envoy::Buffer::Instance* buffer_instance{};
831
118
      decoder_callbacks_->modifyDecodingBuffer(
832
118
          [&buffer_instance](::Envoy::Buffer::Instance& buffer) { buffer_instance = &buffer; });
833
118
      return buffer_.set(buffer_instance);
834
118
    }
835
253
    return buffer_.set(request_body_buffer_);
836
206
  case WasmBufferType::HttpResponseBody:
837
206
    if (buffering_response_body_ && encoder_callbacks_) {
838
      // TODO: consider adding a mutableDecodingBuffer() interface.
839
76
      ::Envoy::Buffer::Instance* buffer_instance{};
840
76
      encoder_callbacks_->modifyEncodingBuffer(
841
76
          [&buffer_instance](::Envoy::Buffer::Instance& buffer) { buffer_instance = &buffer; });
842
76
      return buffer_.set(buffer_instance);
843
76
    }
844
130
    return buffer_.set(response_body_buffer_);
845
6
  case WasmBufferType::NetworkDownstreamData:
846
6
    return buffer_.set(network_downstream_data_buffer_);
847
3
  case WasmBufferType::NetworkUpstreamData:
848
3
    return buffer_.set(network_upstream_data_buffer_);
849
5
  case WasmBufferType::HttpCallResponseBody:
850
5
    response = rootContext()->http_call_response_;
851
5
    if (response) {
852
3
      auto& body = (*response)->body();
853
3
      return buffer_.set(
854
3
          std::string_view(static_cast<const char*>(body.linearize(body.length())), body.length()));
855
3
    }
856
2
    return nullptr;
857
25
  case WasmBufferType::GrpcReceiveBuffer:
858
25
    return buffer_.set(rootContext()->grpc_receive_buffer_.get());
859
  default:
860
    return nullptr;
861
839
  }
862
839
}
863

            
864
27
void Context::onDownstreamConnectionClose(CloseType close_type) {
865
27
  ContextBase::onDownstreamConnectionClose(close_type);
866
27
  downstream_closed_ = true;
867
27
  onCloseTCP();
868
27
}
869

            
870
3
void Context::onUpstreamConnectionClose(CloseType close_type) {
871
3
  ContextBase::onUpstreamConnectionClose(close_type);
872
3
  upstream_closed_ = true;
873
3
  if (downstream_closed_) {
874
    onCloseTCP();
875
  }
876
3
}
877

            
878
// Async call via HTTP
879
WasmResult Context::httpCall(std::string_view cluster, const Pairs& request_headers,
880
                             std::string_view request_body, const Pairs& request_trailers,
881
58
                             int timeout_milliseconds, uint32_t* token_ptr) {
882
58
  if (timeout_milliseconds < 0) {
883
11
    return WasmResult::BadArgument;
884
11
  }
885
47
  auto cluster_string = std::string(cluster);
886
47
  const auto thread_local_cluster = clusterManager().getThreadLocalCluster(cluster_string);
887
47
  if (thread_local_cluster == nullptr) {
888
11
    return WasmResult::BadArgument;
889
11
  }
890

            
891
36
  Http::RequestMessagePtr message(
892
36
      new Http::RequestMessageImpl(buildRequestHeaderMapFromPairs(request_headers)));
893

            
894
  // Check that we were provided certain headers.
895
36
  if (message->headers().Path() == nullptr || message->headers().Method() == nullptr ||
896
36
      message->headers().Host() == nullptr) {
897
11
    return WasmResult::BadArgument;
898
11
  }
899

            
900
25
  if (!request_body.empty()) {
901
24
    message->body().add(toAbslStringView(request_body));
902
24
    message->headers().setContentLength(request_body.size());
903
24
  }
904

            
905
25
  if (!request_trailers.empty()) {
906
15
    message->trailers(buildRequestTrailerMapFromPairs(request_trailers));
907
15
  }
908

            
909
25
  absl::optional<std::chrono::milliseconds> timeout;
910
25
  if (timeout_milliseconds > 0) {
911
25
    timeout = std::chrono::milliseconds(timeout_milliseconds);
912
25
  }
913

            
914
25
  uint32_t token = envoyWasm()->nextHttpCallId();
915
25
  auto& handler = http_request_[token];
916
25
  handler.context_ = this;
917
25
  handler.token_ = token;
918

            
919
  // set default hash policy to be based on :authority to enable consistent hash
920
25
  Http::AsyncClient::RequestOptions options;
921
25
  options.setTimeout(timeout);
922
25
  Protobuf::RepeatedPtrField<HashPolicy> hash_policy;
923
25
  hash_policy.Add()->mutable_header()->set_header_name(Http::Headers::get().Host.get());
924
25
  options.setHashPolicy(hash_policy);
925
25
  options.setSendXff(false);
926
25
  auto http_request =
927
25
      thread_local_cluster->httpAsyncClient().send(std::move(message), handler, options);
928
25
  if (!http_request) {
929
6
    http_request_.erase(token);
930
6
    return WasmResult::InternalFailure;
931
6
  }
932
19
  handler.request_ = http_request;
933
19
  *token_ptr = token;
934
19
  return WasmResult::Ok;
935
25
}
936

            
937
WasmResult Context::grpcCall(std::string_view grpc_service, std::string_view service_name,
938
                             std::string_view method_name, const Pairs& initial_metadata,
939
                             std::string_view request, std::chrono::milliseconds timeout,
940
60
                             uint32_t* token_ptr) {
941
60
  GrpcService service_proto;
942
60
  if (!service_proto.ParseFromString(grpc_service)) {
943
48
    auto cluster_name = std::string(grpc_service.substr(0, grpc_service.size()));
944
48
    const auto thread_local_cluster = clusterManager().getThreadLocalCluster(cluster_name);
945
48
    if (thread_local_cluster == nullptr) {
946
      // TODO(shikugawa): The reason to keep return status as `BadArgument` is not to force
947
      // callers to change their own codebase with ABI 0.1.x. We should treat this failure as
948
      // `BadArgument` after ABI 0.2.x will have released.
949
30
      return WasmResult::ParseFailure;
950
30
    }
951
18
    service_proto.mutable_envoy_grpc()->set_cluster_name(cluster_name);
952
18
  }
953
30
  uint32_t token = envoyWasm()->nextGrpcCallId();
954
30
  auto& handler = grpc_call_request_[token];
955
30
  handler.context_ = this;
956
30
  handler.token_ = token;
957
30
  auto client_or_error = clusterManager().grpcAsyncClientManager().getOrCreateRawAsyncClient(
958
30
      service_proto, *envoyWasm()->scope_, true /* skip_cluster_check */);
959
30
  if (!client_or_error.status().ok()) {
960
    return WasmResult::BadArgument;
961
  }
962
30
  auto grpc_client = client_or_error.value();
963
30
  grpc_initial_metadata_ = buildRequestHeaderMapFromPairs(initial_metadata);
964

            
965
  // set default hash policy to be based on :authority to enable consistent hash
966
30
  Http::AsyncClient::RequestOptions options;
967
30
  options.setTimeout(timeout);
968
30
  Protobuf::RepeatedPtrField<HashPolicy> hash_policy;
969
30
  hash_policy.Add()->mutable_header()->set_header_name(Http::Headers::get().Host.get());
970
30
  options.setHashPolicy(hash_policy);
971
30
  options.setSendXff(false);
972

            
973
30
  auto grpc_request =
974
30
      grpc_client->sendRaw(toAbslStringView(service_name), toAbslStringView(method_name),
975
30
                           std::make_unique<::Envoy::Buffer::OwnedImpl>(toAbslStringView(request)),
976
30
                           handler, Tracing::NullSpan::instance(), options);
977
30
  if (!grpc_request) {
978
5
    grpc_call_request_.erase(token);
979
5
    return WasmResult::InternalFailure;
980
5
  }
981
25
  handler.client_ = std::move(grpc_client);
982
25
  handler.request_ = grpc_request;
983
25
  *token_ptr = token;
984
25
  return WasmResult::Ok;
985
30
}
986

            
987
WasmResult Context::grpcStream(std::string_view grpc_service, std::string_view service_name,
988
                               std::string_view method_name, const Pairs& initial_metadata,
989
75
                               uint32_t* token_ptr) {
990
75
  GrpcService service_proto;
991
75
  if (!service_proto.ParseFromString(grpc_service)) {
992
55
    auto cluster_name = std::string(grpc_service.substr(0, grpc_service.size()));
993
55
    const auto thread_local_cluster = clusterManager().getThreadLocalCluster(cluster_name);
994
55
    if (thread_local_cluster == nullptr) {
995
      // TODO(shikugawa): The reason to keep return status as `BadArgument` is not to force
996
      // callers to change their own codebase with ABI 0.1.x. We should treat this failure as
997
      // `BadArgument` after ABI 0.2.x will have released.
998
25
      return WasmResult::ParseFailure;
999
25
    }
30
    service_proto.mutable_envoy_grpc()->set_cluster_name(cluster_name);
30
  }
50
  uint32_t token = envoyWasm()->nextGrpcStreamId();
50
  auto& handler = grpc_stream_[token];
50
  handler.context_ = this;
50
  handler.token_ = token;
50
  auto client_or_error = clusterManager().grpcAsyncClientManager().getOrCreateRawAsyncClient(
50
      service_proto, *envoyWasm()->scope_, true /* skip_cluster_check */);
50
  if (!client_or_error.status().ok()) {
    return WasmResult::BadArgument;
  }
50
  auto grpc_client = client_or_error.value();
50
  grpc_initial_metadata_ = buildRequestHeaderMapFromPairs(initial_metadata);
  // set default hash policy to be based on :authority to enable consistent hash
50
  Http::AsyncClient::StreamOptions options;
50
  Protobuf::RepeatedPtrField<HashPolicy> hash_policy;
50
  hash_policy.Add()->mutable_header()->set_header_name(Http::Headers::get().Host.get());
50
  options.setHashPolicy(hash_policy);
50
  options.setSendXff(false);
50
  auto grpc_stream = grpc_client->startRaw(toAbslStringView(service_name),
50
                                           toAbslStringView(method_name), handler, options);
50
  if (!grpc_stream) {
25
    grpc_stream_.erase(token);
25
    return WasmResult::InternalFailure;
25
  }
25
  handler.client_ = std::move(grpc_client);
25
  handler.stream_ = grpc_stream;
25
  *token_ptr = token;
25
  return WasmResult::Ok;
50
}
// NB: this is currently called inline, so the token is known to be that of the currently
// executing grpcCall or grpcStream.
void Context::onGrpcCreateInitialMetadata(uint32_t /* token */,
30
                                          Http::RequestHeaderMap& initial_metadata) {
30
  if (grpc_initial_metadata_) {
30
    Http::HeaderMapImpl::copyFrom(initial_metadata, *grpc_initial_metadata_);
30
    grpc_initial_metadata_.reset();
30
  }
30
}
// StreamInfo
745
const StreamInfo::StreamInfo* Context::getConstRequestStreamInfo() const {
745
  if (encoder_callbacks_) {
88
    return &encoder_callbacks_->streamInfo();
657
  } else if (decoder_callbacks_) {
1
    return &decoder_callbacks_->streamInfo();
656
  } else if (access_log_stream_info_) {
1
    return access_log_stream_info_;
655
  } else if (network_read_filter_callbacks_) {
1
    return &network_read_filter_callbacks_->connection().streamInfo();
654
  } else if (network_write_filter_callbacks_) {
1
    return &network_write_filter_callbacks_->connection().streamInfo();
1
  }
653
  return nullptr;
745
}
25
StreamInfo::StreamInfo* Context::getRequestStreamInfo() const {
25
  if (encoder_callbacks_) {
16
    return &encoder_callbacks_->streamInfo();
22
  } else if (decoder_callbacks_) {
1
    return &decoder_callbacks_->streamInfo();
8
  } else if (network_read_filter_callbacks_) {
4
    return &network_read_filter_callbacks_->connection().streamInfo();
7
  } else if (network_write_filter_callbacks_) {
1
    return &network_write_filter_callbacks_->connection().streamInfo();
1
  }
3
  return nullptr;
25
}
11
const Network::Connection* Context::getConnection() const {
11
  if (encoder_callbacks_) {
3
    return encoder_callbacks_->connection().ptr();
8
  } else if (decoder_callbacks_) {
1
    return decoder_callbacks_->connection().ptr();
7
  } else if (network_read_filter_callbacks_) {
1
    return &network_read_filter_callbacks_->connection();
6
  } else if (network_write_filter_callbacks_) {
1
    return &network_write_filter_callbacks_->connection();
1
  }
5
  return nullptr;
11
}
17
WasmResult Context::setProperty(std::string_view path, std::string_view value) {
17
  auto* stream_info = getRequestStreamInfo();
17
  if (!stream_info) {
2
    return WasmResult::NotFound;
2
  }
15
  std::string key;
15
  absl::StrAppend(&key, CelStateKeyPrefix, toAbslStringView(path));
15
  CelState* state = stream_info->filterState()->getDataMutable<CelState>(key);
15
  if (state == nullptr) {
13
    const auto& it = rootContext()->state_prototypes_.find(toAbslStringView(path));
13
    const CelStatePrototype& prototype =
13
        it == rootContext()->state_prototypes_.end()
13
            ? Filters::Common::Expr::DefaultCelStatePrototype::get()
13
            : *it->second.get(); // NOLINT
13
    auto state_ptr = std::make_unique<CelState>(prototype);
13
    state = state_ptr.get();
13
    stream_info->filterState()->setData(key, std::move(state_ptr),
13
                                        StreamInfo::FilterState::StateType::Mutable,
13
                                        prototype.life_span_);
13
  }
15
  if (!state->setValue(toAbslStringView(value))) {
2
    return WasmResult::BadArgument;
2
  }
13
  return WasmResult::Ok;
15
}
WasmResult Context::setEnvoyFilterState(std::string_view path, std::string_view value,
                                        StreamInfo::FilterState::LifeSpan life_span) {
  // `path` selects the registered object factory and is also used as the filter-state key.
  auto* object_factory =
      Registry::FactoryRegistry<StreamInfo::FilterState::ObjectFactory>::getFactory(path);
  if (object_factory == nullptr) {
    return WasmResult::NotFound;
  }
  auto filter_object = object_factory->createFromBytes(value);
  if (!filter_object) {
    return WasmResult::BadArgument;
  }
  auto* stream_info = getRequestStreamInfo();
  if (stream_info == nullptr) {
    return WasmResult::NotFound;
  }
  stream_info->filterState()->setData(path, std::move(filter_object),
                                      StreamInfo::FilterState::StateType::Mutable, life_span);
  return WasmResult::Ok;
}
WasmResult
Context::declareProperty(std::string_view path,
                         Filters::Common::Expr::CelStatePrototypeConstPtr state_prototype) {
  const auto property_name = toAbslStringView(path);
  // A schema may only be declared once; an existing one is never replaced because state
  // objects can still reference it.
  if (state_prototypes_.find(property_name) != state_prototypes_.end()) {
    return WasmResult::BadArgument;
  }
  state_prototypes_[property_name] = std::move(state_prototype);
  return WasmResult::Ok;
}
192
WasmResult Context::log(uint32_t level, std::string_view message) {
192
  switch (static_cast<spdlog::level::level_enum>(level)) {
7
  case spdlog::level::trace:
7
    ENVOY_LOG(trace, "wasm log{}: {}", log_prefix(), message);
7
    return WasmResult::Ok;
26
  case spdlog::level::debug:
26
    ENVOY_LOG(debug, "wasm log{}: {}", log_prefix(), message);
26
    return WasmResult::Ok;
52
  case spdlog::level::info:
52
    ENVOY_LOG(info, "wasm log{}: {}", log_prefix(), message);
52
    return WasmResult::Ok;
37
  case spdlog::level::warn:
37
    ENVOY_LOG(warn, "wasm log{}: {}", log_prefix(), message);
37
    return WasmResult::Ok;
70
  case spdlog::level::err:
70
    ENVOY_LOG(error, "wasm log{}: {}", log_prefix(), message);
70
    return WasmResult::Ok;
  case spdlog::level::critical:
    ENVOY_LOG(critical, "wasm log{}: {}", log_prefix(), message);
    return WasmResult::Ok;
  case spdlog::level::off:
    PANIC("not implemented");
  case spdlog::level::n_levels:
    PANIC("not implemented");
192
  }
  PANIC_DUE_TO_CORRUPT_ENUM;
}
2
uint32_t Context::getLogLevel() {
  // Like the "log" call above, assume that spdlog level as an int
  // matches the enum in the SDK
2
  return static_cast<uint32_t>(ENVOY_LOGGER().level());
2
}
//
// Calls into the Wasm code.
//
bool Context::validateConfiguration(std::string_view configuration,
3
                                    const std::shared_ptr<PluginBase>& plugin_base) {
3
  auto plugin = std::static_pointer_cast<Plugin>(plugin_base);
3
  if (!envoyWasm()->validate_configuration_) {
1
    return true;
1
  }
2
  temp_plugin_ = plugin_base;
2
  auto result =
2
      envoyWasm()
2
          ->validate_configuration_(this, id_, static_cast<uint32_t>(configuration.size()))
2
          .u64_ != 0;
2
  temp_plugin_.reset();
2
  return result;
3
}
2
std::string_view Context::getConfiguration() {
2
  if (temp_plugin_) {
    return temp_plugin_->plugin_configuration_;
2
  } else {
2
    return envoyWasm()->vm_configuration();
2
  }
2
};
22
std::pair<uint32_t, std::string_view> Context::getStatus() {
22
  return std::make_pair(status_code_, toStdStringView(status_message_));
22
}
25
void Context::onGrpcReceiveInitialMetadataWrapper(uint32_t token, Http::HeaderMapPtr&& metadata) {
25
  grpc_receive_initial_metadata_ = std::move(metadata);
25
  onGrpcReceiveInitialMetadata(token, headerSize(grpc_receive_initial_metadata_));
25
  grpc_receive_initial_metadata_ = nullptr;
25
}
10
void Context::onGrpcReceiveTrailingMetadataWrapper(uint32_t token, Http::HeaderMapPtr&& metadata) {
10
  grpc_receive_trailing_metadata_ = std::move(metadata);
10
  onGrpcReceiveTrailingMetadata(token, headerSize(grpc_receive_trailing_metadata_));
10
  grpc_receive_trailing_metadata_ = nullptr;
10
}
WasmResult Context::defineMetric(uint32_t metric_type, std::string_view name,
                                 uint32_t* metric_id_ptr) {
  if (metric_type > static_cast<uint32_t>(MetricType::Max)) {
    return WasmResult::BadArgument;
  }
  const auto type = static_cast<MetricType>(metric_type);
  auto owner = envoyWasm();
  // TODO: Consider rethinking the scoping policy as it does not help in this case.
  Stats::StatNameManagedStorage storage(toAbslStringView(name), owner->scope_->symbolTable());
  const Stats::StatName stat_name = storage.statName();
  // Every user-defined metric is prefixed with custom_stat_namespace_ so that it can be told
  // apart from native Envoy metrics.
  if (type == MetricType::Counter) {
    const auto id = owner->nextCounterMetricId();
    Stats::Counter& counter = Stats::Utility::counterFromElements(
        *owner->scope_, {owner->custom_stat_namespace_, stat_name});
    owner->counters_.emplace(id, &counter);
    *metric_id_ptr = id;
    return WasmResult::Ok;
  }
  if (type == MetricType::Gauge) {
    const auto id = owner->nextGaugeMetricId();
    Stats::Gauge& gauge = Stats::Utility::gaugeFromStatNames(
        *owner->scope_, {owner->custom_stat_namespace_, stat_name},
        Stats::Gauge::ImportMode::Accumulate);
    owner->gauges_.emplace(id, &gauge);
    *metric_id_ptr = id;
    return WasmResult::Ok;
  }
  // Any other in-range type is registered as a histogram.
  const auto id = owner->nextHistogramMetricId();
  Stats::Histogram& histogram = Stats::Utility::histogramFromStatNames(
      *owner->scope_, {owner->custom_stat_namespace_, stat_name},
      Stats::Histogram::Unit::Unspecified);
  owner->histograms_.emplace(id, &histogram);
  *metric_id_ptr = id;
  return WasmResult::Ok;
}
34
WasmResult Context::incrementMetric(uint32_t metric_id, int64_t offset) {
34
  auto type = static_cast<MetricType>(metric_id & Wasm::kMetricTypeMask);
34
  if (type == MetricType::Counter) {
22
    auto it = envoyWasm()->counters_.find(metric_id);
22
    if (it != envoyWasm()->counters_.end()) {
20
      if (offset > 0) {
18
        it->second->add(offset);
18
        return WasmResult::Ok;
18
      } else {
2
        return WasmResult::BadArgument;
2
      }
20
    }
2
    return WasmResult::NotFound;
26
  } else if (type == MetricType::Gauge) {
6
    auto it = envoyWasm()->gauges_.find(metric_id);
6
    if (it != envoyWasm()->gauges_.end()) {
4
      if (offset > 0) {
2
        it->second->add(offset);
2
        return WasmResult::Ok;
2
      } else {
2
        it->second->sub(-offset);
2
        return WasmResult::Ok;
2
      }
4
    }
2
    return WasmResult::NotFound;
6
  }
6
  return WasmResult::BadArgument;
34
}
36
WasmResult Context::recordMetric(uint32_t metric_id, uint64_t value) {
36
  auto type = static_cast<MetricType>(metric_id & Wasm::kMetricTypeMask);
36
  if (type == MetricType::Counter) {
10
    auto it = envoyWasm()->counters_.find(metric_id);
10
    if (it != envoyWasm()->counters_.end()) {
8
      it->second->add(value);
8
      return WasmResult::Ok;
8
    }
26
  } else if (type == MetricType::Gauge) {
12
    auto it = envoyWasm()->gauges_.find(metric_id);
12
    if (it != envoyWasm()->gauges_.end()) {
10
      it->second->set(value);
10
      return WasmResult::Ok;
10
    }
14
  } else if (type == MetricType::Histogram) {
12
    auto it = envoyWasm()->histograms_.find(metric_id);
12
    if (it != envoyWasm()->histograms_.end()) {
10
      it->second->recordValue(value);
10
      return WasmResult::Ok;
10
    }
12
  }
8
  return WasmResult::NotFound;
36
}
48
WasmResult Context::getMetric(uint32_t metric_id, uint64_t* result_uint64_ptr) {
48
  auto type = static_cast<MetricType>(metric_id & Wasm::kMetricTypeMask);
48
  if (type == MetricType::Counter) {
28
    auto it = envoyWasm()->counters_.find(metric_id);
28
    if (it != envoyWasm()->counters_.end()) {
26
      *result_uint64_ptr = it->second->value();
26
      return WasmResult::Ok;
26
    }
2
    return WasmResult::NotFound;
30
  } else if (type == MetricType::Gauge) {
12
    auto it = envoyWasm()->gauges_.find(metric_id);
12
    if (it != envoyWasm()->gauges_.end()) {
10
      *result_uint64_ptr = it->second->value();
10
      return WasmResult::Ok;
10
    }
2
    return WasmResult::NotFound;
12
  }
8
  return WasmResult::BadArgument;
48
}
2139
Context::~Context() {
  // Cancel any outstanding requests.
2139
  for (auto& p : http_request_) {
3
    if (p.second.request_ != nullptr) {
3
      p.second.request_->cancel();
3
    }
3
  }
2139
  for (auto& p : grpc_call_request_) {
1
    if (p.second.request_ != nullptr) {
1
      p.second.request_->cancel();
1
    }
1
  }
2139
  for (auto& p : grpc_stream_) {
5
    if (p.second.stream_ != nullptr) {
5
      p.second.stream_->resetStream();
5
    }
5
  }
2139
}
53
Network::FilterStatus convertNetworkFilterStatus(proxy_wasm::FilterStatus status) {
53
  switch (status) {
  default:
37
  case proxy_wasm::FilterStatus::Continue:
37
    return Network::FilterStatus::Continue;
16
  case proxy_wasm::FilterStatus::StopIteration:
16
    return Network::FilterStatus::StopIteration;
53
  }
53
};
226
Http::FilterHeadersStatus convertFilterHeadersStatus(proxy_wasm::FilterHeadersStatus status) {
226
  switch (status) {
  default:
138
  case proxy_wasm::FilterHeadersStatus::Continue:
138
    return Http::FilterHeadersStatus::Continue;
2
  case proxy_wasm::FilterHeadersStatus::StopIteration:
2
    return Http::FilterHeadersStatus::StopIteration;
2
  case proxy_wasm::FilterHeadersStatus::StopAllIterationAndBuffer:
2
    return Http::FilterHeadersStatus::StopAllIterationAndBuffer;
84
  case proxy_wasm::FilterHeadersStatus::StopAllIterationAndWatermark:
84
    return Http::FilterHeadersStatus::StopAllIterationAndWatermark;
226
  }
226
};
6
Http::FilterTrailersStatus convertFilterTrailersStatus(proxy_wasm::FilterTrailersStatus status) {
6
  switch (status) {
  default:
3
  case proxy_wasm::FilterTrailersStatus::Continue:
3
    return Http::FilterTrailersStatus::Continue;
3
  case proxy_wasm::FilterTrailersStatus::StopIteration:
3
    return Http::FilterTrailersStatus::StopIteration;
6
  }
6
};
6
Http::FilterMetadataStatus convertFilterMetadataStatus(proxy_wasm::FilterMetadataStatus status) {
6
  switch (status) {
  default:
6
  case proxy_wasm::FilterMetadataStatus::Continue:
6
    return Http::FilterMetadataStatus::Continue;
6
  }
6
};
195
Http::FilterDataStatus convertFilterDataStatus(proxy_wasm::FilterDataStatus status) {
195
  switch (status) {
  default:
109
  case proxy_wasm::FilterDataStatus::Continue:
109
    return Http::FilterDataStatus::Continue;
60
  case proxy_wasm::FilterDataStatus::StopIterationAndBuffer:
60
    return Http::FilterDataStatus::StopIterationAndBuffer;
4
  case proxy_wasm::FilterDataStatus::StopIterationAndWatermark:
4
    return Http::FilterDataStatus::StopIterationAndWatermark;
22
  case proxy_wasm::FilterDataStatus::StopIterationNoBuffer:
22
    return Http::FilterDataStatus::StopIterationNoBuffer;
195
  }
195
};
30
Network::FilterStatus Context::onNewConnection() {
30
  onCreate();
30
  return convertNetworkFilterStatus(onNetworkNewConnection());
30
};
15
Network::FilterStatus Context::onData(::Envoy::Buffer::Instance& data, bool end_stream) {
15
  if (!in_vm_context_created_) {
3
    return Network::FilterStatus::Continue;
3
  }
12
  network_downstream_data_buffer_ = &data;
12
  end_of_stream_ = end_stream;
12
  auto result = convertNetworkFilterStatus(onDownstreamData(data.length(), end_stream));
12
  if (result == Network::FilterStatus::Continue) {
7
    network_downstream_data_buffer_ = nullptr;
7
  }
12
  return result;
15
}
14
Network::FilterStatus Context::onWrite(::Envoy::Buffer::Instance& data, bool end_stream) {
14
  if (!in_vm_context_created_) {
3
    return Network::FilterStatus::Continue;
3
  }
11
  network_upstream_data_buffer_ = &data;
11
  end_of_stream_ = end_stream;
11
  auto result = convertNetworkFilterStatus(onUpstreamData(data.length(), end_stream));
11
  if (result == Network::FilterStatus::Continue) {
6
    network_upstream_data_buffer_ = nullptr;
6
  }
11
  if (end_stream) {
    // This is called when seeing end_stream=true and not on an upstream connection event,
    // because registering for latter requires replicating the whole TCP proxy extension.
3
    onUpstreamConnectionClose(CloseType::Unknown);
3
  }
11
  return result;
14
}
33
void Context::onEvent(Network::ConnectionEvent event) {
33
  if (!in_vm_context_created_) {
3
    return;
3
  }
30
  switch (event) {
24
  case Network::ConnectionEvent::LocalClose:
24
    onDownstreamConnectionClose(CloseType::Local);
24
    break;
3
  case Network::ConnectionEvent::RemoteClose:
3
    onDownstreamConnectionClose(CloseType::Remote);
3
    break;
3
  default:
3
    break;
30
  }
30
}
41
void Context::initializeReadFilterCallbacks(Network::ReadFilterCallbacks& callbacks) {
41
  network_read_filter_callbacks_ = &callbacks;
41
  network_read_filter_callbacks_->connection().addConnectionCallbacks(*this);
41
}
40
void Context::initializeWriteFilterCallbacks(Network::WriteFilterCallbacks& callbacks) {
40
  network_write_filter_callbacks_ = &callbacks;
40
}
void Context::log(const Formatter::Context& log_context,
                  const StreamInfo::StreamInfo& stream_info) {
  // `log` may run several times because of mid-request logging; only the final call (once the
  // request is complete) is forwarded to the plugin.
  if (!stream_info.requestComplete().has_value()) {
    return;
  }
  if (!in_vm_context_created_) {
    // Invalid or short-circuited requests (e.g. via sendLocalReply) may never reach
    // onRequestHeaders(), and so never onCreate(), because Envoy has no well-defined lifetime for
    // the combined HTTP + AccessLog filter. Create the context here so they can still be logged.
    onCreate();
  }

  // Expose the log inputs only for the duration of onLog().
  access_log_phase_ = true;
  access_log_request_headers_ = log_context.requestHeaders().ptr();
  // Request trailers are not captured for the access-log phase.
  access_log_response_headers_ = log_context.responseHeaders().ptr();
  access_log_response_trailers_ = log_context.responseTrailers().ptr();
  access_log_stream_info_ = &stream_info;

  onLog();

  access_log_phase_ = false;
  access_log_request_headers_ = nullptr;
  access_log_response_headers_ = nullptr;
  access_log_response_trailers_ = nullptr;
  access_log_stream_info_ = nullptr;
}
113
void Context::onDestroy() {
113
  if (destroyed_ || !in_vm_context_created_) {
6
    return;
6
  }
107
  destroyed_ = true;
107
  onDone();
107
  onDelete();
107
}
22
WasmResult Context::continueStream(WasmStreamType stream_type) {
22
  switch (stream_type) {
11
  case WasmStreamType::Request:
11
    if (decoder_callbacks_) {
      // We are in a reentrant call, so defer.
11
      envoyWasm()->addAfterVmCallAction([this] { decoder_callbacks_->continueDecoding(); });
11
    }
11
    break;
3
  case WasmStreamType::Response:
3
    if (encoder_callbacks_) {
      // We are in a reentrant call, so defer.
3
      envoyWasm()->addAfterVmCallAction([this] { encoder_callbacks_->continueEncoding(); });
3
    }
3
    break;
3
  case WasmStreamType::Downstream:
3
    if (network_read_filter_callbacks_) {
      // We are in a reentrant call, so defer.
3
      envoyWasm()->addAfterVmCallAction(
3
          [this] { network_read_filter_callbacks_->continueReading(); });
3
    }
3
    return WasmResult::Ok;
2
  case WasmStreamType::Upstream:
2
    return WasmResult::Unimplemented;
3
  default:
3
    return WasmResult::BadArgument;
22
  }
14
  request_headers_ = nullptr;
14
  request_body_buffer_ = nullptr;
14
  request_trailers_ = nullptr;
14
  request_metadata_ = nullptr;
14
  return WasmResult::Ok;
22
}
constexpr absl::string_view CloseStreamResponseDetails = "wasm_close_stream";
18
WasmResult Context::closeStream(WasmStreamType stream_type) {
18
  switch (stream_type) {
3
  case WasmStreamType::Request:
3
    if (decoder_callbacks_) {
3
      if (!decoder_callbacks_->streamInfo().responseCodeDetails().has_value()) {
3
        decoder_callbacks_->streamInfo().setResponseCodeDetails(CloseStreamResponseDetails);
3
      }
      // We are in a reentrant call, so defer.
3
      envoyWasm()->addAfterVmCallAction([this] { decoder_callbacks_->resetStream(); });
3
    }
3
    return WasmResult::Ok;
6
  case WasmStreamType::Response:
6
    if (encoder_callbacks_) {
6
      if (!encoder_callbacks_->streamInfo().responseCodeDetails().has_value()) {
6
        encoder_callbacks_->streamInfo().setResponseCodeDetails(CloseStreamResponseDetails);
6
      }
      // We are in a reentrant call, so defer.
6
      envoyWasm()->addAfterVmCallAction([this] { encoder_callbacks_->resetStream(); });
6
    }
6
    return WasmResult::Ok;
3
  case WasmStreamType::Downstream:
3
    if (network_read_filter_callbacks_) {
      // We are in a reentrant call, so defer.
3
      envoyWasm()->addAfterVmCallAction([this] {
3
        network_read_filter_callbacks_->connection().close(
3
            Envoy::Network::ConnectionCloseType::FlushWrite, "wasm_downstream_close");
3
      });
3
    }
3
    return WasmResult::Ok;
3
  case WasmStreamType::Upstream:
3
    if (network_write_filter_callbacks_) {
      // We are in a reentrant call, so defer.
3
      envoyWasm()->addAfterVmCallAction([this] {
3
        network_write_filter_callbacks_->connection().close(
3
            Envoy::Network::ConnectionCloseType::FlushWrite, "wasm_upstream_close");
3
      });
3
    }
3
    return WasmResult::Ok;
18
  }
3
  return WasmResult::BadArgument;
18
}
constexpr absl::string_view FailStreamResponseDetails = "wasm_fail_stream";
92
void Context::failStream(WasmStreamType stream_type) {
92
  switch (stream_type) {
36
  case WasmStreamType::Request:
36
    if (decoder_callbacks_ && !failure_local_reply_sent_) {
36
      decoder_callbacks_->sendLocalReply(Envoy::Http::Code::ServiceUnavailable, "", nullptr,
36
                                         Grpc::Status::WellKnownGrpcStatus::Unavailable,
36
                                         FailStreamResponseDetails);
36
      failure_local_reply_sent_ = true;
36
    }
36
    break;
36
  case WasmStreamType::Response:
36
    if (encoder_callbacks_ && !failure_local_reply_sent_) {
      encoder_callbacks_->sendLocalReply(Envoy::Http::Code::ServiceUnavailable, "", nullptr,
                                         Grpc::Status::WellKnownGrpcStatus::Unavailable,
                                         FailStreamResponseDetails);
      failure_local_reply_sent_ = true;
    }
36
    break;
10
  case WasmStreamType::Downstream:
10
    if (network_read_filter_callbacks_) {
10
      network_read_filter_callbacks_->connection().close(
10
          Envoy::Network::ConnectionCloseType::FlushWrite);
10
    }
10
    break;
10
  case WasmStreamType::Upstream:
10
    if (network_write_filter_callbacks_) {
10
      network_write_filter_callbacks_->connection().close(
10
          Envoy::Network::ConnectionCloseType::FlushWrite);
10
    }
10
    break;
92
  }
92
}
// Sends a local HTTP reply on behalf of the plugin.
//
// The reply is only sent when decoder callbacks are attached; otherwise nothing happens.
// WasmResult::Ok is returned in both cases. The reply is deferred with addAfterVmCallAction()
// because this is called from inside a VM call. Every argument that is a view into caller
// memory (body, headers, details) is therefore copied before deferring.
//
// @param response_code HTTP status code of the reply.
// @param body_text body of the reply.
// @param additional_headers extra headers added to the reply; keys are lower-cased.
// @param grpc_status gRPC status for the reply. Values outside [Ok, MaximumKnown] produce no
//        grpc-status at all. This includes -1, the "unset" value used by SDKs, which arrives
//        here as UINT32_MAX.
// @param details response code details; empty space is replaced via
//        StringUtil::replaceAllEmptySpace().
WasmResult Context::sendLocalResponse(uint32_t response_code, std::string_view body_text,
                                      Pairs additional_headers, uint32_t grpc_status,
                                      std::string_view details) {
  if (!decoder_callbacks_) {
    // Nothing to reply on; skip all the copying below.
    return WasmResult::Ok;
  }
  // "additional_headers" is a collection of string_views. These will no longer be valid when
  // "modify_headers" is finally called below, so we must make copies of all the headers.
  std::vector<std::pair<Http::LowerCaseString, std::string>> additional_headers_copy;
  additional_headers_copy.reserve(additional_headers.size());
  for (const auto& p : additional_headers) {
    additional_headers_copy.emplace_back(Http::LowerCaseString(std::string(p.first)),
                                         std::string(p.second));
  }
  // Move the copies into the closure instead of copying them a second time.
  auto modify_headers = [headers_to_add = std::move(additional_headers_copy)](
                            Http::HeaderMap& headers) {
    for (const auto& [key, value] : headers_to_add) {
      headers.addCopy(key, value);
    }
  };
  // This is a bit subtle because proxy_on_delete() does call DeferAfterCallActions(),
  // so in theory it could call this and the Context in the VM would be invalid,
  // but because it only gets called after the connections have drained, the call to
  // sendLocalReply() will fail. Net net, this is safe.
  envoyWasm()->addAfterVmCallAction([this, response_code, body_text = std::string(body_text),
                                     modify_headers = std::move(modify_headers), grpc_status,
                                     details = StringUtil::replaceAllEmptySpace(
                                         absl::string_view(details.data(), details.size()))] {
    // C++, Rust and other SDKs use -1 (InvalidCode) as the default value if gRPC code is not set,
    // which should be mapped to nullopt in Envoy to prevent it from sending a grpc-status trailer
    // at all.
    absl::optional<Grpc::Status::GrpcStatus> grpc_status_code = absl::nullopt;
    if (grpc_status >= Grpc::Status::WellKnownGrpcStatus::Ok &&
        grpc_status <= Grpc::Status::WellKnownGrpcStatus::MaximumKnown) {
      grpc_status_code = Grpc::Status::WellKnownGrpcStatus(grpc_status);
    }
    decoder_callbacks_->sendLocalReply(static_cast<Envoy::Http::Code>(response_code), body_text,
                                       modify_headers, grpc_status_code, details);
  });
  return WasmResult::Ok;
}
188
// Request-headers entry point of the HTTP filter.
//
// Calls onCreate() first, then hands the headers to the plugin via onRequestHeaders().
// request_headers_ points at `headers` for the duration of the callback. It stays set if the
// plugin pauses the filter, and is cleared when the converted result is Continue.
Http::FilterHeadersStatus Context::decodeHeaders(Http::RequestHeaderMap& headers, bool end_stream) {
  onCreate();
  request_headers_ = &headers;
  end_of_stream_ = end_stream;
  const auto header_count = headerSize(&headers);
  const auto status = convertFilterHeadersStatus(onRequestHeaders(header_count, end_stream));
  if (status == Http::FilterHeadersStatus::Continue) {
    // The plugin let the headers through; stop tracking them.
    request_headers_ = nullptr;
  }
  return status;
}
133
// Request-body entry point of the HTTP filter.
//
// Passes data straight through until the in-VM context exists. If the previous call asked to
// buffer (StopIterationAndBuffer), the new chunk is first appended to the decoder's buffered
// data. The plugin then sees the body size reported by getBuffer(HttpRequestBody).
// After the call:
//  - buffering_request_body_ is true only when the plugin asked to buffer again;
//  - request_body_buffer_ is cleared only when the plugin returned Continue.
Http::FilterDataStatus Context::decodeData(::Envoy::Buffer::Instance& data, bool end_stream) {
  if (!in_vm_context_created_) {
    return Http::FilterDataStatus::Continue;
  }
  if (buffering_request_body_) {
    decoder_callbacks_->addDecodedData(data, false);
    if (destroyed_) {
      // Appending may have exceeded the buffer limit and triggered a local reply (413), which
      // leaves this context destroyed_; do not call into the VM.
      // This is not airtight: if another filter stops the local reply processing, this filter
      // may still call the VM later, but here at least the VM context is known to be valid.
      return Http::FilterDataStatus::StopIterationAndBuffer;
    }
  }
  request_body_buffer_ = &data;
  end_of_stream_ = end_stream;
  const auto body = getBuffer(WasmBufferType::HttpRequestBody);
  const auto body_size = (body != nullptr) ? body->size() : 0;
  const auto status = convertFilterDataStatus(onRequestBody(body_size, end_stream));
  buffering_request_body_ = (status == Http::FilterDataStatus::StopIterationAndBuffer);
  if (status == Http::FilterDataStatus::Continue) {
    request_body_buffer_ = nullptr;
  }
  return status;
}
6
// Request-trailers entry point of the HTTP filter.
//
// Trailers pass straight through until the in-VM context exists. Otherwise request_trailers_
// points at `trailers` while onRequestTrailers() runs. It stays set if the plugin pauses, and
// is cleared when the converted result is Continue.
Http::FilterTrailersStatus Context::decodeTrailers(Http::RequestTrailerMap& trailers) {
  if (in_vm_context_created_) {
    request_trailers_ = &trailers;
    const auto trailer_count = headerSize(&trailers);
    const auto status = convertFilterTrailersStatus(onRequestTrailers(trailer_count));
    if (status == Http::FilterTrailersStatus::Continue) {
      request_trailers_ = nullptr;
    }
    return status;
  }
  return Http::FilterTrailersStatus::Continue;
}
6
// Request-metadata entry point of the HTTP filter.
//
// Metadata passes straight through until the in-VM context exists. Otherwise
// request_metadata_ points at `request_metadata` while onRequestMetadata() runs. It stays set
// if the plugin pauses, and is cleared when the converted result is Continue.
Http::FilterMetadataStatus Context::decodeMetadata(Http::MetadataMap& request_metadata) {
  if (in_vm_context_created_) {
    request_metadata_ = &request_metadata;
    const auto entry_count = headerSize(&request_metadata);
    const auto status = convertFilterMetadataStatus(onRequestMetadata(entry_count));
    if (status == Http::FilterMetadataStatus::Continue) {
      request_metadata_ = nullptr;
    }
    return status;
  }
  return Http::FilterMetadataStatus::Continue;
}
266
// Stores the decoder-side filter callbacks. They are used later to buffer request data, send
// local replies and reset the request stream. The callbacks are not owned.
void Context::setDecoderFilterCallbacks(Envoy::Http::StreamDecoderFilterCallbacks& callbacks) {
  decoder_callbacks_ = std::addressof(callbacks);
}
3
// 1xx informational response headers are not shown to the plugin; they always pass through.
Http::Filter1xxHeadersStatus Context::encode1xxHeaders(Http::ResponseHeaderMap& /*headers*/) {
  return Http::Filter1xxHeadersStatus::Continue;
}
// Response-headers entry point of the HTTP filter.
//
// The VM is skipped when the in-VM context was never created, or when failStream() has already
// sent its local reply for this stream. Otherwise response_headers_ points at `headers` while
// onResponseHeaders() runs, and is cleared when the converted result is Continue.
Http::FilterHeadersStatus Context::encodeHeaders(Http::ResponseHeaderMap& headers,
                                                 bool end_stream) {
  const bool skip_vm = !in_vm_context_created_ || failure_local_reply_sent_;
  if (skip_vm) {
    return Http::FilterHeadersStatus::Continue;
  }
  response_headers_ = &headers;
  end_of_stream_ = end_stream;
  const auto header_count = headerSize(&headers);
  const auto status = convertFilterHeadersStatus(onResponseHeaders(header_count, end_stream));
  if (status == Http::FilterHeadersStatus::Continue) {
    response_headers_ = nullptr;
  }
  return status;
}
72
// Response-body entry point of the HTTP filter.
//
// The VM is skipped when the in-VM context was never created, or when failStream() has already
// sent its local reply. If the previous call asked to buffer (StopIterationAndBuffer), the new
// chunk is first appended to the encoder's buffered data. The plugin then sees the body size
// reported by getBuffer(HttpResponseBody). After the call:
//  - buffering_response_body_ is true only when the plugin asked to buffer again;
//  - response_body_buffer_ is cleared only when the plugin returned Continue.
Http::FilterDataStatus Context::encodeData(::Envoy::Buffer::Instance& data, bool end_stream) {
  const bool skip_vm = !in_vm_context_created_ || failure_local_reply_sent_;
  if (skip_vm) {
    return Http::FilterDataStatus::Continue;
  }
  if (buffering_response_body_) {
    encoder_callbacks_->addEncodedData(data, false);
    if (destroyed_) {
      // Appending may have exceeded the buffer limit and triggered a local reply (413), which
      // leaves this context destroyed_; do not call into the VM.
      // This is not airtight: if another filter stops the local reply processing, this filter
      // may still call the VM later, but here at least the VM context is known to be valid.
      return Http::FilterDataStatus::StopIterationAndBuffer;
    }
  }
  response_body_buffer_ = &data;
  end_of_stream_ = end_stream;
  const auto body = getBuffer(WasmBufferType::HttpResponseBody);
  const auto body_size = (body != nullptr) ? body->size() : 0;
  const auto status = convertFilterDataStatus(onResponseBody(body_size, end_stream));
  buffering_response_body_ = (status == Http::FilterDataStatus::StopIterationAndBuffer);
  if (status == Http::FilterDataStatus::Continue) {
    response_body_buffer_ = nullptr;
  }
  return status;
}
6
// Response-trailers entry point of the HTTP filter.
//
// The VM is skipped when the in-VM context was never created, or when failStream() has already
// sent its local reply. Otherwise response_trailers_ points at `trailers` while
// onResponseTrailers() runs, and is cleared when the converted result is Continue.
Http::FilterTrailersStatus Context::encodeTrailers(Http::ResponseTrailerMap& trailers) {
  if (in_vm_context_created_ && !failure_local_reply_sent_) {
    response_trailers_ = &trailers;
    const auto trailer_count = headerSize(&trailers);
    const auto status = convertFilterTrailersStatus(onResponseTrailers(trailer_count));
    if (status == Http::FilterTrailersStatus::Continue) {
      response_trailers_ = nullptr;
    }
    return status;
  }
  return Http::FilterTrailersStatus::Continue;
}
6
// Response-metadata entry point of the HTTP filter.
//
// The VM is skipped when the in-VM context was never created, or when failStream() has already
// sent its local reply. Otherwise response_metadata_ points at `response_metadata` while
// onResponseMetadata() runs, and is cleared when the converted result is Continue.
Http::FilterMetadataStatus Context::encodeMetadata(Http::MetadataMap& response_metadata) {
  if (in_vm_context_created_ && !failure_local_reply_sent_) {
    response_metadata_ = &response_metadata;
    const auto entry_count = headerSize(&response_metadata);
    const auto status = convertFilterMetadataStatus(onResponseMetadata(entry_count));
    if (status == Http::FilterMetadataStatus::Continue) {
      response_metadata_ = nullptr;
    }
    return status;
  }
  return Http::FilterMetadataStatus::Continue;
}
//  Http::FilterMetadataStatus::Continue;
255
// Stores the encoder-side filter callbacks. They are used later to buffer response data, send
// failure local replies and reset the response stream. The callbacks are not owned.
void Context::setEncoderFilterCallbacks(Envoy::Http::StreamEncoderFilterCallbacks& callbacks) {
  encoder_callbacks_ = std::addressof(callbacks);
}
31
void Context::onHttpCallSuccess(uint32_t token, Envoy::Http::ResponseMessagePtr&& response) {
  // TODO: convert this into a function in proxy-wasm-cpp-host and use here.
31
  if (proxy_wasm::current_context_ != nullptr) {
    // We are in a reentrant call, so defer.
15
    envoyWasm()->addAfterVmCallAction([this, token, response = response.release()] {
15
      onHttpCallSuccess(token, std::unique_ptr<Envoy::Http::ResponseMessage>(response));
15
    });
15
    return;
15
  }
16
  auto handler = http_request_.find(token);
16
  if (handler == http_request_.end()) {
3
    return;
3
  }
13
  http_call_response_ = &response;
13
  uint32_t body_size = response->body().length();
  // Deferred "after VM call" actions are going to be executed upon returning from
  // ContextBase::*, which might include deleting Context object via proxy_done().
13
  envoyWasm()->addAfterVmCallAction([this, handler] {
13
    http_call_response_ = nullptr;
13
    http_request_.erase(handler);
13
  });
13
  ContextBase::onHttpCallResponse(token, response->headers().size(), body_size,
13
                                  headerSize(response->trailers()));
13
}
6
void Context::onHttpCallFailure(uint32_t token, Http::AsyncClient::FailureReason reason) {
6
  if (proxy_wasm::current_context_ != nullptr) {
    // We are in a reentrant call, so defer.
3
    envoyWasm()->addAfterVmCallAction([this, token, reason] { onHttpCallFailure(token, reason); });
3
    return;
3
  }
3
  auto handler = http_request_.find(token);
3
  if (handler == http_request_.end()) {
    return;
  }
3
  status_code_ = static_cast<uint32_t>(WasmResult::BrokenConnection);
  // TODO(botengyao): handle different failure reasons.
3
  ASSERT(reason == Http::AsyncClient::FailureReason::Reset ||
3
         reason == Http::AsyncClient::FailureReason::ExceedResponseBufferLimit);
3
  status_message_ = "reset";
  // Deferred "after VM call" actions are going to be executed upon returning from
  // ContextBase::*, which might include deleting Context object via proxy_done().
3
  envoyWasm()->addAfterVmCallAction([this, handler] {
3
    status_message_ = "";
3
    http_request_.erase(handler);
3
  });
3
  ContextBase::onHttpCallResponse(token, 0, 0, 0);
3
}
27
void Context::onGrpcReceiveWrapper(uint32_t token, ::Envoy::Buffer::InstancePtr response) {
27
  ASSERT(proxy_wasm::current_context_ == nullptr); // Non-reentrant.
27
  auto cleanup = [this, token] {
27
    if (envoyWasm()->isGrpcCallId(token)) {
7
      grpc_call_request_.erase(token);
7
    }
27
  };
27
  if (envoyWasm()->on_grpc_receive_) {
27
    grpc_receive_buffer_ = std::move(response);
27
    uint32_t response_size = grpc_receive_buffer_->length();
    // Deferred "after VM call" actions are going to be executed upon returning from
    // ContextBase::*, which might include deleting Context object via proxy_done().
27
    envoyWasm()->addAfterVmCallAction([this, cleanup] {
27
      grpc_receive_buffer_.reset();
27
      cleanup();
27
    });
27
    ContextBase::onGrpcReceive(token, response_size);
27
  } else {
    cleanup();
  }
27
}
// Delivers the close of gRPC call/stream `token` (with its status and message) to the plugin.
//
// If another VM call is in progress, this is deferred until that call has returned. The message
// is copied first because the view may not outlive this call. The cleanup that follows always
// removes the grpc_call_request_ entry for call tokens. For stream tokens it removes the
// grpc_stream_ entry only if the plugin already closed its side (local_closed_). When the plugin
// has an on_grpc_close_ handler, status_code_/status_message_ are set while it runs, and the
// message is cleared again by a deferred action.
void Context::onGrpcCloseWrapper(uint32_t token, const Grpc::Status::GrpcStatus& status,
                                 const std::string_view message) {
  if (proxy_wasm::current_context_ != nullptr) {
    // We are in a reentrant call, so defer; keep an owned copy of the message.
    envoyWasm()->addAfterVmCallAction(
        [this, token, status, owned_message = std::string(message)] {
          onGrpcCloseWrapper(token, status, owned_message);
        });
    return;
  }
  auto forget_token = [this, token] {
    if (envoyWasm()->isGrpcCallId(token)) {
      grpc_call_request_.erase(token);
      return;
    }
    if (!envoyWasm()->isGrpcStreamId(token)) {
      return;
    }
    auto stream = grpc_stream_.find(token);
    if (stream != grpc_stream_.end() && stream->second.local_closed_) {
      grpc_stream_.erase(token);
    }
  };
  if (!envoyWasm()->on_grpc_close_) {
    forget_token();
    return;
  }
  status_code_ = static_cast<uint32_t>(status);
  status_message_ = toAbslStringView(message);
  // Deferred "after VM call" actions run upon returning from ContextBase::*, which might
  // include deleting this Context via proxy_done().
  envoyWasm()->addAfterVmCallAction([this, forget_token] {
    status_message_ = "";
    forget_token();
  });
  ContextBase::onGrpcClose(token, status_code_);
}
30
// Sends `message` on the gRPC stream identified by `token`.
//
// @return BadArgument if `token` is not a stream id, NotFound if the stream is not (or no
//         longer) tracked, and Ok otherwise. Ok is also returned when the tracked entry has no
//         underlying stream, in which case nothing is sent.
WasmResult Context::grpcSend(uint32_t token, std::string_view message, bool end_stream) {
  if (!envoyWasm()->isGrpcStreamId(token)) {
    return WasmResult::BadArgument;
  }
  auto it = grpc_stream_.find(token);
  if (it == grpc_stream_.end()) {
    return WasmResult::NotFound;
  }
  if (it->second.stream_) {
    // Copy the caller-owned bytes into a buffer whose ownership moves to the stream.
    it->second.stream_->sendMessageRaw(
        std::make_unique<::Envoy::Buffer::OwnedImpl>(message.data(), message.size()), end_stream);
  }
  return WasmResult::Ok;
}
29
// Closes the plugin's side of gRPC call/stream `token`.
//
// Calls: the request (if any) is cancelled and the entry is removed.
// Streams: the stream (if any) is half-closed with closeStream(). The entry is removed right
// away if the remote side has already closed (remote_closed_); otherwise local_closed_ is set
// so the close handler can remove it later.
//
// @return BadArgument if `token` is neither a call nor a stream id, NotFound if it is not
//         tracked, Ok otherwise.
WasmResult Context::grpcClose(uint32_t token) {
  if (envoyWasm()->isGrpcCallId(token)) {
    auto call = grpc_call_request_.find(token);
    if (call == grpc_call_request_.end()) {
      return WasmResult::NotFound;
    }
    if (call->second.request_) {
      call->second.request_->cancel();
    }
    grpc_call_request_.erase(token);
    return WasmResult::Ok;
  }
  if (!envoyWasm()->isGrpcStreamId(token)) {
    return WasmResult::BadArgument;
  }
  auto stream = grpc_stream_.find(token);
  if (stream == grpc_stream_.end()) {
    return WasmResult::NotFound;
  }
  if (stream->second.stream_) {
    stream->second.stream_->closeStream();
  }
  if (!stream->second.remote_closed_) {
    // Wait for the remote close before forgetting the stream.
    stream->second.local_closed_ = true;
    return WasmResult::Ok;
  }
  grpc_stream_.erase(token);
  return WasmResult::Ok;
}
35
// Cancels gRPC call/stream `token` and forgets it.
//
// Calls: the request (if any) is cancelled. Streams: the stream (if any) is reset. Unlike
// grpcClose(), the tracked entry is removed immediately in both cases.
//
// @return BadArgument if `token` is neither a call nor a stream id, NotFound if it is not
//         tracked, Ok otherwise.
WasmResult Context::grpcCancel(uint32_t token) {
  const bool is_call = envoyWasm()->isGrpcCallId(token);
  // Short-circuit keeps the original order: the stream check only runs for non-call tokens.
  if (!is_call && !envoyWasm()->isGrpcStreamId(token)) {
    return WasmResult::BadArgument;
  }
  if (is_call) {
    auto call = grpc_call_request_.find(token);
    if (call == grpc_call_request_.end()) {
      return WasmResult::NotFound;
    }
    if (call->second.request_) {
      call->second.request_->cancel();
    }
    grpc_call_request_.erase(token);
    return WasmResult::Ok;
  }
  auto stream = grpc_stream_.find(token);
  if (stream == grpc_stream_.end()) {
    return WasmResult::NotFound;
  }
  if (stream->second.stream_) {
    stream->second.stream_->resetStream();
  }
  grpc_stream_.erase(token);
  return WasmResult::Ok;
}
} // namespace Wasm
} // namespace Common
} // namespace Extensions
} // namespace Envoy