Line data Source code
1 : #include "source/extensions/common/wasm/context.h"
2 :
3 : #include <algorithm>
4 : #include <cctype>
5 : #include <cstring>
6 : #include <ctime>
7 : #include <limits>
8 : #include <memory>
9 : #include <string>
10 :
11 : #include "envoy/common/exception.h"
12 : #include "envoy/extensions/wasm/v3/wasm.pb.validate.h"
13 : #include "envoy/grpc/status.h"
14 : #include "envoy/http/codes.h"
15 : #include "envoy/local_info/local_info.h"
16 : #include "envoy/network/filter.h"
17 : #include "envoy/stats/sink.h"
18 : #include "envoy/thread_local/thread_local.h"
19 :
20 : #include "source/common/buffer/buffer_impl.h"
21 : #include "source/common/common/assert.h"
22 : #include "source/common/common/empty_string.h"
23 : #include "source/common/common/enum_to_int.h"
24 : #include "source/common/common/logger.h"
25 : #include "source/common/common/safe_memcpy.h"
26 : #include "source/common/http/header_map_impl.h"
27 : #include "source/common/http/message_impl.h"
28 : #include "source/common/http/utility.h"
29 : #include "source/common/tracing/http_tracer_impl.h"
30 : #include "source/extensions/common/wasm/plugin.h"
31 : #include "source/extensions/common/wasm/wasm.h"
32 : #include "source/extensions/filters/common/expr/context.h"
33 :
34 : #include "absl/base/casts.h"
35 : #include "absl/container/flat_hash_map.h"
36 : #include "absl/container/node_hash_map.h"
37 : #include "absl/strings/str_cat.h"
38 : #include "absl/synchronization/mutex.h"
39 : #include "eval/public/cel_value.h"
40 : #include "eval/public/containers/field_access.h"
41 : #include "eval/public/containers/field_backed_list_impl.h"
42 : #include "eval/public/containers/field_backed_map_impl.h"
43 : #include "eval/public/structs/cel_proto_wrapper.h"
44 : #include "include/proxy-wasm/pairs_util.h"
45 : #include "openssl/bytestring.h"
46 : #include "openssl/hmac.h"
47 : #include "openssl/sha.h"
48 :
49 : using proxy_wasm::MetricType;
50 : using proxy_wasm::Word;
51 :
52 : namespace Envoy {
53 : namespace Extensions {
54 : namespace Common {
55 : namespace Wasm {
56 :
57 : namespace {
58 :
59 : // FilterState prefix for CelState values.
60 : constexpr absl::string_view CelStateKeyPrefix = "wasm.";
61 :
62 : using HashPolicy = envoy::config::route::v3::RouteAction::HashPolicy;
63 : using CelState = Filters::Common::Expr::CelState;
64 : using CelStatePrototype = Filters::Common::Expr::CelStatePrototype;
65 :
66 0 : Http::RequestTrailerMapPtr buildRequestTrailerMapFromPairs(const Pairs& pairs) {
67 0 : auto map = Http::RequestTrailerMapImpl::create();
68 0 : for (auto& p : pairs) {
69 : // Note: because of the lack of a string_view interface for addCopy and
70 : // the lack of an interface to add an entry with an empty value and return
71 : // the entry, there is no efficient way to prevent either a double copy
72 : // of the value or a double lookup of the entry.
73 0 : map->addCopy(Http::LowerCaseString(std::string(p.first)), std::string(p.second));
74 0 : }
75 0 : return map;
76 0 : }
77 :
78 0 : Http::RequestHeaderMapPtr buildRequestHeaderMapFromPairs(const Pairs& pairs) {
79 0 : auto map = Http::RequestHeaderMapImpl::create();
80 0 : for (auto& p : pairs) {
81 : // Note: because of the lack of a string_view interface for addCopy and
82 : // the lack of an interface to add an entry with an empty value and return
83 : // the entry, there is no efficient way to prevent either a double copy
84 : // of the value or a double lookup of the entry.
85 0 : map->addCopy(Http::LowerCaseString(std::string(p.first)), std::string(p.second));
86 0 : }
87 0 : return map;
88 0 : }
89 :
// Returns p->size() for a non-null pointer-like `p`, and 0 when `p` tests false.
template <typename P> static uint32_t headerSize(const P& p) {
  if (!p) {
    return 0;
  }
  return p->size();
}
91 :
92 0 : Upstream::HostDescriptionConstSharedPtr getHost(const StreamInfo::StreamInfo* info) {
93 0 : if (info && info->upstreamInfo() && info->upstreamInfo().value().get().upstreamHost()) {
94 0 : return info->upstreamInfo().value().get().upstreamHost();
95 0 : }
96 0 : return nullptr;
97 0 : }
98 :
99 : } // namespace
100 :
101 : // Test support.
102 :
103 0 : size_t Buffer::size() const {
104 0 : if (const_buffer_instance_) {
105 0 : return const_buffer_instance_->length();
106 0 : }
107 0 : return proxy_wasm::BufferBase::size();
108 0 : }
109 :
110 : WasmResult Buffer::copyTo(WasmBase* wasm, size_t start, size_t length, uint64_t ptr_ptr,
111 0 : uint64_t size_ptr) const {
112 0 : if (const_buffer_instance_) {
113 0 : uint64_t pointer;
114 0 : auto p = wasm->allocMemory(length, &pointer);
115 0 : if (!p) {
116 0 : return WasmResult::InvalidMemoryAccess;
117 0 : }
118 0 : const_buffer_instance_->copyOut(start, length, p);
119 0 : if (!wasm->wasm_vm()->setWord(ptr_ptr, Word(pointer))) {
120 0 : return WasmResult::InvalidMemoryAccess;
121 0 : }
122 0 : if (!wasm->wasm_vm()->setWord(size_ptr, Word(length))) {
123 0 : return WasmResult::InvalidMemoryAccess;
124 0 : }
125 0 : return WasmResult::Ok;
126 0 : }
127 0 : return proxy_wasm::BufferBase::copyTo(wasm, start, length, ptr_ptr, size_ptr);
128 0 : }
129 :
130 0 : WasmResult Buffer::copyFrom(size_t start, size_t length, std::string_view data) {
131 0 : if (buffer_instance_) {
132 0 : if (start == 0) {
133 0 : if (length != 0) {
134 0 : buffer_instance_->drain(length);
135 0 : }
136 0 : buffer_instance_->prepend(toAbslStringView(data));
137 0 : return WasmResult::Ok;
138 0 : } else if (start >= buffer_instance_->length()) {
139 0 : buffer_instance_->add(toAbslStringView(data));
140 0 : return WasmResult::Ok;
141 0 : } else {
142 0 : return WasmResult::BadArgument;
143 0 : }
144 0 : }
145 0 : if (const_buffer_instance_) { // This buffer is immutable.
146 0 : return WasmResult::BadArgument;
147 0 : }
148 0 : return proxy_wasm::BufferBase::copyFrom(start, length, data);
149 0 : }
150 :
151 0 : Context::Context() = default;
152 0 : Context::Context(Wasm* wasm) : ContextBase(wasm) {}
153 0 : Context::Context(Wasm* wasm, const PluginSharedPtr& plugin) : ContextBase(wasm, plugin) {
154 0 : root_local_info_ = &std::static_pointer_cast<Plugin>(plugin)->localInfo();
155 0 : }
156 : Context::Context(Wasm* wasm, uint32_t root_context_id, PluginHandleSharedPtr plugin_handle)
157 0 : : ContextBase(wasm, root_context_id, plugin_handle), plugin_handle_(plugin_handle) {}
158 :
159 0 : Wasm* Context::wasm() const { return static_cast<Wasm*>(wasm_); }
160 0 : Plugin* Context::plugin() const { return static_cast<Plugin*>(plugin_.get()); }
161 0 : Context* Context::rootContext() const { return static_cast<Context*>(root_context()); }
162 0 : Upstream::ClusterManager& Context::clusterManager() const { return wasm()->clusterManager(); }
163 :
164 0 : void Context::error(std::string_view message) { ENVOY_LOG(trace, message); }
165 :
166 0 : uint64_t Context::getCurrentTimeNanoseconds() {
167 0 : return std::chrono::duration_cast<std::chrono::nanoseconds>(
168 0 : wasm()->time_source_.systemTime().time_since_epoch())
169 0 : .count();
170 0 : }
171 :
172 0 : uint64_t Context::getMonotonicTimeNanoseconds() {
173 0 : return std::chrono::duration_cast<std::chrono::nanoseconds>(
174 0 : wasm()->time_source_.monotonicTime().time_since_epoch())
175 0 : .count();
176 0 : }
177 :
178 0 : void Context::onCloseTCP() {
179 0 : if (tcp_connection_closed_ || !in_vm_context_created_) {
180 0 : return;
181 0 : }
182 0 : tcp_connection_closed_ = true;
183 0 : onDone();
184 0 : onLog();
185 0 : onDelete();
186 0 : }
187 :
188 : void Context::onResolveDns(uint32_t token, Envoy::Network::DnsResolver::ResolutionStatus status,
189 0 : std::list<Envoy::Network::DnsResponse>&& response) {
190 0 : proxy_wasm::DeferAfterCallActions actions(this);
191 0 : if (wasm()->isFailed() || !wasm()->on_resolve_dns_) {
192 0 : return;
193 0 : }
194 0 : if (status != Network::DnsResolver::ResolutionStatus::Success) {
195 0 : buffer_.set("");
196 0 : wasm()->on_resolve_dns_(this, id_, token, 0);
197 0 : return;
198 0 : }
199 : // buffer format:
200 : // 4 bytes number of entries = N
201 : // N * 4 bytes TTL for each entry
202 : // N * null-terminated addresses
203 0 : uint32_t s = 4; // length
204 0 : for (auto& e : response) {
205 0 : s += 4; // for TTL
206 0 : s += e.addrInfo().address_->asStringView().size() + 1; // null terminated.
207 0 : }
208 0 : auto buffer = std::unique_ptr<char[]>(new char[s]);
209 0 : char* b = buffer.get();
210 0 : uint32_t n = response.size();
211 0 : safeMemcpyUnsafeDst(b, &n);
212 0 : b += sizeof(uint32_t);
213 0 : for (auto& e : response) {
214 0 : uint32_t ttl = e.addrInfo().ttl_.count();
215 0 : safeMemcpyUnsafeDst(b, &ttl);
216 0 : b += sizeof(uint32_t);
217 0 : };
218 0 : for (auto& e : response) {
219 0 : memcpy(b, e.addrInfo().address_->asStringView().data(), // NOLINT(safe-memcpy)
220 0 : e.addrInfo().address_->asStringView().size());
221 0 : b += e.addrInfo().address_->asStringView().size();
222 0 : *b++ = 0;
223 0 : };
224 0 : buffer_.set(std::move(buffer), s);
225 0 : wasm()->on_resolve_dns_(this, id_, token, s);
226 0 : }
227 :
// Rounds `i` up to the next multiple of sizeof(I); the mask arithmetic assumes
// sizeof(I) is a power of two.
template <typename I> inline uint32_t align(uint32_t i) {
  constexpr auto mask = sizeof(I) - 1;
  return (i + mask) & ~mask;
}
231 :
// Rounds the address `p` up to the next multiple of sizeof(I); the mask
// arithmetic assumes sizeof(I) is a power of two.
template <typename I> inline char* align(char* p) {
  const auto address = reinterpret_cast<uintptr_t>(p);
  const uintptr_t mask = sizeof(I) - 1;
  return reinterpret_cast<char*>((address + mask) & ~mask);
}
236 :
237 0 : void Context::onStatsUpdate(Envoy::Stats::MetricSnapshot& snapshot) {
238 0 : proxy_wasm::DeferAfterCallActions actions(this);
239 0 : if (wasm()->isFailed() || !wasm()->on_stats_update_) {
240 0 : return;
241 0 : }
242 : // buffer format:
243 : // uint32 size of block of this type
244 : // uint32 type
245 : // uint32 count
246 : // uint32 length of name
247 : // name
248 : // 8 byte alignment padding
249 : // 8 bytes of absolute value
250 : // 8 bytes of delta (if appropriate, e.g. for counters)
251 : // uint32 size of block of this type
252 :
253 0 : uint32_t counter_block_size = 3 * sizeof(uint32_t); // type of stat
254 0 : uint32_t num_counters = snapshot.counters().size();
255 0 : uint32_t counter_type = 1;
256 :
257 0 : uint32_t gauge_block_size = 3 * sizeof(uint32_t); // type of stat
258 0 : uint32_t num_gauges = snapshot.gauges().size();
259 0 : uint32_t gauge_type = 2;
260 :
261 0 : uint32_t n = 0;
262 0 : uint64_t v = 0;
263 :
264 0 : for (const auto& counter : snapshot.counters()) {
265 0 : if (counter.counter_.get().used()) {
266 0 : counter_block_size += sizeof(uint32_t) + counter.counter_.get().name().size();
267 0 : counter_block_size = align<uint64_t>(counter_block_size + 2 * sizeof(uint64_t));
268 0 : }
269 0 : }
270 :
271 0 : for (const auto& gauge : snapshot.gauges()) {
272 0 : if (gauge.get().used()) {
273 0 : gauge_block_size += sizeof(uint32_t) + gauge.get().name().size();
274 0 : gauge_block_size += align<uint64_t>(gauge_block_size + sizeof(uint64_t));
275 0 : }
276 0 : }
277 :
278 0 : auto buffer = std::unique_ptr<char[]>(new char[counter_block_size + gauge_block_size]);
279 0 : char* b = buffer.get();
280 :
281 0 : safeMemcpyUnsafeDst(b, &counter_block_size);
282 0 : b += sizeof(uint32_t);
283 0 : safeMemcpyUnsafeDst(b, &counter_type);
284 0 : b += sizeof(uint32_t);
285 0 : safeMemcpyUnsafeDst(b, &num_counters);
286 0 : b += sizeof(uint32_t);
287 :
288 0 : for (const auto& counter : snapshot.counters()) {
289 0 : if (counter.counter_.get().used()) {
290 0 : n = counter.counter_.get().name().size();
291 0 : safeMemcpyUnsafeDst(b, &n);
292 0 : b += sizeof(uint32_t);
293 0 : memcpy(b, counter.counter_.get().name().data(), // NOLINT(safe-memcpy)
294 0 : counter.counter_.get().name().size());
295 0 : b = align<uint64_t>(b + counter.counter_.get().name().size());
296 0 : v = counter.counter_.get().value();
297 0 : safeMemcpyUnsafeDst(b, &v);
298 0 : b += sizeof(uint64_t);
299 0 : v = counter.delta_;
300 0 : safeMemcpyUnsafeDst(b, &v);
301 0 : b += sizeof(uint64_t);
302 0 : }
303 0 : }
304 :
305 0 : safeMemcpyUnsafeDst(b, &gauge_block_size);
306 0 : b += sizeof(uint32_t);
307 0 : safeMemcpyUnsafeDst(b, &gauge_type);
308 0 : b += sizeof(uint32_t);
309 0 : safeMemcpyUnsafeDst(b, &num_gauges);
310 0 : b += sizeof(uint32_t);
311 :
312 0 : for (const auto& gauge : snapshot.gauges()) {
313 0 : if (gauge.get().used()) {
314 0 : n = gauge.get().name().size();
315 0 : safeMemcpyUnsafeDst(b, &n);
316 0 : b += sizeof(uint32_t);
317 0 : memcpy(b, gauge.get().name().data(), gauge.get().name().size()); // NOLINT(safe-memcpy)
318 0 : b = align<uint64_t>(b + gauge.get().name().size());
319 0 : v = gauge.get().value();
320 0 : safeMemcpyUnsafeDst(b, &v);
321 0 : b += sizeof(uint64_t);
322 0 : }
323 0 : }
324 0 : buffer_.set(std::move(buffer), counter_block_size + gauge_block_size);
325 0 : wasm()->on_stats_update_(this, id_, counter_block_size + gauge_block_size);
326 0 : }
327 :
328 : // Native serializer carrying over bit representation from CEL value to the extension.
329 : // This implementation assumes that the value type is static and known to the consumer.
330 0 : WasmResult serializeValue(Filters::Common::Expr::CelValue value, std::string* result) {
331 0 : using Filters::Common::Expr::CelValue;
332 0 : int64_t out_int64;
333 0 : uint64_t out_uint64;
334 0 : double out_double;
335 0 : bool out_bool;
336 0 : const Protobuf::Message* out_message;
337 0 : switch (value.type()) {
338 0 : case CelValue::Type::kString:
339 0 : result->assign(value.StringOrDie().value().data(), value.StringOrDie().value().size());
340 0 : return WasmResult::Ok;
341 0 : case CelValue::Type::kBytes:
342 0 : result->assign(value.BytesOrDie().value().data(), value.BytesOrDie().value().size());
343 0 : return WasmResult::Ok;
344 0 : case CelValue::Type::kInt64:
345 0 : out_int64 = value.Int64OrDie();
346 0 : result->assign(reinterpret_cast<const char*>(&out_int64), sizeof(int64_t));
347 0 : return WasmResult::Ok;
348 0 : case CelValue::Type::kUint64:
349 0 : out_uint64 = value.Uint64OrDie();
350 0 : result->assign(reinterpret_cast<const char*>(&out_uint64), sizeof(uint64_t));
351 0 : return WasmResult::Ok;
352 0 : case CelValue::Type::kDouble:
353 0 : out_double = value.DoubleOrDie();
354 0 : result->assign(reinterpret_cast<const char*>(&out_double), sizeof(double));
355 0 : return WasmResult::Ok;
356 0 : case CelValue::Type::kBool:
357 0 : out_bool = value.BoolOrDie();
358 0 : result->assign(reinterpret_cast<const char*>(&out_bool), sizeof(bool));
359 0 : return WasmResult::Ok;
360 0 : case CelValue::Type::kDuration:
361 : // Warning: loss of precision to nanoseconds
362 0 : out_int64 = absl::ToInt64Nanoseconds(value.DurationOrDie());
363 0 : result->assign(reinterpret_cast<const char*>(&out_int64), sizeof(int64_t));
364 0 : return WasmResult::Ok;
365 0 : case CelValue::Type::kTimestamp:
366 : // Warning: loss of precision to nanoseconds
367 0 : out_int64 = absl::ToUnixNanos(value.TimestampOrDie());
368 0 : result->assign(reinterpret_cast<const char*>(&out_int64), sizeof(int64_t));
369 0 : return WasmResult::Ok;
370 0 : case CelValue::Type::kMessage:
371 0 : out_message = value.MessageOrDie();
372 0 : result->clear();
373 0 : if (!out_message || out_message->SerializeToString(result)) {
374 0 : return WasmResult::Ok;
375 0 : }
376 0 : return WasmResult::SerializationFailure;
377 0 : case CelValue::Type::kMap: {
378 0 : const auto& map = *value.MapOrDie();
379 0 : auto keys_list = map.ListKeys();
380 0 : if (!keys_list.ok()) {
381 0 : return WasmResult::SerializationFailure;
382 0 : }
383 0 : const auto& keys = *keys_list.value();
384 0 : std::vector<std::pair<std::string, std::string>> pairs(map.size(), std::make_pair("", ""));
385 0 : for (auto i = 0; i < map.size(); i++) {
386 0 : if (serializeValue(keys[i], &pairs[i].first) != WasmResult::Ok) {
387 0 : return WasmResult::SerializationFailure;
388 0 : }
389 0 : if (serializeValue(map[keys[i]].value(), &pairs[i].second) != WasmResult::Ok) {
390 0 : return WasmResult::SerializationFailure;
391 0 : }
392 0 : }
393 0 : auto size = proxy_wasm::PairsUtil::pairsSize(pairs);
394 : // prevent string inlining which violates byte alignment
395 0 : result->resize(std::max(size, static_cast<size_t>(30)));
396 0 : if (!proxy_wasm::PairsUtil::marshalPairs(pairs, result->data(), size)) {
397 0 : return WasmResult::SerializationFailure;
398 0 : }
399 0 : result->resize(size);
400 0 : return WasmResult::Ok;
401 0 : }
402 0 : case CelValue::Type::kList: {
403 0 : const auto& list = *value.ListOrDie();
404 0 : std::vector<std::pair<std::string, std::string>> pairs(list.size(), std::make_pair("", ""));
405 0 : for (auto i = 0; i < list.size(); i++) {
406 0 : if (serializeValue(list[i], &pairs[i].first) != WasmResult::Ok) {
407 0 : return WasmResult::SerializationFailure;
408 0 : }
409 0 : }
410 0 : auto size = proxy_wasm::PairsUtil::pairsSize(pairs);
411 : // prevent string inlining which violates byte alignment
412 0 : if (size < 30) {
413 0 : result->reserve(30);
414 0 : }
415 0 : result->resize(size);
416 0 : if (!proxy_wasm::PairsUtil::marshalPairs(pairs, result->data(), size)) {
417 0 : return WasmResult::SerializationFailure;
418 0 : }
419 0 : return WasmResult::Ok;
420 0 : }
421 0 : default:
422 0 : break;
423 0 : }
424 0 : return WasmResult::SerializationFailure;
425 0 : }
426 :
// X-macro listing every property token. `_f` is applied to each identifier in
// the order given here, which is also the order of the generated enumerators.
#define PROPERTY_TOKENS(_f)                                                                        \
  _f(NODE)                                                                                         \
  _f(LISTENER_DIRECTION)                                                                           \
  _f(LISTENER_METADATA)                                                                            \
  _f(CLUSTER_NAME)                                                                                 \
  _f(CLUSTER_METADATA)                                                                             \
  _f(ROUTE_NAME)                                                                                   \
  _f(ROUTE_METADATA)                                                                               \
  _f(PLUGIN_NAME)                                                                                  \
  _f(UPSTREAM_HOST_METADATA)                                                                       \
  _f(PLUGIN_ROOT_ID)                                                                               \
  _f(PLUGIN_VM_ID)                                                                                 \
  _f(CONNECTION_ID)
431 :
// Returns a copy of `s` with every character passed through std::tolower.
static inline std::string downCase(std::string s) {
  for (char& ch : s) {
    // Go through unsigned char: std::tolower is undefined for negative values.
    ch = static_cast<char>(std::tolower(static_cast<unsigned char>(ch)));
  }
  return s;
}
436 :
437 : #define _DECLARE(_t) _t,
438 : enum class PropertyToken { PROPERTY_TOKENS(_DECLARE) };
439 : #undef _DECLARE
440 :
441 : #define _PAIR(_t) {downCase(#_t), PropertyToken::_t},
442 : static absl::flat_hash_map<std::string, PropertyToken> property_tokens = {PROPERTY_TOKENS(_PAIR)};
443 : #undef _PAIR
444 :
445 : absl::optional<google::api::expr::runtime::CelValue>
446 0 : Context::FindValue(absl::string_view name, Protobuf::Arena* arena) const {
447 0 : return findValue(name, arena, false);
448 0 : }
449 :
450 : absl::optional<google::api::expr::runtime::CelValue>
451 0 : Context::findValue(absl::string_view name, Protobuf::Arena* arena, bool last) const {
452 0 : using google::api::expr::runtime::CelProtoWrapper;
453 0 : using google::api::expr::runtime::CelValue;
454 :
455 0 : const StreamInfo::StreamInfo* info = getConstRequestStreamInfo();
456 : // In order to delegate to the StreamActivation method, we have to set the
457 : // context properties to match the Wasm context properties in all callbacks
458 : // (e.g. onLog or onEncodeHeaders) for the duration of the call.
459 0 : activation_info_ = info;
460 0 : activation_request_headers_ = request_headers_ ? request_headers_ : access_log_request_headers_;
461 0 : activation_response_headers_ =
462 0 : response_headers_ ? response_headers_ : access_log_response_headers_;
463 0 : activation_response_trailers_ =
464 0 : response_trailers_ ? response_trailers_ : access_log_response_trailers_;
465 0 : auto value = StreamActivation::FindValue(name, arena);
466 0 : resetActivation();
467 0 : if (value) {
468 0 : return value;
469 0 : }
470 :
471 : // Convert into a dense token to enable a jump table implementation.
472 0 : auto part_token = property_tokens.find(name);
473 0 : if (part_token == property_tokens.end()) {
474 0 : if (info) {
475 0 : std::string key = absl::StrCat(CelStateKeyPrefix, name);
476 0 : const CelState* state = info->filterState().getDataReadOnly<CelState>(key);
477 0 : if (state == nullptr) {
478 0 : if (info->upstreamInfo().has_value() &&
479 0 : info->upstreamInfo().value().get().upstreamFilterState() != nullptr) {
480 0 : state =
481 0 : info->upstreamInfo().value().get().upstreamFilterState()->getDataReadOnly<CelState>(
482 0 : key);
483 0 : }
484 0 : }
485 0 : if (state != nullptr) {
486 0 : return state->exprValue(arena, last);
487 0 : }
488 0 : }
489 0 : return {};
490 0 : }
491 :
492 0 : switch (part_token->second) {
493 0 : case PropertyToken::CONNECTION_ID: {
494 0 : auto conn = getConnection();
495 0 : if (conn) {
496 0 : return CelValue::CreateUint64(conn->id());
497 0 : }
498 0 : break;
499 0 : }
500 0 : case PropertyToken::NODE:
501 0 : if (root_local_info_) {
502 0 : return CelProtoWrapper::CreateMessage(&root_local_info_->node(), arena);
503 0 : } else if (plugin_) {
504 0 : return CelProtoWrapper::CreateMessage(&plugin()->localInfo().node(), arena);
505 0 : }
506 0 : break;
507 0 : case PropertyToken::LISTENER_DIRECTION:
508 0 : if (plugin_) {
509 0 : return CelValue::CreateInt64(plugin()->direction());
510 0 : }
511 0 : break;
512 0 : case PropertyToken::LISTENER_METADATA:
513 0 : if (plugin_) {
514 0 : return CelProtoWrapper::CreateMessage(plugin()->listenerMetadata(), arena);
515 0 : }
516 0 : break;
517 0 : case PropertyToken::CLUSTER_NAME:
518 0 : if (getHost(info)) {
519 0 : return CelValue::CreateString(&getHost(info)->cluster().name());
520 0 : } else if (info && info->route() && info->route()->routeEntry()) {
521 0 : return CelValue::CreateString(&info->route()->routeEntry()->clusterName());
522 0 : } else if (info && info->upstreamClusterInfo().has_value() &&
523 0 : info->upstreamClusterInfo().value()) {
524 0 : return CelValue::CreateString(&info->upstreamClusterInfo().value()->name());
525 0 : }
526 0 : break;
527 0 : case PropertyToken::CLUSTER_METADATA:
528 0 : if (getHost(info)) {
529 0 : return CelProtoWrapper::CreateMessage(&getHost(info)->cluster().metadata(), arena);
530 0 : } else if (info && info->upstreamClusterInfo().has_value() &&
531 0 : info->upstreamClusterInfo().value()) {
532 0 : return CelProtoWrapper::CreateMessage(&info->upstreamClusterInfo().value()->metadata(),
533 0 : arena);
534 0 : }
535 0 : break;
536 0 : case PropertyToken::UPSTREAM_HOST_METADATA:
537 0 : if (getHost(info)) {
538 0 : return CelProtoWrapper::CreateMessage(getHost(info)->metadata().get(), arena);
539 0 : }
540 0 : break;
541 0 : case PropertyToken::ROUTE_NAME:
542 0 : if (info) {
543 0 : return CelValue::CreateString(&info->getRouteName());
544 0 : }
545 0 : break;
546 0 : case PropertyToken::ROUTE_METADATA:
547 0 : if (info && info->route()) {
548 0 : return CelProtoWrapper::CreateMessage(&info->route()->metadata(), arena);
549 0 : }
550 0 : break;
551 0 : case PropertyToken::PLUGIN_NAME:
552 0 : if (plugin_) {
553 0 : return CelValue::CreateStringView(plugin()->name_);
554 0 : }
555 0 : break;
556 0 : case PropertyToken::PLUGIN_ROOT_ID:
557 0 : return CelValue::CreateStringView(toAbslStringView(root_id()));
558 0 : case PropertyToken::PLUGIN_VM_ID:
559 0 : return CelValue::CreateStringView(toAbslStringView(wasm()->vm_id()));
560 0 : }
561 0 : return {};
562 0 : }
563 :
564 0 : WasmResult Context::getProperty(std::string_view path, std::string* result) {
565 0 : using google::api::expr::runtime::CelValue;
566 :
567 0 : bool first = true;
568 0 : CelValue value;
569 0 : Protobuf::Arena arena;
570 :
571 0 : size_t start = 0;
572 0 : while (true) {
573 0 : if (start >= path.size()) {
574 0 : break;
575 0 : }
576 :
577 0 : size_t end = path.find('\0', start);
578 0 : if (end == absl::string_view::npos) {
579 0 : end = start + path.size();
580 0 : }
581 0 : auto part = path.substr(start, end - start);
582 0 : start = end + 1;
583 :
584 0 : if (first) {
585 : // top-level identifier
586 0 : first = false;
587 0 : auto top_value = findValue(toAbslStringView(part), &arena, start >= path.size());
588 0 : if (!top_value.has_value()) {
589 0 : return WasmResult::NotFound;
590 0 : }
591 0 : value = top_value.value();
592 0 : } else if (value.IsMap()) {
593 0 : auto& map = *value.MapOrDie();
594 0 : auto field = map[CelValue::CreateStringView(toAbslStringView(part))];
595 0 : if (!field.has_value()) {
596 0 : return WasmResult::NotFound;
597 0 : }
598 0 : value = field.value();
599 0 : } else if (value.IsMessage()) {
600 0 : auto msg = value.MessageOrDie();
601 0 : if (msg == nullptr) {
602 0 : return WasmResult::NotFound;
603 0 : }
604 0 : const Protobuf::Descriptor* desc = msg->GetDescriptor();
605 0 : const Protobuf::FieldDescriptor* field_desc = desc->FindFieldByName(std::string(part));
606 0 : if (field_desc == nullptr) {
607 0 : return WasmResult::NotFound;
608 0 : }
609 0 : if (field_desc->is_map()) {
610 0 : value = CelValue::CreateMap(
611 0 : Protobuf::Arena::Create<google::api::expr::runtime::FieldBackedMapImpl>(
612 0 : &arena, msg, field_desc, &arena));
613 0 : } else if (field_desc->is_repeated()) {
614 0 : value = CelValue::CreateList(
615 0 : Protobuf::Arena::Create<google::api::expr::runtime::FieldBackedListImpl>(
616 0 : &arena, msg, field_desc, &arena));
617 0 : } else {
618 0 : auto status =
619 0 : google::api::expr::runtime::CreateValueFromSingleField(msg, field_desc, &arena, &value);
620 0 : if (!status.ok()) {
621 0 : return WasmResult::InternalFailure;
622 0 : }
623 0 : }
624 0 : } else if (value.IsList()) {
625 0 : auto& list = *value.ListOrDie();
626 0 : int idx = 0;
627 0 : if (!absl::SimpleAtoi(toAbslStringView(part), &idx)) {
628 0 : return WasmResult::NotFound;
629 0 : }
630 0 : if (idx < 0 || idx >= list.size()) {
631 0 : return WasmResult::NotFound;
632 0 : }
633 0 : value = list[idx];
634 0 : } else {
635 0 : return WasmResult::NotFound;
636 0 : }
637 0 : }
638 :
639 0 : return serializeValue(value, result);
640 0 : }
641 :
642 : // Header/Trailer/Metadata Maps.
643 0 : Http::HeaderMap* Context::getMap(WasmHeaderMapType type) {
644 0 : switch (type) {
645 0 : case WasmHeaderMapType::RequestHeaders:
646 0 : return request_headers_;
647 0 : case WasmHeaderMapType::RequestTrailers:
648 0 : if (request_trailers_ == nullptr && request_body_buffer_ && end_of_stream_ &&
649 0 : decoder_callbacks_) {
650 0 : request_trailers_ = &decoder_callbacks_->addDecodedTrailers();
651 0 : }
652 0 : return request_trailers_;
653 0 : case WasmHeaderMapType::ResponseHeaders:
654 0 : return response_headers_;
655 0 : case WasmHeaderMapType::ResponseTrailers:
656 0 : if (response_trailers_ == nullptr && response_body_buffer_ && end_of_stream_ &&
657 0 : encoder_callbacks_) {
658 0 : response_trailers_ = &encoder_callbacks_->addEncodedTrailers();
659 0 : }
660 0 : return response_trailers_;
661 0 : default:
662 0 : return nullptr;
663 0 : }
664 0 : }
665 :
666 0 : const Http::HeaderMap* Context::getConstMap(WasmHeaderMapType type) {
667 0 : switch (type) {
668 0 : case WasmHeaderMapType::RequestHeaders:
669 0 : if (access_log_phase_) {
670 0 : return access_log_request_headers_;
671 0 : }
672 0 : return request_headers_;
673 0 : case WasmHeaderMapType::RequestTrailers:
674 0 : if (access_log_phase_) {
675 0 : return nullptr;
676 0 : }
677 0 : return request_trailers_;
678 0 : case WasmHeaderMapType::ResponseHeaders:
679 0 : if (access_log_phase_) {
680 0 : return access_log_response_headers_;
681 0 : }
682 0 : return response_headers_;
683 0 : case WasmHeaderMapType::ResponseTrailers:
684 0 : if (access_log_phase_) {
685 0 : return access_log_response_trailers_;
686 0 : }
687 0 : return response_trailers_;
688 0 : case WasmHeaderMapType::GrpcReceiveInitialMetadata:
689 0 : return rootContext()->grpc_receive_initial_metadata_.get();
690 0 : case WasmHeaderMapType::GrpcReceiveTrailingMetadata:
691 0 : return rootContext()->grpc_receive_trailing_metadata_.get();
692 0 : case WasmHeaderMapType::HttpCallResponseHeaders: {
693 0 : Envoy::Http::ResponseMessagePtr* response = rootContext()->http_call_response_;
694 0 : if (response) {
695 0 : return &(*response)->headers();
696 0 : }
697 0 : return nullptr;
698 0 : }
699 0 : case WasmHeaderMapType::HttpCallResponseTrailers: {
700 0 : Envoy::Http::ResponseMessagePtr* response = rootContext()->http_call_response_;
701 0 : if (response) {
702 0 : return (*response)->trailers();
703 0 : }
704 0 : return nullptr;
705 0 : }
706 0 : }
707 0 : IS_ENVOY_BUG("unexpected");
708 0 : return nullptr;
709 0 : }
710 :
711 : WasmResult Context::addHeaderMapValue(WasmHeaderMapType type, std::string_view key,
712 0 : std::string_view value) {
713 0 : auto map = getMap(type);
714 0 : if (!map) {
715 0 : return WasmResult::BadArgument;
716 0 : }
717 0 : const Http::LowerCaseString lower_key{std::string(key)};
718 0 : map->addCopy(lower_key, std::string(value));
719 0 : if (type == WasmHeaderMapType::RequestHeaders && decoder_callbacks_) {
720 0 : decoder_callbacks_->downstreamCallbacks()->clearRouteCache();
721 0 : }
722 0 : return WasmResult::Ok;
723 0 : }
724 :
725 : WasmResult Context::getHeaderMapValue(WasmHeaderMapType type, std::string_view key,
726 0 : std::string_view* value) {
727 0 : auto map = getConstMap(type);
728 0 : if (!map) {
729 0 : if (access_log_phase_) {
730 : // Maps might point to nullptr in the access log phase.
731 0 : if (wasm()->abiVersion() == proxy_wasm::AbiVersion::ProxyWasm_0_1_0) {
732 0 : *value = "";
733 0 : return WasmResult::Ok;
734 0 : } else {
735 0 : return WasmResult::NotFound;
736 0 : }
737 0 : }
738 : // Requested map type is not currently available.
739 0 : return WasmResult::BadArgument;
740 0 : }
741 0 : const Http::LowerCaseString lower_key{std::string(key)};
742 0 : const auto entry = map->get(lower_key);
743 0 : if (entry.empty()) {
744 0 : if (wasm()->abiVersion() == proxy_wasm::AbiVersion::ProxyWasm_0_1_0) {
745 0 : *value = "";
746 0 : return WasmResult::Ok;
747 0 : } else {
748 0 : return WasmResult::NotFound;
749 0 : }
750 0 : }
751 : // TODO(kyessenov, PiotrSikora): This needs to either return a concatenated list of values, or
752 : // the ABI needs to be changed to return multiple values. This is a potential security issue.
753 0 : *value = toStdStringView(entry[0]->value().getStringView());
754 0 : return WasmResult::Ok;
755 0 : }
756 :
757 0 : Pairs headerMapToPairs(const Http::HeaderMap* map) {
758 0 : if (!map) {
759 0 : return {};
760 0 : }
761 0 : Pairs pairs;
762 0 : pairs.reserve(map->size());
763 0 : map->iterate([&pairs](const Http::HeaderEntry& header) -> Http::HeaderMap::Iterate {
764 0 : pairs.push_back(std::make_pair(toStdStringView(header.key().getStringView()),
765 0 : toStdStringView(header.value().getStringView())));
766 0 : return Http::HeaderMap::Iterate::Continue;
767 0 : });
768 0 : return pairs;
769 0 : }
770 :
771 0 : WasmResult Context::getHeaderMapPairs(WasmHeaderMapType type, Pairs* result) {
772 0 : *result = headerMapToPairs(getConstMap(type));
773 0 : return WasmResult::Ok;
774 0 : }
775 :
776 0 : WasmResult Context::setHeaderMapPairs(WasmHeaderMapType type, const Pairs& pairs) {
777 0 : auto map = getMap(type);
778 0 : if (!map) {
779 0 : return WasmResult::BadArgument;
780 0 : }
781 0 : std::vector<std::string> keys;
782 0 : map->iterate([&keys](const Http::HeaderEntry& header) -> Http::HeaderMap::Iterate {
783 0 : keys.push_back(std::string(header.key().getStringView()));
784 0 : return Http::HeaderMap::Iterate::Continue;
785 0 : });
786 0 : for (auto& k : keys) {
787 0 : const Http::LowerCaseString lower_key{k};
788 0 : map->remove(lower_key);
789 0 : }
790 0 : for (auto& p : pairs) {
791 0 : const Http::LowerCaseString lower_key{std::string(p.first)};
792 0 : map->addCopy(lower_key, std::string(p.second));
793 0 : }
794 0 : if (type == WasmHeaderMapType::RequestHeaders && decoder_callbacks_) {
795 0 : decoder_callbacks_->downstreamCallbacks()->clearRouteCache();
796 0 : }
797 0 : return WasmResult::Ok;
798 0 : }
799 :
800 0 : WasmResult Context::removeHeaderMapValue(WasmHeaderMapType type, std::string_view key) {
801 0 : auto map = getMap(type);
802 0 : if (!map) {
803 0 : return WasmResult::BadArgument;
804 0 : }
805 0 : const Http::LowerCaseString lower_key{std::string(key)};
806 0 : map->remove(lower_key);
807 0 : if (type == WasmHeaderMapType::RequestHeaders && decoder_callbacks_) {
808 0 : decoder_callbacks_->downstreamCallbacks()->clearRouteCache();
809 0 : }
810 0 : return WasmResult::Ok;
811 0 : }
812 :
813 : WasmResult Context::replaceHeaderMapValue(WasmHeaderMapType type, std::string_view key,
814 0 : std::string_view value) {
815 0 : auto map = getMap(type);
816 0 : if (!map) {
817 0 : return WasmResult::BadArgument;
818 0 : }
819 0 : const Http::LowerCaseString lower_key{std::string(key)};
820 0 : map->setCopy(lower_key, toAbslStringView(value));
821 0 : if (type == WasmHeaderMapType::RequestHeaders && decoder_callbacks_) {
822 0 : decoder_callbacks_->downstreamCallbacks()->clearRouteCache();
823 0 : }
824 0 : return WasmResult::Ok;
825 0 : }
826 :
827 0 : WasmResult Context::getHeaderMapSize(WasmHeaderMapType type, uint32_t* result) {
828 0 : auto map = getMap(type);
829 0 : if (!map) {
830 0 : return WasmResult::BadArgument;
831 0 : }
832 0 : *result = map->byteSize();
833 0 : return WasmResult::Ok;
834 0 : }
835 :
836 : // Buffer
837 :
838 0 : BufferInterface* Context::getBuffer(WasmBufferType type) {
839 0 : Envoy::Http::ResponseMessagePtr* response = nullptr;
840 0 : switch (type) {
841 0 : case WasmBufferType::CallData:
842 : // Set before the call.
843 0 : return &buffer_;
844 0 : case WasmBufferType::VmConfiguration:
845 0 : return buffer_.set(wasm()->vm_configuration());
846 0 : case WasmBufferType::PluginConfiguration:
847 0 : if (temp_plugin_) {
848 0 : return buffer_.set(temp_plugin_->plugin_configuration_);
849 0 : }
850 0 : return nullptr;
851 0 : case WasmBufferType::HttpRequestBody:
852 0 : if (buffering_request_body_ && decoder_callbacks_) {
853 : // We need the mutable version, so capture it using a callback.
854 : // TODO: consider adding a mutableDecodingBuffer() interface.
855 0 : ::Envoy::Buffer::Instance* buffer_instance{};
856 0 : decoder_callbacks_->modifyDecodingBuffer(
857 0 : [&buffer_instance](::Envoy::Buffer::Instance& buffer) { buffer_instance = &buffer; });
858 0 : return buffer_.set(buffer_instance);
859 0 : }
860 0 : return buffer_.set(request_body_buffer_);
861 0 : case WasmBufferType::HttpResponseBody:
862 0 : if (buffering_response_body_ && encoder_callbacks_) {
863 : // TODO: consider adding a mutableDecodingBuffer() interface.
864 0 : ::Envoy::Buffer::Instance* buffer_instance{};
865 0 : encoder_callbacks_->modifyEncodingBuffer(
866 0 : [&buffer_instance](::Envoy::Buffer::Instance& buffer) { buffer_instance = &buffer; });
867 0 : return buffer_.set(buffer_instance);
868 0 : }
869 0 : return buffer_.set(response_body_buffer_);
870 0 : case WasmBufferType::NetworkDownstreamData:
871 0 : return buffer_.set(network_downstream_data_buffer_);
872 0 : case WasmBufferType::NetworkUpstreamData:
873 0 : return buffer_.set(network_upstream_data_buffer_);
874 0 : case WasmBufferType::HttpCallResponseBody:
875 0 : response = rootContext()->http_call_response_;
876 0 : if (response) {
877 0 : auto& body = (*response)->body();
878 0 : return buffer_.set(
879 0 : std::string_view(static_cast<const char*>(body.linearize(body.length())), body.length()));
880 0 : }
881 0 : return nullptr;
882 0 : case WasmBufferType::GrpcReceiveBuffer:
883 0 : return buffer_.set(rootContext()->grpc_receive_buffer_.get());
884 0 : default:
885 0 : return nullptr;
886 0 : }
887 0 : }
888 :
889 0 : void Context::onDownstreamConnectionClose(CloseType close_type) {
890 0 : ContextBase::onDownstreamConnectionClose(close_type);
891 0 : downstream_closed_ = true;
892 0 : onCloseTCP();
893 0 : }
894 :
895 0 : void Context::onUpstreamConnectionClose(CloseType close_type) {
896 0 : ContextBase::onUpstreamConnectionClose(close_type);
897 0 : upstream_closed_ = true;
898 0 : if (downstream_closed_) {
899 0 : onCloseTCP();
900 0 : }
901 0 : }
902 :
903 : // Async call via HTTP
904 : WasmResult Context::httpCall(std::string_view cluster, const Pairs& request_headers,
905 : std::string_view request_body, const Pairs& request_trailers,
906 0 : int timeout_milliseconds, uint32_t* token_ptr) {
907 0 : if (timeout_milliseconds < 0) {
908 0 : return WasmResult::BadArgument;
909 0 : }
910 0 : auto cluster_string = std::string(cluster);
911 0 : const auto thread_local_cluster = clusterManager().getThreadLocalCluster(cluster_string);
912 0 : if (thread_local_cluster == nullptr) {
913 0 : return WasmResult::BadArgument;
914 0 : }
915 :
916 0 : Http::RequestMessagePtr message(
917 0 : new Http::RequestMessageImpl(buildRequestHeaderMapFromPairs(request_headers)));
918 :
919 : // Check that we were provided certain headers.
920 0 : if (message->headers().Path() == nullptr || message->headers().Method() == nullptr ||
921 0 : message->headers().Host() == nullptr) {
922 0 : return WasmResult::BadArgument;
923 0 : }
924 :
925 0 : if (!request_body.empty()) {
926 0 : message->body().add(toAbslStringView(request_body));
927 0 : message->headers().setContentLength(request_body.size());
928 0 : }
929 :
930 0 : if (!request_trailers.empty()) {
931 0 : message->trailers(buildRequestTrailerMapFromPairs(request_trailers));
932 0 : }
933 :
934 0 : absl::optional<std::chrono::milliseconds> timeout;
935 0 : if (timeout_milliseconds > 0) {
936 0 : timeout = std::chrono::milliseconds(timeout_milliseconds);
937 0 : }
938 :
939 0 : uint32_t token = wasm()->nextHttpCallId();
940 0 : auto& handler = http_request_[token];
941 0 : handler.context_ = this;
942 0 : handler.token_ = token;
943 :
944 : // set default hash policy to be based on :authority to enable consistent hash
945 0 : Http::AsyncClient::RequestOptions options;
946 0 : options.setTimeout(timeout);
947 0 : Protobuf::RepeatedPtrField<HashPolicy> hash_policy;
948 0 : hash_policy.Add()->mutable_header()->set_header_name(Http::Headers::get().Host.get());
949 0 : options.setHashPolicy(hash_policy);
950 0 : options.setSendXff(false);
951 0 : auto http_request =
952 0 : thread_local_cluster->httpAsyncClient().send(std::move(message), handler, options);
953 0 : if (!http_request) {
954 0 : http_request_.erase(token);
955 0 : return WasmResult::InternalFailure;
956 0 : }
957 0 : handler.request_ = http_request;
958 0 : *token_ptr = token;
959 0 : return WasmResult::Ok;
960 0 : }
961 :
962 : WasmResult Context::grpcCall(std::string_view grpc_service, std::string_view service_name,
963 : std::string_view method_name, const Pairs& initial_metadata,
964 : std::string_view request, std::chrono::milliseconds timeout,
965 0 : uint32_t* token_ptr) {
966 0 : GrpcService service_proto;
967 0 : if (!service_proto.ParseFromArray(grpc_service.data(), grpc_service.size())) {
968 0 : auto cluster_name = std::string(grpc_service.substr(0, grpc_service.size()));
969 0 : const auto thread_local_cluster = clusterManager().getThreadLocalCluster(cluster_name);
970 0 : if (thread_local_cluster == nullptr) {
971 : // TODO(shikugawa): The reason to keep return status as `BadArgument` is not to force
972 : // callers to change their own codebase with ABI 0.1.x. We should treat this failure as
973 : // `BadArgument` after ABI 0.2.x will have released.
974 0 : return WasmResult::ParseFailure;
975 0 : }
976 0 : service_proto.mutable_envoy_grpc()->set_cluster_name(cluster_name);
977 0 : }
978 0 : uint32_t token = wasm()->nextGrpcCallId();
979 0 : auto& handler = grpc_call_request_[token];
980 0 : handler.context_ = this;
981 0 : handler.token_ = token;
982 0 : auto grpc_client = clusterManager().grpcAsyncClientManager().getOrCreateRawAsyncClient(
983 0 : service_proto, *wasm()->scope_, true /* skip_cluster_check */);
984 0 : grpc_initial_metadata_ = buildRequestHeaderMapFromPairs(initial_metadata);
985 :
986 : // set default hash policy to be based on :authority to enable consistent hash
987 0 : Http::AsyncClient::RequestOptions options;
988 0 : options.setTimeout(timeout);
989 0 : Protobuf::RepeatedPtrField<HashPolicy> hash_policy;
990 0 : hash_policy.Add()->mutable_header()->set_header_name(Http::Headers::get().Host.get());
991 0 : options.setHashPolicy(hash_policy);
992 0 : options.setSendXff(false);
993 :
994 0 : auto grpc_request =
995 0 : grpc_client->sendRaw(toAbslStringView(service_name), toAbslStringView(method_name),
996 0 : std::make_unique<::Envoy::Buffer::OwnedImpl>(toAbslStringView(request)),
997 0 : handler, Tracing::NullSpan::instance(), options);
998 0 : if (!grpc_request) {
999 0 : grpc_call_request_.erase(token);
1000 0 : return WasmResult::InternalFailure;
1001 0 : }
1002 0 : handler.client_ = std::move(grpc_client);
1003 0 : handler.request_ = grpc_request;
1004 0 : *token_ptr = token;
1005 0 : return WasmResult::Ok;
1006 0 : }
1007 :
1008 : WasmResult Context::grpcStream(std::string_view grpc_service, std::string_view service_name,
1009 : std::string_view method_name, const Pairs& initial_metadata,
1010 0 : uint32_t* token_ptr) {
1011 0 : GrpcService service_proto;
1012 0 : if (!service_proto.ParseFromArray(grpc_service.data(), grpc_service.size())) {
1013 0 : auto cluster_name = std::string(grpc_service.substr(0, grpc_service.size()));
1014 0 : const auto thread_local_cluster = clusterManager().getThreadLocalCluster(cluster_name);
1015 0 : if (thread_local_cluster == nullptr) {
1016 : // TODO(shikugawa): The reason to keep return status as `BadArgument` is not to force
1017 : // callers to change their own codebase with ABI 0.1.x. We should treat this failure as
1018 : // `BadArgument` after ABI 0.2.x will have released.
1019 0 : return WasmResult::ParseFailure;
1020 0 : }
1021 0 : service_proto.mutable_envoy_grpc()->set_cluster_name(cluster_name);
1022 0 : }
1023 0 : uint32_t token = wasm()->nextGrpcStreamId();
1024 0 : auto& handler = grpc_stream_[token];
1025 0 : handler.context_ = this;
1026 0 : handler.token_ = token;
1027 0 : auto grpc_client = clusterManager().grpcAsyncClientManager().getOrCreateRawAsyncClient(
1028 0 : service_proto, *wasm()->scope_, true /* skip_cluster_check */);
1029 0 : grpc_initial_metadata_ = buildRequestHeaderMapFromPairs(initial_metadata);
1030 :
1031 : // set default hash policy to be based on :authority to enable consistent hash
1032 0 : Http::AsyncClient::StreamOptions options;
1033 0 : Protobuf::RepeatedPtrField<HashPolicy> hash_policy;
1034 0 : hash_policy.Add()->mutable_header()->set_header_name(Http::Headers::get().Host.get());
1035 0 : options.setHashPolicy(hash_policy);
1036 0 : options.setSendXff(false);
1037 :
1038 0 : auto grpc_stream = grpc_client->startRaw(toAbslStringView(service_name),
1039 0 : toAbslStringView(method_name), handler, options);
1040 0 : if (!grpc_stream) {
1041 0 : grpc_stream_.erase(token);
1042 0 : return WasmResult::InternalFailure;
1043 0 : }
1044 0 : handler.client_ = std::move(grpc_client);
1045 0 : handler.stream_ = grpc_stream;
1046 0 : *token_ptr = token;
1047 0 : return WasmResult::Ok;
1048 0 : }
1049 :
1050 : // NB: this is currently called inline, so the token is known to be that of the currently
1051 : // executing grpcCall or grpcStream.
1052 : void Context::onGrpcCreateInitialMetadata(uint32_t /* token */,
1053 0 : Http::RequestHeaderMap& initial_metadata) {
1054 0 : if (grpc_initial_metadata_) {
1055 0 : Http::HeaderMapImpl::copyFrom(initial_metadata, *grpc_initial_metadata_);
1056 0 : grpc_initial_metadata_.reset();
1057 0 : }
1058 0 : }
1059 :
1060 : // StreamInfo
1061 0 : const StreamInfo::StreamInfo* Context::getConstRequestStreamInfo() const {
1062 0 : if (encoder_callbacks_) {
1063 0 : return &encoder_callbacks_->streamInfo();
1064 0 : } else if (decoder_callbacks_) {
1065 0 : return &decoder_callbacks_->streamInfo();
1066 0 : } else if (access_log_stream_info_) {
1067 0 : return access_log_stream_info_;
1068 0 : } else if (network_read_filter_callbacks_) {
1069 0 : return &network_read_filter_callbacks_->connection().streamInfo();
1070 0 : } else if (network_write_filter_callbacks_) {
1071 0 : return &network_write_filter_callbacks_->connection().streamInfo();
1072 0 : }
1073 0 : return nullptr;
1074 0 : }
1075 :
1076 0 : StreamInfo::StreamInfo* Context::getRequestStreamInfo() const {
1077 0 : if (encoder_callbacks_) {
1078 0 : return &encoder_callbacks_->streamInfo();
1079 0 : } else if (decoder_callbacks_) {
1080 0 : return &decoder_callbacks_->streamInfo();
1081 0 : } else if (network_read_filter_callbacks_) {
1082 0 : return &network_read_filter_callbacks_->connection().streamInfo();
1083 0 : } else if (network_write_filter_callbacks_) {
1084 0 : return &network_write_filter_callbacks_->connection().streamInfo();
1085 0 : }
1086 0 : return nullptr;
1087 0 : }
1088 :
1089 0 : const Network::Connection* Context::getConnection() const {
1090 0 : if (encoder_callbacks_) {
1091 0 : return encoder_callbacks_->connection().ptr();
1092 0 : } else if (decoder_callbacks_) {
1093 0 : return decoder_callbacks_->connection().ptr();
1094 0 : } else if (network_read_filter_callbacks_) {
1095 0 : return &network_read_filter_callbacks_->connection();
1096 0 : } else if (network_write_filter_callbacks_) {
1097 0 : return &network_write_filter_callbacks_->connection();
1098 0 : }
1099 0 : return nullptr;
1100 0 : }
1101 :
1102 0 : WasmResult Context::setProperty(std::string_view path, std::string_view value) {
1103 0 : auto* stream_info = getRequestStreamInfo();
1104 0 : if (!stream_info) {
1105 0 : return WasmResult::NotFound;
1106 0 : }
1107 0 : std::string key;
1108 0 : absl::StrAppend(&key, CelStateKeyPrefix, toAbslStringView(path));
1109 0 : CelState* state = stream_info->filterState()->getDataMutable<CelState>(key);
1110 0 : if (state == nullptr) {
1111 0 : const auto& it = rootContext()->state_prototypes_.find(toAbslStringView(path));
1112 0 : const CelStatePrototype& prototype =
1113 0 : it == rootContext()->state_prototypes_.end()
1114 0 : ? Filters::Common::Expr::DefaultCelStatePrototype::get()
1115 0 : : *it->second.get(); // NOLINT
1116 0 : auto state_ptr = std::make_unique<CelState>(prototype);
1117 0 : state = state_ptr.get();
1118 0 : stream_info->filterState()->setData(key, std::move(state_ptr),
1119 0 : StreamInfo::FilterState::StateType::Mutable,
1120 0 : prototype.life_span_);
1121 0 : }
1122 0 : if (!state->setValue(toAbslStringView(value))) {
1123 0 : return WasmResult::BadArgument;
1124 0 : }
1125 0 : return WasmResult::Ok;
1126 0 : }
1127 :
1128 : WasmResult Context::setEnvoyFilterState(std::string_view path, std::string_view value,
1129 0 : StreamInfo::FilterState::LifeSpan life_span) {
1130 0 : auto* factory =
1131 0 : Registry::FactoryRegistry<StreamInfo::FilterState::ObjectFactory>::getFactory(path);
1132 0 : if (!factory) {
1133 0 : return WasmResult::NotFound;
1134 0 : }
1135 :
1136 0 : auto object = factory->createFromBytes(value);
1137 0 : if (!object) {
1138 0 : return WasmResult::BadArgument;
1139 0 : }
1140 :
1141 0 : auto* stream_info = getRequestStreamInfo();
1142 0 : if (!stream_info) {
1143 0 : return WasmResult::NotFound;
1144 0 : }
1145 :
1146 0 : stream_info->filterState()->setData(path, std::move(object),
1147 0 : StreamInfo::FilterState::StateType::Mutable, life_span);
1148 0 : return WasmResult::Ok;
1149 0 : }
1150 :
1151 : WasmResult
1152 : Context::declareProperty(std::string_view path,
1153 0 : Filters::Common::Expr::CelStatePrototypeConstPtr state_prototype) {
1154 : // Do not delete existing schema since it can be referenced by state objects.
1155 0 : if (state_prototypes_.find(toAbslStringView(path)) == state_prototypes_.end()) {
1156 0 : state_prototypes_[toAbslStringView(path)] = std::move(state_prototype);
1157 0 : return WasmResult::Ok;
1158 0 : }
1159 0 : return WasmResult::BadArgument;
1160 0 : }
1161 :
1162 0 : WasmResult Context::log(uint32_t level, std::string_view message) {
1163 0 : switch (static_cast<spdlog::level::level_enum>(level)) {
1164 0 : case spdlog::level::trace:
1165 0 : ENVOY_LOG(trace, "wasm log{}: {}", log_prefix(), message);
1166 0 : return WasmResult::Ok;
1167 0 : case spdlog::level::debug:
1168 0 : ENVOY_LOG(debug, "wasm log{}: {}", log_prefix(), message);
1169 0 : return WasmResult::Ok;
1170 0 : case spdlog::level::info:
1171 0 : ENVOY_LOG(info, "wasm log{}: {}", log_prefix(), message);
1172 0 : return WasmResult::Ok;
1173 0 : case spdlog::level::warn:
1174 0 : ENVOY_LOG(warn, "wasm log{}: {}", log_prefix(), message);
1175 0 : return WasmResult::Ok;
1176 0 : case spdlog::level::err:
1177 0 : ENVOY_LOG(error, "wasm log{}: {}", log_prefix(), message);
1178 0 : return WasmResult::Ok;
1179 0 : case spdlog::level::critical:
1180 0 : ENVOY_LOG(critical, "wasm log{}: {}", log_prefix(), message);
1181 0 : return WasmResult::Ok;
1182 0 : case spdlog::level::off:
1183 0 : PANIC("not implemented");
1184 0 : case spdlog::level::n_levels:
1185 0 : PANIC("not implemented");
1186 0 : }
1187 0 : PANIC_DUE_TO_CORRUPT_ENUM;
1188 0 : }
1189 :
1190 0 : uint32_t Context::getLogLevel() {
1191 : // Like the "log" call above, assume that spdlog level as an int
1192 : // matches the enum in the SDK
1193 0 : return static_cast<uint32_t>(ENVOY_LOGGER().level());
1194 0 : }
1195 :
1196 : //
1197 : // Calls into the Wasm code.
1198 : //
1199 : bool Context::validateConfiguration(std::string_view configuration,
1200 0 : const std::shared_ptr<PluginBase>& plugin_base) {
1201 0 : auto plugin = std::static_pointer_cast<Plugin>(plugin_base);
1202 0 : if (!wasm()->validate_configuration_) {
1203 0 : return true;
1204 0 : }
1205 0 : temp_plugin_ = plugin_base;
1206 0 : auto result =
1207 0 : wasm()
1208 0 : ->validate_configuration_(this, id_, static_cast<uint32_t>(configuration.size()))
1209 0 : .u64_ != 0;
1210 0 : temp_plugin_.reset();
1211 0 : return result;
1212 0 : }
1213 :
1214 0 : std::string_view Context::getConfiguration() {
1215 0 : if (temp_plugin_) {
1216 0 : return temp_plugin_->plugin_configuration_;
1217 0 : } else {
1218 0 : return wasm()->vm_configuration();
1219 0 : }
1220 0 : };
1221 :
1222 0 : std::pair<uint32_t, std::string_view> Context::getStatus() {
1223 0 : return std::make_pair(status_code_, toStdStringView(status_message_));
1224 0 : }
1225 :
1226 0 : void Context::onGrpcReceiveInitialMetadataWrapper(uint32_t token, Http::HeaderMapPtr&& metadata) {
1227 0 : grpc_receive_initial_metadata_ = std::move(metadata);
1228 0 : onGrpcReceiveInitialMetadata(token, headerSize(grpc_receive_initial_metadata_));
1229 0 : grpc_receive_initial_metadata_ = nullptr;
1230 0 : }
1231 :
1232 0 : void Context::onGrpcReceiveTrailingMetadataWrapper(uint32_t token, Http::HeaderMapPtr&& metadata) {
1233 0 : grpc_receive_trailing_metadata_ = std::move(metadata);
1234 0 : onGrpcReceiveTrailingMetadata(token, headerSize(grpc_receive_trailing_metadata_));
1235 0 : grpc_receive_trailing_metadata_ = nullptr;
1236 0 : }
1237 :
1238 : WasmResult Context::defineMetric(uint32_t metric_type, std::string_view name,
1239 0 : uint32_t* metric_id_ptr) {
1240 0 : if (metric_type > static_cast<uint32_t>(MetricType::Max)) {
1241 0 : return WasmResult::BadArgument;
1242 0 : }
1243 0 : auto type = static_cast<MetricType>(metric_type);
1244 : // TODO: Consider rethinking the scoping policy as it does not help in this case.
1245 0 : Stats::StatNameManagedStorage storage(toAbslStringView(name), wasm()->scope_->symbolTable());
1246 0 : Stats::StatName stat_name = storage.statName();
1247 : // We prefix the given name with custom_stat_name_ so that these user-defined
1248 : // custom metrics can be distinguished from native Envoy metrics.
1249 0 : if (type == MetricType::Counter) {
1250 0 : auto id = wasm()->nextCounterMetricId();
1251 0 : Stats::Counter* c = &Stats::Utility::counterFromElements(
1252 0 : *wasm()->scope_, {wasm()->custom_stat_namespace_, stat_name});
1253 0 : wasm()->counters_.emplace(id, c);
1254 0 : *metric_id_ptr = id;
1255 0 : return WasmResult::Ok;
1256 0 : }
1257 0 : if (type == MetricType::Gauge) {
1258 0 : auto id = wasm()->nextGaugeMetricId();
1259 0 : Stats::Gauge* g = &Stats::Utility::gaugeFromStatNames(
1260 0 : *wasm()->scope_, {wasm()->custom_stat_namespace_, stat_name},
1261 0 : Stats::Gauge::ImportMode::Accumulate);
1262 0 : wasm()->gauges_.emplace(id, g);
1263 0 : *metric_id_ptr = id;
1264 0 : return WasmResult::Ok;
1265 0 : }
1266 : // (type == MetricType::Histogram) {
1267 0 : auto id = wasm()->nextHistogramMetricId();
1268 0 : Stats::Histogram* h = &Stats::Utility::histogramFromStatNames(
1269 0 : *wasm()->scope_, {wasm()->custom_stat_namespace_, stat_name},
1270 0 : Stats::Histogram::Unit::Unspecified);
1271 0 : wasm()->histograms_.emplace(id, h);
1272 0 : *metric_id_ptr = id;
1273 0 : return WasmResult::Ok;
1274 0 : }
1275 :
1276 0 : WasmResult Context::incrementMetric(uint32_t metric_id, int64_t offset) {
1277 0 : auto type = static_cast<MetricType>(metric_id & Wasm::kMetricTypeMask);
1278 0 : if (type == MetricType::Counter) {
1279 0 : auto it = wasm()->counters_.find(metric_id);
1280 0 : if (it != wasm()->counters_.end()) {
1281 0 : if (offset > 0) {
1282 0 : it->second->add(offset);
1283 0 : return WasmResult::Ok;
1284 0 : } else {
1285 0 : return WasmResult::BadArgument;
1286 0 : }
1287 0 : }
1288 0 : return WasmResult::NotFound;
1289 0 : } else if (type == MetricType::Gauge) {
1290 0 : auto it = wasm()->gauges_.find(metric_id);
1291 0 : if (it != wasm()->gauges_.end()) {
1292 0 : if (offset > 0) {
1293 0 : it->second->add(offset);
1294 0 : return WasmResult::Ok;
1295 0 : } else {
1296 0 : it->second->sub(-offset);
1297 0 : return WasmResult::Ok;
1298 0 : }
1299 0 : }
1300 0 : return WasmResult::NotFound;
1301 0 : }
1302 0 : return WasmResult::BadArgument;
1303 0 : }
1304 :
1305 0 : WasmResult Context::recordMetric(uint32_t metric_id, uint64_t value) {
1306 0 : auto type = static_cast<MetricType>(metric_id & Wasm::kMetricTypeMask);
1307 0 : if (type == MetricType::Counter) {
1308 0 : auto it = wasm()->counters_.find(metric_id);
1309 0 : if (it != wasm()->counters_.end()) {
1310 0 : it->second->add(value);
1311 0 : return WasmResult::Ok;
1312 0 : }
1313 0 : } else if (type == MetricType::Gauge) {
1314 0 : auto it = wasm()->gauges_.find(metric_id);
1315 0 : if (it != wasm()->gauges_.end()) {
1316 0 : it->second->set(value);
1317 0 : return WasmResult::Ok;
1318 0 : }
1319 0 : } else if (type == MetricType::Histogram) {
1320 0 : auto it = wasm()->histograms_.find(metric_id);
1321 0 : if (it != wasm()->histograms_.end()) {
1322 0 : it->second->recordValue(value);
1323 0 : return WasmResult::Ok;
1324 0 : }
1325 0 : }
1326 0 : return WasmResult::NotFound;
1327 0 : }
1328 :
1329 0 : WasmResult Context::getMetric(uint32_t metric_id, uint64_t* result_uint64_ptr) {
1330 0 : auto type = static_cast<MetricType>(metric_id & Wasm::kMetricTypeMask);
1331 0 : if (type == MetricType::Counter) {
1332 0 : auto it = wasm()->counters_.find(metric_id);
1333 0 : if (it != wasm()->counters_.end()) {
1334 0 : *result_uint64_ptr = it->second->value();
1335 0 : return WasmResult::Ok;
1336 0 : }
1337 0 : return WasmResult::NotFound;
1338 0 : } else if (type == MetricType::Gauge) {
1339 0 : auto it = wasm()->gauges_.find(metric_id);
1340 0 : if (it != wasm()->gauges_.end()) {
1341 0 : *result_uint64_ptr = it->second->value();
1342 0 : return WasmResult::Ok;
1343 0 : }
1344 0 : return WasmResult::NotFound;
1345 0 : }
1346 0 : return WasmResult::BadArgument;
1347 0 : }
1348 :
1349 0 : Context::~Context() {
1350 : // Cancel any outstanding requests.
1351 0 : for (auto& p : http_request_) {
1352 0 : p.second.request_->cancel();
1353 0 : }
1354 0 : for (auto& p : grpc_call_request_) {
1355 0 : p.second.request_->cancel();
1356 0 : }
1357 0 : for (auto& p : grpc_stream_) {
1358 0 : p.second.stream_->resetStream();
1359 0 : }
1360 0 : }
1361 :
1362 0 : Network::FilterStatus convertNetworkFilterStatus(proxy_wasm::FilterStatus status) {
1363 0 : switch (status) {
1364 0 : default:
1365 0 : case proxy_wasm::FilterStatus::Continue:
1366 0 : return Network::FilterStatus::Continue;
1367 0 : case proxy_wasm::FilterStatus::StopIteration:
1368 0 : return Network::FilterStatus::StopIteration;
1369 0 : }
1370 0 : };
1371 :
1372 0 : Http::FilterHeadersStatus convertFilterHeadersStatus(proxy_wasm::FilterHeadersStatus status) {
1373 0 : switch (status) {
1374 0 : default:
1375 0 : case proxy_wasm::FilterHeadersStatus::Continue:
1376 0 : return Http::FilterHeadersStatus::Continue;
1377 0 : case proxy_wasm::FilterHeadersStatus::StopIteration:
1378 0 : return Http::FilterHeadersStatus::StopIteration;
1379 0 : case proxy_wasm::FilterHeadersStatus::StopAllIterationAndBuffer:
1380 0 : return Http::FilterHeadersStatus::StopAllIterationAndBuffer;
1381 0 : case proxy_wasm::FilterHeadersStatus::StopAllIterationAndWatermark:
1382 0 : return Http::FilterHeadersStatus::StopAllIterationAndWatermark;
1383 0 : }
1384 0 : };
1385 :
1386 0 : Http::FilterTrailersStatus convertFilterTrailersStatus(proxy_wasm::FilterTrailersStatus status) {
1387 0 : switch (status) {
1388 0 : default:
1389 0 : case proxy_wasm::FilterTrailersStatus::Continue:
1390 0 : return Http::FilterTrailersStatus::Continue;
1391 0 : case proxy_wasm::FilterTrailersStatus::StopIteration:
1392 0 : return Http::FilterTrailersStatus::StopIteration;
1393 0 : }
1394 0 : };
1395 :
1396 0 : Http::FilterMetadataStatus convertFilterMetadataStatus(proxy_wasm::FilterMetadataStatus status) {
1397 0 : switch (status) {
1398 0 : default:
1399 0 : case proxy_wasm::FilterMetadataStatus::Continue:
1400 0 : return Http::FilterMetadataStatus::Continue;
1401 0 : }
1402 0 : };
1403 :
1404 0 : Http::FilterDataStatus convertFilterDataStatus(proxy_wasm::FilterDataStatus status) {
1405 0 : switch (status) {
1406 0 : default:
1407 0 : case proxy_wasm::FilterDataStatus::Continue:
1408 0 : return Http::FilterDataStatus::Continue;
1409 0 : case proxy_wasm::FilterDataStatus::StopIterationAndBuffer:
1410 0 : return Http::FilterDataStatus::StopIterationAndBuffer;
1411 0 : case proxy_wasm::FilterDataStatus::StopIterationAndWatermark:
1412 0 : return Http::FilterDataStatus::StopIterationAndWatermark;
1413 0 : case proxy_wasm::FilterDataStatus::StopIterationNoBuffer:
1414 0 : return Http::FilterDataStatus::StopIterationNoBuffer;
1415 0 : }
1416 0 : };
1417 :
1418 0 : Network::FilterStatus Context::onNewConnection() {
1419 0 : onCreate();
1420 0 : return convertNetworkFilterStatus(onNetworkNewConnection());
1421 0 : };
1422 :
1423 0 : Network::FilterStatus Context::onData(::Envoy::Buffer::Instance& data, bool end_stream) {
1424 0 : if (!in_vm_context_created_) {
1425 0 : return Network::FilterStatus::Continue;
1426 0 : }
1427 0 : network_downstream_data_buffer_ = &data;
1428 0 : end_of_stream_ = end_stream;
1429 0 : auto result = convertNetworkFilterStatus(onDownstreamData(data.length(), end_stream));
1430 0 : if (result == Network::FilterStatus::Continue) {
1431 0 : network_downstream_data_buffer_ = nullptr;
1432 0 : }
1433 0 : return result;
1434 0 : }
1435 :
1436 0 : Network::FilterStatus Context::onWrite(::Envoy::Buffer::Instance& data, bool end_stream) {
1437 0 : if (!in_vm_context_created_) {
1438 0 : return Network::FilterStatus::Continue;
1439 0 : }
1440 0 : network_upstream_data_buffer_ = &data;
1441 0 : end_of_stream_ = end_stream;
1442 0 : auto result = convertNetworkFilterStatus(onUpstreamData(data.length(), end_stream));
1443 0 : if (result == Network::FilterStatus::Continue) {
1444 0 : network_upstream_data_buffer_ = nullptr;
1445 0 : }
1446 0 : if (end_stream) {
1447 : // This is called when seeing end_stream=true and not on an upstream connection event,
1448 : // because registering for latter requires replicating the whole TCP proxy extension.
1449 0 : onUpstreamConnectionClose(CloseType::Unknown);
1450 0 : }
1451 0 : return result;
1452 0 : }
1453 :
1454 0 : void Context::onEvent(Network::ConnectionEvent event) {
1455 0 : if (!in_vm_context_created_) {
1456 0 : return;
1457 0 : }
1458 0 : switch (event) {
1459 0 : case Network::ConnectionEvent::LocalClose:
1460 0 : onDownstreamConnectionClose(CloseType::Local);
1461 0 : break;
1462 0 : case Network::ConnectionEvent::RemoteClose:
1463 0 : onDownstreamConnectionClose(CloseType::Remote);
1464 0 : break;
1465 0 : default:
1466 0 : break;
1467 0 : }
1468 0 : }
1469 :
1470 0 : void Context::initializeReadFilterCallbacks(Network::ReadFilterCallbacks& callbacks) {
1471 0 : network_read_filter_callbacks_ = &callbacks;
1472 0 : network_read_filter_callbacks_->connection().addConnectionCallbacks(*this);
1473 0 : }
1474 :
1475 0 : void Context::initializeWriteFilterCallbacks(Network::WriteFilterCallbacks& callbacks) {
1476 0 : network_write_filter_callbacks_ = &callbacks;
1477 0 : }
1478 :
1479 : void Context::log(const Formatter::HttpFormatterContext& log_context,
1480 0 : const StreamInfo::StreamInfo& stream_info) {
1481 : // `log` may be called multiple times due to mid-request logging -- we only want to run on the
1482 : // last call.
1483 0 : if (!stream_info.requestComplete().has_value()) {
1484 0 : return;
1485 0 : }
1486 0 : if (!in_vm_context_created_) {
1487 : // If the request is invalid then onRequestHeaders() will not be called and neither will
1488 : // onCreate() in cases like sendLocalReply who short-circuits envoy
1489 : // lifecycle. This is because Envoy does not have a well defined lifetime for the combined
1490 : // HTTP
1491 : // + AccessLog filter. Thus, to log these scenarios, we call onCreate() in log function below.
1492 0 : onCreate();
1493 0 : }
1494 :
1495 0 : access_log_phase_ = true;
1496 0 : access_log_request_headers_ = &log_context.requestHeaders();
1497 : // ? request_trailers ?
1498 0 : access_log_response_headers_ = &log_context.responseHeaders();
1499 0 : access_log_response_trailers_ = &log_context.responseTrailers();
1500 0 : access_log_stream_info_ = &stream_info;
1501 :
1502 0 : onLog();
1503 :
1504 0 : access_log_phase_ = false;
1505 0 : access_log_request_headers_ = nullptr;
1506 : // ? request_trailers ?
1507 0 : access_log_response_headers_ = nullptr;
1508 0 : access_log_response_trailers_ = nullptr;
1509 0 : access_log_stream_info_ = nullptr;
1510 0 : }
1511 :
1512 0 : void Context::onDestroy() {
1513 0 : if (destroyed_ || !in_vm_context_created_) {
1514 0 : return;
1515 0 : }
1516 0 : destroyed_ = true;
1517 0 : onDone();
1518 0 : onDelete();
1519 0 : }
1520 :
1521 0 : WasmResult Context::continueStream(WasmStreamType stream_type) {
1522 0 : switch (stream_type) {
1523 0 : case WasmStreamType::Request:
1524 0 : if (decoder_callbacks_) {
1525 : // We are in a reentrant call, so defer.
1526 0 : wasm()->addAfterVmCallAction([this] { decoder_callbacks_->continueDecoding(); });
1527 0 : }
1528 0 : break;
1529 0 : case WasmStreamType::Response:
1530 0 : if (encoder_callbacks_) {
1531 : // We are in a reentrant call, so defer.
1532 0 : wasm()->addAfterVmCallAction([this] { encoder_callbacks_->continueEncoding(); });
1533 0 : }
1534 0 : break;
1535 0 : case WasmStreamType::Downstream:
1536 0 : if (network_read_filter_callbacks_) {
1537 : // We are in a reentrant call, so defer.
1538 0 : wasm()->addAfterVmCallAction([this] { network_read_filter_callbacks_->continueReading(); });
1539 0 : }
1540 0 : return WasmResult::Ok;
1541 0 : case WasmStreamType::Upstream:
1542 0 : return WasmResult::Unimplemented;
1543 0 : default:
1544 0 : return WasmResult::BadArgument;
1545 0 : }
1546 0 : request_headers_ = nullptr;
1547 0 : request_body_buffer_ = nullptr;
1548 0 : request_trailers_ = nullptr;
1549 0 : request_metadata_ = nullptr;
1550 0 : return WasmResult::Ok;
1551 0 : }
1552 :
1553 : constexpr absl::string_view CloseStreamResponseDetails = "wasm_close_stream";
1554 :
1555 0 : WasmResult Context::closeStream(WasmStreamType stream_type) {
1556 0 : switch (stream_type) {
1557 0 : case WasmStreamType::Request:
1558 0 : if (decoder_callbacks_) {
1559 0 : if (!decoder_callbacks_->streamInfo().responseCodeDetails().has_value()) {
1560 0 : decoder_callbacks_->streamInfo().setResponseCodeDetails(CloseStreamResponseDetails);
1561 0 : }
1562 : // We are in a reentrant call, so defer.
1563 0 : wasm()->addAfterVmCallAction([this] { decoder_callbacks_->resetStream(); });
1564 0 : }
1565 0 : return WasmResult::Ok;
1566 0 : case WasmStreamType::Response:
1567 0 : if (encoder_callbacks_) {
1568 0 : if (!encoder_callbacks_->streamInfo().responseCodeDetails().has_value()) {
1569 0 : encoder_callbacks_->streamInfo().setResponseCodeDetails(CloseStreamResponseDetails);
1570 0 : }
1571 : // We are in a reentrant call, so defer.
1572 0 : wasm()->addAfterVmCallAction([this] { encoder_callbacks_->resetStream(); });
1573 0 : }
1574 0 : return WasmResult::Ok;
1575 0 : case WasmStreamType::Downstream:
1576 0 : if (network_read_filter_callbacks_) {
1577 : // We are in a reentrant call, so defer.
1578 0 : wasm()->addAfterVmCallAction([this] {
1579 0 : network_read_filter_callbacks_->connection().close(
1580 0 : Envoy::Network::ConnectionCloseType::FlushWrite, "wasm_downstream_close");
1581 0 : });
1582 0 : }
1583 0 : return WasmResult::Ok;
1584 0 : case WasmStreamType::Upstream:
1585 0 : if (network_write_filter_callbacks_) {
1586 : // We are in a reentrant call, so defer.
1587 0 : wasm()->addAfterVmCallAction([this] {
1588 0 : network_write_filter_callbacks_->connection().close(
1589 0 : Envoy::Network::ConnectionCloseType::FlushWrite, "wasm_upstream_close");
1590 0 : });
1591 0 : }
1592 0 : return WasmResult::Ok;
1593 0 : }
1594 0 : return WasmResult::BadArgument;
1595 0 : }
1596 :
1597 : constexpr absl::string_view FailStreamResponseDetails = "wasm_fail_stream";
1598 :
1599 0 : void Context::failStream(WasmStreamType stream_type) {
1600 0 : switch (stream_type) {
1601 0 : case WasmStreamType::Request:
1602 0 : if (decoder_callbacks_ && !local_reply_sent_) {
1603 0 : decoder_callbacks_->sendLocalReply(Envoy::Http::Code::ServiceUnavailable, "", nullptr,
1604 0 : Grpc::Status::WellKnownGrpcStatus::Unavailable,
1605 0 : FailStreamResponseDetails);
1606 0 : local_reply_sent_ = true;
1607 0 : }
1608 0 : break;
1609 0 : case WasmStreamType::Response:
1610 0 : if (encoder_callbacks_ && !local_reply_sent_) {
1611 0 : encoder_callbacks_->sendLocalReply(Envoy::Http::Code::ServiceUnavailable, "", nullptr,
1612 0 : Grpc::Status::WellKnownGrpcStatus::Unavailable,
1613 0 : FailStreamResponseDetails);
1614 0 : local_reply_sent_ = true;
1615 0 : }
1616 0 : break;
1617 0 : case WasmStreamType::Downstream:
1618 0 : if (network_read_filter_callbacks_) {
1619 0 : network_read_filter_callbacks_->connection().close(
1620 0 : Envoy::Network::ConnectionCloseType::FlushWrite);
1621 0 : }
1622 0 : break;
1623 0 : case WasmStreamType::Upstream:
1624 0 : if (network_write_filter_callbacks_) {
1625 0 : network_write_filter_callbacks_->connection().close(
1626 0 : Envoy::Network::ConnectionCloseType::FlushWrite);
1627 0 : }
1628 0 : break;
1629 0 : }
1630 0 : }
1631 :
1632 : WasmResult Context::sendLocalResponse(uint32_t response_code, std::string_view body_text,
1633 : Pairs additional_headers, uint32_t grpc_status,
1634 0 : std::string_view details) {
1635 : // This flag is used to avoid calling sendLocalReply() twice, even if wasm code has this
1636 : // logic. We can't reuse "local_reply_sent_" here because it can't avoid calling nested
1637 : // sendLocalReply() during encodeHeaders().
1638 0 : if (local_reply_hold_) {
1639 0 : return WasmResult::BadArgument;
1640 0 : }
1641 : // "additional_headers" is a collection of string_views. These will no longer
1642 : // be valid when "modify_headers" is finally called below, so we must
1643 : // make copies of all the headers.
1644 0 : std::vector<std::pair<Http::LowerCaseString, std::string>> additional_headers_copy;
1645 0 : for (auto& p : additional_headers) {
1646 0 : const Http::LowerCaseString lower_key{std::string(p.first)};
1647 0 : additional_headers_copy.emplace_back(lower_key, std::string(p.second));
1648 0 : }
1649 :
1650 0 : auto modify_headers = [additional_headers_copy](Http::HeaderMap& headers) {
1651 0 : for (auto& p : additional_headers_copy) {
1652 0 : headers.addCopy(p.first, p.second);
1653 0 : }
1654 0 : };
1655 :
1656 0 : if (decoder_callbacks_) {
1657 : // This is a bit subtle because proxy_on_delete() does call DeferAfterCallActions(),
1658 : // so in theory it could call this and the Context in the VM would be invalid,
1659 : // but because it only gets called after the connections have drained, the call to
1660 : // sendLocalReply() will fail. Net net, this is safe.
1661 0 : wasm()->addAfterVmCallAction([this, response_code, body_text = std::string(body_text),
1662 0 : modify_headers = std::move(modify_headers), grpc_status,
1663 0 : details = StringUtil::replaceAllEmptySpace(
1664 0 : absl::string_view(details.data(), details.size()))] {
1665 : // When the wasm vm fails, failStream() is called if the plugin is fail-closed, we need
1666 : // this flag to avoid calling sendLocalReply() twice.
1667 0 : if (local_reply_sent_) {
1668 0 : return;
1669 0 : }
1670 0 : decoder_callbacks_->sendLocalReply(static_cast<Envoy::Http::Code>(response_code), body_text,
1671 0 : modify_headers, grpc_status, details);
1672 0 : local_reply_sent_ = true;
1673 0 : });
1674 0 : }
1675 0 : local_reply_hold_ = true;
1676 0 : return WasmResult::Ok;
1677 0 : }
1678 :
1679 0 : Http::FilterHeadersStatus Context::decodeHeaders(Http::RequestHeaderMap& headers, bool end_stream) {
1680 0 : onCreate();
1681 0 : request_headers_ = &headers;
1682 0 : end_of_stream_ = end_stream;
1683 0 : auto result = convertFilterHeadersStatus(onRequestHeaders(headerSize(&headers), end_stream));
1684 0 : if (result == Http::FilterHeadersStatus::Continue) {
1685 0 : request_headers_ = nullptr;
1686 0 : }
1687 0 : return result;
1688 0 : }
1689 :
1690 0 : Http::FilterDataStatus Context::decodeData(::Envoy::Buffer::Instance& data, bool end_stream) {
1691 0 : if (!in_vm_context_created_) {
1692 0 : return Http::FilterDataStatus::Continue;
1693 0 : }
1694 0 : request_body_buffer_ = &data;
1695 0 : end_of_stream_ = end_stream;
1696 0 : const auto buffer = getBuffer(WasmBufferType::HttpRequestBody);
1697 0 : const auto buffer_size = (buffer == nullptr) ? 0 : buffer->size();
1698 0 : auto result = convertFilterDataStatus(onRequestBody(buffer_size, end_stream));
1699 0 : buffering_request_body_ = false;
1700 0 : switch (result) {
1701 0 : case Http::FilterDataStatus::Continue:
1702 0 : request_body_buffer_ = nullptr;
1703 0 : break;
1704 0 : case Http::FilterDataStatus::StopIterationAndBuffer:
1705 0 : buffering_request_body_ = true;
1706 0 : break;
1707 0 : case Http::FilterDataStatus::StopIterationAndWatermark:
1708 0 : case Http::FilterDataStatus::StopIterationNoBuffer:
1709 0 : break;
1710 0 : }
1711 0 : return result;
1712 0 : }
1713 :
1714 0 : Http::FilterTrailersStatus Context::decodeTrailers(Http::RequestTrailerMap& trailers) {
1715 0 : if (!in_vm_context_created_) {
1716 0 : return Http::FilterTrailersStatus::Continue;
1717 0 : }
1718 0 : request_trailers_ = &trailers;
1719 0 : auto result = convertFilterTrailersStatus(onRequestTrailers(headerSize(&trailers)));
1720 0 : if (result == Http::FilterTrailersStatus::Continue) {
1721 0 : request_trailers_ = nullptr;
1722 0 : }
1723 0 : return result;
1724 0 : }
1725 :
1726 0 : Http::FilterMetadataStatus Context::decodeMetadata(Http::MetadataMap& request_metadata) {
1727 0 : if (!in_vm_context_created_) {
1728 0 : return Http::FilterMetadataStatus::Continue;
1729 0 : }
1730 0 : request_metadata_ = &request_metadata;
1731 0 : auto result = convertFilterMetadataStatus(onRequestMetadata(headerSize(&request_metadata)));
1732 0 : if (result == Http::FilterMetadataStatus::Continue) {
1733 0 : request_metadata_ = nullptr;
1734 0 : }
1735 0 : return result;
1736 0 : }
1737 :
1738 0 : void Context::setDecoderFilterCallbacks(Envoy::Http::StreamDecoderFilterCallbacks& callbacks) {
1739 0 : decoder_callbacks_ = &callbacks;
1740 0 : }
1741 :
1742 0 : Http::Filter1xxHeadersStatus Context::encode1xxHeaders(Http::ResponseHeaderMap&) {
1743 0 : return Http::Filter1xxHeadersStatus::Continue;
1744 0 : }
1745 :
1746 : Http::FilterHeadersStatus Context::encodeHeaders(Http::ResponseHeaderMap& headers,
1747 0 : bool end_stream) {
1748 0 : if (!in_vm_context_created_) {
1749 0 : return Http::FilterHeadersStatus::Continue;
1750 0 : }
1751 0 : response_headers_ = &headers;
1752 0 : end_of_stream_ = end_stream;
1753 0 : auto result = convertFilterHeadersStatus(onResponseHeaders(headerSize(&headers), end_stream));
1754 0 : if (result == Http::FilterHeadersStatus::Continue) {
1755 0 : response_headers_ = nullptr;
1756 0 : }
1757 0 : return result;
1758 0 : }
1759 :
1760 0 : Http::FilterDataStatus Context::encodeData(::Envoy::Buffer::Instance& data, bool end_stream) {
1761 0 : if (!in_vm_context_created_) {
1762 0 : return Http::FilterDataStatus::Continue;
1763 0 : }
1764 0 : response_body_buffer_ = &data;
1765 0 : end_of_stream_ = end_stream;
1766 0 : const auto buffer = getBuffer(WasmBufferType::HttpResponseBody);
1767 0 : const auto buffer_size = (buffer == nullptr) ? 0 : buffer->size();
1768 0 : auto result = convertFilterDataStatus(onResponseBody(buffer_size, end_stream));
1769 0 : buffering_response_body_ = false;
1770 0 : switch (result) {
1771 0 : case Http::FilterDataStatus::Continue:
1772 0 : request_body_buffer_ = nullptr;
1773 0 : break;
1774 0 : case Http::FilterDataStatus::StopIterationAndBuffer:
1775 0 : buffering_response_body_ = true;
1776 0 : break;
1777 0 : case Http::FilterDataStatus::StopIterationAndWatermark:
1778 0 : case Http::FilterDataStatus::StopIterationNoBuffer:
1779 0 : break;
1780 0 : }
1781 0 : return result;
1782 0 : }
1783 :
1784 0 : Http::FilterTrailersStatus Context::encodeTrailers(Http::ResponseTrailerMap& trailers) {
1785 0 : if (!in_vm_context_created_) {
1786 0 : return Http::FilterTrailersStatus::Continue;
1787 0 : }
1788 0 : response_trailers_ = &trailers;
1789 0 : auto result = convertFilterTrailersStatus(onResponseTrailers(headerSize(&trailers)));
1790 0 : if (result == Http::FilterTrailersStatus::Continue) {
1791 0 : response_trailers_ = nullptr;
1792 0 : }
1793 0 : return result;
1794 0 : }
1795 :
1796 0 : Http::FilterMetadataStatus Context::encodeMetadata(Http::MetadataMap& response_metadata) {
1797 0 : if (!in_vm_context_created_) {
1798 0 : return Http::FilterMetadataStatus::Continue;
1799 0 : }
1800 0 : response_metadata_ = &response_metadata;
1801 0 : auto result = convertFilterMetadataStatus(onResponseMetadata(headerSize(&response_metadata)));
1802 0 : if (result == Http::FilterMetadataStatus::Continue) {
1803 0 : response_metadata_ = nullptr;
1804 0 : }
1805 0 : return result;
1806 0 : }
1807 :
1808 : // Http::FilterMetadataStatus::Continue;
1809 :
1810 0 : void Context::setEncoderFilterCallbacks(Envoy::Http::StreamEncoderFilterCallbacks& callbacks) {
1811 0 : encoder_callbacks_ = &callbacks;
1812 0 : }
1813 :
1814 0 : void Context::onHttpCallSuccess(uint32_t token, Envoy::Http::ResponseMessagePtr&& response) {
1815 : // TODO: convert this into a function in proxy-wasm-cpp-host and use here.
1816 0 : if (proxy_wasm::current_context_ != nullptr) {
1817 : // We are in a reentrant call, so defer.
1818 0 : wasm()->addAfterVmCallAction([this, token, response = response.release()] {
1819 0 : onHttpCallSuccess(token, std::unique_ptr<Envoy::Http::ResponseMessage>(response));
1820 0 : });
1821 0 : return;
1822 0 : }
1823 0 : auto handler = http_request_.find(token);
1824 0 : if (handler == http_request_.end()) {
1825 0 : return;
1826 0 : }
1827 0 : http_call_response_ = &response;
1828 0 : uint32_t body_size = response->body().length();
1829 : // Deferred "after VM call" actions are going to be executed upon returning from
1830 : // ContextBase::*, which might include deleting Context object via proxy_done().
1831 0 : wasm()->addAfterVmCallAction([this, handler] {
1832 0 : http_call_response_ = nullptr;
1833 0 : http_request_.erase(handler);
1834 0 : });
1835 0 : ContextBase::onHttpCallResponse(token, response->headers().size(), body_size,
1836 0 : headerSize(response->trailers()));
1837 0 : }
1838 :
1839 0 : void Context::onHttpCallFailure(uint32_t token, Http::AsyncClient::FailureReason reason) {
1840 0 : if (proxy_wasm::current_context_ != nullptr) {
1841 : // We are in a reentrant call, so defer.
1842 0 : wasm()->addAfterVmCallAction([this, token, reason] { onHttpCallFailure(token, reason); });
1843 0 : return;
1844 0 : }
1845 0 : auto handler = http_request_.find(token);
1846 0 : if (handler == http_request_.end()) {
1847 0 : return;
1848 0 : }
1849 0 : status_code_ = static_cast<uint32_t>(WasmResult::BrokenConnection);
1850 : // This is the only value currently.
1851 0 : ASSERT(reason == Http::AsyncClient::FailureReason::Reset);
1852 0 : status_message_ = "reset";
1853 : // Deferred "after VM call" actions are going to be executed upon returning from
1854 : // ContextBase::*, which might include deleting Context object via proxy_done().
1855 0 : wasm()->addAfterVmCallAction([this, handler] {
1856 0 : status_message_ = "";
1857 0 : http_request_.erase(handler);
1858 0 : });
1859 0 : ContextBase::onHttpCallResponse(token, 0, 0, 0);
1860 0 : }
1861 :
1862 0 : void Context::onGrpcReceiveWrapper(uint32_t token, ::Envoy::Buffer::InstancePtr response) {
1863 0 : ASSERT(proxy_wasm::current_context_ == nullptr); // Non-reentrant.
1864 0 : auto cleanup = [this, token] {
1865 0 : if (wasm()->isGrpcCallId(token)) {
1866 0 : grpc_call_request_.erase(token);
1867 0 : }
1868 0 : };
1869 0 : if (wasm()->on_grpc_receive_) {
1870 0 : grpc_receive_buffer_ = std::move(response);
1871 0 : uint32_t response_size = grpc_receive_buffer_->length();
1872 : // Deferred "after VM call" actions are going to be executed upon returning from
1873 : // ContextBase::*, which might include deleting Context object via proxy_done().
1874 0 : wasm()->addAfterVmCallAction([this, cleanup] {
1875 0 : grpc_receive_buffer_.reset();
1876 0 : cleanup();
1877 0 : });
1878 0 : ContextBase::onGrpcReceive(token, response_size);
1879 0 : } else {
1880 0 : cleanup();
1881 0 : }
1882 0 : }
1883 :
1884 : void Context::onGrpcCloseWrapper(uint32_t token, const Grpc::Status::GrpcStatus& status,
1885 0 : const std::string_view message) {
1886 0 : if (proxy_wasm::current_context_ != nullptr) {
1887 : // We are in a reentrant call, so defer.
1888 0 : wasm()->addAfterVmCallAction([this, token, status, message = std::string(message)] {
1889 0 : onGrpcCloseWrapper(token, status, message);
1890 0 : });
1891 0 : return;
1892 0 : }
1893 0 : auto cleanup = [this, token] {
1894 0 : if (wasm()->isGrpcCallId(token)) {
1895 0 : grpc_call_request_.erase(token);
1896 0 : } else if (wasm()->isGrpcStreamId(token)) {
1897 0 : auto it = grpc_stream_.find(token);
1898 0 : if (it != grpc_stream_.end()) {
1899 0 : if (it->second.local_closed_) {
1900 0 : grpc_stream_.erase(token);
1901 0 : }
1902 0 : }
1903 0 : }
1904 0 : };
1905 0 : if (wasm()->on_grpc_close_) {
1906 0 : status_code_ = static_cast<uint32_t>(status);
1907 0 : status_message_ = toAbslStringView(message);
1908 : // Deferred "after VM call" actions are going to be executed upon returning from
1909 : // ContextBase::*, which might include deleting Context object via proxy_done().
1910 0 : wasm()->addAfterVmCallAction([this, cleanup] {
1911 0 : status_message_ = "";
1912 0 : cleanup();
1913 0 : });
1914 0 : ContextBase::onGrpcClose(token, status_code_);
1915 0 : } else {
1916 0 : cleanup();
1917 0 : }
1918 0 : }
1919 :
1920 0 : WasmResult Context::grpcSend(uint32_t token, std::string_view message, bool end_stream) {
1921 0 : if (!wasm()->isGrpcStreamId(token)) {
1922 0 : return WasmResult::BadArgument;
1923 0 : }
1924 0 : auto it = grpc_stream_.find(token);
1925 0 : if (it == grpc_stream_.end()) {
1926 0 : return WasmResult::NotFound;
1927 0 : }
1928 0 : if (it->second.stream_) {
1929 0 : it->second.stream_->sendMessageRaw(::Envoy::Buffer::InstancePtr(new ::Envoy::Buffer::OwnedImpl(
1930 0 : message.data(), message.size())),
1931 0 : end_stream);
1932 0 : }
1933 0 : return WasmResult::Ok;
1934 0 : }
1935 :
1936 0 : WasmResult Context::grpcClose(uint32_t token) {
1937 0 : if (wasm()->isGrpcCallId(token)) {
1938 0 : auto it = grpc_call_request_.find(token);
1939 0 : if (it == grpc_call_request_.end()) {
1940 0 : return WasmResult::NotFound;
1941 0 : }
1942 0 : if (it->second.request_) {
1943 0 : it->second.request_->cancel();
1944 0 : }
1945 0 : grpc_call_request_.erase(token);
1946 0 : return WasmResult::Ok;
1947 0 : } else if (wasm()->isGrpcStreamId(token)) {
1948 0 : auto it = grpc_stream_.find(token);
1949 0 : if (it == grpc_stream_.end()) {
1950 0 : return WasmResult::NotFound;
1951 0 : }
1952 0 : if (it->second.stream_) {
1953 0 : it->second.stream_->closeStream();
1954 0 : }
1955 0 : if (it->second.remote_closed_) {
1956 0 : grpc_stream_.erase(token);
1957 0 : } else {
1958 0 : it->second.local_closed_ = true;
1959 0 : }
1960 0 : return WasmResult::Ok;
1961 0 : }
1962 0 : return WasmResult::BadArgument;
1963 0 : }
1964 :
1965 0 : WasmResult Context::grpcCancel(uint32_t token) {
1966 0 : if (wasm()->isGrpcCallId(token)) {
1967 0 : auto it = grpc_call_request_.find(token);
1968 0 : if (it == grpc_call_request_.end()) {
1969 0 : return WasmResult::NotFound;
1970 0 : }
1971 0 : if (it->second.request_) {
1972 0 : it->second.request_->cancel();
1973 0 : }
1974 0 : grpc_call_request_.erase(token);
1975 0 : return WasmResult::Ok;
1976 0 : } else if (wasm()->isGrpcStreamId(token)) {
1977 0 : auto it = grpc_stream_.find(token);
1978 0 : if (it == grpc_stream_.end()) {
1979 0 : return WasmResult::NotFound;
1980 0 : }
1981 0 : if (it->second.stream_) {
1982 0 : it->second.stream_->resetStream();
1983 0 : }
1984 0 : grpc_stream_.erase(token);
1985 0 : return WasmResult::Ok;
1986 0 : }
1987 0 : return WasmResult::BadArgument;
1988 0 : }
1989 :
1990 : } // namespace Wasm
1991 : } // namespace Common
1992 : } // namespace Extensions
1993 : } // namespace Envoy
|