Coverage Report

Created: 2024-09-19 09:45

/proc/self/cwd/source/extensions/common/wasm/context.cc
Line
Count
Source (jump to first uncovered line)
1
#include "source/extensions/common/wasm/context.h"
2
3
#include <algorithm>
4
#include <cctype>
5
#include <cstring>
6
#include <ctime>
7
#include <limits>
8
#include <memory>
9
#include <string>
10
11
#include "envoy/common/exception.h"
12
#include "envoy/extensions/wasm/v3/wasm.pb.validate.h"
13
#include "envoy/grpc/status.h"
14
#include "envoy/http/codes.h"
15
#include "envoy/local_info/local_info.h"
16
#include "envoy/network/filter.h"
17
#include "envoy/stats/sink.h"
18
#include "envoy/thread_local/thread_local.h"
19
20
#include "source/common/buffer/buffer_impl.h"
21
#include "source/common/common/assert.h"
22
#include "source/common/common/empty_string.h"
23
#include "source/common/common/enum_to_int.h"
24
#include "source/common/common/logger.h"
25
#include "source/common/common/safe_memcpy.h"
26
#include "source/common/http/header_map_impl.h"
27
#include "source/common/http/message_impl.h"
28
#include "source/common/http/utility.h"
29
#include "source/common/tracing/http_tracer_impl.h"
30
#include "source/extensions/common/wasm/plugin.h"
31
#include "source/extensions/common/wasm/wasm.h"
32
#include "source/extensions/filters/common/expr/context.h"
33
34
#include "absl/base/casts.h"
35
#include "absl/container/flat_hash_map.h"
36
#include "absl/container/node_hash_map.h"
37
#include "absl/strings/str_cat.h"
38
#include "absl/synchronization/mutex.h"
39
40
#if defined(__GNUC__)
41
#pragma GCC diagnostic push
42
#pragma GCC diagnostic ignored "-Woverloaded-virtual"
43
#endif
44
45
#include "eval/public/cel_value.h"
46
#include "eval/public/containers/field_access.h"
47
#include "eval/public/containers/field_backed_list_impl.h"
48
#include "eval/public/containers/field_backed_map_impl.h"
49
#include "eval/public/structs/cel_proto_wrapper.h"
50
51
#if defined(__GNUC__)
52
#pragma GCC diagnostic pop
53
#endif
54
55
#include "include/proxy-wasm/pairs_util.h"
56
#include "openssl/bytestring.h"
57
#include "openssl/hmac.h"
58
#include "openssl/sha.h"
59
60
using proxy_wasm::MetricType;
61
using proxy_wasm::Word;
62
63
namespace Envoy {
64
namespace Extensions {
65
namespace Common {
66
namespace Wasm {
67
68
namespace {
69
70
// FilterState prefix for CelState values.
71
constexpr absl::string_view CelStateKeyPrefix = "wasm.";
72
73
using HashPolicy = envoy::config::route::v3::RouteAction::HashPolicy;
74
using CelState = Filters::Common::Expr::CelState;
75
using CelStatePrototype = Filters::Common::Expr::CelStatePrototype;
76
77
0
// Builds a request trailer map from (key, value) pairs; keys are lower-cased via LowerCaseString.
Http::RequestTrailerMapPtr buildRequestTrailerMapFromPairs(const Pairs& pairs) {
  auto trailers = Http::RequestTrailerMapImpl::create();
  for (const auto& [key, value] : pairs) {
    // addCopy() has no string_view overload, and there is no API to insert an empty entry and
    // get it back for in-place assignment, so both key and value are copied into std::string
    // (accepting either a double copy or a double lookup).
    trailers->addCopy(Http::LowerCaseString(std::string(key)), std::string(value));
  }
  return trailers;
}
88
89
0
// Builds a request header map from (key, value) pairs; keys are lower-cased via LowerCaseString.
Http::RequestHeaderMapPtr buildRequestHeaderMapFromPairs(const Pairs& pairs) {
  auto headers = Http::RequestHeaderMapImpl::create();
  for (const auto& [key, value] : pairs) {
    // addCopy() has no string_view overload, and there is no API to insert an empty entry and
    // get it back for in-place assignment, so both key and value are copied into std::string
    // (accepting either a double copy or a double lookup).
    headers->addCopy(Http::LowerCaseString(std::string(key)), std::string(value));
  }
  return headers;
}
100
101
0
// Returns the number of entries in the map pointed to by `p`, or 0 when `p` is null.
template <typename P> static uint32_t headerSize(const P& p) {
  if (!p) {
    return 0;
  }
  return p->size();
}
Unexecuted instantiation: context.cc:unsigned int Envoy::Extensions::Common::Wasm::(anonymous namespace)::headerSize<std::__1::unique_ptr<Envoy::Http::HeaderMap, std::__1::default_delete<Envoy::Http::HeaderMap> > >(std::__1::unique_ptr<Envoy::Http::HeaderMap, std::__1::default_delete<Envoy::Http::HeaderMap> > const&)
Unexecuted instantiation: context.cc:unsigned int Envoy::Extensions::Common::Wasm::(anonymous namespace)::headerSize<Envoy::Http::RequestHeaderMap*>(Envoy::Http::RequestHeaderMap* const&)
Unexecuted instantiation: context.cc:unsigned int Envoy::Extensions::Common::Wasm::(anonymous namespace)::headerSize<Envoy::Http::RequestTrailerMap*>(Envoy::Http::RequestTrailerMap* const&)
Unexecuted instantiation: context.cc:unsigned int Envoy::Extensions::Common::Wasm::(anonymous namespace)::headerSize<Envoy::Http::MetadataMap*>(Envoy::Http::MetadataMap* const&)
Unexecuted instantiation: context.cc:unsigned int Envoy::Extensions::Common::Wasm::(anonymous namespace)::headerSize<Envoy::Http::ResponseHeaderMap*>(Envoy::Http::ResponseHeaderMap* const&)
Unexecuted instantiation: context.cc:unsigned int Envoy::Extensions::Common::Wasm::(anonymous namespace)::headerSize<Envoy::Http::ResponseTrailerMap*>(Envoy::Http::ResponseTrailerMap* const&)
102
103
0
// Returns the upstream host recorded in `info`, or nullptr when `info` is null, has no upstream
// info, or no upstream host.
Upstream::HostDescriptionConstSharedPtr getHost(const StreamInfo::StreamInfo* info) {
  if (info == nullptr) {
    return nullptr;
  }
  auto upstream_info = info->upstreamInfo();
  if (!upstream_info) {
    return nullptr;
  }
  // upstreamHost() may itself be null; that is passed straight through.
  return upstream_info.value().get().upstreamHost();
}
109
110
} // namespace
111
112
// Test support.
113
114
0
size_t Buffer::size() const {
115
0
  if (const_buffer_instance_) {
116
0
    return const_buffer_instance_->length();
117
0
  }
118
0
  return proxy_wasm::BufferBase::size();
119
0
}
120
121
// Copies [start, start + length) of the wrapped const Envoy buffer into freshly allocated VM
// memory and writes the VM pointer and length to `ptr_ptr` / `size_ptr`. Without a const buffer,
// defers to the proxy-wasm base implementation.
WasmResult Buffer::copyTo(WasmBase* wasm, size_t start, size_t length, uint64_t ptr_ptr,
                          uint64_t size_ptr) const {
  if (!const_buffer_instance_) {
    return proxy_wasm::BufferBase::copyTo(wasm, start, length, ptr_ptr, size_ptr);
  }
  uint64_t vm_pointer;
  auto dest = wasm->allocMemory(length, &vm_pointer);
  if (!dest) {
    return WasmResult::InvalidMemoryAccess;
  }
  const_buffer_instance_->copyOut(start, length, dest);
  auto vm = wasm->wasm_vm();
  // Short-circuit keeps the original order: the pointer is written before the size.
  if (!vm->setWord(ptr_ptr, Word(vm_pointer)) || !vm->setWord(size_ptr, Word(length))) {
    return WasmResult::InvalidMemoryAccess;
  }
  return WasmResult::Ok;
}
140
141
0
// Writes `data` into the wrapped mutable Envoy buffer.
//  - start == 0: drain the first `length` bytes, then prepend `data` (i.e. replace a prefix);
//  - start at/after the end: append `data`;
//  - anything else (mid-buffer writes) is rejected with BadArgument.
// A wrapped const buffer is immutable (BadArgument); with no Envoy buffer at all this defers to
// the proxy-wasm base implementation.
WasmResult Buffer::copyFrom(size_t start, size_t length, std::string_view data) {
  if (!buffer_instance_) {
    return const_buffer_instance_ ? WasmResult::BadArgument
                                  : proxy_wasm::BufferBase::copyFrom(start, length, data);
  }
  if (start == 0) {
    if (length > 0) {
      buffer_instance_->drain(length);
    }
    buffer_instance_->prepend(toAbslStringView(data));
    return WasmResult::Ok;
  }
  if (start < buffer_instance_->length()) {
    return WasmResult::BadArgument;
  }
  buffer_instance_->add(toAbslStringView(data));
  return WasmResult::Ok;
}
161
162
0
// Defaulted constructor: no Wasm instance or plugin is supplied, and every member takes its
// default initializer.
Context::Context() = default;
Unexecuted instantiation: Envoy::Extensions::Common::Wasm::Context::Context()
Unexecuted instantiation: Envoy::Extensions::Common::Wasm::Context::Context()
163
0
// Constructs a context bound to `wasm` with no plugin; simply forwards to ContextBase(wasm).
Context::Context(Wasm* wasm) : ContextBase(wasm) {}
Unexecuted instantiation: Envoy::Extensions::Common::Wasm::Context::Context(Envoy::Extensions::Common::Wasm::Wasm*)
Unexecuted instantiation: Envoy::Extensions::Common::Wasm::Context::Context(Envoy::Extensions::Common::Wasm::Wasm*)
164
0
// Constructs a context for `plugin` and caches a pointer to the plugin's LocalInfo in
// root_local_info_, which findValue() prefers over plugin()->localInfo().
Context::Context(Wasm* wasm, const PluginSharedPtr& plugin) : ContextBase(wasm, plugin) {
  const auto wasm_plugin = std::static_pointer_cast<Plugin>(plugin);
  root_local_info_ = &wasm_plugin->localInfo();
}
Unexecuted instantiation: Envoy::Extensions::Common::Wasm::Context::Context(Envoy::Extensions::Common::Wasm::Wasm*, std::__1::shared_ptr<Envoy::Extensions::Common::Wasm::Plugin> const&)
Unexecuted instantiation: Envoy::Extensions::Common::Wasm::Context::Context(Envoy::Extensions::Common::Wasm::Wasm*, std::__1::shared_ptr<Envoy::Extensions::Common::Wasm::Plugin> const&)
167
// Constructs a context that forwards `root_context_id` (presumably the id of its parent root
// context — confirm against ContextBase) and `plugin_handle` to ContextBase, and also keeps its
// own reference to `plugin_handle` in plugin_handle_.
Context::Context(Wasm* wasm, uint32_t root_context_id, PluginHandleSharedPtr plugin_handle)
    : ContextBase(wasm, root_context_id, plugin_handle), plugin_handle_(plugin_handle) {}
Unexecuted instantiation: Envoy::Extensions::Common::Wasm::Context::Context(Envoy::Extensions::Common::Wasm::Wasm*, unsigned int, std::__1::shared_ptr<Envoy::Extensions::Common::Wasm::PluginHandle>)
Unexecuted instantiation: Envoy::Extensions::Common::Wasm::Context::Context(Envoy::Extensions::Common::Wasm::Wasm*, unsigned int, std::__1::shared_ptr<Envoy::Extensions::Common::Wasm::PluginHandle>)
169
170
0
// Returns the owning Wasm instance, downcast from the proxy-wasm base pointer wasm_.
Wasm* Context::wasm() const {
  auto* const base = wasm_;
  return static_cast<Wasm*>(base);
}
171
0
// Returns the plugin, downcast from the proxy-wasm base pointer held in plugin_.
Plugin* Context::plugin() const {
  auto* const base = plugin_.get();
  return static_cast<Plugin*>(base);
}
172
0
// Returns root_context() downcast to the Envoy Context type.
Context* Context::rootContext() const {
  auto* const root = root_context();
  return static_cast<Context*>(root);
}
173
0
// Returns the cluster manager of the owning Wasm instance.
Upstream::ClusterManager& Context::clusterManager() const {
  Wasm* const owner = wasm();
  return owner->clusterManager();
}
174
175
0
void Context::error(std::string_view message) { ENVOY_LOG(trace, message); }
176
177
0
uint64_t Context::getCurrentTimeNanoseconds() {
178
0
  return std::chrono::duration_cast<std::chrono::nanoseconds>(
179
0
             wasm()->time_source_.systemTime().time_since_epoch())
180
0
      .count();
181
0
}
182
183
0
uint64_t Context::getMonotonicTimeNanoseconds() {
184
0
  return std::chrono::duration_cast<std::chrono::nanoseconds>(
185
0
             wasm()->time_source_.monotonicTime().time_since_epoch())
186
0
      .count();
187
0
}
188
189
0
void Context::onCloseTCP() {
190
0
  if (tcp_connection_closed_ || !in_vm_context_created_) {
191
0
    return;
192
0
  }
193
0
  tcp_connection_closed_ = true;
194
0
  onDone();
195
0
  onLog();
196
0
  onDelete();
197
0
}
198
199
void Context::onResolveDns(uint32_t token, Envoy::Network::DnsResolver::ResolutionStatus status,
200
0
                           std::list<Envoy::Network::DnsResponse>&& response) {
201
0
  proxy_wasm::DeferAfterCallActions actions(this);
202
0
  if (wasm()->isFailed() || !wasm()->on_resolve_dns_) {
203
0
    return;
204
0
  }
205
0
  if (status != Network::DnsResolver::ResolutionStatus::Success) {
206
0
    buffer_.set("");
207
0
    wasm()->on_resolve_dns_(this, id_, token, 0);
208
0
    return;
209
0
  }
210
  // buffer format:
211
  //    4 bytes number of entries = N
212
  //    N * 4 bytes TTL for each entry
213
  //    N * null-terminated addresses
214
0
  uint32_t s = 4; // length
215
0
  for (auto& e : response) {
216
0
    s += 4;                                                // for TTL
217
0
    s += e.addrInfo().address_->asStringView().size() + 1; // null terminated.
218
0
  }
219
0
  auto buffer = std::unique_ptr<char[]>(new char[s]);
220
0
  char* b = buffer.get();
221
0
  uint32_t n = response.size();
222
0
  safeMemcpyUnsafeDst(b, &n);
223
0
  b += sizeof(uint32_t);
224
0
  for (auto& e : response) {
225
0
    uint32_t ttl = e.addrInfo().ttl_.count();
226
0
    safeMemcpyUnsafeDst(b, &ttl);
227
0
    b += sizeof(uint32_t);
228
0
  };
229
0
  for (auto& e : response) {
230
0
    memcpy(b, e.addrInfo().address_->asStringView().data(), // NOLINT(safe-memcpy)
231
0
           e.addrInfo().address_->asStringView().size());
232
0
    b += e.addrInfo().address_->asStringView().size();
233
0
    *b++ = 0;
234
0
  };
235
0
  buffer_.set(std::move(buffer), s);
236
0
  wasm()->on_resolve_dns_(this, id_, token, s);
237
0
}
238
239
0
// Rounds `i` up to the next multiple of sizeof(I); sizeof(I) is assumed to be a power of two.
template <typename I> inline uint32_t align(uint32_t i) {
  constexpr auto mask = sizeof(I) - 1;
  return (i + mask) & ~mask;
}
242
243
0
// Rounds the address `p` up to the next multiple of sizeof(I); sizeof(I) is assumed to be a power
// of two.
template <typename I> inline char* align(char* p) {
  constexpr uintptr_t mask = sizeof(I) - 1;
  const uintptr_t address = reinterpret_cast<uintptr_t>(p);
  return reinterpret_cast<char*>((address + mask) & ~mask);
}
247
248
0
// Serializes the used counters and gauges of `snapshot` into buffer_ and invokes the module's
// on_stats_update callback with the total serialized size.
void Context::onStatsUpdate(Envoy::Stats::MetricSnapshot& snapshot) {
  proxy_wasm::DeferAfterCallActions actions(this);
  if (wasm()->isFailed() || !wasm()->on_stats_update_) {
    return;
  }
  // buffer format:
  //  uint32 size of block of this type
  //  uint32 type
  //  uint32 count
  //    uint32 length of name
  //    name
  //    8 byte alignment padding
  //    8 bytes of absolute value
  //    8 bytes of delta  (if appropriate, e.g. for counters)
  //  uint32 size of block of this type
  //  ...
  //
  // Only metrics whose used() flag is set are serialized, and "count" is the number of entries
  // actually written. Alignment padding is measured from the start of the buffer.

  // One serialized metric. Metrics are captured once, up front, so that the sizing pass and the
  // write pass see exactly the same set of metrics, names and values even if used() flips
  // concurrently.
  struct StatEntry {
    std::string name;
    uint64_t value;
    uint64_t delta; // Serialized for counters only.
  };

  std::vector<StatEntry> counters;
  for (const auto& counter : snapshot.counters()) {
    if (counter.counter_.get().used()) {
      counters.push_back(
          {counter.counter_.get().name(), counter.counter_.get().value(), counter.delta_});
    }
  }
  std::vector<StatEntry> gauges;
  for (const auto& gauge : snapshot.gauges()) {
    if (gauge.get().used()) {
      gauges.push_back({gauge.get().name(), gauge.get().value(), 0});
    }
  }

  constexpr uint32_t block_header_size = 3 * sizeof(uint32_t); // size, type, count
  uint32_t counter_type = 1;
  uint32_t gauge_type = 2;
  uint32_t num_counters = counters.size();
  uint32_t num_gauges = gauges.size();

  // Sizing pass. It uses the same offset arithmetic as the write pass below, so each declared
  // block size always equals the bytes actually written.
  uint32_t offset = block_header_size;
  for (const auto& counter : counters) {
    // name length + name, padded to 8 bytes, then value and delta.
    offset = align<uint64_t>(offset + sizeof(uint32_t) + counter.name.size()) +
             2 * sizeof(uint64_t);
  }
  uint32_t counter_block_size = offset;

  offset += block_header_size;
  for (const auto& gauge : gauges) {
    // name length + name, padded to 8 bytes, then value.
    offset = align<uint64_t>(offset + sizeof(uint32_t) + gauge.name.size()) + sizeof(uint64_t);
  }
  uint32_t gauge_block_size = offset - counter_block_size;
  const uint32_t total_size = offset;

  // Value-initialized so that padding bytes handed to the Wasm module are zero rather than
  // uninitialized heap memory.
  auto buffer = std::make_unique<char[]>(total_size);
  char* const base = buffer.get();
  char* b = base;
  // Advances `p` to the next 8-byte boundary relative to the start of the buffer, matching the
  // sizing pass.
  const auto pad = [base](char* p) {
    return base + align<uint64_t>(static_cast<uint32_t>(p - base));
  };
  uint32_t n = 0;
  uint64_t v = 0;

  safeMemcpyUnsafeDst(b, &counter_block_size);
  b += sizeof(uint32_t);
  safeMemcpyUnsafeDst(b, &counter_type);
  b += sizeof(uint32_t);
  safeMemcpyUnsafeDst(b, &num_counters);
  b += sizeof(uint32_t);

  for (const auto& counter : counters) {
    n = counter.name.size();
    safeMemcpyUnsafeDst(b, &n);
    b += sizeof(uint32_t);
    memcpy(b, counter.name.data(), counter.name.size()); // NOLINT(safe-memcpy)
    b = pad(b + counter.name.size());
    v = counter.value;
    safeMemcpyUnsafeDst(b, &v);
    b += sizeof(uint64_t);
    v = counter.delta;
    safeMemcpyUnsafeDst(b, &v);
    b += sizeof(uint64_t);
  }

  safeMemcpyUnsafeDst(b, &gauge_block_size);
  b += sizeof(uint32_t);
  safeMemcpyUnsafeDst(b, &gauge_type);
  b += sizeof(uint32_t);
  safeMemcpyUnsafeDst(b, &num_gauges);
  b += sizeof(uint32_t);

  for (const auto& gauge : gauges) {
    n = gauge.name.size();
    safeMemcpyUnsafeDst(b, &n);
    b += sizeof(uint32_t);
    memcpy(b, gauge.name.data(), gauge.name.size()); // NOLINT(safe-memcpy)
    b = pad(b + gauge.name.size());
    v = gauge.value;
    safeMemcpyUnsafeDst(b, &v);
    b += sizeof(uint64_t);
  }

  ASSERT(b == base + total_size);
  buffer_.set(std::move(buffer), total_size);
  wasm()->on_stats_update_(this, id_, total_size);
}
338
339
// Native serializer carrying over bit representation from CEL value to the extension.
340
// This implementation assumes that the value type is static and known to the consumer.
341
0
WasmResult serializeValue(Filters::Common::Expr::CelValue value, std::string* result) {
342
0
  using Filters::Common::Expr::CelValue;
343
0
  int64_t out_int64;
344
0
  uint64_t out_uint64;
345
0
  double out_double;
346
0
  bool out_bool;
347
0
  const Protobuf::Message* out_message;
348
0
  switch (value.type()) {
349
0
  case CelValue::Type::kString:
350
0
    result->assign(value.StringOrDie().value().data(), value.StringOrDie().value().size());
351
0
    return WasmResult::Ok;
352
0
  case CelValue::Type::kBytes:
353
0
    result->assign(value.BytesOrDie().value().data(), value.BytesOrDie().value().size());
354
0
    return WasmResult::Ok;
355
0
  case CelValue::Type::kInt64:
356
0
    out_int64 = value.Int64OrDie();
357
0
    result->assign(reinterpret_cast<const char*>(&out_int64), sizeof(int64_t));
358
0
    return WasmResult::Ok;
359
0
  case CelValue::Type::kUint64:
360
0
    out_uint64 = value.Uint64OrDie();
361
0
    result->assign(reinterpret_cast<const char*>(&out_uint64), sizeof(uint64_t));
362
0
    return WasmResult::Ok;
363
0
  case CelValue::Type::kDouble:
364
0
    out_double = value.DoubleOrDie();
365
0
    result->assign(reinterpret_cast<const char*>(&out_double), sizeof(double));
366
0
    return WasmResult::Ok;
367
0
  case CelValue::Type::kBool:
368
0
    out_bool = value.BoolOrDie();
369
0
    result->assign(reinterpret_cast<const char*>(&out_bool), sizeof(bool));
370
0
    return WasmResult::Ok;
371
0
  case CelValue::Type::kDuration:
372
    // Warning: loss of precision to nanoseconds
373
0
    out_int64 = absl::ToInt64Nanoseconds(value.DurationOrDie());
374
0
    result->assign(reinterpret_cast<const char*>(&out_int64), sizeof(int64_t));
375
0
    return WasmResult::Ok;
376
0
  case CelValue::Type::kTimestamp:
377
    // Warning: loss of precision to nanoseconds
378
0
    out_int64 = absl::ToUnixNanos(value.TimestampOrDie());
379
0
    result->assign(reinterpret_cast<const char*>(&out_int64), sizeof(int64_t));
380
0
    return WasmResult::Ok;
381
0
  case CelValue::Type::kMessage:
382
0
    out_message = value.MessageOrDie();
383
0
    result->clear();
384
0
    if (!out_message || out_message->SerializeToString(result)) {
385
0
      return WasmResult::Ok;
386
0
    }
387
0
    return WasmResult::SerializationFailure;
388
0
  case CelValue::Type::kMap: {
389
0
    const auto& map = *value.MapOrDie();
390
0
    auto keys_list = map.ListKeys();
391
0
    if (!keys_list.ok()) {
392
0
      return WasmResult::SerializationFailure;
393
0
    }
394
0
    const auto& keys = *keys_list.value();
395
0
    std::vector<std::pair<std::string, std::string>> pairs(map.size(), std::make_pair("", ""));
396
0
    for (auto i = 0; i < map.size(); i++) {
397
0
      if (serializeValue(keys[i], &pairs[i].first) != WasmResult::Ok) {
398
0
        return WasmResult::SerializationFailure;
399
0
      }
400
0
      if (serializeValue(map[keys[i]].value(), &pairs[i].second) != WasmResult::Ok) {
401
0
        return WasmResult::SerializationFailure;
402
0
      }
403
0
    }
404
0
    auto size = proxy_wasm::PairsUtil::pairsSize(pairs);
405
    // prevent string inlining which violates byte alignment
406
0
    result->resize(std::max(size, static_cast<size_t>(30)));
407
0
    if (!proxy_wasm::PairsUtil::marshalPairs(pairs, result->data(), size)) {
408
0
      return WasmResult::SerializationFailure;
409
0
    }
410
0
    result->resize(size);
411
0
    return WasmResult::Ok;
412
0
  }
413
0
  case CelValue::Type::kList: {
414
0
    const auto& list = *value.ListOrDie();
415
0
    std::vector<std::pair<std::string, std::string>> pairs(list.size(), std::make_pair("", ""));
416
0
    for (auto i = 0; i < list.size(); i++) {
417
0
      if (serializeValue(list[i], &pairs[i].first) != WasmResult::Ok) {
418
0
        return WasmResult::SerializationFailure;
419
0
      }
420
0
    }
421
0
    auto size = proxy_wasm::PairsUtil::pairsSize(pairs);
422
    // prevent string inlining which violates byte alignment
423
0
    if (size < 30) {
424
0
      result->reserve(30);
425
0
    }
426
0
    result->resize(size);
427
0
    if (!proxy_wasm::PairsUtil::marshalPairs(pairs, result->data(), size)) {
428
0
      return WasmResult::SerializationFailure;
429
0
    }
430
0
    return WasmResult::Ok;
431
0
  }
432
0
  default:
433
0
    break;
434
0
  }
435
0
  return WasmResult::SerializationFailure;
436
0
}
437
438
// X-macro listing the well-known property names handled by Context::findValue(). `_f` is applied
// to each token; it is expanded below into the PropertyToken enum and into the property_tokens
// lookup table keyed by the lower-cased token name.
#define PROPERTY_TOKENS(_f)                                                                        \
  _f(NODE) _f(LISTENER_DIRECTION) _f(LISTENER_METADATA) _f(CLUSTER_NAME) _f(CLUSTER_METADATA)      \
      _f(ROUTE_NAME) _f(ROUTE_METADATA) _f(PLUGIN_NAME) _f(UPSTREAM_HOST_METADATA)                 \
          _f(PLUGIN_ROOT_ID) _f(PLUGIN_VM_ID) _f(CONNECTION_ID)
442
443
72
// Returns a copy of `s` with every byte lower-cased via std::tolower (bytes are treated as
// unsigned char so high-bit characters are handled safely).
static inline std::string downCase(std::string s) {
  for (char& c : s) {
    c = static_cast<char>(std::tolower(static_cast<unsigned char>(c)));
  }
  return s;
}
447
448
// Dense enum generated from PROPERTY_TOKENS; findValue() switches over it.
#define _DECLARE(_t) _t,
enum class PropertyToken { PROPERTY_TOKENS(_DECLARE) };
#undef _DECLARE
451
452
// Maps the lower-cased spelling of each token (e.g. "cluster_name") to its PropertyToken.
// Built once during static initialization.
#define _PAIR(_t) {downCase(#_t), PropertyToken::_t},
static absl::flat_hash_map<std::string, PropertyToken> property_tokens = {PROPERTY_TOKENS(_PAIR)};
#undef _PAIR
455
456
// CEL activation lookup entry point. It delegates to findValue() with `last` set to false.
absl::optional<google::api::expr::runtime::CelValue>
Context::FindValue(absl::string_view name, Protobuf::Arena* arena) const {
  constexpr bool kLast = false;
  return findValue(name, arena, kLast);
}
460
461
// Resolves a top-level property `name`, trying in order:
//   1. StreamActivation's attributes, evaluated against this context's headers/stream info;
//   2. if `name` is not a PropertyToken: a CelState stored in filter state under
//      "wasm.<name>" (the downstream filter state first, then the upstream filter state);
//   3. otherwise, the PropertyToken-specific lookup in the switch below.
// `last` is forwarded to CelState::exprValue(); getProperty() sets it when `name` is the final
// path segment. Returns an empty optional when nothing matches.
absl::optional<google::api::expr::runtime::CelValue>
Context::findValue(absl::string_view name, Protobuf::Arena* arena, bool last) const {
  using google::api::expr::runtime::CelProtoWrapper;
  using google::api::expr::runtime::CelValue;

  const StreamInfo::StreamInfo* info = getConstRequestStreamInfo();
  // In order to delegate to the StreamActivation method, we have to set the
  // context properties to match the Wasm context properties in all callbacks
  // (e.g. onLog or onEncodeHeaders) for the duration of the call.
  if (root_local_info_) {
    local_info_ = root_local_info_;
  } else if (plugin_) {
    local_info_ = &plugin()->localInfo();
  }
  activation_info_ = info;
  // Live maps take precedence; the access_log_* maps are used when the live ones are null.
  activation_request_headers_ = request_headers_ ? request_headers_ : access_log_request_headers_;
  activation_response_headers_ =
      response_headers_ ? response_headers_ : access_log_response_headers_;
  activation_response_trailers_ =
      response_trailers_ ? response_trailers_ : access_log_response_trailers_;
  auto value = StreamActivation::FindValue(name, arena);
  // NOTE(review): resetActivation() is skipped if FindValue() throws; consider a scope guard.
  resetActivation();
  if (value) {
    return value;
  }

  // Convert into a dense token to enable a jump table implementation.
  auto part_token = property_tokens.find(name);
  if (part_token == property_tokens.end()) {
    // Not a well-known token: look for a CelState published under the "wasm." prefix.
    if (info) {
      std::string key = absl::StrCat(CelStateKeyPrefix, name);
      const CelState* state = info->filterState().getDataReadOnly<CelState>(key);
      if (state == nullptr) {
        if (info->upstreamInfo().has_value() &&
            info->upstreamInfo().value().get().upstreamFilterState() != nullptr) {
          state =
              info->upstreamInfo().value().get().upstreamFilterState()->getDataReadOnly<CelState>(
                  key);
        }
      }
      if (state != nullptr) {
        return state->exprValue(arena, last);
      }
    }
    return {};
  }

  switch (part_token->second) {
  case PropertyToken::CONNECTION_ID: {
    auto conn = getConnection();
    if (conn) {
      return CelValue::CreateUint64(conn->id());
    }
    break;
  }
  case PropertyToken::NODE:
    if (root_local_info_) {
      return CelProtoWrapper::CreateMessage(&root_local_info_->node(), arena);
    } else if (plugin_) {
      return CelProtoWrapper::CreateMessage(&plugin()->localInfo().node(), arena);
    }
    break;
  case PropertyToken::LISTENER_DIRECTION:
    if (plugin_) {
      return CelValue::CreateInt64(plugin()->direction());
    }
    break;
  case PropertyToken::LISTENER_METADATA:
    if (plugin_) {
      return CelProtoWrapper::CreateMessage(plugin()->listenerMetadata(), arena);
    }
    break;
  case PropertyToken::CLUSTER_NAME:
    // Prefer the selected upstream host's cluster, then the route's target cluster, then the
    // upstream cluster info. NOTE(review): the returned CelValue points into objects owned
    // elsewhere (getHost() returns a temporary shared_ptr); this relies on `info` keeping them
    // alive.
    if (getHost(info)) {
      return CelValue::CreateString(&getHost(info)->cluster().name());
    } else if (info && info->route() && info->route()->routeEntry()) {
      return CelValue::CreateString(&info->route()->routeEntry()->clusterName());
    } else if (info && info->upstreamClusterInfo().has_value() &&
               info->upstreamClusterInfo().value()) {
      return CelValue::CreateString(&info->upstreamClusterInfo().value()->name());
    }
    break;
  case PropertyToken::CLUSTER_METADATA:
    // Same host-then-cluster-info preference as CLUSTER_NAME (no route fallback).
    if (getHost(info)) {
      return CelProtoWrapper::CreateMessage(&getHost(info)->cluster().metadata(), arena);
    } else if (info && info->upstreamClusterInfo().has_value() &&
               info->upstreamClusterInfo().value()) {
      return CelProtoWrapper::CreateMessage(&info->upstreamClusterInfo().value()->metadata(),
                                            arena);
    }
    break;
  case PropertyToken::UPSTREAM_HOST_METADATA:
    if (getHost(info)) {
      return CelProtoWrapper::CreateMessage(getHost(info)->metadata().get(), arena);
    }
    break;
  case PropertyToken::ROUTE_NAME:
    if (info) {
      return CelValue::CreateString(&info->getRouteName());
    }
    break;
  case PropertyToken::ROUTE_METADATA:
    if (info && info->route()) {
      return CelProtoWrapper::CreateMessage(&info->route()->metadata(), arena);
    }
    break;
  case PropertyToken::PLUGIN_NAME:
    if (plugin_) {
      return CelValue::CreateStringView(plugin()->name_);
    }
    break;
  case PropertyToken::PLUGIN_ROOT_ID:
    return CelValue::CreateStringView(toAbslStringView(root_id()));
  case PropertyToken::PLUGIN_VM_ID:
    return CelValue::CreateStringView(toAbslStringView(wasm()->vm_id()));
  }
  return {};
}
579
580
0
// Resolves the property named by `path` and serializes it into `*result` with serializeValue().
//
// `path` is a sequence of NUL-separated segments. The first segment names a top-level property
// (resolved by findValue()). Each following segment indexes into the value found so far: a map
// key, a protobuf field name (maps and repeated fields become CEL maps/lists), or a decimal list
// index.
//
// Returns NotFound when a segment cannot be resolved, InternalFailure when a protobuf field
// cannot be converted, and otherwise whatever serializeValue() returns.
WasmResult Context::getProperty(std::string_view path, std::string* result) {
  using google::api::expr::runtime::CelValue;

  bool first = true;
  CelValue value;
  Protobuf::Arena arena;

  size_t start = 0;
  while (start < path.size()) {
    size_t end = path.find('\0', start);
    if (end == std::string_view::npos) {
      // The final segment runs to the end of the path.
      end = path.size();
    }
    auto part = path.substr(start, end - start);
    start = end + 1;

    if (first) {
      // top-level identifier
      first = false;
      auto top_value = findValue(toAbslStringView(part), &arena, start >= path.size());
      if (!top_value.has_value()) {
        return WasmResult::NotFound;
      }
      value = top_value.value();
    } else if (value.IsMap()) {
      auto& map = *value.MapOrDie();
      auto field = map[CelValue::CreateStringView(toAbslStringView(part))];
      if (!field.has_value()) {
        return WasmResult::NotFound;
      }
      value = field.value();
    } else if (value.IsMessage()) {
      auto msg = value.MessageOrDie();
      if (msg == nullptr) {
        return WasmResult::NotFound;
      }
      const Protobuf::Descriptor* desc = msg->GetDescriptor();
      const Protobuf::FieldDescriptor* field_desc = desc->FindFieldByName(std::string(part));
      if (field_desc == nullptr) {
        return WasmResult::NotFound;
      }
      if (field_desc->is_map()) {
        value = CelValue::CreateMap(
            Protobuf::Arena::Create<google::api::expr::runtime::FieldBackedMapImpl>(
                &arena, msg, field_desc, &arena));
      } else if (field_desc->is_repeated()) {
        value = CelValue::CreateList(
            Protobuf::Arena::Create<google::api::expr::runtime::FieldBackedListImpl>(
                &arena, msg, field_desc, &arena));
      } else {
        auto status =
            google::api::expr::runtime::CreateValueFromSingleField(msg, field_desc, &arena, &value);
        if (!status.ok()) {
          return WasmResult::InternalFailure;
        }
      }
    } else if (value.IsList()) {
      auto& list = *value.ListOrDie();
      int idx = 0;
      if (!absl::SimpleAtoi(toAbslStringView(part), &idx)) {
        return WasmResult::NotFound;
      }
      if (idx < 0 || idx >= list.size()) {
        return WasmResult::NotFound;
      }
      value = list[idx];
    } else {
      return WasmResult::NotFound;
    }
  }

  return serializeValue(value, result);
}
657
658
// Header/Trailer/Metadata Maps.
659
0
// Returns the mutable header/trailer map for `type`, or nullptr if it is unavailable or `type`
// is not a mutable stream map. Request/response trailers are created on demand through the
// filter callbacks when none exist yet, a body buffer is present, and end of stream was seen.
Http::HeaderMap* Context::getMap(WasmHeaderMapType type) {
  if (type == WasmHeaderMapType::RequestHeaders) {
    return request_headers_;
  }
  if (type == WasmHeaderMapType::ResponseHeaders) {
    return response_headers_;
  }
  if (type == WasmHeaderMapType::RequestTrailers) {
    if (request_trailers_ == nullptr && request_body_buffer_ && end_of_stream_ &&
        decoder_callbacks_) {
      request_trailers_ = &decoder_callbacks_->addDecodedTrailers();
    }
    return request_trailers_;
  }
  if (type == WasmHeaderMapType::ResponseTrailers) {
    if (response_trailers_ == nullptr && response_body_buffer_ && end_of_stream_ &&
        encoder_callbacks_) {
      response_trailers_ = &encoder_callbacks_->addEncodedTrailers();
    }
    return response_trailers_;
  }
  return nullptr;
}
681
682
0
// Returns a read-only view of the header/trailer/metadata map selected by `type`, or nullptr
// when it is unavailable.
//
// During the access-log phase the request/response maps come from the access-log snapshot,
// and request trailers are reported as unavailable. gRPC metadata and HTTP-call response maps
// are read from the root context. An unknown `type` is reported as an Envoy bug.
const Http::HeaderMap* Context::getConstMap(WasmHeaderMapType type) {
  switch (type) {
  case WasmHeaderMapType::RequestHeaders:
    if (!access_log_phase_) {
      return request_headers_;
    }
    return access_log_request_headers_;
  case WasmHeaderMapType::RequestTrailers:
    if (!access_log_phase_) {
      return request_trailers_;
    }
    return nullptr;
  case WasmHeaderMapType::ResponseHeaders:
    if (!access_log_phase_) {
      return response_headers_;
    }
    return access_log_response_headers_;
  case WasmHeaderMapType::ResponseTrailers:
    if (!access_log_phase_) {
      return response_trailers_;
    }
    return access_log_response_trailers_;
  case WasmHeaderMapType::GrpcReceiveInitialMetadata:
    return rootContext()->grpc_receive_initial_metadata_.get();
  case WasmHeaderMapType::GrpcReceiveTrailingMetadata:
    return rootContext()->grpc_receive_trailing_metadata_.get();
  case WasmHeaderMapType::HttpCallResponseHeaders: {
    Envoy::Http::ResponseMessagePtr* call_response = rootContext()->http_call_response_;
    if (call_response == nullptr) {
      return nullptr;
    }
    return &(*call_response)->headers();
  }
  case WasmHeaderMapType::HttpCallResponseTrailers: {
    Envoy::Http::ResponseMessagePtr* call_response = rootContext()->http_call_response_;
    if (call_response == nullptr) {
      return nullptr;
    }
    return (*call_response)->trailers();
  }
  }
  IS_ENVOY_BUG("unexpected");
  return nullptr;
}
726
727
// Appends `key: value` to the mutable map selected by `type`; existing entries with the same
// key are kept. Returns BadArgument when the map is not available. Mutating the request
// headers also clears the cached route.
WasmResult Context::addHeaderMapValue(WasmHeaderMapType type, std::string_view key,
                                      std::string_view value) {
  Http::HeaderMap* const target = getMap(type);
  if (target == nullptr) {
    return WasmResult::BadArgument;
  }
  target->addCopy(Http::LowerCaseString(std::string(key)), std::string(value));
  if (type == WasmHeaderMapType::RequestHeaders) {
    clearRouteCache();
  }
  return WasmResult::Ok;
}
740
741
// Looks up the first value for `key` in the read-only map selected by `type` and stores a view
// of it in `*value`.
//
// A missing value is reported according to the VM's ABI: ProxyWasm 0.1.0 expects Ok with an
// empty string, later ABIs expect NotFound. An unavailable map yields BadArgument, except during
// the access-log phase where maps may legitimately be absent and the "missing" rule applies.
WasmResult Context::getHeaderMapValue(WasmHeaderMapType type, std::string_view key,
                                      std::string_view* value) {
  const auto report_missing = [this, value]() -> WasmResult {
    if (wasm()->abiVersion() != proxy_wasm::AbiVersion::ProxyWasm_0_1_0) {
      return WasmResult::NotFound;
    }
    *value = "";
    return WasmResult::Ok;
  };

  const Http::HeaderMap* const source = getConstMap(type);
  if (source == nullptr) {
    return access_log_phase_ ? report_missing() : WasmResult::BadArgument;
  }

  const auto matches = source->get(Http::LowerCaseString(std::string(key)));
  if (matches.empty()) {
    return report_missing();
  }
  // TODO(kyessenov, PiotrSikora): This needs to either return a concatenated list of values, or
  // the ABI needs to be changed to return multiple values. This is a potential security issue.
  *value = toStdStringView(matches[0]->value().getStringView());
  return WasmResult::Ok;
}
772
773
0
// Flattens `map` into (key, value) view pairs in iteration order. The views borrow from the map,
// so they are only valid while the map is alive and unmodified. A null map yields no pairs.
Pairs headerMapToPairs(const Http::HeaderMap* map) {
  Pairs flattened;
  if (map == nullptr) {
    return flattened;
  }
  flattened.reserve(map->size());
  map->iterate([&flattened](const Http::HeaderEntry& entry) -> Http::HeaderMap::Iterate {
    flattened.emplace_back(toStdStringView(entry.key().getStringView()),
                           toStdStringView(entry.value().getStringView()));
    return Http::HeaderMap::Iterate::Continue;
  });
  return flattened;
}
786
787
0
// Copies every (key, value) pair of the read-only map selected by `type` into `*result`.
// An unavailable map produces an empty list rather than an error.
WasmResult Context::getHeaderMapPairs(WasmHeaderMapType type, Pairs* result) {
  const Http::HeaderMap* const source = getConstMap(type);
  *result = headerMapToPairs(source);
  return WasmResult::Ok;
}
791
792
0
// Replaces the entire contents of the mutable map selected by `type` with `pairs`.
// Returns BadArgument when the map is not available. Mutating the request headers also clears
// the cached route.
WasmResult Context::setHeaderMapPairs(WasmHeaderMapType type, const Pairs& pairs) {
  Http::HeaderMap* const target = getMap(type);
  if (target == nullptr) {
    return WasmResult::BadArgument;
  }

  // Snapshot the current keys first so that nothing is removed while the map is being iterated.
  std::vector<std::string> existing_keys;
  target->iterate([&existing_keys](const Http::HeaderEntry& entry) -> Http::HeaderMap::Iterate {
    existing_keys.push_back(std::string(entry.key().getStringView()));
    return Http::HeaderMap::Iterate::Continue;
  });
  for (const std::string& name : existing_keys) {
    target->remove(Http::LowerCaseString(name));
  }

  for (const auto& [name, val] : pairs) {
    target->addCopy(Http::LowerCaseString(std::string(name)), std::string(val));
  }

  if (type == WasmHeaderMapType::RequestHeaders) {
    clearRouteCache();
  }
  return WasmResult::Ok;
}
815
816
0
// Removes every entry named `key` from the mutable map selected by `type`. Returns BadArgument
// when the map is not available. Mutating the request headers also clears the cached route.
WasmResult Context::removeHeaderMapValue(WasmHeaderMapType type, std::string_view key) {
  Http::HeaderMap* const target = getMap(type);
  if (target == nullptr) {
    return WasmResult::BadArgument;
  }
  target->remove(Http::LowerCaseString(std::string(key)));
  if (type == WasmHeaderMapType::RequestHeaders) {
    clearRouteCache();
  }
  return WasmResult::Ok;
}
828
829
// Sets `key` to exactly `value` in the mutable map selected by `type` (via setCopy). Returns
// BadArgument when the map is not available. Mutating the request headers also clears the
// cached route.
WasmResult Context::replaceHeaderMapValue(WasmHeaderMapType type, std::string_view key,
                                          std::string_view value) {
  Http::HeaderMap* const target = getMap(type);
  if (target == nullptr) {
    return WasmResult::BadArgument;
  }
  target->setCopy(Http::LowerCaseString(std::string(key)), toAbslStringView(value));
  if (type == WasmHeaderMapType::RequestHeaders) {
    clearRouteCache();
  }
  return WasmResult::Ok;
}
842
843
0
// Reports the byte size of the header map selected by `type` in `*result`.
//
// The size is computed over the read-only view returned by getConstMap(), i.e. the same map
// that getHeaderMapPairs()/getHeaderMapValue() expose. This makes the size available for the
// gRPC metadata and HTTP-call response maps (which getMap() never returns), honours the
// access-log phase maps, and avoids getMap()'s side effect of creating trailers merely to
// measure them. Returns BadArgument when the map is not available.
WasmResult Context::getHeaderMapSize(WasmHeaderMapType type, uint32_t* result) {
  const Http::HeaderMap* const map = getConstMap(type);
  if (map == nullptr) {
    return WasmResult::BadArgument;
  }
  *result = map->byteSize();
  return WasmResult::Ok;
}
851
852
// Buffer
853
854
0
// Returns the buffer selected by `type`, wrapped in the context's reusable `buffer_`, or nullptr
// when that buffer is not available.
//
// While a request/response body is being buffered by the filter chain, the mutable body buffer
// is fetched through the decoder/encoder callbacks. Otherwise the most recently seen body
// buffer is used.
BufferInterface* Context::getBuffer(WasmBufferType type) {
  switch (type) {
  case WasmBufferType::CallData:
    // Populated by the caller before invoking into the VM.
    return &buffer_;
  case WasmBufferType::VmConfiguration:
    return buffer_.set(wasm()->vm_configuration());
  case WasmBufferType::PluginConfiguration:
    if (!temp_plugin_) {
      return nullptr;
    }
    return buffer_.set(temp_plugin_->plugin_configuration_);
  case WasmBufferType::HttpRequestBody: {
    if (!buffering_request_body_ || !decoder_callbacks_) {
      return buffer_.set(request_body_buffer_);
    }
    // The mutable buffer is only reachable through a callback, so capture it from there.
    // TODO: consider adding a mutableDecodingBuffer() interface.
    ::Envoy::Buffer::Instance* mutable_body = nullptr;
    decoder_callbacks_->modifyDecodingBuffer(
        [&mutable_body](::Envoy::Buffer::Instance& buffer) { mutable_body = &buffer; });
    return buffer_.set(mutable_body);
  }
  case WasmBufferType::HttpResponseBody: {
    if (!buffering_response_body_ || !encoder_callbacks_) {
      return buffer_.set(response_body_buffer_);
    }
    // TODO: consider adding a mutableEncodingBuffer() interface.
    ::Envoy::Buffer::Instance* mutable_body = nullptr;
    encoder_callbacks_->modifyEncodingBuffer(
        [&mutable_body](::Envoy::Buffer::Instance& buffer) { mutable_body = &buffer; });
    return buffer_.set(mutable_body);
  }
  case WasmBufferType::NetworkDownstreamData:
    return buffer_.set(network_downstream_data_buffer_);
  case WasmBufferType::NetworkUpstreamData:
    return buffer_.set(network_upstream_data_buffer_);
  case WasmBufferType::HttpCallResponseBody: {
    Envoy::Http::ResponseMessagePtr* call_response = rootContext()->http_call_response_;
    if (call_response == nullptr) {
      return nullptr;
    }
    auto& body = (*call_response)->body();
    const auto body_length = body.length();
    return buffer_.set(
        std::string_view(static_cast<const char*>(body.linearize(body_length)), body_length));
  }
  case WasmBufferType::GrpcReceiveBuffer:
    return buffer_.set(rootContext()->grpc_receive_buffer_.get());
  default:
    return nullptr;
  }
}
904
905
0
// Handles the downstream side closing. It notifies the base context, records the closure, and
// always delivers the TCP close callback to the VM.
void Context::onDownstreamConnectionClose(CloseType close_type) {
  ContextBase::onDownstreamConnectionClose(close_type);
  downstream_closed_ = true;
  onCloseTCP();
}
910
911
0
// Handles the upstream side closing. It notifies the base context and records the closure. The
// TCP close callback is delivered only if the downstream side has already closed.
void Context::onUpstreamConnectionClose(CloseType close_type) {
  ContextBase::onUpstreamConnectionClose(close_type);
  upstream_closed_ = true;
  if (!downstream_closed_) {
    return;
  }
  onCloseTCP();
}
918
919
// Async call via HTTP
920
// Starts an asynchronous HTTP request to `cluster` on behalf of the VM.
//
// - `request_headers` must include :path, :method and the Host/:authority header.
// - A non-empty `request_body` is attached with a matching content-length.
// - Non-empty `request_trailers` are attached as trailers.
// - `timeout_milliseconds` == 0 means no timeout; negative values are rejected.
//
// On success the call's token is written to `*token_ptr`. Invalid arguments or an unknown
// cluster yield BadArgument; a request the async client could not start yields InternalFailure.
WasmResult Context::httpCall(std::string_view cluster, const Pairs& request_headers,
                             std::string_view request_body, const Pairs& request_trailers,
                             int timeout_milliseconds, uint32_t* token_ptr) {
  if (timeout_milliseconds < 0) {
    return WasmResult::BadArgument;
  }
  const auto thread_local_cluster = clusterManager().getThreadLocalCluster(std::string(cluster));
  if (thread_local_cluster == nullptr) {
    return WasmResult::BadArgument;
  }

  Http::RequestMessagePtr message =
      std::make_unique<Http::RequestMessageImpl>(buildRequestHeaderMapFromPairs(request_headers));
  Http::RequestHeaderMap& headers = message->headers();
  const bool has_required_headers =
      headers.Path() != nullptr && headers.Method() != nullptr && headers.Host() != nullptr;
  if (!has_required_headers) {
    return WasmResult::BadArgument;
  }
  if (!request_body.empty()) {
    message->body().add(toAbslStringView(request_body));
    headers.setContentLength(request_body.size());
  }
  if (!request_trailers.empty()) {
    message->trailers(buildRequestTrailerMapFromPairs(request_trailers));
  }

  // Zero means "no timeout": leave the optional unset (negatives were rejected above).
  absl::optional<std::chrono::milliseconds> timeout;
  if (timeout_milliseconds != 0) {
    timeout = std::chrono::milliseconds(timeout_milliseconds);
  }

  const uint32_t token = wasm()->nextHttpCallId();
  auto& handler = http_request_[token];
  handler.context_ = this;
  handler.token_ = token;

  // Hash on :authority by default so that consistent-hash load balancing works.
  Http::AsyncClient::RequestOptions options;
  options.setTimeout(timeout);
  Protobuf::RepeatedPtrField<HashPolicy> hash_policy;
  hash_policy.Add()->mutable_header()->set_header_name(Http::Headers::get().Host.get());
  options.setHashPolicy(hash_policy);
  options.setSendXff(false);

  auto in_flight =
      thread_local_cluster->httpAsyncClient().send(std::move(message), handler, options);
  if (!in_flight) {
    http_request_.erase(token);
    return WasmResult::InternalFailure;
  }
  handler.request_ = in_flight;
  *token_ptr = token;
  return WasmResult::Ok;
}
977
978
// Starts a unary gRPC call on behalf of the VM.
//
// `grpc_service` is either a serialized GrpcService proto or, if it does not parse as one, the
// name of a cluster to reach via Envoy gRPC. `request` is sent as the raw message body and
// `initial_metadata` is attached when the call is created. On success the call's token is
// written to `*token_ptr`.
WasmResult Context::grpcCall(std::string_view grpc_service, std::string_view service_name,
                             std::string_view method_name, const Pairs& initial_metadata,
                             std::string_view request, std::chrono::milliseconds timeout,
                             uint32_t* token_ptr) {
  GrpcService service_proto;
  if (!service_proto.ParseFromArray(grpc_service.data(), grpc_service.size())) {
    // Not a serialized GrpcService: interpret the bytes as a plain cluster name.
    auto cluster_name = std::string(grpc_service);
    const auto thread_local_cluster = clusterManager().getThreadLocalCluster(cluster_name);
    if (thread_local_cluster == nullptr) {
      // TODO(shikugawa): ParseFailure (rather than BadArgument) is kept here so that callers
      // written against ABI 0.1.x do not have to change their own codebase. This failure should
      // be reported as `BadArgument` once ABI 0.2.x has been released.
      return WasmResult::ParseFailure;
    }
    service_proto.mutable_envoy_grpc()->set_cluster_name(cluster_name);
  }
  uint32_t token = wasm()->nextGrpcCallId();
  auto& handler = grpc_call_request_[token];
  handler.context_ = this;
  handler.token_ = token;
  auto client_or_error = clusterManager().grpcAsyncClientManager().getOrCreateRawAsyncClient(
      service_proto, *wasm()->scope_, true /* skip_cluster_check */);
  if (!client_or_error.status().ok()) {
    // Drop the half-initialised handler: it has no request_, and ~Context() calls
    // request_->cancel() on every entry left in grpc_call_request_.
    grpc_call_request_.erase(token);
    return WasmResult::BadArgument;
  }
  auto grpc_client = client_or_error.value();
  grpc_initial_metadata_ = buildRequestHeaderMapFromPairs(initial_metadata);

  // set default hash policy to be based on :authority to enable consistent hash
  Http::AsyncClient::RequestOptions options;
  options.setTimeout(timeout);
  Protobuf::RepeatedPtrField<HashPolicy> hash_policy;
  hash_policy.Add()->mutable_header()->set_header_name(Http::Headers::get().Host.get());
  options.setHashPolicy(hash_policy);
  options.setSendXff(false);

  auto grpc_request =
      grpc_client->sendRaw(toAbslStringView(service_name), toAbslStringView(method_name),
                           std::make_unique<::Envoy::Buffer::OwnedImpl>(toAbslStringView(request)),
                           handler, Tracing::NullSpan::instance(), options);
  if (!grpc_request) {
    grpc_call_request_.erase(token);
    return WasmResult::InternalFailure;
  }
  handler.client_ = std::move(grpc_client);
  handler.request_ = grpc_request;
  *token_ptr = token;
  return WasmResult::Ok;
}
1027
1028
// Opens a streaming gRPC call on behalf of the VM.
//
// `grpc_service` is either a serialized GrpcService proto or, if it does not parse as one, the
// name of a cluster to reach via Envoy gRPC. `initial_metadata` is attached when the stream is
// created. On success the stream's token is written to `*token_ptr`.
WasmResult Context::grpcStream(std::string_view grpc_service, std::string_view service_name,
                               std::string_view method_name, const Pairs& initial_metadata,
                               uint32_t* token_ptr) {
  GrpcService service_proto;
  if (!service_proto.ParseFromArray(grpc_service.data(), grpc_service.size())) {
    // Not a serialized GrpcService: interpret the bytes as a plain cluster name.
    auto cluster_name = std::string(grpc_service);
    const auto thread_local_cluster = clusterManager().getThreadLocalCluster(cluster_name);
    if (thread_local_cluster == nullptr) {
      // TODO(shikugawa): ParseFailure (rather than BadArgument) is kept here so that callers
      // written against ABI 0.1.x do not have to change their own codebase. This failure should
      // be reported as `BadArgument` once ABI 0.2.x has been released.
      return WasmResult::ParseFailure;
    }
    service_proto.mutable_envoy_grpc()->set_cluster_name(cluster_name);
  }
  uint32_t token = wasm()->nextGrpcStreamId();
  auto& handler = grpc_stream_[token];
  handler.context_ = this;
  handler.token_ = token;
  auto client_or_error = clusterManager().grpcAsyncClientManager().getOrCreateRawAsyncClient(
      service_proto, *wasm()->scope_, true /* skip_cluster_check */);
  if (!client_or_error.status().ok()) {
    // Drop the half-initialised handler: it has no stream_, and ~Context() calls
    // stream_->resetStream() on every entry left in grpc_stream_.
    grpc_stream_.erase(token);
    return WasmResult::BadArgument;
  }
  auto grpc_client = client_or_error.value();
  grpc_initial_metadata_ = buildRequestHeaderMapFromPairs(initial_metadata);

  // set default hash policy to be based on :authority to enable consistent hash
  Http::AsyncClient::StreamOptions options;
  Protobuf::RepeatedPtrField<HashPolicy> hash_policy;
  hash_policy.Add()->mutable_header()->set_header_name(Http::Headers::get().Host.get());
  options.setHashPolicy(hash_policy);
  options.setSendXff(false);

  auto grpc_stream = grpc_client->startRaw(toAbslStringView(service_name),
                                           toAbslStringView(method_name), handler, options);
  if (!grpc_stream) {
    grpc_stream_.erase(token);
    return WasmResult::InternalFailure;
  }
  handler.client_ = std::move(grpc_client);
  handler.stream_ = grpc_stream;
  *token_ptr = token;
  return WasmResult::Ok;
}
1073
1074
// NB: this is currently called inline, so the token is known to be that of the currently
1075
// executing grpcCall or grpcStream.
1076
void Context::onGrpcCreateInitialMetadata(uint32_t /* token */,
1077
0
                                          Http::RequestHeaderMap& initial_metadata) {
1078
0
  if (grpc_initial_metadata_) {
1079
0
    Http::HeaderMapImpl::copyFrom(initial_metadata, *grpc_initial_metadata_);
1080
0
    grpc_initial_metadata_.reset();
1081
0
  }
1082
0
}
1083
1084
// StreamInfo
1085
0
// Returns the read-only StreamInfo for this context, trying in priority order:
// 1. HTTP encoder callbacks
// 2. HTTP decoder callbacks
// 3. access-log stream info
// 4. network read filter connection
// 5. network write filter connection
// Returns nullptr when none is set.
const StreamInfo::StreamInfo* Context::getConstRequestStreamInfo() const {
  if (encoder_callbacks_) {
    return &encoder_callbacks_->streamInfo();
  }
  if (decoder_callbacks_) {
    return &decoder_callbacks_->streamInfo();
  }
  if (access_log_stream_info_) {
    return access_log_stream_info_;
  }
  if (network_read_filter_callbacks_) {
    return &network_read_filter_callbacks_->connection().streamInfo();
  }
  if (network_write_filter_callbacks_) {
    return &network_write_filter_callbacks_->connection().streamInfo();
  }
  return nullptr;
}
1099
1100
0
// Returns the mutable StreamInfo for this context, trying in priority order:
// 1. HTTP encoder callbacks
// 2. HTTP decoder callbacks
// 3. network read filter connection
// 4. network write filter connection
// Returns nullptr when none is set. Unlike the const variant, the access-log stream info is not
// consulted.
StreamInfo::StreamInfo* Context::getRequestStreamInfo() const {
  if (encoder_callbacks_) {
    return &encoder_callbacks_->streamInfo();
  }
  if (decoder_callbacks_) {
    return &decoder_callbacks_->streamInfo();
  }
  if (network_read_filter_callbacks_) {
    return &network_read_filter_callbacks_->connection().streamInfo();
  }
  if (network_write_filter_callbacks_) {
    return &network_write_filter_callbacks_->connection().streamInfo();
  }
  return nullptr;
}
1112
1113
0
// Returns the connection associated with this context, trying in priority order:
// 1. HTTP encoder callbacks
// 2. HTTP decoder callbacks
// 3. network read filter callbacks
// 4. network write filter callbacks
// Returns nullptr when none is available.
const Network::Connection* Context::getConnection() const {
  if (encoder_callbacks_) {
    return encoder_callbacks_->connection().ptr();
  }
  if (decoder_callbacks_) {
    return decoder_callbacks_->connection().ptr();
  }
  if (network_read_filter_callbacks_) {
    return &network_read_filter_callbacks_->connection();
  }
  if (network_write_filter_callbacks_) {
    return &network_write_filter_callbacks_->connection();
  }
  return nullptr;
}
1125
1126
0
// Stores `value` into the CelState filter-state object for `path`, creating that object on
// first use.
//
// A new object uses the prototype declared for `path` on the root context, falling back to the
// default CEL state prototype. Returns NotFound without stream info, and BadArgument when the
// state rejects the value.
WasmResult Context::setProperty(std::string_view path, std::string_view value) {
  StreamInfo::StreamInfo* const stream_info = getRequestStreamInfo();
  if (stream_info == nullptr) {
    return WasmResult::NotFound;
  }
  const std::string key = absl::StrCat(CelStateKeyPrefix, toAbslStringView(path));
  CelState* state = stream_info->filterState()->getDataMutable<CelState>(key);
  if (state == nullptr) {
    const auto& prototypes = rootContext()->state_prototypes_;
    const auto found = prototypes.find(toAbslStringView(path));
    const CelStatePrototype& prototype =
        found != prototypes.end()
            ? *found->second.get()
            : Filters::Common::Expr::DefaultCelStatePrototype::get(); // NOLINT
    auto fresh_state = std::make_unique<CelState>(prototype);
    state = fresh_state.get();
    stream_info->filterState()->setData(key, std::move(fresh_state),
                                        StreamInfo::FilterState::StateType::Mutable,
                                        prototype.life_span_);
  }
  return state->setValue(toAbslStringView(value)) ? WasmResult::Ok : WasmResult::BadArgument;
}
1151
1152
// Builds a filter-state object for `path` from `value` using the registered ObjectFactory of
// that name, and stores it (mutable, with `life_span`) in this context's stream filter state.
//
// Returns NotFound if no factory is registered or there is no stream info, and BadArgument if
// the factory cannot build an object from `value`.
WasmResult Context::setEnvoyFilterState(std::string_view path, std::string_view value,
                                        StreamInfo::FilterState::LifeSpan life_span) {
  auto* const factory =
      Registry::FactoryRegistry<StreamInfo::FilterState::ObjectFactory>::getFactory(path);
  if (factory == nullptr) {
    return WasmResult::NotFound;
  }
  auto created = factory->createFromBytes(value);
  if (created == nullptr) {
    return WasmResult::BadArgument;
  }
  StreamInfo::StreamInfo* const stream_info = getRequestStreamInfo();
  if (stream_info == nullptr) {
    return WasmResult::NotFound;
  }
  stream_info->filterState()->setData(path, std::move(created),
                                      StreamInfo::FilterState::StateType::Mutable, life_span);
  return WasmResult::Ok;
}
1174
1175
// Registers the CEL state prototype (schema) used for properties stored under `path`.
// A path can only be declared once. Redeclaring returns BadArgument and keeps the existing
// schema, because state objects may still reference it.
WasmResult
Context::declareProperty(std::string_view path,
                         Filters::Common::Expr::CelStatePrototypeConstPtr state_prototype) {
  const auto name = toAbslStringView(path);
  if (state_prototypes_.find(name) != state_prototypes_.end()) {
    return WasmResult::BadArgument;
  }
  state_prototypes_[name] = std::move(state_prototype);
  return WasmResult::Ok;
}
1185
1186
0
// Emits `message` from the VM at the given log level. `level` is interpreted as a spdlog level
// number. The `off`/`n_levels` pseudo-levels and unknown values panic.
WasmResult Context::log(uint32_t level, std::string_view message) {
  const auto spd_level = static_cast<spdlog::level::level_enum>(level);
  switch (spd_level) {
  case spdlog::level::trace:
    ENVOY_LOG(trace, "wasm log{}: {}", log_prefix(), message);
    return WasmResult::Ok;
  case spdlog::level::debug:
    ENVOY_LOG(debug, "wasm log{}: {}", log_prefix(), message);
    return WasmResult::Ok;
  case spdlog::level::info:
    ENVOY_LOG(info, "wasm log{}: {}", log_prefix(), message);
    return WasmResult::Ok;
  case spdlog::level::warn:
    ENVOY_LOG(warn, "wasm log{}: {}", log_prefix(), message);
    return WasmResult::Ok;
  case spdlog::level::err:
    ENVOY_LOG(error, "wasm log{}: {}", log_prefix(), message);
    return WasmResult::Ok;
  case spdlog::level::critical:
    ENVOY_LOG(critical, "wasm log{}: {}", log_prefix(), message);
    return WasmResult::Ok;
  case spdlog::level::off:
  case spdlog::level::n_levels:
    PANIC("not implemented");
  }
  PANIC_DUE_TO_CORRUPT_ENUM;
}
1213
1214
0
uint32_t Context::getLogLevel() {
1215
  // Like the "log" call above, assume that spdlog level as an int
1216
  // matches the enum in the SDK
1217
0
  return static_cast<uint32_t>(ENVOY_LOGGER().level());
1218
0
}
1219
1220
//
1221
// Calls into the Wasm code.
1222
//
1223
bool Context::validateConfiguration(std::string_view configuration,
1224
0
                                    const std::shared_ptr<PluginBase>& plugin_base) {
1225
0
  auto plugin = std::static_pointer_cast<Plugin>(plugin_base);
1226
0
  if (!wasm()->validate_configuration_) {
1227
0
    return true;
1228
0
  }
1229
0
  temp_plugin_ = plugin_base;
1230
0
  auto result =
1231
0
      wasm()
1232
0
          ->validate_configuration_(this, id_, static_cast<uint32_t>(configuration.size()))
1233
0
          .u64_ != 0;
1234
0
  temp_plugin_.reset();
1235
0
  return result;
1236
0
}
1237
1238
0
std::string_view Context::getConfiguration() {
1239
0
  if (temp_plugin_) {
1240
0
    return temp_plugin_->plugin_configuration_;
1241
0
  } else {
1242
0
    return wasm()->vm_configuration();
1243
0
  }
1244
0
};
1245
1246
0
std::pair<uint32_t, std::string_view> Context::getStatus() {
1247
0
  return std::make_pair(status_code_, toStdStringView(status_message_));
1248
0
}
1249
1250
0
void Context::onGrpcReceiveInitialMetadataWrapper(uint32_t token, Http::HeaderMapPtr&& metadata) {
1251
0
  grpc_receive_initial_metadata_ = std::move(metadata);
1252
0
  onGrpcReceiveInitialMetadata(token, headerSize(grpc_receive_initial_metadata_));
1253
0
  grpc_receive_initial_metadata_ = nullptr;
1254
0
}
1255
1256
0
void Context::onGrpcReceiveTrailingMetadataWrapper(uint32_t token, Http::HeaderMapPtr&& metadata) {
1257
0
  grpc_receive_trailing_metadata_ = std::move(metadata);
1258
0
  onGrpcReceiveTrailingMetadata(token, headerSize(grpc_receive_trailing_metadata_));
1259
0
  grpc_receive_trailing_metadata_ = nullptr;
1260
0
}
1261
1262
// Defines a user metric named `name` of kind `metric_type` and writes its new id to
// `*metric_id_ptr`. Any type other than counter or gauge that passes the range check is created
// as a histogram. Returns BadArgument when `metric_type` exceeds MetricType::Max.
WasmResult Context::defineMetric(uint32_t metric_type, std::string_view name,
                                 uint32_t* metric_id_ptr) {
  if (metric_type > static_cast<uint32_t>(MetricType::Max)) {
    return WasmResult::BadArgument;
  }
  const auto type = static_cast<MetricType>(metric_type);
  Stats::Scope& scope = *wasm()->scope_;
  // TODO: Consider rethinking the scoping policy as it does not help in this case.
  Stats::StatNameManagedStorage storage(toAbslStringView(name), scope.symbolTable());
  const Stats::StatName stat_name = storage.statName();

  // Every user-defined metric is placed under wasm()->custom_stat_namespace_ so that it can be
  // told apart from native Envoy metrics.
  uint32_t id = 0;
  switch (type) {
  case MetricType::Counter: {
    id = wasm()->nextCounterMetricId();
    Stats::Counter& counter =
        Stats::Utility::counterFromElements(scope, {wasm()->custom_stat_namespace_, stat_name});
    wasm()->counters_.emplace(id, &counter);
    break;
  }
  case MetricType::Gauge: {
    id = wasm()->nextGaugeMetricId();
    Stats::Gauge& gauge = Stats::Utility::gaugeFromStatNames(
        scope, {wasm()->custom_stat_namespace_, stat_name}, Stats::Gauge::ImportMode::Accumulate);
    wasm()->gauges_.emplace(id, &gauge);
    break;
  }
  default: {
    id = wasm()->nextHistogramMetricId();
    Stats::Histogram& histogram = Stats::Utility::histogramFromStatNames(
        scope, {wasm()->custom_stat_namespace_, stat_name}, Stats::Histogram::Unit::Unspecified);
    wasm()->histograms_.emplace(id, &histogram);
    break;
  }
  }
  *metric_id_ptr = id;
  return WasmResult::Ok;
}
1299
1300
0
// Adds `offset` to the counter or gauge identified by `metric_id`.
//
// - Counters accept only strictly positive offsets; anything else returns BadArgument.
// - Gauges accept any offset: positive values are added, zero or negative values are subtracted
//   by magnitude.
// - Returns NotFound when no such metric exists, and BadArgument for any other metric type.
WasmResult Context::incrementMetric(uint32_t metric_id, int64_t offset) {
  auto type = static_cast<MetricType>(metric_id & Wasm::kMetricTypeMask);
  if (type == MetricType::Counter) {
    auto it = wasm()->counters_.find(metric_id);
    if (it == wasm()->counters_.end()) {
      return WasmResult::NotFound;
    }
    if (offset <= 0) {
      return WasmResult::BadArgument;
    }
    it->second->add(static_cast<uint64_t>(offset));
    return WasmResult::Ok;
  }
  if (type == MetricType::Gauge) {
    auto it = wasm()->gauges_.find(metric_id);
    if (it == wasm()->gauges_.end()) {
      return WasmResult::NotFound;
    }
    if (offset > 0) {
      it->second->add(static_cast<uint64_t>(offset));
    } else {
      // Negate in unsigned arithmetic. `-offset` is signed overflow (undefined behaviour) for
      // INT64_MIN, whereas 0 - uint64_t(offset) gives the exact magnitude for every offset <= 0.
      it->second->sub(uint64_t{0} - static_cast<uint64_t>(offset));
    }
    return WasmResult::Ok;
  }
  return WasmResult::BadArgument;
}
1328
1329
0
// Records `value` into the metric identified by `metric_id`:
// - counters: `value` is added
// - gauges: set to `value`
// - histograms: `value` is recorded as a sample
// Returns NotFound when no such metric exists or the type is unknown.
WasmResult Context::recordMetric(uint32_t metric_id, uint64_t value) {
  switch (static_cast<MetricType>(metric_id & Wasm::kMetricTypeMask)) {
  case MetricType::Counter: {
    const auto found = wasm()->counters_.find(metric_id);
    if (found == wasm()->counters_.end()) {
      break;
    }
    found->second->add(value);
    return WasmResult::Ok;
  }
  case MetricType::Gauge: {
    const auto found = wasm()->gauges_.find(metric_id);
    if (found == wasm()->gauges_.end()) {
      break;
    }
    found->second->set(value);
    return WasmResult::Ok;
  }
  case MetricType::Histogram: {
    const auto found = wasm()->histograms_.find(metric_id);
    if (found == wasm()->histograms_.end()) {
      break;
    }
    found->second->recordValue(value);
    return WasmResult::Ok;
  }
  default:
    break;
  }
  return WasmResult::NotFound;
}
1352
1353
0
// Reads the current value of the counter or gauge identified by `metric_id` into
// `*result_uint64_ptr`. Returns NotFound when no such metric exists, and BadArgument for any
// other metric type (histograms cannot be read back).
WasmResult Context::getMetric(uint32_t metric_id, uint64_t* result_uint64_ptr) {
  switch (static_cast<MetricType>(metric_id & Wasm::kMetricTypeMask)) {
  case MetricType::Counter: {
    const auto found = wasm()->counters_.find(metric_id);
    if (found == wasm()->counters_.end()) {
      return WasmResult::NotFound;
    }
    *result_uint64_ptr = found->second->value();
    return WasmResult::Ok;
  }
  case MetricType::Gauge: {
    const auto found = wasm()->gauges_.find(metric_id);
    if (found == wasm()->gauges_.end()) {
      return WasmResult::NotFound;
    }
    *result_uint64_ptr = found->second->value();
    return WasmResult::Ok;
  }
  default:
    return WasmResult::BadArgument;
  }
}
1372
1373
0
// Cancels every HTTP call and gRPC call, and resets every gRPC stream, still tracked by this
// context. Entries whose request/stream handle was never set are skipped instead of being
// dereferenced. Such an entry exists, for example, when a gRPC client could not be created
// after the handler had already been registered.
Context::~Context() {
  for (auto& entry : http_request_) {
    if (entry.second.request_) {
      entry.second.request_->cancel();
    }
  }
  for (auto& entry : grpc_call_request_) {
    if (entry.second.request_) {
      entry.second.request_->cancel();
    }
  }
  for (auto& entry : grpc_stream_) {
    if (entry.second.stream_) {
      entry.second.stream_->resetStream();
    }
  }
}
1385
1386
0
// Maps a proxy-wasm network filter status onto Envoy's. Only StopIteration stops the chain;
// every other value (including unknown ones) continues.
Network::FilterStatus convertNetworkFilterStatus(proxy_wasm::FilterStatus status) {
  if (status == proxy_wasm::FilterStatus::StopIteration) {
    return Network::FilterStatus::StopIteration;
  }
  return Network::FilterStatus::Continue;
}
1395
1396
0
// Maps a proxy-wasm headers status onto Envoy's. Values without a direct counterpart fall back
// to Continue.
Http::FilterHeadersStatus convertFilterHeadersStatus(proxy_wasm::FilterHeadersStatus status) {
  switch (status) {
  case proxy_wasm::FilterHeadersStatus::StopIteration:
    return Http::FilterHeadersStatus::StopIteration;
  case proxy_wasm::FilterHeadersStatus::StopAllIterationAndBuffer:
    return Http::FilterHeadersStatus::StopAllIterationAndBuffer;
  case proxy_wasm::FilterHeadersStatus::StopAllIterationAndWatermark:
    return Http::FilterHeadersStatus::StopAllIterationAndWatermark;
  case proxy_wasm::FilterHeadersStatus::Continue:
  default:
    return Http::FilterHeadersStatus::Continue;
  }
}
1409
1410
0
// Maps a proxy-wasm trailers status onto Envoy's. Only StopIteration stops the chain; every
// other value continues.
Http::FilterTrailersStatus convertFilterTrailersStatus(proxy_wasm::FilterTrailersStatus status) {
  return status == proxy_wasm::FilterTrailersStatus::StopIteration
             ? Http::FilterTrailersStatus::StopIteration
             : Http::FilterTrailersStatus::Continue;
}
1419
1420
0
// Maps a proxy-wasm metadata status onto Envoy's. Every value maps to Continue.
Http::FilterMetadataStatus convertFilterMetadataStatus(proxy_wasm::FilterMetadataStatus status) {
  switch (status) {
  case proxy_wasm::FilterMetadataStatus::Continue:
  default:
    return Http::FilterMetadataStatus::Continue;
  }
}
1427
1428
0
// Maps a proxy-wasm body-data status onto Envoy's. Values without a direct counterpart fall
// back to Continue.
Http::FilterDataStatus convertFilterDataStatus(proxy_wasm::FilterDataStatus status) {
  switch (status) {
  case proxy_wasm::FilterDataStatus::StopIterationAndBuffer:
    return Http::FilterDataStatus::StopIterationAndBuffer;
  case proxy_wasm::FilterDataStatus::StopIterationAndWatermark:
    return Http::FilterDataStatus::StopIterationAndWatermark;
  case proxy_wasm::FilterDataStatus::StopIterationNoBuffer:
    return Http::FilterDataStatus::StopIterationNoBuffer;
  case proxy_wasm::FilterDataStatus::Continue:
  default:
    return Http::FilterDataStatus::Continue;
  }
}
1441
1442
0
// Creates the in-VM context for a new network connection and returns the VM's verdict,
// converted to Envoy's filter status.
Network::FilterStatus Context::onNewConnection() {
  onCreate();
  const auto vm_status = onNetworkNewConnection();
  return convertNetworkFilterStatus(vm_status);
}
1446
1447
0
// Passes downstream data to the VM. The data buffer stays exposed to the VM only while
// iteration is stopped; on Continue it is cleared. Data is passed through untouched when no
// in-VM context exists.
Network::FilterStatus Context::onData(::Envoy::Buffer::Instance& data, bool end_stream) {
  if (!in_vm_context_created_) {
    return Network::FilterStatus::Continue;
  }
  network_downstream_data_buffer_ = &data;
  end_of_stream_ = end_stream;
  const auto data_length = data.length();
  const Network::FilterStatus status =
      convertNetworkFilterStatus(onDownstreamData(data_length, end_stream));
  if (status == Network::FilterStatus::Continue) {
    network_downstream_data_buffer_ = nullptr;
  }
  return status;
}
1459
1460
0
// Passes upstream data (being written downstream) to the VM. The data buffer stays exposed to
// the VM only while iteration is stopped. On end_stream the upstream side is treated as closed.
// Data is passed through untouched when no in-VM context exists.
Network::FilterStatus Context::onWrite(::Envoy::Buffer::Instance& data, bool end_stream) {
  if (!in_vm_context_created_) {
    return Network::FilterStatus::Continue;
  }
  network_upstream_data_buffer_ = &data;
  end_of_stream_ = end_stream;
  const auto data_length = data.length();
  const Network::FilterStatus status =
      convertNetworkFilterStatus(onUpstreamData(data_length, end_stream));
  if (status == Network::FilterStatus::Continue) {
    network_upstream_data_buffer_ = nullptr;
  }
  // Upstream close is inferred from end_stream=true instead of a connection event, because
  // registering for the latter would require replicating the whole TCP proxy extension.
  if (end_stream) {
    onUpstreamConnectionClose(CloseType::Unknown);
  }
  return status;
}
1477
1478
0
// Connection-event hook. Local and remote closes are reported to the plugin as a
// downstream connection close with the matching CloseType; every other event is
// ignored, as are all events when no in-VM context exists.
void Context::onEvent(Network::ConnectionEvent event) {
  if (!in_vm_context_created_) {
    return;
  }
  if (event == Network::ConnectionEvent::LocalClose) {
    onDownstreamConnectionClose(CloseType::Local);
  } else if (event == Network::ConnectionEvent::RemoteClose) {
    onDownstreamConnectionClose(CloseType::Remote);
  }
}
1493
1494
0
// Remembers the read-filter callbacks and registers this context as a
// connection-callbacks listener on the underlying connection.
void Context::initializeReadFilterCallbacks(Network::ReadFilterCallbacks& callbacks) {
  network_read_filter_callbacks_ = &callbacks;
  callbacks.connection().addConnectionCallbacks(*this);
}
1498
1499
0
// Remembers the write-filter callbacks (used to close the connection for the
// Upstream stream type in closeStream()/failStream()).
void Context::initializeWriteFilterCallbacks(Network::WriteFilterCallbacks& callbacks) {
  network_write_filter_callbacks_ = &callbacks;
}
1502
1503
// Access-log hook. Once the request is complete, runs the plugin's onLog() with
// the request headers, response headers/trailers and stream info exposed to the
// VM for the duration of that call only.
void Context::log(const Formatter::HttpFormatterContext& log_context,
                  const StreamInfo::StreamInfo& stream_info) {
  // `log` may be called multiple times due to mid-request logging -- only the
  // final call (request completed) is forwarded to the plugin.
  const bool request_complete = stream_info.requestComplete().has_value();
  if (!request_complete) {
    return;
  }
  // For invalid requests onRequestHeaders() (and hence onCreate()) may never run,
  // e.g. when sendLocalReply short-circuits the Envoy lifecycle. Envoy has no
  // well-defined lifetime for the combined HTTP + AccessLog filter, so create the
  // in-VM context here to be able to log these cases.
  if (!in_vm_context_created_) {
    onCreate();
  }

  // Publish the access-log view to the VM.
  access_log_phase_ = true;
  access_log_stream_info_ = &stream_info;
  access_log_request_headers_ = &log_context.requestHeaders();
  // NOTE: request trailers are not captured for the access-log phase.
  access_log_response_headers_ = &log_context.responseHeaders();
  access_log_response_trailers_ = &log_context.responseTrailers();

  onLog();

  // Withdraw the view so no pointers outlive this call.
  access_log_phase_ = false;
  access_log_stream_info_ = nullptr;
  access_log_request_headers_ = nullptr;
  access_log_response_headers_ = nullptr;
  access_log_response_trailers_ = nullptr;
}
1535
1536
0
void Context::onDestroy() {
1537
0
  if (destroyed_ || !in_vm_context_created_) {
1538
0
    return;
1539
0
  }
1540
0
  destroyed_ = true;
1541
0
  onDone();
1542
0
  onDelete();
1543
0
}
1544
1545
0
// Resumes a stream the plugin previously paused.
//
// Request/Response resume HTTP decoding/encoding and Downstream resumes reading
// from the connection. Each is scheduled as an after-VM-call action because this
// runs re-entrantly from inside the VM. Resuming Upstream is unimplemented, and
// any other stream type is rejected with BadArgument.
//
// NOTE(review): for both Request and Response the cached request_* pointers are
// cleared after scheduling the continuation (Downstream returns before that).
// Clearing request state for Response looks surprising -- confirm it is intended.
WasmResult Context::continueStream(WasmStreamType stream_type) {
  if (stream_type == WasmStreamType::Downstream) {
    if (network_read_filter_callbacks_) {
      // We are in a reentrant call, so defer.
      wasm()->addAfterVmCallAction([this] { network_read_filter_callbacks_->continueReading(); });
    }
    return WasmResult::Ok;
  }
  if (stream_type == WasmStreamType::Upstream) {
    return WasmResult::Unimplemented;
  }

  if (stream_type == WasmStreamType::Request) {
    if (decoder_callbacks_) {
      // We are in a reentrant call, so defer.
      wasm()->addAfterVmCallAction([this] { decoder_callbacks_->continueDecoding(); });
    }
  } else if (stream_type == WasmStreamType::Response) {
    if (encoder_callbacks_) {
      // We are in a reentrant call, so defer.
      wasm()->addAfterVmCallAction([this] { encoder_callbacks_->continueEncoding(); });
    }
  } else {
    return WasmResult::BadArgument;
  }

  request_headers_ = nullptr;
  request_body_buffer_ = nullptr;
  request_trailers_ = nullptr;
  request_metadata_ = nullptr;
  return WasmResult::Ok;
}
1576
1577
// Response code details recorded by closeStream() when none are set yet.
constexpr absl::string_view CloseStreamResponseDetails{"wasm_close_stream"};
1578
1579
0
// Closes a stream on behalf of the plugin.
//
// HTTP streams are reset, first recording CloseStreamResponseDetails as the
// response code details unless details are already set. Network streams are
// closed with FlushWrite. All actions are deferred until the current VM call
// returns. Stream types without callbacks are a no-op returning Ok; unknown
// stream types return BadArgument.
WasmResult Context::closeStream(WasmStreamType stream_type) {
  if (stream_type == WasmStreamType::Request) {
    if (decoder_callbacks_) {
      auto& info = decoder_callbacks_->streamInfo();
      if (!info.responseCodeDetails().has_value()) {
        info.setResponseCodeDetails(CloseStreamResponseDetails);
      }
      // We are in a reentrant call, so defer.
      wasm()->addAfterVmCallAction([this] { decoder_callbacks_->resetStream(); });
    }
    return WasmResult::Ok;
  }
  if (stream_type == WasmStreamType::Response) {
    if (encoder_callbacks_) {
      auto& info = encoder_callbacks_->streamInfo();
      if (!info.responseCodeDetails().has_value()) {
        info.setResponseCodeDetails(CloseStreamResponseDetails);
      }
      // We are in a reentrant call, so defer.
      wasm()->addAfterVmCallAction([this] { encoder_callbacks_->resetStream(); });
    }
    return WasmResult::Ok;
  }
  if (stream_type == WasmStreamType::Downstream) {
    if (network_read_filter_callbacks_) {
      // We are in a reentrant call, so defer.
      wasm()->addAfterVmCallAction([this] {
        network_read_filter_callbacks_->connection().close(
            Envoy::Network::ConnectionCloseType::FlushWrite, "wasm_downstream_close");
      });
    }
    return WasmResult::Ok;
  }
  if (stream_type == WasmStreamType::Upstream) {
    if (network_write_filter_callbacks_) {
      // We are in a reentrant call, so defer.
      wasm()->addAfterVmCallAction([this] {
        network_write_filter_callbacks_->connection().close(
            Envoy::Network::ConnectionCloseType::FlushWrite, "wasm_upstream_close");
      });
    }
    return WasmResult::Ok;
  }
  return WasmResult::BadArgument;
}
1620
1621
// Response code details attached to the 503 local reply sent by failStream().
constexpr absl::string_view FailStreamResponseDetails{"wasm_fail_stream"};
1622
1623
0
// Fails a stream after a plugin failure (per sendLocalResponse(), this is what a
// fail-closed plugin triggers when the VM fails).
//
// HTTP streams get a single 503 local reply with gRPC status Unavailable;
// local_reply_sent_ guards against sending a second reply. Network streams are
// closed synchronously with FlushWrite.
void Context::failStream(WasmStreamType stream_type) {
  if (stream_type == WasmStreamType::Request) {
    if (decoder_callbacks_ && !local_reply_sent_) {
      decoder_callbacks_->sendLocalReply(Envoy::Http::Code::ServiceUnavailable, "", nullptr,
                                         Grpc::Status::WellKnownGrpcStatus::Unavailable,
                                         FailStreamResponseDetails);
      local_reply_sent_ = true;
    }
  } else if (stream_type == WasmStreamType::Response) {
    if (encoder_callbacks_ && !local_reply_sent_) {
      encoder_callbacks_->sendLocalReply(Envoy::Http::Code::ServiceUnavailable, "", nullptr,
                                         Grpc::Status::WellKnownGrpcStatus::Unavailable,
                                         FailStreamResponseDetails);
      local_reply_sent_ = true;
    }
  } else if (stream_type == WasmStreamType::Downstream) {
    if (network_read_filter_callbacks_) {
      network_read_filter_callbacks_->connection().close(
          Envoy::Network::ConnectionCloseType::FlushWrite);
    }
  } else if (stream_type == WasmStreamType::Upstream) {
    if (network_write_filter_callbacks_) {
      network_write_filter_callbacks_->connection().close(
          Envoy::Network::ConnectionCloseType::FlushWrite);
    }
  }
}
1655
1656
// Queues a local HTTP reply on the decoder path on behalf of the plugin.
//
// The reply is sent from an after-VM-call action, not inline. Only the first
// request per context is accepted.
//
// @param response_code      HTTP status code of the reply.
// @param body_text          reply body (copied into the deferred action).
// @param additional_headers extra headers to add to the reply. They are copied
//                           because the string_views do not outlive this call;
//                           keys are lower-cased.
// @param grpc_status        gRPC status. Values outside [Ok, MaximumKnown], such as
//                           the SDKs' -1 (InvalidCode) default, mean "send no
//                           grpc-status".
// @param details            response code details (whitespace is replaced via
//                           StringUtil::replaceAllEmptySpace).
// @return Ok when accepted, even if there are no decoder callbacks to reply on.
//         BadArgument if a local response was already requested.
WasmResult Context::sendLocalResponse(uint32_t response_code, std::string_view body_text,
                                      Pairs additional_headers, uint32_t grpc_status,
                                      std::string_view details) {
  // This flag is used to avoid calling sendLocalReply() twice, even if wasm code has this
  // logic. We can't reuse "local_reply_sent_" here because it can't avoid calling nested
  // sendLocalReply() during encodeHeaders().
  if (local_reply_hold_) {
    return WasmResult::BadArgument;
  }
  // "additional_headers" is a collection of string_views. These will no longer
  // be valid when "modify_headers" is finally called below, so we must
  // make copies of all the headers.
  std::vector<std::pair<Http::LowerCaseString, std::string>> additional_headers_copy;
  additional_headers_copy.reserve(additional_headers.size());
  for (const auto& header : additional_headers) {
    additional_headers_copy.emplace_back(Http::LowerCaseString(std::string(header.first)),
                                         std::string(header.second));
  }

  // The local copy is not used after this point, so move it into the closure
  // instead of copying every header again.
  auto modify_headers = [additional_headers_copy =
                             std::move(additional_headers_copy)](Http::HeaderMap& headers) {
    for (const auto& header : additional_headers_copy) {
      headers.addCopy(header.first, header.second);
    }
  };

  if (decoder_callbacks_) {
    // This is a bit subtle because proxy_on_delete() does call DeferAfterCallActions(),
    // so in theory it could call this and the Context in the VM would be invalid,
    // but because it only gets called after the connections have drained, the call to
    // sendLocalReply() will fail. Net net, this is safe.
    wasm()->addAfterVmCallAction([this, response_code, body_text = std::string(body_text),
                                  modify_headers = std::move(modify_headers), grpc_status,
                                  details = StringUtil::replaceAllEmptySpace(
                                      absl::string_view(details.data(), details.size()))] {
      // When the wasm vm fails, failStream() is called if the plugin is fail-closed, we need
      // this flag to avoid calling sendLocalReply() twice.
      if (local_reply_sent_) {
        return;
      }
      // C++, Rust and other SDKs use -1 (InvalidCode) as the default value if gRPC code is not set,
      // which should be mapped to nullopt in Envoy to prevent it from sending a grpc-status trailer
      // at all.
      absl::optional<Grpc::Status::GrpcStatus> grpc_status_code = absl::nullopt;
      if (grpc_status >= Grpc::Status::WellKnownGrpcStatus::Ok &&
          grpc_status <= Grpc::Status::WellKnownGrpcStatus::MaximumKnown) {
        grpc_status_code = Grpc::Status::WellKnownGrpcStatus(grpc_status);
      }
      decoder_callbacks_->sendLocalReply(static_cast<Envoy::Http::Code>(response_code), body_text,
                                         modify_headers, grpc_status_code, details);
      local_reply_sent_ = true;
    });
  }
  local_reply_hold_ = true;
  return WasmResult::Ok;
}
1710
1711
0
// Request-headers hook. Always creates the in-VM context, then runs the plugin's
// onRequestHeaders(). The header map pointer is cleared again once the plugin
// returns Continue; otherwise it stays set.
Http::FilterHeadersStatus Context::decodeHeaders(Http::RequestHeaderMap& headers, bool end_stream) {
  onCreate();
  request_headers_ = &headers;
  end_of_stream_ = end_stream;
  const auto header_count = headerSize(&headers);
  const Http::FilterHeadersStatus status =
      convertFilterHeadersStatus(onRequestHeaders(header_count, end_stream));
  if (status == Http::FilterHeadersStatus::Continue) {
    request_headers_ = nullptr;
  }
  return status;
}
1721
1722
0
// Request-body hook. Publishes the body chunk to the VM and runs onRequestBody().
//
// After the call:
// - buffering_request_body_ is true only for StopIterationAndBuffer.
// - The body pointer is cleared only for Continue.
// Without an in-VM context the data passes through untouched.
Http::FilterDataStatus Context::decodeData(::Envoy::Buffer::Instance& data, bool end_stream) {
  if (!in_vm_context_created_) {
    return Http::FilterDataStatus::Continue;
  }
  request_body_buffer_ = &data;
  end_of_stream_ = end_stream;
  const auto body = getBuffer(WasmBufferType::HttpRequestBody);
  const auto body_size = (body != nullptr) ? body->size() : 0;
  const Http::FilterDataStatus status = convertFilterDataStatus(onRequestBody(body_size, end_stream));
  buffering_request_body_ = (status == Http::FilterDataStatus::StopIterationAndBuffer);
  if (status == Http::FilterDataStatus::Continue) {
    request_body_buffer_ = nullptr;
  }
  return status;
}
1745
1746
0
// Request-trailers hook; passes through untouched when no in-VM context exists.
// The trailer map pointer is cleared once the plugin returns Continue.
Http::FilterTrailersStatus Context::decodeTrailers(Http::RequestTrailerMap& trailers) {
  if (in_vm_context_created_) {
    request_trailers_ = &trailers;
    const auto status = convertFilterTrailersStatus(onRequestTrailers(headerSize(&trailers)));
    if (status == Http::FilterTrailersStatus::Continue) {
      request_trailers_ = nullptr;
    }
    return status;
  }
  return Http::FilterTrailersStatus::Continue;
}
1757
1758
0
// Request-metadata hook; passes through untouched when no in-VM context exists.
// The metadata pointer is cleared once the plugin returns Continue.
Http::FilterMetadataStatus Context::decodeMetadata(Http::MetadataMap& request_metadata) {
  if (in_vm_context_created_) {
    request_metadata_ = &request_metadata;
    const auto status =
        convertFilterMetadataStatus(onRequestMetadata(headerSize(&request_metadata)));
    if (status == Http::FilterMetadataStatus::Continue) {
      request_metadata_ = nullptr;
    }
    return status;
  }
  return Http::FilterMetadataStatus::Continue;
}
1769
1770
0
// Remembers the HTTP decoder callbacks (used to continue, reset or locally reply
// on the request path).
void Context::setDecoderFilterCallbacks(Envoy::Http::StreamDecoderFilterCallbacks& callbacks) {
  decoder_callbacks_ = &callbacks;
}
1773
1774
0
// 1xx informational headers are not forwarded to the plugin; always continue.
Http::Filter1xxHeadersStatus Context::encode1xxHeaders(Http::ResponseHeaderMap& /*headers*/) {
  return Http::Filter1xxHeadersStatus::Continue;
}
1777
1778
// Response-headers hook; passes through untouched when no in-VM context exists.
// The header map pointer is cleared once the plugin returns Continue.
Http::FilterHeadersStatus Context::encodeHeaders(Http::ResponseHeaderMap& headers,
                                                 bool end_stream) {
  if (in_vm_context_created_) {
    response_headers_ = &headers;
    end_of_stream_ = end_stream;
    const auto header_count = headerSize(&headers);
    const auto status = convertFilterHeadersStatus(onResponseHeaders(header_count, end_stream));
    if (status == Http::FilterHeadersStatus::Continue) {
      response_headers_ = nullptr;
    }
    return status;
  }
  return Http::FilterHeadersStatus::Continue;
}
1791
1792
0
// Response-body hook. Publishes the body chunk to the VM and runs onResponseBody().
//
// After the call:
// - buffering_response_body_ is true only for StopIterationAndBuffer.
// - On Continue the *response* body pointer is cleared, mirroring decodeData().
//   This avoids keeping a pointer to a buffer Envoy may release after the chunk
//   moves on, and leaves the request body pointer alone.
// Without an in-VM context the data passes through untouched.
Http::FilterDataStatus Context::encodeData(::Envoy::Buffer::Instance& data, bool end_stream) {
  if (!in_vm_context_created_) {
    return Http::FilterDataStatus::Continue;
  }
  response_body_buffer_ = &data;
  end_of_stream_ = end_stream;
  const auto buffer = getBuffer(WasmBufferType::HttpResponseBody);
  const auto buffer_size = (buffer == nullptr) ? 0 : buffer->size();
  auto result = convertFilterDataStatus(onResponseBody(buffer_size, end_stream));
  buffering_response_body_ = false;
  switch (result) {
  case Http::FilterDataStatus::Continue:
    response_body_buffer_ = nullptr;
    break;
  case Http::FilterDataStatus::StopIterationAndBuffer:
    buffering_response_body_ = true;
    break;
  case Http::FilterDataStatus::StopIterationAndWatermark:
  case Http::FilterDataStatus::StopIterationNoBuffer:
    break;
  }
  return result;
}
1815
1816
0
// Response-trailers hook; passes through untouched when no in-VM context exists.
// The trailer map pointer is cleared once the plugin returns Continue.
Http::FilterTrailersStatus Context::encodeTrailers(Http::ResponseTrailerMap& trailers) {
  if (in_vm_context_created_) {
    response_trailers_ = &trailers;
    const auto status = convertFilterTrailersStatus(onResponseTrailers(headerSize(&trailers)));
    if (status == Http::FilterTrailersStatus::Continue) {
      response_trailers_ = nullptr;
    }
    return status;
  }
  return Http::FilterTrailersStatus::Continue;
}
1827
1828
0
// Response-metadata hook; passes through untouched when no in-VM context exists.
// The metadata pointer is cleared once the plugin returns Continue.
Http::FilterMetadataStatus Context::encodeMetadata(Http::MetadataMap& response_metadata) {
  if (in_vm_context_created_) {
    response_metadata_ = &response_metadata;
    const auto status =
        convertFilterMetadataStatus(onResponseMetadata(headerSize(&response_metadata)));
    if (status == Http::FilterMetadataStatus::Continue) {
      response_metadata_ = nullptr;
    }
    return status;
  }
  return Http::FilterMetadataStatus::Continue;
}
1839
1840
//  Http::FilterMetadataStatus::Continue;
1841
1842
0
// Remembers the HTTP encoder callbacks (used to continue, reset or locally reply
// on the response path).
void Context::setEncoderFilterCallbacks(Envoy::Http::StreamEncoderFilterCallbacks& callbacks) {
  encoder_callbacks_ = &callbacks;
}
1845
1846
0
void Context::onHttpCallSuccess(uint32_t token, Envoy::Http::ResponseMessagePtr&& response) {
1847
  // TODO: convert this into a function in proxy-wasm-cpp-host and use here.
1848
0
  if (proxy_wasm::current_context_ != nullptr) {
1849
    // We are in a reentrant call, so defer.
1850
0
    wasm()->addAfterVmCallAction([this, token, response = response.release()] {
1851
0
      onHttpCallSuccess(token, std::unique_ptr<Envoy::Http::ResponseMessage>(response));
1852
0
    });
1853
0
    return;
1854
0
  }
1855
0
  auto handler = http_request_.find(token);
1856
0
  if (handler == http_request_.end()) {
1857
0
    return;
1858
0
  }
1859
0
  http_call_response_ = &response;
1860
0
  uint32_t body_size = response->body().length();
1861
  // Deferred "after VM call" actions are going to be executed upon returning from
1862
  // ContextBase::*, which might include deleting Context object via proxy_done().
1863
0
  wasm()->addAfterVmCallAction([this, handler] {
1864
0
    http_call_response_ = nullptr;
1865
0
    http_request_.erase(handler);
1866
0
  });
1867
0
  ContextBase::onHttpCallResponse(token, response->headers().size(), body_size,
1868
0
                                  headerSize(response->trailers()));
1869
0
}
1870
1871
0
// Completion hook for a failed outbound HTTP call identified by `token`.
//
// Reentrancy and cleanup are handled as in onHttpCallSuccess(). The plugin's
// response callback is invoked with zero headers/body/trailers, while status_code_
// is BrokenConnection and status_message_ is "reset".
void Context::onHttpCallFailure(uint32_t token, Http::AsyncClient::FailureReason reason) {
  if (proxy_wasm::current_context_ != nullptr) {
    // We are in a reentrant call, so defer.
    wasm()->addAfterVmCallAction([this, token, reason] { onHttpCallFailure(token, reason); });
    return;
  }
  auto pending = http_request_.find(token);
  if (pending == http_request_.end()) {
    return;
  }
  status_code_ = static_cast<uint32_t>(WasmResult::BrokenConnection);
  // TODO(botengyao): handle different failure reasons.
  ASSERT(reason == Http::AsyncClient::FailureReason::Reset ||
         reason == Http::AsyncClient::FailureReason::ExceedResponseBufferLimit);
  status_message_ = "reset";
  // Deferred "after VM call" actions are going to be executed upon returning from
  // ContextBase::*, which might include deleting Context object via proxy_done().
  wasm()->addAfterVmCallAction([this, pending] {
    status_message_ = "";
    http_request_.erase(pending);
  });
  ContextBase::onHttpCallResponse(token, 0, 0, 0);
}
1894
1895
0
// Delivers a received gRPC message for `token` to the plugin. Must not be called
// re-entrantly.
//
// If the plugin has no on_grpc_receive handler, only the cleanup runs. Otherwise
// the message is parked in grpc_receive_buffer_ for the callback and released
// afterwards by a deferred action. In both cases, a unary-call token is removed
// from grpc_call_request_; stream tokens are not touched here.
void Context::onGrpcReceiveWrapper(uint32_t token, ::Envoy::Buffer::InstancePtr response) {
  ASSERT(proxy_wasm::current_context_ == nullptr); // Non-reentrant.
  const auto forget_call = [this, token] {
    if (wasm()->isGrpcCallId(token)) {
      grpc_call_request_.erase(token);
    }
  };
  if (!wasm()->on_grpc_receive_) {
    forget_call();
    return;
  }
  grpc_receive_buffer_ = std::move(response);
  const uint32_t response_size = grpc_receive_buffer_->length();
  // Deferred "after VM call" actions are going to be executed upon returning from
  // ContextBase::*, which might include deleting Context object via proxy_done().
  wasm()->addAfterVmCallAction([this, forget_call] {
    grpc_receive_buffer_.reset();
    forget_call();
  });
  ContextBase::onGrpcReceive(token, response_size);
}
1916
1917
// Reports the remote close of a gRPC call or stream identified by `token`.
//
// If invoked while a VM call is in progress, it is re-scheduled as an after-VM-call
// action, with `message` copied. Bookkeeping cleanup:
// - unary-call entries are erased;
// - stream entries are erased only if the plugin already closed its side locally.
// With an on_grpc_close handler, the status code/message are exposed for the
// callback and the cleanup is deferred until after it.
//
// NOTE(review): a stream that is not yet locally closed keeps its entry and
// remote_closed_ is not set here; grpcClose() only erases entries with
// remote_closed_ set -- confirm it is updated elsewhere.
void Context::onGrpcCloseWrapper(uint32_t token, const Grpc::Status::GrpcStatus& status,
                                 const std::string_view message) {
  if (proxy_wasm::current_context_ != nullptr) {
    // We are in a reentrant call, so defer.
    wasm()->addAfterVmCallAction([this, token, status, message = std::string(message)] {
      onGrpcCloseWrapper(token, status, message);
    });
    return;
  }
  const auto release_token = [this, token] {
    if (wasm()->isGrpcCallId(token)) {
      grpc_call_request_.erase(token);
      return;
    }
    if (!wasm()->isGrpcStreamId(token)) {
      return;
    }
    auto stream_it = grpc_stream_.find(token);
    if (stream_it != grpc_stream_.end() && stream_it->second.local_closed_) {
      grpc_stream_.erase(token);
    }
  };
  if (!wasm()->on_grpc_close_) {
    release_token();
    return;
  }
  status_code_ = static_cast<uint32_t>(status);
  status_message_ = toAbslStringView(message);
  // Deferred "after VM call" actions are going to be executed upon returning from
  // ContextBase::*, which might include deleting Context object via proxy_done().
  wasm()->addAfterVmCallAction([this, release_token] {
    status_message_ = "";
    release_token();
  });
  ContextBase::onGrpcClose(token, status_code_);
}
1952
1953
0
// Sends `message` on the gRPC stream identified by `token`, optionally ending it.
//
// @return BadArgument if `token` is not a stream id, NotFound if the stream is
//         unknown, otherwise Ok (also when the stream handle is null and nothing
//         is sent).
WasmResult Context::grpcSend(uint32_t token, std::string_view message, bool end_stream) {
  if (!wasm()->isGrpcStreamId(token)) {
    return WasmResult::BadArgument;
  }
  auto stream_it = grpc_stream_.find(token);
  if (stream_it == grpc_stream_.end()) {
    return WasmResult::NotFound;
  }
  auto& stream = stream_it->second.stream_;
  if (stream) {
    // The bytes are copied into a fresh owned buffer; `message` need not outlive the call.
    stream->sendMessageRaw(
        std::make_unique<::Envoy::Buffer::OwnedImpl>(message.data(), message.size()), end_stream);
  }
  return WasmResult::Ok;
}
1968
1969
0
// Closes the plugin's side of a gRPC call or stream.
//
// Unary calls are cancelled and forgotten. For streams, the local side is closed.
// The entry is erased if the remote side already closed; otherwise it is marked
// local_closed_.
//
// @return Ok on success, NotFound for an unknown call/stream, BadArgument if
//         `token` is neither a call id nor a stream id.
WasmResult Context::grpcClose(uint32_t token) {
  if (wasm()->isGrpcCallId(token)) {
    auto call_it = grpc_call_request_.find(token);
    if (call_it == grpc_call_request_.end()) {
      return WasmResult::NotFound;
    }
    if (call_it->second.request_) {
      call_it->second.request_->cancel();
    }
    grpc_call_request_.erase(token);
    return WasmResult::Ok;
  }
  if (!wasm()->isGrpcStreamId(token)) {
    return WasmResult::BadArgument;
  }
  auto stream_it = grpc_stream_.find(token);
  if (stream_it == grpc_stream_.end()) {
    return WasmResult::NotFound;
  }
  if (stream_it->second.stream_) {
    stream_it->second.stream_->closeStream();
  }
  if (!stream_it->second.remote_closed_) {
    stream_it->second.local_closed_ = true;
  } else {
    grpc_stream_.erase(token);
  }
  return WasmResult::Ok;
}
1997
1998
0
// Aborts a gRPC call (cancel) or stream (reset) and forgets it immediately.
//
// @return Ok on success, NotFound for an unknown call/stream, BadArgument if
//         `token` is neither a call id nor a stream id.
WasmResult Context::grpcCancel(uint32_t token) {
  if (wasm()->isGrpcCallId(token)) {
    auto call_it = grpc_call_request_.find(token);
    if (call_it == grpc_call_request_.end()) {
      return WasmResult::NotFound;
    }
    if (call_it->second.request_) {
      call_it->second.request_->cancel();
    }
    grpc_call_request_.erase(token);
    return WasmResult::Ok;
  }
  if (!wasm()->isGrpcStreamId(token)) {
    return WasmResult::BadArgument;
  }
  auto stream_it = grpc_stream_.find(token);
  if (stream_it == grpc_stream_.end()) {
    return WasmResult::NotFound;
  }
  if (stream_it->second.stream_) {
    stream_it->second.stream_->resetStream();
  }
  grpc_stream_.erase(token);
  return WasmResult::Ok;
}
2022
2023
} // namespace Wasm
2024
} // namespace Common
2025
} // namespace Extensions
2026
} // namespace Envoy