/proc/self/cwd/eval/compiler/flat_expr_builder.cc
Line | Count | Source |
1 | | /* |
2 | | * Copyright 2021 Google LLC |
3 | | * |
4 | | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | | * you may not use this file except in compliance with the License. |
6 | | * You may obtain a copy of the License at |
7 | | * |
8 | | * https://www.apache.org/licenses/LICENSE-2.0 |
9 | | * |
10 | | * Unless required by applicable law or agreed to in writing, software |
11 | | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | | * See the License for the specific language governing permissions and |
14 | | * limitations under the License. |
15 | | */ |
16 | | |
17 | | #include "eval/compiler/flat_expr_builder.h" |
18 | | |
19 | | #include <algorithm> |
20 | | #include <cstddef> |
21 | | #include <cstdint> |
22 | | #include <deque> |
23 | | #include <iterator> |
24 | | #include <limits> |
25 | | #include <memory> |
26 | | #include <stack> |
27 | | #include <string> |
28 | | #include <type_traits> |
29 | | #include <utility> |
30 | | #include <vector> |
31 | | |
32 | | #include "absl/algorithm/container.h" |
33 | | #include "absl/base/attributes.h" |
34 | | #include "absl/base/optimization.h" |
35 | | #include "absl/container/flat_hash_map.h" |
36 | | #include "absl/container/flat_hash_set.h" |
37 | | #include "absl/container/node_hash_map.h" |
38 | | #include "absl/functional/any_invocable.h" |
39 | | #include "absl/log/absl_check.h" |
40 | | #include "absl/log/check.h" |
41 | | #include "absl/status/status.h" |
42 | | #include "absl/status/statusor.h" |
43 | | #include "absl/strings/match.h" |
44 | | #include "absl/strings/numbers.h" |
45 | | #include "absl/strings/str_cat.h" |
46 | | #include "absl/strings/string_view.h" |
47 | | #include "absl/strings/strip.h" |
48 | | #include "absl/types/optional.h" |
49 | | #include "absl/types/span.h" |
50 | | #include "absl/types/variant.h" |
51 | | #include "base/ast.h" |
52 | | #include "base/builtins.h" |
53 | | #include "base/type_provider.h" |
54 | | #include "common/allocator.h" |
55 | | #include "common/ast.h" |
56 | | #include "common/ast_traverse.h" |
57 | | #include "common/ast_visitor.h" |
58 | | #include "common/constant.h" |
59 | | #include "common/expr.h" |
60 | | #include "common/kind.h" |
61 | | #include "common/type.h" |
62 | | #include "common/value.h" |
63 | | #include "eval/compiler/check_ast_extensions.h" |
64 | | #include "eval/compiler/flat_expr_builder_extensions.h" |
65 | | #include "eval/compiler/resolver.h" |
66 | | #include "eval/eval/comprehension_step.h" |
67 | | #include "eval/eval/const_value_step.h" |
68 | | #include "eval/eval/container_access_step.h" |
69 | | #include "eval/eval/create_list_step.h" |
70 | | #include "eval/eval/create_map_step.h" |
71 | | #include "eval/eval/create_struct_step.h" |
72 | | #include "eval/eval/direct_expression_step.h" |
73 | | #include "eval/eval/equality_steps.h" |
74 | | #include "eval/eval/evaluator_core.h" |
75 | | #include "eval/eval/function_step.h" |
76 | | #include "eval/eval/ident_step.h" |
77 | | #include "eval/eval/jump_step.h" |
78 | | #include "eval/eval/lazy_init_step.h" |
79 | | #include "eval/eval/logic_step.h" |
80 | | #include "eval/eval/optional_or_step.h" |
81 | | #include "eval/eval/select_step.h" |
82 | | #include "eval/eval/shadowable_value_step.h" |
83 | | #include "eval/eval/ternary_step.h" |
84 | | #include "eval/eval/trace_step.h" |
85 | | #include "internal/status_macros.h" |
86 | | #include "runtime/internal/convert_constant.h" |
87 | | #include "runtime/internal/issue_collector.h" |
88 | | #include "runtime/runtime_issue.h" |
89 | | #include "runtime/runtime_options.h" |
90 | | #include "runtime/type_registry.h" |
91 | | #include "google/protobuf/arena.h" |
92 | | |
93 | | namespace google::api::expr::runtime { |
94 | | |
95 | | namespace { |
96 | | |
97 | | using ::cel::Ast; |
98 | | using ::cel::AstTraverse; |
99 | | using ::cel::RuntimeIssue; |
100 | | using ::cel::StringValue; |
101 | | using ::cel::Value; |
102 | | using ::cel::runtime_internal::ConvertConstant; |
103 | | using ::cel::runtime_internal::GetLegacyRuntimeTypeProvider; |
104 | | using ::cel::runtime_internal::GetRuntimeTypeProvider; |
105 | | using ::cel::runtime_internal::IssueCollector; |
106 | | |
107 | | constexpr absl::string_view kOptionalOrFn = "or"; |
108 | | constexpr absl::string_view kOptionalOrValueFn = "orValue"; |
109 | | constexpr absl::string_view kBlock = "cel.@block"; |
110 | | |
111 | | // Forward declare to resolve circular dependency for short_circuiting visitors. |
112 | | class FlatExprVisitor; |
113 | | |
114 | | // Error code for failed recursive program building. Generally indicates an |
115 | | // optimization doesn't support recursive programs. |
116 | 0 | absl::Status FailedRecursivePlanning() { |
117 | 0 | return absl::InternalError( |
118 | 0 | "failed to build recursive program. check for unsupported optimizations"); |
119 | 0 | } |
120 | | |
121 | | // Helper for bookkeeping variables mapped to indexes. |
122 | | class IndexManager { |
123 | | public: |
124 | 10.3k | IndexManager() : next_free_slot_(0), max_slot_count_(0) {} |
125 | | |
126 | 0 | size_t ReserveSlots(size_t n) { |
127 | 0 | size_t result = next_free_slot_; |
128 | 0 | next_free_slot_ += n; |
129 | 0 | if (next_free_slot_ > max_slot_count_) { |
130 | 0 | max_slot_count_ = next_free_slot_; |
131 | 0 | } |
132 | 0 | return result; |
133 | 0 | } |
134 | | |
135 | 0 | size_t ReleaseSlots(size_t n) { |
136 | 0 | next_free_slot_ -= n; |
137 | 0 | return next_free_slot_; |
138 | 0 | } |
139 | | |
140 | 9.92k | size_t max_slot_count() const { return max_slot_count_; } |
141 | | |
142 | | private: |
143 | | size_t next_free_slot_; |
144 | | size_t max_slot_count_; |
145 | | }; |
146 | | |
147 | | // Helper for computing jump offsets. |
148 | | // |
149 | | // Jumps should be self-contained to a single expression node -- jumping |
150 | | // outside that range is a bug. |
151 | | struct ProgramStepIndex { |
152 | | int index; |
153 | | ProgramBuilder::Subexpression* subexpression; |
154 | | }; |
155 | | |
156 | | // A convenience wrapper for offset-calculating logic. |
157 | | class Jump { |
158 | | public: |
159 | | // Default constructor for empty jump. |
160 | | // |
161 | | // Users must check that jump is non-empty before calling member functions. |
162 | 12.7k | explicit Jump() : self_index_{-1, nullptr}, jump_step_(nullptr) {} |
163 | | Jump(ProgramStepIndex self_index, JumpStepBase* jump_step) |
164 | 12.7k | : self_index_(self_index), jump_step_(jump_step) {} |
165 | | |
166 | | static absl::StatusOr<int> CalculateOffset(ProgramStepIndex base, |
167 | 12.4k | ProgramStepIndex target) { |
168 | 12.4k | if (target.subexpression != base.subexpression) { |
169 | 0 | return absl::InternalError( |
170 | 0 | "Jump target must be contained in the parent" |
171 | 0 | "subexpression"); |
172 | 0 | } |
173 | | |
174 | 12.4k | int offset = base.subexpression->CalculateOffset(base.index, target.index); |
175 | 12.4k | return offset; |
176 | 12.4k | } |
177 | | |
178 | 12.4k | absl::Status set_target(ProgramStepIndex target) { |
179 | 12.4k | CEL_ASSIGN_OR_RETURN(int offset, CalculateOffset(self_index_, target)); |
180 | | |
181 | 12.4k | jump_step_->set_jump_offset(offset); |
182 | 12.4k | return absl::OkStatus(); |
183 | 12.4k | } |
184 | | |
185 | 2.08k | bool exists() { return jump_step_ != nullptr; } |
186 | | |
187 | | private: |
188 | | ProgramStepIndex self_index_; |
189 | | JumpStepBase* jump_step_; |
190 | | }; |
191 | | |
192 | | class CondVisitor { |
193 | | public: |
194 | 11.1k | virtual ~CondVisitor() = default; |
195 | | virtual void PreVisit(const cel::Expr* expr) = 0; |
196 | | virtual void PostVisitArg(int arg_num, const cel::Expr* expr) = 0; |
197 | | virtual void PostVisit(const cel::Expr* expr) = 0; |
198 | 0 | virtual void PostVisitTarget(const cel::Expr* expr) {} |
199 | | }; |
200 | | |
201 | | enum class BinaryCond { |
202 | | kAnd = 0, |
203 | | kOr, |
204 | | kOptionalOr, |
205 | | kOptionalOrValue, |
206 | | }; |
207 | | |
208 | | // Visitor managing the "&&" and "||" operatiions. |
209 | | // Implements short-circuiting if enabled. |
210 | | // |
211 | | // With short-circuiting enabled, generates a program like: |
212 | | // +-------------+------------------------+-----------------------+ |
213 | | // | PC | Step | Stack | |
214 | | // +-------------+------------------------+-----------------------+ |
215 | | // | i + 0 | <Arg1> | arg1 | |
216 | | // | i + 1 | ConditionalJump i + 4 | arg1 | |
217 | | // | i + 2 | <Arg2> | arg1, arg2 | |
218 | | // | i + 3 | BooleanOperator | Op(arg1, arg2) | |
219 | | // | i + 4 | <rest of program> | arg1 | Op(arg1, arg2) | |
220 | | // +-------------+------------------------+------------------------+ |
221 | | class BinaryCondVisitor : public CondVisitor { |
222 | | public: |
223 | | explicit BinaryCondVisitor(FlatExprVisitor* visitor, BinaryCond cond, |
224 | | bool short_circuiting) |
225 | 10.4k | : visitor_(visitor), cond_(cond), short_circuiting_(short_circuiting) {} |
226 | | |
227 | | void PreVisit(const cel::Expr* expr) override; |
228 | | void PostVisitArg(int arg_num, const cel::Expr* expr) override; |
229 | | void PostVisit(const cel::Expr* expr) override; |
230 | | void PostVisitTarget(const cel::Expr* expr) override; |
231 | | |
232 | | private: |
233 | | FlatExprVisitor* visitor_; |
234 | | const BinaryCond cond_; |
235 | | Jump jump_step_; |
236 | | bool short_circuiting_; |
237 | | }; |
238 | | |
239 | | class TernaryCondVisitor : public CondVisitor { |
240 | | public: |
241 | 783 | explicit TernaryCondVisitor(FlatExprVisitor* visitor) : visitor_(visitor) {} |
242 | | |
243 | | void PreVisit(const cel::Expr* expr) override; |
244 | | void PostVisitArg(int arg_num, const cel::Expr* expr) override; |
245 | | void PostVisit(const cel::Expr* expr) override; |
246 | | |
247 | | private: |
248 | | FlatExprVisitor* visitor_; |
249 | | Jump jump_to_second_; |
250 | | Jump error_jump_; |
251 | | Jump jump_after_first_; |
252 | | }; |
253 | | |
254 | | class ExhaustiveTernaryCondVisitor : public CondVisitor { |
255 | | public: |
256 | | explicit ExhaustiveTernaryCondVisitor(FlatExprVisitor* visitor) |
257 | 0 | : visitor_(visitor) {} |
258 | | |
259 | | void PreVisit(const cel::Expr* expr) override; |
260 | 0 | void PostVisitArg(int arg_num, const cel::Expr* expr) override {} |
261 | | void PostVisit(const cel::Expr* expr) override; |
262 | | |
263 | | private: |
264 | | FlatExprVisitor* visitor_; |
265 | | }; |
266 | | |
267 | | // Returns a hint for the number of program nodes (steps or subexpressions) that |
268 | | // will be created for this expr. |
269 | 379k | size_t SizeHint(const cel::Expr& expr) { |
270 | 379k | switch (expr.kind_case()) { |
271 | 117k | case cel::ExprKindCase::kConstant: |
272 | 117k | return 1; |
273 | 59.6k | case cel::ExprKindCase::kIdentExpr: |
274 | 59.6k | return 1; |
275 | 16.8k | case cel::ExprKindCase::kSelectExpr: |
276 | 16.8k | return 2; |
277 | 166k | case cel::ExprKindCase::kCallExpr: |
278 | 166k | return expr.call_expr().args().size() + |
279 | 166k | (expr.call_expr().has_target() ? 2 : 1); |
280 | 8.92k | case cel::ExprKindCase::kListExpr: |
281 | 8.92k | return expr.list_expr().elements().size() + 1; |
282 | 2.18k | case cel::ExprKindCase::kStructExpr: |
283 | 2.18k | return expr.struct_expr().fields().size() + 1; |
284 | 8.59k | case cel::ExprKindCase::kMapExpr: |
285 | 8.59k | return 2 * expr.struct_expr().fields().size() + 1; |
286 | 0 | default: |
287 | 0 | return 1; |
288 | 379k | } |
289 | 0 | return 0; |
290 | 379k | } |
291 | | |
292 | | // Returns whether this comprehension appears to be a standard map/filter |
293 | | // macro implementation. It is not exhaustive, so it is unsafe to use with |
294 | | // custom comprehensions outside of the standard macros or hand crafted ASTs. |
295 | | bool IsOptimizableListAppend(const cel::ComprehensionExpr* comprehension, |
296 | 0 | bool enable_comprehension_list_append) { |
297 | 0 | if (!enable_comprehension_list_append) { |
298 | 0 | return false; |
299 | 0 | } |
300 | 0 | absl::string_view accu_var = comprehension->accu_var(); |
301 | 0 | if (accu_var.empty() || |
302 | 0 | comprehension->result().ident_expr().name() != accu_var) { |
303 | 0 | return false; |
304 | 0 | } |
305 | 0 | if (!comprehension->accu_init().has_list_expr() || |
306 | 0 | !comprehension->accu_init().list_expr().elements().empty()) { |
307 | 0 | return false; |
308 | 0 | } |
309 | | |
310 | 0 | if (!comprehension->loop_step().has_call_expr()) { |
311 | 0 | return false; |
312 | 0 | } |
313 | | |
314 | | // Macro loop_step for a filter() will contain a ternary: |
315 | | // filter ? accu_var + [elem] : accu_var |
316 | | // Macro loop_step for a map() will contain a list concat operation: |
317 | | // accu_var + [elem] |
318 | 0 | const auto* call_expr = &comprehension->loop_step().call_expr(); |
319 | |
|
320 | 0 | if (call_expr->function() == cel::builtin::kTernary && |
321 | 0 | call_expr->args().size() == 3) { |
322 | 0 | if (!call_expr->args()[1].has_call_expr()) { |
323 | 0 | return false; |
324 | 0 | } |
325 | 0 | call_expr = &(call_expr->args()[1].call_expr()); |
326 | 0 | } |
327 | | |
328 | 0 | return call_expr->function() == cel::builtin::kAdd && |
329 | 0 | call_expr->args().size() == 2 && |
330 | 0 | call_expr->args()[0].has_ident_expr() && |
331 | 0 | call_expr->args()[0].ident_expr().name() == accu_var && |
332 | 0 | call_expr->args()[1].has_list_expr() && |
333 | 0 | call_expr->args()[1].list_expr().elements().size() == 1; |
334 | 0 | } |
335 | | |
336 | | // Assuming `IsOptimizableListAppend()` return true, return a pointer to the |
337 | | // call `accu_var + [elem]`. |
338 | | const cel::CallExpr* GetOptimizableListAppendCall( |
339 | 0 | const cel::ComprehensionExpr* comprehension) { |
340 | 0 | ABSL_DCHECK(IsOptimizableListAppend( |
341 | 0 | comprehension, /*enable_comprehension_list_append=*/true)); |
342 | | |
343 | | // Macro loop_step for a filter() will contain a ternary: |
344 | | // filter ? accu_var + [elem] : accu_var |
345 | | // Macro loop_step for a map() will contain a list concat operation: |
346 | | // accu_var + [elem] |
347 | 0 | const auto* call_expr = &comprehension->loop_step().call_expr(); |
348 | |
|
349 | 0 | if (call_expr->function() == cel::builtin::kTernary && |
350 | 0 | call_expr->args().size() == 3) { |
351 | 0 | call_expr = &(call_expr->args()[1].call_expr()); |
352 | 0 | } |
353 | 0 | return call_expr; |
354 | 0 | } |
355 | | |
356 | | // Assuming `IsOptimizableListAppend()` return true, return a pointer to the |
357 | | // node `[elem]`. |
358 | | const cel::Expr* GetOptimizableListAppendOperand( |
359 | 0 | const cel::ComprehensionExpr* comprehension) { |
360 | 0 | return &GetOptimizableListAppendCall(comprehension)->args()[1]; |
361 | 0 | } |
362 | | |
363 | | // Returns whether this comprehension appears to be a macro implementation for |
364 | | // map transformations. It is not exhaustive, so it is unsafe to use with custom |
365 | | // comprehensions outside of the standard macros or hand crafted ASTs. |
366 | | bool IsOptimizableMapInsert(const cel::ComprehensionExpr* comprehension, |
367 | 0 | bool enable_comprehension_mutable_map) { |
368 | 0 | if (!enable_comprehension_mutable_map) { |
369 | 0 | return false; |
370 | 0 | } |
371 | 0 | if (comprehension->iter_var().empty() || comprehension->iter_var2().empty()) { |
372 | 0 | return false; |
373 | 0 | } |
374 | 0 | absl::string_view accu_var = comprehension->accu_var(); |
375 | 0 | if (accu_var.empty() || !comprehension->has_result() || |
376 | 0 | !comprehension->result().has_ident_expr() || |
377 | 0 | comprehension->result().ident_expr().name() != accu_var) { |
378 | 0 | return false; |
379 | 0 | } |
380 | 0 | if (!comprehension->accu_init().has_map_expr()) { |
381 | 0 | return false; |
382 | 0 | } |
383 | 0 | if (!comprehension->loop_step().has_call_expr()) { |
384 | 0 | return false; |
385 | 0 | } |
386 | 0 | const auto* call_expr = &comprehension->loop_step().call_expr(); |
387 | |
|
388 | 0 | if (call_expr->function() == cel::builtin::kTernary && |
389 | 0 | call_expr->args().size() == 3) { |
390 | 0 | if (!call_expr->args()[1].has_call_expr()) { |
391 | 0 | return false; |
392 | 0 | } |
393 | 0 | call_expr = &(call_expr->args()[1].call_expr()); |
394 | 0 | } |
395 | 0 | return call_expr->function() == "cel.@mapInsert" && |
396 | 0 | (call_expr->args().size() == 2 || call_expr->args().size() == 3) && |
397 | 0 | call_expr->args()[0].has_ident_expr() && |
398 | 0 | call_expr->args()[0].ident_expr().name() == accu_var; |
399 | 0 | } |
400 | | |
401 | 0 | bool IsBind(const cel::ComprehensionExpr* comprehension) { |
402 | 0 | static constexpr absl::string_view kUnusedIterVar = "#unused"; |
403 | |
|
404 | 0 | return comprehension->loop_condition().const_expr().has_bool_value() && |
405 | 0 | comprehension->loop_condition().const_expr().bool_value() == false && |
406 | 0 | comprehension->iter_var() == kUnusedIterVar && |
407 | 0 | comprehension->iter_var2().empty() && |
408 | 0 | comprehension->iter_range().has_list_expr() && |
409 | 0 | comprehension->iter_range().list_expr().elements().empty(); |
410 | 0 | } |
411 | | |
412 | 154k | bool IsBlock(const cel::CallExpr* call) { return call->function() == kBlock; } |
413 | | |
414 | | // Visitor for Comprehension expressions. |
415 | | class ComprehensionVisitor { |
416 | | public: |
417 | | explicit ComprehensionVisitor(FlatExprVisitor* visitor, bool short_circuiting, |
418 | | bool is_trivial, size_t iter_slot, |
419 | | size_t iter2_slot, size_t accu_slot) |
420 | 0 | : visitor_(visitor), |
421 | 0 | next_step_(nullptr), |
422 | 0 | cond_step_(nullptr), |
423 | 0 | short_circuiting_(short_circuiting), |
424 | 0 | is_trivial_(is_trivial), |
425 | 0 | accu_init_extracted_(false), |
426 | 0 | iter_slot_(iter_slot), |
427 | 0 | iter2_slot_(iter2_slot), |
428 | 0 | accu_slot_(accu_slot) {} |
429 | | |
430 | | void PreVisit(const cel::Expr* expr); |
431 | | absl::Status PostVisitArg(cel::ComprehensionArg arg_num, |
432 | 0 | const cel::Expr* comprehension_expr) { |
433 | 0 | if (is_trivial_) { |
434 | 0 | PostVisitArgTrivial(arg_num, comprehension_expr); |
435 | 0 | return absl::OkStatus(); |
436 | 0 | } else { |
437 | 0 | return PostVisitArgDefault(arg_num, comprehension_expr); |
438 | 0 | } |
439 | 0 | } |
440 | | void PostVisit(const cel::Expr* expr); |
441 | | |
442 | 0 | void MarkAccuInitExtracted() { accu_init_extracted_ = true; } |
443 | | |
444 | | private: |
445 | | void PostVisitArgTrivial(cel::ComprehensionArg arg_num, |
446 | | const cel::Expr* comprehension_expr); |
447 | | |
448 | | absl::Status PostVisitArgDefault(cel::ComprehensionArg arg_num, |
449 | | const cel::Expr* comprehension_expr); |
450 | | |
451 | | FlatExprVisitor* visitor_; |
452 | | ComprehensionInitStep* init_step_; |
453 | | ComprehensionNextStep* next_step_; |
454 | | ComprehensionCondStep* cond_step_; |
455 | | ProgramStepIndex init_step_pos_; |
456 | | ProgramStepIndex next_step_pos_; |
457 | | ProgramStepIndex cond_step_pos_; |
458 | | bool short_circuiting_; |
459 | | bool is_trivial_; |
460 | | bool accu_init_extracted_; |
461 | | size_t iter_slot_; |
462 | | size_t iter2_slot_; |
463 | | size_t accu_slot_; |
464 | | }; |
465 | | |
466 | | absl::flat_hash_set<int32_t> MakeOptionalIndicesSet( |
467 | 0 | const cel::ListExpr& create_list_expr) { |
468 | 0 | absl::flat_hash_set<int32_t> optional_indices; |
469 | 0 | for (size_t i = 0; i < create_list_expr.elements().size(); ++i) { |
470 | 0 | if (create_list_expr.elements()[i].optional()) { |
471 | 0 | optional_indices.insert(static_cast<int32_t>(i)); |
472 | 0 | } |
473 | 0 | } |
474 | 0 | return optional_indices; |
475 | 0 | } |
476 | | |
477 | | absl::flat_hash_set<int32_t> MakeOptionalIndicesSet( |
478 | 1.89k | const cel::StructExpr& create_struct_expr) { |
479 | 1.89k | absl::flat_hash_set<int32_t> optional_indices; |
480 | 1.89k | for (size_t i = 0; i < create_struct_expr.fields().size(); ++i) { |
481 | 0 | if (create_struct_expr.fields()[i].optional()) { |
482 | 0 | optional_indices.insert(static_cast<int32_t>(i)); |
483 | 0 | } |
484 | 0 | } |
485 | 1.89k | return optional_indices; |
486 | 1.89k | } |
487 | | |
488 | | absl::flat_hash_set<int32_t> MakeOptionalIndicesSet( |
489 | 8.70k | const cel::MapExpr& map_expr) { |
490 | 8.70k | absl::flat_hash_set<int32_t> optional_indices; |
491 | 12.5k | for (size_t i = 0; i < map_expr.entries().size(); ++i) { |
492 | 3.86k | if (map_expr.entries()[i].optional()) { |
493 | 0 | optional_indices.insert(static_cast<int32_t>(i)); |
494 | 0 | } |
495 | 3.86k | } |
496 | 8.70k | return optional_indices; |
497 | 8.70k | } |
498 | | |
499 | | class FlatExprVisitor : public cel::AstVisitor { |
500 | | public: |
501 | | enum class CallHandlerResult { |
502 | | // The call was intercepted, no additional processing is needed. |
503 | | kIntercepted, |
504 | | // The call was not intercepted, continue with the default processing. |
505 | | kNotIntercepted, |
506 | | }; |
507 | | |
508 | | // Handler for functions with builtin implementations. |
509 | | // This is used to replace the usual dispatcher step that applies |
510 | | // the arguments to a candidate function from the function registry. |
511 | | using CallHandler = absl::AnyInvocable<CallHandlerResult( |
512 | | const cel::Expr&, const cel::CallExpr&)>; |
513 | | |
514 | | FlatExprVisitor( |
515 | | const Resolver& resolver, const cel::RuntimeOptions& options, |
516 | | std::vector<std::unique_ptr<ProgramOptimizer>> program_optimizers, |
517 | | const absl::flat_hash_map<int64_t, cel::Reference>& reference_map, |
518 | | const cel::TypeProvider& type_provider, IssueCollector& issue_collector, |
519 | | ProgramBuilder& program_builder, PlannerContext& extension_context, |
520 | | bool enable_optional_types) |
521 | 10.3k | : resolver_(resolver), |
522 | 10.3k | type_provider_(type_provider), |
523 | 10.3k | progress_status_(absl::OkStatus()), |
524 | 10.3k | resolved_select_expr_(nullptr), |
525 | 10.3k | options_(options), |
526 | 10.3k | program_optimizers_(std::move(program_optimizers)), |
527 | 10.3k | issue_collector_(issue_collector), |
528 | 10.3k | program_builder_(program_builder), |
529 | 10.3k | extension_context_(extension_context), |
530 | 10.3k | enable_optional_types_(enable_optional_types) { |
531 | 10.3k | constexpr size_t kCallHandlerSizeHint = 11; |
532 | 10.3k | call_handlers_.reserve(kCallHandlerSizeHint); |
533 | 10.3k | call_handlers_[cel::builtin::kIndex] = [this](const cel::Expr& expr, |
534 | 10.3k | const cel::CallExpr& call) { |
535 | 6.18k | return HandleIndex(expr, call); |
536 | 6.18k | }; |
537 | 10.3k | call_handlers_[kBlock] = [this](const cel::Expr& expr, |
538 | 10.3k | const cel::CallExpr& call) { |
539 | 0 | return HandleBlock(expr, call); |
540 | 0 | }; |
541 | 10.3k | call_handlers_[cel::builtin::kAdd] = [this](const cel::Expr& expr, |
542 | 10.3k | const cel::CallExpr& call) { |
543 | 9.21k | return HandleListAppend(expr, call); |
544 | 9.21k | }; |
545 | 10.3k | if (options_.enable_fast_builtins) { |
546 | 10.3k | call_handlers_[cel::builtin::kNotStrictlyFalse] = |
547 | 10.3k | [this](const cel::Expr& expr, const cel::CallExpr& call) { |
548 | 0 | return HandleNotStrictlyFalse(expr, call); |
549 | 0 | }; |
550 | 10.3k | call_handlers_[cel::builtin::kNotStrictlyFalseDeprecated] = |
551 | 10.3k | [this](const cel::Expr& expr, const cel::CallExpr& call) { |
552 | 0 | return HandleNotStrictlyFalse(expr, call); |
553 | 0 | }; |
554 | 10.3k | call_handlers_[cel::builtin::kNot] = [this](const cel::Expr& expr, |
555 | 10.3k | const cel::CallExpr& call) { |
556 | 829 | return HandleNot(expr, call); |
557 | 829 | }; |
558 | 10.3k | if (options_.enable_heterogeneous_equality) { |
559 | 10.3k | for (const auto& in_op : |
560 | 10.3k | {cel::builtin::kIn, cel::builtin::kInDeprecated, |
561 | 31.0k | cel::builtin::kInFunction}) { |
562 | 31.0k | call_handlers_[in_op] = [this](const cel::Expr& expr, |
563 | 31.0k | const cel::CallExpr& call) { |
564 | 2.75k | return HandleHeterogeneousEqualityIn(expr, call); |
565 | 2.75k | }; |
566 | 31.0k | } |
567 | | // Try to detect if the environment is setup with a custom equality |
568 | | // implementation. |
569 | 10.3k | if (resolver_ |
570 | 10.3k | .FindOverloads(cel::builtin::kEqual, |
571 | 10.3k | /*receiver_style=*/false, |
572 | 10.3k | {cel::Kind::kAny, cel::Kind::kAny}) |
573 | 10.3k | .empty()) { |
574 | 10.3k | call_handlers_[cel::builtin::kEqual] = |
575 | 10.3k | [this](const cel::Expr& expr, const cel::CallExpr& call) { |
576 | 3.94k | return HandleHeterogeneousEquality(expr, call, |
577 | 3.94k | /*inequality=*/false); |
578 | 3.94k | }; |
579 | 10.3k | call_handlers_[cel::builtin::kInequal] = |
580 | 10.3k | [this](const cel::Expr& expr, const cel::CallExpr& call) { |
581 | 557 | return HandleHeterogeneousEquality(expr, call, |
582 | 557 | /*inequality=*/true); |
583 | 557 | }; |
584 | 10.3k | } |
585 | 10.3k | } |
586 | 10.3k | } |
587 | 10.3k | } |
588 | | |
589 | 0 | void SetMaxRecursionDepth(int max_recursion_depth) { |
590 | 0 | max_recursion_depth_ = max_recursion_depth; |
591 | 0 | } |
592 | | |
593 | 222k | bool PlanRecursiveProgram() const { return max_recursion_depth_ > 0; } |
594 | | |
595 | 385k | void PreVisitExpr(const cel::Expr& expr) override { |
596 | 385k | ValidateOrError(!absl::holds_alternative<cel::UnspecifiedExpr>(expr.kind()), |
597 | 385k | "Invalid empty expression"); |
598 | 385k | if (!progress_status_.ok()) { |
599 | 5.78k | return; |
600 | 5.78k | } |
601 | 379k | if (resume_from_suppressed_branch_ == nullptr && |
602 | 379k | suppressed_branches_.find(&expr) != suppressed_branches_.end()) { |
603 | 0 | resume_from_suppressed_branch_ = &expr; |
604 | 0 | } |
605 | | |
606 | 379k | if (block_.has_value()) { |
607 | 0 | BlockInfo& block = *block_; |
608 | 0 | if (block.in && block.bindings_set.contains(&expr)) { |
609 | 0 | block.current_binding = &expr; |
610 | 0 | } |
611 | 0 | } |
612 | | |
613 | 379k | auto* subexpression = |
614 | 379k | program_builder_.EnterSubexpression(&expr, SizeHint(expr)); |
615 | 379k | if (subexpression == nullptr) { |
616 | 0 | progress_status_.Update( |
617 | 0 | absl::InternalError("same CEL expr visited twice")); |
618 | 0 | return; |
619 | 0 | } |
620 | | |
621 | 379k | for (const std::unique_ptr<ProgramOptimizer>& optimizer : |
622 | 379k | program_optimizers_) { |
623 | 0 | absl::Status status = optimizer->OnPreVisit(extension_context_, expr); |
624 | 0 | if (!status.ok()) { |
625 | 0 | SetProgressStatusError(status); |
626 | 0 | } |
627 | 0 | } |
628 | 379k | } |
629 | | |
630 | 385k | void PostVisitExpr(const cel::Expr& expr) override { |
631 | 385k | if (!progress_status_.ok()) { |
632 | 9.10k | return; |
633 | 9.10k | } |
634 | 376k | if (&expr == resume_from_suppressed_branch_) { |
635 | 0 | resume_from_suppressed_branch_ = nullptr; |
636 | 0 | } |
637 | | |
638 | 376k | for (const std::unique_ptr<ProgramOptimizer>& optimizer : |
639 | 376k | program_optimizers_) { |
640 | 0 | absl::Status status = optimizer->OnPostVisit(extension_context_, expr); |
641 | 0 | if (!status.ok()) { |
642 | 0 | SetProgressStatusError(status); |
643 | 0 | return; |
644 | 0 | } |
645 | 0 | } |
646 | | |
647 | 376k | auto* subexpression = program_builder_.current(); |
648 | 376k | if (subexpression != nullptr && options_.enable_recursive_tracing && |
649 | 0 | subexpression->IsRecursive()) { |
650 | 0 | auto program = subexpression->ExtractRecursiveProgram(); |
651 | 0 | subexpression->set_recursive_program( |
652 | 0 | std::make_unique<TraceStep>(std::move(program.step)), program.depth); |
653 | 0 | } |
654 | | |
655 | 376k | program_builder_.ExitSubexpression(&expr); |
656 | | |
657 | 376k | if (!comprehension_stack_.empty() && |
658 | 0 | comprehension_stack_.back().is_optimizable_bind && |
659 | 0 | (&comprehension_stack_.back().comprehension->accu_init() == &expr)) { |
660 | 0 | SetProgressStatusError( |
661 | 0 | MaybeExtractSubexpression(&expr, comprehension_stack_.back())); |
662 | 0 | } |
663 | | |
664 | 376k | if (block_.has_value()) { |
665 | 0 | BlockInfo& block = *block_; |
666 | 0 | if (block.current_binding == &expr) { |
667 | 0 | int index = program_builder_.ExtractSubexpression(&expr); |
668 | 0 | if (index == -1) { |
669 | 0 | SetProgressStatusError( |
670 | 0 | absl::InvalidArgumentError("failed to extract subexpression")); |
671 | 0 | return; |
672 | 0 | } |
673 | 0 | block.subexpressions[block.current_index++] = index; |
674 | 0 | block.current_binding = nullptr; |
675 | 0 | } |
676 | 0 | } |
677 | 376k | } |
678 | | |
679 | | void PostVisitConst(const cel::Expr& expr, |
680 | 118k | const cel::Constant& const_expr) override { |
681 | 118k | if (!progress_status_.ok()) { |
682 | 803 | return; |
683 | 803 | } |
684 | | |
685 | 117k | absl::StatusOr<cel::Value> converted_value = |
686 | 117k | ConvertConstant(const_expr, cel::NewDeleteAllocator()); |
687 | | |
688 | 117k | if (!converted_value.ok()) { |
689 | 0 | SetProgressStatusError(converted_value.status()); |
690 | 0 | return; |
691 | 0 | } |
692 | | |
693 | 117k | if (options_.max_recursion_depth > 0 || options_.max_recursion_depth < 0) { |
694 | 0 | SetRecursiveStep(CreateConstValueDirectStep( |
695 | 0 | std::move(converted_value).value(), expr.id()), |
696 | 0 | 1); |
697 | 0 | return; |
698 | 0 | } |
699 | | |
700 | 117k | AddStep( |
701 | 117k | CreateConstValueStep(std::move(converted_value).value(), expr.id())); |
702 | 117k | } |
703 | | |
704 | | struct SlotLookupResult { |
705 | | int slot; |
706 | | int subexpression; |
707 | | }; |
708 | | |
709 | | // Helper to lookup a variable mapped to a slot. |
710 | | // |
711 | | // If lazy evaluation enabled and ided as a lazy expression, |
712 | | // subexpression and slot will be set. |
713 | 59.6k | SlotLookupResult LookupSlot(absl::string_view path) { |
714 | | // If there's a leading dot, it cannot resolve to a local variable. |
715 | 59.6k | if (absl::StartsWith(path, ".")) { |
716 | 773 | return {-1, -1}; |
717 | 773 | } |
718 | 58.8k | if (block_.has_value()) { |
719 | 0 | const BlockInfo& block = *block_; |
720 | 0 | if (block.in) { |
721 | 0 | absl::string_view index_suffix = path; |
722 | 0 | if (absl::ConsumePrefix(&index_suffix, "@index")) { |
723 | 0 | size_t index; |
724 | 0 | if (!absl::SimpleAtoi(index_suffix, &index)) { |
725 | 0 | SetProgressStatusError( |
726 | 0 | issue_collector_.AddIssue(RuntimeIssue::CreateError( |
727 | 0 | absl::InvalidArgumentError("bad @index")))); |
728 | 0 | return {-1, -1}; |
729 | 0 | } |
730 | 0 | if (index >= block.size) { |
731 | 0 | SetProgressStatusError( |
732 | 0 | issue_collector_.AddIssue(RuntimeIssue::CreateError( |
733 | 0 | absl::InvalidArgumentError(absl::StrCat( |
734 | 0 | "invalid @index greater than number of bindings: ", |
735 | 0 | index, " >= ", block.size))))); |
736 | 0 | return {-1, -1}; |
737 | 0 | } |
738 | 0 | if (index >= block.current_index) { |
739 | 0 | SetProgressStatusError( |
740 | 0 | issue_collector_.AddIssue(RuntimeIssue::CreateError( |
741 | 0 | absl::InvalidArgumentError(absl::StrCat( |
742 | 0 | "@index references current or future binding: ", index, |
743 | 0 | " >= ", block.current_index))))); |
744 | 0 | return {-1, -1}; |
745 | 0 | } |
746 | 0 | return {static_cast<int>(block.index + index), |
747 | 0 | block.subexpressions[index]}; |
748 | 0 | } |
749 | 0 | } |
750 | 0 | } |
751 | 58.8k | if (!comprehension_stack_.empty()) { |
752 | 0 | for (int i = comprehension_stack_.size() - 1; i >= 0; i--) { |
753 | 0 | const ComprehensionStackRecord& record = comprehension_stack_[i]; |
754 | 0 | if (record.iter_var_in_scope && |
755 | 0 | record.comprehension->iter_var() == path) { |
756 | 0 | if (record.is_optimizable_bind) { |
757 | 0 | SetProgressStatusError(issue_collector_.AddIssue( |
758 | 0 | RuntimeIssue::CreateWarning(absl::InvalidArgumentError( |
759 | 0 | "Unexpected iter_var access in trivial comprehension")))); |
760 | 0 | return {-1, -1}; |
761 | 0 | } |
762 | 0 | return {static_cast<int>(record.iter_slot), -1}; |
763 | 0 | } |
764 | 0 | if (record.iter_var2_in_scope && |
765 | 0 | record.comprehension->iter_var2() == path) { |
766 | 0 | return {static_cast<int>(record.iter2_slot), -1}; |
767 | 0 | } |
768 | 0 | if (record.accu_var_in_scope && |
769 | 0 | record.comprehension->accu_var() == path) { |
770 | 0 | int slot = record.accu_slot; |
771 | 0 | int subexpression = -1; |
772 | 0 | if (record.is_optimizable_bind) { |
773 | 0 | subexpression = record.subexpression; |
774 | 0 | } |
775 | 0 | return {slot, subexpression}; |
776 | 0 | } |
777 | 0 | } |
778 | 0 | } |
779 | 58.8k | if (absl::StartsWith(path, "@it:") || absl::StartsWith(path, "@it2:") || |
780 | 58.8k | absl::StartsWith(path, "@ac:")) { |
781 | | // If we see a CSE generated comprehension variable that was not |
782 | | // resolvable through the normal comprehension scope resolution, reject it |
783 | | // now rather than surfacing errors at activation time. |
784 | 0 | SetProgressStatusError( |
785 | 0 | issue_collector_.AddIssue(RuntimeIssue::CreateError( |
786 | 0 | absl::InvalidArgumentError("out of scope reference to CSE " |
787 | 0 | "generated comprehension variable")))); |
788 | 0 | } |
789 | 58.8k | return {-1, -1}; |
790 | 58.8k | } |
791 | | |
792 | | // Ident node handler. |
793 | | // Invoked after child nodes are processed. |
794 | | void PostVisitIdent(const cel::Expr& expr, |
795 | 61.6k | const cel::IdentExpr& ident_expr) override { |
796 | 61.6k | if (!progress_status_.ok()) { |
797 | 2.03k | return; |
798 | 2.03k | } |
799 | 59.6k | absl::string_view path = ident_expr.name(); |
800 | 59.6k | if (!ValidateOrError( |
801 | 59.6k | !path.empty(), |
802 | 59.6k | "Invalid expression: identifier 'name' must not be empty")) { |
803 | 0 | return; |
804 | 0 | } |
805 | | |
806 | | // Check if this is a local variable first (since it should shadow most |
807 | | // other interpretations). |
808 | 59.6k | SlotLookupResult slot = LookupSlot(path); |
809 | | |
810 | 59.6k | if (slot.subexpression >= 0) { |
811 | 0 | auto* subexpression = |
812 | 0 | program_builder_.GetExtractedSubexpression(slot.subexpression); |
813 | 0 | if (subexpression == nullptr) { |
814 | 0 | SetProgressStatusError( |
815 | 0 | absl::InternalError("bad subexpression reference")); |
816 | 0 | return; |
817 | 0 | } |
818 | 0 | if (subexpression->IsRecursive()) { |
819 | 0 | const auto& program = subexpression->recursive_program(); |
820 | 0 | SetRecursiveStep( |
821 | 0 | CreateDirectLazyInitStep(slot.slot, program.step.get(), expr.id()), |
822 | 0 | program.depth + 1); |
823 | 0 | } else { |
824 | | // Off by one since mainline expression will be index 0. |
825 | 0 | AddStep( |
826 | 0 | CreateLazyInitStep(slot.slot, slot.subexpression + 1, expr.id())); |
827 | 0 | } |
828 | 0 | return; |
829 | 59.6k | } else if (slot.slot >= 0) { |
830 | 0 | if (options_.max_recursion_depth != 0) { |
831 | 0 | SetRecursiveStep( |
832 | 0 | CreateDirectSlotIdentStep(ident_expr.name(), slot.slot, expr.id()), |
833 | 0 | 1); |
834 | 0 | } else { |
835 | 0 | AddStep( |
836 | 0 | CreateIdentStepForSlot(ident_expr.name(), slot.slot, expr.id())); |
837 | 0 | } |
838 | 0 | return; |
839 | 0 | } |
840 | | |
841 | | // Attempt to resolve a select expression as a namespaced identifier for an |
842 | | // enum or type constant value. |
843 | 59.6k | std::optional<cel::Value> const_value; |
844 | 59.6k | int64_t select_root_id = -1; |
845 | 59.6k | std::string path_candidate; |
846 | | |
847 | 68.6k | while (!namespace_stack_.empty()) { |
848 | 9.04k | const auto& select_node = namespace_stack_.front(); |
849 | | // Generate path in format "<ident>.<field 0>.<field 1>...". |
850 | 9.04k | const cel::Expr* select_expr = select_node.first; |
851 | 9.04k | path_candidate = absl::StrCat(path, ".", select_node.second); |
852 | | |
853 | | // Attempt to find a constant enum or type value which matches the |
854 | | // qualified path present in the expression. Whether the identifier |
855 | | // can be resolved to a type instance depends on whether the option to |
856 | | // 'enable_qualified_type_identifiers' is set to true. |
857 | 9.04k | const_value = resolver_.FindConstant(path_candidate, select_expr->id()); |
858 | 9.04k | if (const_value) { |
859 | 48 | resolved_select_expr_ = select_expr; |
860 | 48 | select_root_id = select_expr->id(); |
861 | 48 | path = path_candidate; |
862 | 48 | namespace_stack_.clear(); |
863 | 48 | break; |
864 | 48 | } |
865 | 8.99k | namespace_stack_.pop_front(); |
866 | 8.99k | } |
867 | | |
868 | 59.6k | if (!const_value) { |
869 | | // Attempt to resolve a simple identifier as an enum or type constant |
870 | | // value. |
871 | 59.5k | const_value = resolver_.FindConstant(path, expr.id()); |
872 | 59.5k | select_root_id = expr.id(); |
873 | 59.5k | } |
874 | | |
875 | | // TODO(issues/97): Need to add support for resolving packaged names at |
876 | | // runtime if Parse-only. For checked, checker should have reported the |
877 | | // expected interpretation. |
878 | 59.6k | if (const_value) { |
879 | | // If the path starts with a dot, strip it. |
880 | 765 | absl::string_view name = absl::StripPrefix(path, "."); |
881 | 765 | if (options_.max_recursion_depth != 0) { |
882 | 0 | SetRecursiveStep( |
883 | 0 | CreateDirectShadowableValueStep( |
884 | 0 | name, std::move(const_value).value(), select_root_id), |
885 | 0 | 1); |
886 | 0 | return; |
887 | 0 | } |
888 | 765 | AddStep(CreateShadowableValueStep(name, std::move(const_value).value(), |
889 | 765 | select_root_id)); |
890 | 765 | return; |
891 | 765 | } |
892 | | |
893 | 58.8k | absl::string_view ident_name = absl::StripPrefix(ident_expr.name(), "."); |
894 | 58.8k | if (options_.max_recursion_depth != 0) { |
895 | 0 | SetRecursiveStep(CreateDirectIdentStep(ident_name, expr.id()), 1); |
896 | 58.8k | } else { |
897 | 58.8k | AddStep(CreateIdentStep(ident_name, expr.id())); |
898 | 58.8k | } |
899 | 58.8k | } |
900 | | |
901 | | void PreVisitSelect(const cel::Expr& expr, |
902 | 17.2k | const cel::SelectExpr& select_expr) override { |
903 | 17.2k | if (!progress_status_.ok()) { |
904 | 445 | return; |
905 | 445 | } |
906 | 16.8k | if (!ValidateOrError( |
907 | 16.8k | !select_expr.field().empty(), |
908 | 16.8k | "invalid expression: select 'field' must not be empty")) { |
909 | 0 | return; |
910 | 0 | } |
911 | 16.8k | if (!ValidateOrError( |
912 | 16.8k | select_expr.has_operand() && |
913 | 16.8k | select_expr.operand().kind_case() != |
914 | 16.8k | cel::ExprKindCase::kUnspecifiedExpr, |
915 | 16.8k | "invalid expression: select must specify an operand")) { |
916 | 0 | return; |
917 | 0 | } |
918 | | |
919 | | // Not exactly the cleanest solution - we peek into child of |
920 | | // select_expr. |
921 | | // Chain of multiple SELECT ending with IDENT can represent namespaced |
922 | | // entity. |
923 | 16.8k | if (!select_expr.test_only() && (select_expr.operand().has_ident_expr() || |
924 | 15.1k | select_expr.operand().has_select_expr())) { |
925 | | // select expressions are pushed in reverse order: |
926 | | // google.type.Expr is pushed as: |
927 | | // - field: 'Expr' |
928 | | // - field: 'type' |
929 | | // - id: 'google' |
930 | | // |
931 | | // The search order though is as follows: |
932 | | // - id: 'google.type.Expr' |
933 | | // - id: 'google.type', field: 'Expr' |
934 | | // - id: 'google', field: 'type', field: 'Expr' |
935 | 209k | for (size_t i = 0; i < namespace_stack_.size(); i++) { |
936 | 197k | auto ns = namespace_stack_[i]; |
937 | 197k | namespace_stack_[i] = { |
938 | 197k | ns.first, absl::StrCat(select_expr.field(), ".", ns.second)}; |
939 | 197k | } |
940 | 11.8k | namespace_stack_.push_back({&expr, select_expr.field()}); |
941 | 11.8k | } else { |
942 | 4.96k | namespace_stack_.clear(); |
943 | 4.96k | } |
944 | 16.8k | } |
945 | | |
946 | | // Select node handler. |
947 | | // Invoked after child nodes are processed. |
948 | | void PostVisitSelect(const cel::Expr& expr, |
949 | 17.2k | const cel::SelectExpr& select_expr) override { |
950 | 17.2k | if (!progress_status_.ok()) { |
951 | 670 | return; |
952 | 670 | } |
953 | | |
954 | | // Check if we are "in the middle" of namespaced name. |
955 | | // This is currently enum specific. Constant expression that corresponds |
956 | | // to resolved enum value has been already created, thus preceding chain |
957 | | // of selects is no longer relevant. |
958 | 16.6k | if (resolved_select_expr_) { |
959 | 144 | if (&expr == resolved_select_expr_) { |
960 | 48 | resolved_select_expr_ = nullptr; |
961 | 48 | } |
962 | 144 | return; |
963 | 144 | } |
964 | | |
965 | 16.4k | if (auto depth = RecursionEligible(); depth.has_value()) { |
966 | 0 | auto deps = ExtractRecursiveDependencies(); |
967 | 0 | if (deps.size() != 1) { |
968 | 0 | SetProgressStatusError(absl::InternalError( |
969 | 0 | "unexpected number of dependencies for select operation.")); |
970 | 0 | return; |
971 | 0 | } |
972 | 0 | StringValue field = cel::StringValue(select_expr.field()); |
973 | |
|
974 | 0 | SetRecursiveStep( |
975 | 0 | CreateDirectSelectStep(std::move(deps[0]), std::move(field), |
976 | 0 | select_expr.test_only(), expr.id(), |
977 | 0 | options_.enable_empty_wrapper_null_unboxing, |
978 | 0 | enable_optional_types_), |
979 | 0 | *depth + 1); |
980 | 0 | return; |
981 | 0 | } |
982 | | |
983 | 16.4k | AddStep(CreateSelectStep(select_expr, expr.id(), |
984 | 16.4k | options_.enable_empty_wrapper_null_unboxing, |
985 | 16.4k | enable_optional_types_)); |
986 | 16.4k | } |
987 | | |
988 | | // Call node handler group. |
989 | | // We provide finer granularity for Call node callbacks to allow special |
990 | | // handling for short-circuiting |
991 | | // PreVisitCall is invoked before child nodes are processed. |
992 | | void PreVisitCall(const cel::Expr& expr, |
993 | 167k | const cel::CallExpr& call_expr) override { |
994 | 167k | if (!progress_status_.ok()) { |
995 | 1.83k | return; |
996 | 1.83k | } |
997 | | |
998 | 166k | std::unique_ptr<CondVisitor> cond_visitor; |
999 | 166k | if (call_expr.function() == cel::builtin::kAnd) { |
1000 | 6.58k | cond_visitor = std::make_unique<BinaryCondVisitor>( |
1001 | 6.58k | this, BinaryCond::kAnd, options_.short_circuiting); |
1002 | 159k | } else if (call_expr.function() == cel::builtin::kOr) { |
1003 | 3.82k | cond_visitor = std::make_unique<BinaryCondVisitor>( |
1004 | 3.82k | this, BinaryCond::kOr, options_.short_circuiting); |
1005 | 155k | } else if (call_expr.function() == cel::builtin::kTernary) { |
1006 | 783 | if (options_.short_circuiting) { |
1007 | 783 | cond_visitor = std::make_unique<TernaryCondVisitor>(this); |
1008 | 783 | } else { |
1009 | 0 | cond_visitor = std::make_unique<ExhaustiveTernaryCondVisitor>(this); |
1010 | 0 | } |
1011 | 154k | } else if (enable_optional_types_ && |
1012 | 0 | call_expr.function() == kOptionalOrFn && |
1013 | 0 | call_expr.has_target() && call_expr.args().size() == 1) { |
1014 | 0 | cond_visitor = std::make_unique<BinaryCondVisitor>( |
1015 | 0 | this, BinaryCond::kOptionalOr, options_.short_circuiting); |
1016 | 154k | } else if (enable_optional_types_ && |
1017 | 0 | call_expr.function() == kOptionalOrValueFn && |
1018 | 0 | call_expr.has_target() && call_expr.args().size() == 1) { |
1019 | 0 | cond_visitor = std::make_unique<BinaryCondVisitor>( |
1020 | 0 | this, BinaryCond::kOptionalOrValue, options_.short_circuiting); |
1021 | 154k | } else if (IsBlock(&call_expr)) { |
1022 | | // cel.@block |
1023 | 0 | if (block_.has_value()) { |
1024 | | // There can only be one for now. |
1025 | 0 | SetProgressStatusError( |
1026 | 0 | absl::InvalidArgumentError("multiple cel.@block are not allowed")); |
1027 | 0 | return; |
1028 | 0 | } |
1029 | 0 | block_ = BlockInfo(); |
1030 | 0 | BlockInfo& block = *block_; |
1031 | 0 | block.in = true; |
1032 | 0 | if (call_expr.args().empty()) { |
1033 | 0 | SetProgressStatusError(absl::InvalidArgumentError( |
1034 | 0 | "malformed cel.@block: missing list of bound expressions")); |
1035 | 0 | return; |
1036 | 0 | } |
1037 | 0 | if (call_expr.args().size() != 2) { |
1038 | 0 | SetProgressStatusError(absl::InvalidArgumentError( |
1039 | 0 | "malformed cel.@block: missing bound expression")); |
1040 | 0 | return; |
1041 | 0 | } |
1042 | 0 | if (!call_expr.args()[0].has_list_expr()) { |
1043 | 0 | SetProgressStatusError( |
1044 | 0 | absl::InvalidArgumentError("malformed cel.@block: first argument " |
1045 | 0 | "is not a list of bound expressions")); |
1046 | 0 | return; |
1047 | 0 | } |
1048 | 0 | const auto& list_expr = call_expr.args().front().list_expr(); |
1049 | 0 | block.size = list_expr.elements().size(); |
1050 | 0 | if (block.size == 0) { |
1051 | 0 | SetProgressStatusError(absl::InvalidArgumentError( |
1052 | 0 | "malformed cel.@block: list of bound expressions is empty")); |
1053 | 0 | return; |
1054 | 0 | } |
1055 | 0 | block.bindings_set.reserve(block.size); |
1056 | 0 | for (const auto& list_expr_element : list_expr.elements()) { |
1057 | 0 | if (list_expr_element.optional()) { |
1058 | 0 | SetProgressStatusError( |
1059 | 0 | absl::InvalidArgumentError("malformed cel.@block: list of bound " |
1060 | 0 | "expressions contains an optional")); |
1061 | 0 | return; |
1062 | 0 | } |
1063 | 0 | block.bindings_set.insert(&list_expr_element.expr()); |
1064 | 0 | } |
1065 | 0 | block.index = index_manager().ReserveSlots(block.size); |
1066 | 0 | block.slot_count = block.size; |
1067 | 0 | block.expr = &expr; |
1068 | 0 | block.bindings = &call_expr.args()[0]; |
1069 | 0 | block.bound = &call_expr.args()[1]; |
1070 | 0 | block.subexpressions.resize(block.size, -1); |
1071 | 154k | } else { |
1072 | 154k | return; |
1073 | 154k | } |
1074 | | |
1075 | 11.1k | if (cond_visitor) { |
1076 | 11.1k | cond_visitor->PreVisit(&expr); |
1077 | 11.1k | cond_visitor_stack_.push({&expr, std::move(cond_visitor)}); |
1078 | 11.1k | } |
1079 | 11.1k | } |
1080 | | |
1081 | | // Returns the maximum recursion depth of the current program if it is |
1082 | | // eligible for recursion, or nullopt if it is not. |
1083 | 188k | std::optional<int> RecursionEligible() { |
1084 | 188k | if (!PlanRecursiveProgram() || program_builder_.current() == nullptr) { |
1085 | 188k | return absl::nullopt; |
1086 | 188k | } |
1087 | 0 | return program_builder_.current()->RecursiveDependencyDepth(); |
1088 | 188k | } |
1089 | | |
1090 | | std::vector<std::unique_ptr<DirectExpressionStep>> |
1091 | 0 | ExtractRecursiveDependencies() { |
1092 | | // Must check recursion eligibility before calling. |
1093 | 0 | ABSL_DCHECK(program_builder_.current() != nullptr); |
1094 | |
|
1095 | 0 | return program_builder_.current()->ExtractRecursiveDependencies(); |
1096 | 0 | } |
1097 | | |
1098 | 0 | void MakeTernaryRecursive(const cel::Expr* expr) { |
1099 | 0 | if (expr->call_expr().args().size() != 3) { |
1100 | 0 | SetProgressStatusError(absl::InvalidArgumentError( |
1101 | 0 | "unexpected number of args for builtin ternary")); |
1102 | 0 | return; |
1103 | 0 | } |
1104 | | |
1105 | 0 | const cel::Expr* condition_expr = &expr->call_expr().args()[0]; |
1106 | 0 | const cel::Expr* left_expr = &expr->call_expr().args()[1]; |
1107 | 0 | const cel::Expr* right_expr = &expr->call_expr().args()[2]; |
1108 | |
|
1109 | 0 | auto* condition_plan = program_builder_.GetSubexpression(condition_expr); |
1110 | 0 | auto* left_plan = program_builder_.GetSubexpression(left_expr); |
1111 | 0 | auto* right_plan = program_builder_.GetSubexpression(right_expr); |
1112 | |
|
1113 | 0 | if (condition_plan == nullptr || !condition_plan->IsRecursive() || |
1114 | 0 | left_plan == nullptr || !left_plan->IsRecursive() || |
1115 | 0 | right_plan == nullptr || !right_plan->IsRecursive()) { |
1116 | 0 | SetProgressStatusError(FailedRecursivePlanning()); |
1117 | 0 | return; |
1118 | 0 | } |
1119 | | |
1120 | 0 | int max_depth = std::max({0, condition_plan->recursive_program().depth, |
1121 | 0 | left_plan->recursive_program().depth, |
1122 | 0 | right_plan->recursive_program().depth}); |
1123 | |
|
1124 | 0 | SetRecursiveStep( |
1125 | 0 | CreateDirectTernaryStep(condition_plan->ExtractRecursiveProgram().step, |
1126 | 0 | left_plan->ExtractRecursiveProgram().step, |
1127 | 0 | right_plan->ExtractRecursiveProgram().step, |
1128 | 0 | expr->id(), options_.short_circuiting), |
1129 | 0 | max_depth + 1); |
1130 | 0 | } |
1131 | | |
1132 | 0 | void MakeShortcircuitRecursive(const cel::Expr* expr, bool is_or) { |
1133 | 0 | if (expr->call_expr().args().size() != 2) { |
1134 | 0 | SetProgressStatusError(absl::InvalidArgumentError( |
1135 | 0 | "unexpected number of args for builtin boolean operator &&/||")); |
1136 | 0 | return; |
1137 | 0 | } |
1138 | 0 | const cel::Expr* left_expr = &expr->call_expr().args()[0]; |
1139 | 0 | const cel::Expr* right_expr = &expr->call_expr().args()[1]; |
1140 | |
|
1141 | 0 | auto* left_plan = program_builder_.GetSubexpression(left_expr); |
1142 | 0 | auto* right_plan = program_builder_.GetSubexpression(right_expr); |
1143 | |
|
1144 | 0 | if (left_plan == nullptr || !left_plan->IsRecursive() || |
1145 | 0 | right_plan == nullptr || !right_plan->IsRecursive()) { |
1146 | 0 | SetProgressStatusError(FailedRecursivePlanning()); |
1147 | 0 | return; |
1148 | 0 | } |
1149 | | |
1150 | 0 | int max_depth = std::max({0, left_plan->recursive_program().depth, |
1151 | 0 | right_plan->recursive_program().depth}); |
1152 | |
|
1153 | 0 | if (is_or) { |
1154 | 0 | SetRecursiveStep( |
1155 | 0 | CreateDirectOrStep(left_plan->ExtractRecursiveProgram().step, |
1156 | 0 | right_plan->ExtractRecursiveProgram().step, |
1157 | 0 | expr->id(), options_.short_circuiting), |
1158 | 0 | max_depth + 1); |
1159 | 0 | } else { |
1160 | 0 | SetRecursiveStep( |
1161 | 0 | CreateDirectAndStep(left_plan->ExtractRecursiveProgram().step, |
1162 | 0 | right_plan->ExtractRecursiveProgram().step, |
1163 | 0 | expr->id(), options_.short_circuiting), |
1164 | 0 | max_depth + 1); |
1165 | 0 | } |
1166 | 0 | } |
1167 | | |
1168 | 0 | void MakeOptionalShortcircuit(const cel::Expr* expr, bool is_or_value) { |
1169 | 0 | if (!expr->call_expr().has_target() || |
1170 | 0 | expr->call_expr().args().size() != 1) { |
1171 | 0 | SetProgressStatusError(absl::InvalidArgumentError( |
1172 | 0 | "unexpected number of args for optional.or{Value}")); |
1173 | 0 | return; |
1174 | 0 | } |
1175 | 0 | const cel::Expr* left_expr = &expr->call_expr().target(); |
1176 | 0 | const cel::Expr* right_expr = &expr->call_expr().args()[0]; |
1177 | |
|
1178 | 0 | auto* left_plan = program_builder_.GetSubexpression(left_expr); |
1179 | 0 | auto* right_plan = program_builder_.GetSubexpression(right_expr); |
1180 | |
|
1181 | 0 | if (left_plan == nullptr || !left_plan->IsRecursive() || |
1182 | 0 | right_plan == nullptr || !right_plan->IsRecursive()) { |
1183 | 0 | SetProgressStatusError(FailedRecursivePlanning()); |
1184 | 0 | return; |
1185 | 0 | } |
1186 | 0 | int max_depth = std::max({0, left_plan->recursive_program().depth, |
1187 | 0 | right_plan->recursive_program().depth}); |
1188 | |
|
1189 | 0 | SetRecursiveStep(CreateDirectOptionalOrStep( |
1190 | 0 | expr->id(), left_plan->ExtractRecursiveProgram().step, |
1191 | 0 | right_plan->ExtractRecursiveProgram().step, |
1192 | 0 | is_or_value, options_.short_circuiting), |
1193 | 0 | max_depth + 1); |
1194 | 0 | } |
1195 | | |
1196 | | void MaybeMakeBindRecursive(const cel::Expr* expr, |
1197 | | const cel::ComprehensionExpr* comprehension, |
1198 | 0 | size_t accu_slot) { |
1199 | 0 | if (!PlanRecursiveProgram()) { |
1200 | 0 | return; |
1201 | 0 | } |
1202 | | |
1203 | 0 | auto* result_plan = |
1204 | 0 | program_builder_.GetSubexpression(&comprehension->result()); |
1205 | |
|
1206 | 0 | if (result_plan == nullptr || !result_plan->IsRecursive()) { |
1207 | 0 | SetProgressStatusError(FailedRecursivePlanning()); |
1208 | 0 | return; |
1209 | 0 | } |
1210 | | |
1211 | 0 | int result_depth = result_plan->recursive_program().depth; |
1212 | |
|
1213 | 0 | auto program = result_plan->ExtractRecursiveProgram(); |
1214 | 0 | SetRecursiveStep( |
1215 | 0 | CreateDirectBindStep(accu_slot, std::move(program.step), expr->id()), |
1216 | 0 | result_depth + 1); |
1217 | 0 | } |
1218 | | |
1219 | | void MaybeMakeComprehensionRecursive( |
1220 | | const cel::Expr* expr, const cel::ComprehensionExpr* comprehension, |
1221 | 0 | size_t iter_slot, size_t iter2_slot, size_t accu_slot) { |
1222 | 0 | if (!PlanRecursiveProgram()) { |
1223 | 0 | return; |
1224 | 0 | } |
1225 | | |
1226 | 0 | auto* accu_plan = |
1227 | 0 | program_builder_.GetSubexpression(&comprehension->accu_init()); |
1228 | 0 | auto* range_plan = |
1229 | 0 | program_builder_.GetSubexpression(&comprehension->iter_range()); |
1230 | 0 | auto* loop_plan = |
1231 | 0 | program_builder_.GetSubexpression(&comprehension->loop_step()); |
1232 | 0 | auto* condition_plan = |
1233 | 0 | program_builder_.GetSubexpression(&comprehension->loop_condition()); |
1234 | 0 | auto* result_plan = |
1235 | 0 | program_builder_.GetSubexpression(&comprehension->result()); |
1236 | 0 | if (accu_plan == nullptr || !accu_plan->IsRecursive() || |
1237 | 0 | range_plan == nullptr || !range_plan->IsRecursive() || |
1238 | 0 | loop_plan == nullptr || !loop_plan->IsRecursive() || |
1239 | 0 | condition_plan == nullptr || !condition_plan->IsRecursive() || |
1240 | 0 | result_plan == nullptr || !result_plan->IsRecursive()) { |
1241 | 0 | SetProgressStatusError(FailedRecursivePlanning()); |
1242 | 0 | return; |
1243 | 0 | } |
1244 | | |
1245 | 0 | int max_depth = 0; |
1246 | 0 | max_depth = std::max(max_depth, accu_plan->recursive_program().depth); |
1247 | 0 | max_depth = std::max(max_depth, range_plan->recursive_program().depth); |
1248 | 0 | max_depth = std::max(max_depth, loop_plan->recursive_program().depth); |
1249 | 0 | max_depth = std::max(max_depth, condition_plan->recursive_program().depth); |
1250 | 0 | max_depth = std::max(max_depth, result_plan->recursive_program().depth); |
1251 | |
|
1252 | 0 | auto step = CreateDirectComprehensionStep( |
1253 | 0 | iter_slot, iter2_slot, accu_slot, |
1254 | 0 | range_plan->ExtractRecursiveProgram().step, |
1255 | 0 | accu_plan->ExtractRecursiveProgram().step, |
1256 | 0 | loop_plan->ExtractRecursiveProgram().step, |
1257 | 0 | condition_plan->ExtractRecursiveProgram().step, |
1258 | 0 | result_plan->ExtractRecursiveProgram().step, options_.short_circuiting, |
1259 | 0 | expr->id()); |
1260 | |
|
1261 | 0 | SetRecursiveStep(std::move(step), max_depth + 1); |
1262 | 0 | } |
1263 | | |
1264 | | // Invoked after all child nodes are processed. |
1265 | | void PostVisitCall(const cel::Expr& expr, |
1266 | 167k | const cel::CallExpr& call_expr) override { |
1267 | 167k | if (!progress_status_.ok()) { |
1268 | 4.49k | return; |
1269 | 4.49k | } |
1270 | | |
1271 | 163k | auto cond_visitor = FindCondVisitor(&expr); |
1272 | 163k | if (cond_visitor) { |
1273 | 11.0k | cond_visitor->PostVisit(&expr); |
1274 | 11.0k | cond_visitor_stack_.pop(); |
1275 | 11.0k | return; |
1276 | 11.0k | } |
1277 | | |
1278 | | // Check if the call is intercepted by a custom handler. |
1279 | 152k | if (auto handler = call_handlers_.find(call_expr.function()); |
1280 | 152k | handler != call_handlers_.end()) { |
1281 | 23.4k | CallHandlerResult result = handler->second(expr, call_expr); |
1282 | 23.4k | if (result == CallHandlerResult::kIntercepted) { |
1283 | 14.2k | return; |
1284 | 14.2k | } // otherwise, apply default function handling. |
1285 | 23.4k | } |
1286 | | |
1287 | 138k | AddResolvedFunctionStep(&call_expr, &expr, call_expr.function()); |
1288 | 138k | } |
1289 | | |
1290 | | void PreVisitComprehension( |
1291 | | const cel::Expr& expr, |
1292 | 0 | const cel::ComprehensionExpr& comprehension) override { |
1293 | 0 | if (!progress_status_.ok()) { |
1294 | 0 | return; |
1295 | 0 | } |
1296 | 0 | if (!ValidateOrError(options_.enable_comprehension, |
1297 | 0 | "Comprehension support is disabled")) { |
1298 | 0 | return; |
1299 | 0 | } |
1300 | 0 | const auto& accu_var = comprehension.accu_var(); |
1301 | 0 | const auto& iter_var = comprehension.iter_var(); |
1302 | 0 | const auto& iter_var2 = comprehension.iter_var2(); |
1303 | 0 | ValidateOrError(!accu_var.empty(), |
1304 | 0 | "Invalid comprehension: 'accu_var' must not be empty"); |
1305 | 0 | ValidateOrError(!iter_var.empty(), |
1306 | 0 | "Invalid comprehension: 'iter_var' must not be empty"); |
1307 | 0 | ValidateOrError( |
1308 | 0 | accu_var != iter_var, |
1309 | 0 | "Invalid comprehension: 'accu_var' must not be the same as 'iter_var'"); |
1310 | 0 | ValidateOrError(accu_var != iter_var2, |
1311 | 0 | "Invalid comprehension: 'accu_var' must not be the same as " |
1312 | 0 | "'iter_var2'"); |
1313 | 0 | ValidateOrError(iter_var2 != iter_var, |
1314 | 0 | "Invalid comprehension: 'iter_var2' must not be the same " |
1315 | 0 | "as 'iter_var'"); |
1316 | 0 | ValidateOrError(comprehension.has_accu_init(), |
1317 | 0 | "Invalid comprehension: 'accu_init' must be set"); |
1318 | 0 | ValidateOrError(comprehension.has_loop_condition(), |
1319 | 0 | "Invalid comprehension: 'loop_condition' must be set"); |
1320 | 0 | ValidateOrError(comprehension.has_loop_step(), |
1321 | 0 | "Invalid comprehension: 'loop_step' must be set"); |
1322 | 0 | ValidateOrError(comprehension.has_result(), |
1323 | 0 | "Invalid comprehension: 'result' must be set"); |
1324 | |
|
1325 | 0 | size_t iter_slot, iter2_slot, accu_slot, slot_count; |
1326 | 0 | bool is_bind = IsBind(&comprehension); |
1327 | |
|
1328 | 0 | if (is_bind) { |
1329 | 0 | accu_slot = iter_slot = iter2_slot = index_manager_.ReserveSlots(1); |
1330 | 0 | slot_count = 1; |
1331 | 0 | } else if (comprehension.iter_var2().empty()) { |
1332 | 0 | iter_slot = iter2_slot = index_manager_.ReserveSlots(2); |
1333 | 0 | accu_slot = iter_slot + 1; |
1334 | 0 | slot_count = 2; |
1335 | 0 | } else { |
1336 | 0 | iter_slot = index_manager_.ReserveSlots(3); |
1337 | 0 | iter2_slot = iter_slot + 1; |
1338 | 0 | accu_slot = iter2_slot + 1; |
1339 | 0 | slot_count = 3; |
1340 | 0 | } |
1341 | |
|
1342 | 0 | if (block_.has_value()) { |
1343 | 0 | BlockInfo& block = *block_; |
1344 | 0 | if (block.in) { |
1345 | 0 | block.slot_count += slot_count; |
1346 | 0 | slot_count = 0; |
1347 | 0 | } |
1348 | 0 | } |
1349 | | // If this is in the scope of an optimized bind accu-init, account the slots |
1350 | | // to the outermost bind-init scope. |
1351 | | // |
1352 | | // The init expression is effectively inlined at the first usage in the |
1353 | | // critical path (which is unknown at plan time), so the used slots need to |
1354 | | // be dedicated for the entire scope of that bind. |
1355 | 0 | for (ComprehensionStackRecord& record : comprehension_stack_) { |
1356 | 0 | if (record.in_accu_init && record.is_optimizable_bind) { |
1357 | 0 | record.slot_count += slot_count; |
1358 | 0 | slot_count = 0; |
1359 | 0 | break; |
1360 | 0 | } |
1361 | | // If no bind init subexpression, account normally. |
1362 | 0 | } |
1363 | |
|
1364 | 0 | comprehension_stack_.push_back( |
1365 | 0 | {&expr, &comprehension, iter_slot, iter2_slot, accu_slot, slot_count, |
1366 | 0 | /*subexpression=*/-1, |
1367 | | /*.is_optimizable_list_append=*/ |
1368 | 0 | IsOptimizableListAppend(&comprehension, |
1369 | 0 | options_.enable_comprehension_list_append), |
1370 | | /*.is_optimizable_map_insert=*/ |
1371 | 0 | IsOptimizableMapInsert(&comprehension, |
1372 | 0 | options_.enable_comprehension_mutable_map), |
1373 | 0 | /*.is_optimizable_bind=*/is_bind, |
1374 | 0 | /*.iter_var_in_scope=*/false, |
1375 | 0 | /*.iter_var2_in_scope=*/false, |
1376 | 0 | /*.accu_var_in_scope=*/false, |
1377 | 0 | /*.in_accu_init=*/false, |
1378 | 0 | std::make_unique<ComprehensionVisitor>(this, options_.short_circuiting, |
1379 | 0 | is_bind, iter_slot, iter2_slot, |
1380 | 0 | accu_slot)}); |
1381 | 0 | comprehension_stack_.back().visitor->PreVisit(&expr); |
1382 | 0 | } |
1383 | | |
1384 | | // Invoked after all child nodes are processed. |
1385 | | void PostVisitComprehension( |
1386 | | const cel::Expr& expr, |
1387 | 0 | const cel::ComprehensionExpr& comprehension_expr) override { |
1388 | 0 | if (!progress_status_.ok()) { |
1389 | 0 | return; |
1390 | 0 | } |
1391 | | |
1392 | 0 | ComprehensionStackRecord& record = comprehension_stack_.back(); |
1393 | 0 | if (comprehension_stack_.empty() || |
1394 | 0 | record.comprehension != &comprehension_expr) { |
1395 | 0 | return; |
1396 | 0 | } |
1397 | | |
1398 | 0 | record.visitor->PostVisit(&expr); |
1399 | |
|
1400 | 0 | index_manager_.ReleaseSlots(record.slot_count); |
1401 | 0 | comprehension_stack_.pop_back(); |
1402 | 0 | } |
1403 | | |
1404 | | void PreVisitComprehensionSubexpression( |
1405 | | const cel::Expr& expr, const cel::ComprehensionExpr& compr, |
1406 | 0 | cel::ComprehensionArg comprehension_arg) override { |
1407 | 0 | if (!progress_status_.ok()) { |
1408 | 0 | return; |
1409 | 0 | } |
1410 | | |
1411 | 0 | if (comprehension_stack_.empty() || |
1412 | 0 | comprehension_stack_.back().comprehension != &compr) { |
1413 | 0 | return; |
1414 | 0 | } |
1415 | | |
1416 | 0 | ComprehensionStackRecord& record = comprehension_stack_.back(); |
1417 | |
|
1418 | 0 | switch (comprehension_arg) { |
1419 | 0 | case cel::ITER_RANGE: { |
1420 | 0 | record.in_accu_init = false; |
1421 | 0 | record.iter_var_in_scope = false; |
1422 | 0 | record.iter_var2_in_scope = false; |
1423 | 0 | record.accu_var_in_scope = false; |
1424 | 0 | break; |
1425 | 0 | } |
1426 | 0 | case cel::ACCU_INIT: { |
1427 | 0 | record.in_accu_init = true; |
1428 | 0 | record.iter_var_in_scope = false; |
1429 | 0 | record.iter_var2_in_scope = false; |
1430 | 0 | record.accu_var_in_scope = false; |
1431 | 0 | break; |
1432 | 0 | } |
1433 | 0 | case cel::LOOP_CONDITION: { |
1434 | 0 | record.in_accu_init = false; |
1435 | 0 | record.iter_var_in_scope = true; |
1436 | 0 | record.iter_var2_in_scope = true; |
1437 | 0 | record.accu_var_in_scope = true; |
1438 | 0 | break; |
1439 | 0 | } |
1440 | 0 | case cel::LOOP_STEP: { |
1441 | 0 | record.in_accu_init = false; |
1442 | 0 | record.iter_var_in_scope = true; |
1443 | 0 | record.iter_var2_in_scope = true; |
1444 | 0 | record.accu_var_in_scope = true; |
1445 | 0 | break; |
1446 | 0 | } |
1447 | 0 | case cel::RESULT: { |
1448 | 0 | record.in_accu_init = false; |
1449 | 0 | record.iter_var_in_scope = false; |
1450 | 0 | record.iter_var2_in_scope = false; |
1451 | 0 | record.accu_var_in_scope = true; |
1452 | 0 | break; |
1453 | 0 | } |
1454 | 0 | } |
1455 | 0 | } |
1456 | | |
1457 | | void PostVisitComprehensionSubexpression( |
1458 | | const cel::Expr& expr, const cel::ComprehensionExpr& compr, |
1459 | 0 | cel::ComprehensionArg comprehension_arg) override { |
1460 | 0 | if (!progress_status_.ok()) { |
1461 | 0 | return; |
1462 | 0 | } |
1463 | | |
1464 | 0 | if (comprehension_stack_.empty() || |
1465 | 0 | comprehension_stack_.back().comprehension != &compr) { |
1466 | 0 | return; |
1467 | 0 | } |
1468 | | |
1469 | 0 | SetProgressStatusError(comprehension_stack_.back().visitor->PostVisitArg( |
1470 | 0 | comprehension_arg, comprehension_stack_.back().expr)); |
1471 | 0 | } |
1472 | | |
1473 | | // Invoked after each argument node processed. |
1474 | 317k | void PostVisitArg(const cel::Expr& expr, int arg_num) override { |
1475 | 317k | if (!progress_status_.ok()) { |
1476 | 7.14k | return; |
1477 | 7.14k | } |
1478 | 310k | auto cond_visitor = FindCondVisitor(&expr); |
1479 | 310k | if (cond_visitor) { |
1480 | 22.9k | cond_visitor->PostVisitArg(arg_num, &expr); |
1481 | 22.9k | } |
1482 | 310k | } |
1483 | | |
1484 | 782 | void PostVisitTarget(const cel::Expr& expr) override { |
1485 | 782 | if (!progress_status_.ok()) { |
1486 | 523 | return; |
1487 | 523 | } |
1488 | 259 | auto cond_visitor = FindCondVisitor(&expr); |
1489 | 259 | if (cond_visitor) { |
1490 | 0 | cond_visitor->PostVisitTarget(&expr); |
1491 | 0 | } |
1492 | 259 | } |
1493 | | |
1494 | | // CreateList node handler. |
1495 | | // Invoked after child nodes are processed. |
1496 | | void PostVisitList(const cel::Expr& expr, |
1497 | 9.20k | const cel::ListExpr& list_expr) override { |
1498 | 9.20k | if (!progress_status_.ok()) { |
1499 | 300 | return; |
1500 | 300 | } |
1501 | | |
1502 | 8.90k | if (block_.has_value()) { |
1503 | 0 | BlockInfo& block = *block_; |
1504 | 0 | if (block.bindings == &expr) { |
1505 | | // Do nothing, this is the cel.@block bindings list. |
1506 | 0 | return; |
1507 | 0 | } |
1508 | 0 | } |
1509 | | |
1510 | 8.90k | if (!comprehension_stack_.empty()) { |
1511 | 0 | const ComprehensionStackRecord& comprehension = |
1512 | 0 | comprehension_stack_.back(); |
1513 | 0 | if (comprehension.is_optimizable_list_append) { |
1514 | 0 | if (&(comprehension.comprehension->accu_init()) == &expr) { |
1515 | 0 | if (PlanRecursiveProgram()) { |
1516 | 0 | SetRecursiveStep(CreateDirectMutableListStep(expr.id()), 1); |
1517 | 0 | return; |
1518 | 0 | } |
1519 | 0 | AddStep(CreateMutableListStep(expr.id())); |
1520 | 0 | return; |
1521 | 0 | } |
1522 | 0 | if (GetOptimizableListAppendOperand(comprehension.comprehension) == |
1523 | 0 | &expr) { |
1524 | 0 | return; |
1525 | 0 | } |
1526 | 0 | } |
1527 | 0 | } |
1528 | 8.90k | if (std::optional<int> depth = RecursionEligible(); depth.has_value()) { |
1529 | 0 | auto deps = ExtractRecursiveDependencies(); |
1530 | 0 | if (deps.size() != list_expr.elements().size()) { |
1531 | 0 | SetProgressStatusError(absl::InternalError( |
1532 | 0 | "Unexpected number of plan elements for CreateList expr")); |
1533 | 0 | return; |
1534 | 0 | } |
1535 | 0 | auto step = CreateDirectListStep( |
1536 | 0 | std::move(deps), MakeOptionalIndicesSet(list_expr), expr.id()); |
1537 | 0 | SetRecursiveStep(std::move(step), *depth + 1); |
1538 | 0 | return; |
1539 | 0 | } |
1540 | 8.90k | AddStep(CreateCreateListStep(list_expr, expr.id())); |
1541 | 8.90k | } |
1542 | | |
1543 | | // CreateStruct node handler. |
1544 | | // Invoked after child nodes are processed. |
1545 | | void PostVisitStruct(const cel::Expr& expr, |
1546 | 2.45k | const cel::StructExpr& struct_expr) override { |
1547 | 2.45k | if (!progress_status_.ok()) { |
1548 | 288 | return; |
1549 | 288 | } |
1550 | | |
1551 | 2.16k | auto status_or_resolved_fields = |
1552 | 2.16k | ResolveCreateStructFields(struct_expr, expr.id()); |
1553 | 2.16k | if (!status_or_resolved_fields.ok()) { |
1554 | 275 | SetProgressStatusError(status_or_resolved_fields.status()); |
1555 | 275 | return; |
1556 | 275 | } |
1557 | 1.89k | std::string resolved_name = |
1558 | 1.89k | std::move(status_or_resolved_fields.value().first); |
1559 | 1.89k | std::vector<std::string> fields = |
1560 | 1.89k | std::move(status_or_resolved_fields.value().second); |
1561 | | |
1562 | 1.89k | if (auto depth = RecursionEligible(); depth.has_value()) { |
1563 | 0 | auto deps = ExtractRecursiveDependencies(); |
1564 | 0 | if (deps.size() != struct_expr.fields().size()) { |
1565 | 0 | SetProgressStatusError(absl::InternalError( |
1566 | 0 | "Unexpected number of plan elements for CreateStruct expr")); |
1567 | 0 | return; |
1568 | 0 | } |
1569 | 0 | auto step = CreateDirectCreateStructStep( |
1570 | 0 | std::move(resolved_name), std::move(fields), std::move(deps), |
1571 | 0 | MakeOptionalIndicesSet(struct_expr), expr.id()); |
1572 | 0 | SetRecursiveStep(std::move(step), *depth + 1); |
1573 | 0 | return; |
1574 | 0 | } |
1575 | | |
1576 | 1.89k | AddStep(CreateCreateStructStep(std::move(resolved_name), std::move(fields), |
1577 | 1.89k | MakeOptionalIndicesSet(struct_expr), |
1578 | 1.89k | expr.id())); |
1579 | 1.89k | } |
1580 | | |
1581 | | void PostVisitMap(const cel::Expr& expr, |
1582 | 8.70k | const cel::MapExpr& map_expr) override { |
1583 | 8.70k | for (const auto& entry : map_expr.entries()) { |
1584 | 3.86k | ValidateOrError(entry.has_key(), "Map entry missing key"); |
1585 | 3.86k | ValidateOrError(entry.has_value(), "Map entry missing value"); |
1586 | 3.86k | } |
1587 | | |
1588 | 8.70k | if (!comprehension_stack_.empty()) { |
1589 | 0 | const ComprehensionStackRecord& comprehension = |
1590 | 0 | comprehension_stack_.back(); |
1591 | 0 | if (comprehension.is_optimizable_map_insert) { |
1592 | 0 | if (&(comprehension.comprehension->accu_init()) == &expr) { |
1593 | 0 | if (PlanRecursiveProgram()) { |
1594 | 0 | SetRecursiveStep(CreateDirectMutableMapStep(expr.id()), 1); |
1595 | 0 | return; |
1596 | 0 | } |
1597 | 0 | AddStep(CreateMutableMapStep(expr.id())); |
1598 | 0 | return; |
1599 | 0 | } |
1600 | 0 | } |
1601 | 0 | } |
1602 | | |
1603 | 8.70k | if (auto depth = RecursionEligible(); depth.has_value()) { |
1604 | 0 | auto deps = ExtractRecursiveDependencies(); |
1605 | 0 | if (deps.size() != 2 * map_expr.entries().size()) { |
1606 | 0 | SetProgressStatusError(absl::InternalError( |
1607 | 0 | "Unexpected number of plan elements for CreateStruct expr")); |
1608 | 0 | return; |
1609 | 0 | } |
1610 | 0 | auto step = CreateDirectCreateMapStep( |
1611 | 0 | std::move(deps), MakeOptionalIndicesSet(map_expr), expr.id()); |
1612 | 0 | SetRecursiveStep(std::move(step), *depth + 1); |
1613 | 0 | return; |
1614 | 0 | } |
1615 | 8.70k | AddStep(CreateCreateStructStepForMap(map_expr.entries().size(), |
1616 | 8.70k | MakeOptionalIndicesSet(map_expr), |
1617 | 8.70k | expr.id())); |
1618 | 8.70k | } |
1619 | | |
1620 | 10.7k | absl::Status progress_status() const { return progress_status_; } |
1621 | | |
1622 | | // Mark a branch as suppressed. The visitor will continue as normal, but |
1623 | | // any emitted program steps are ignored. |
1624 | | // |
1625 | | // Only applies to branches that have not yet been visited (pre-order). |
1626 | 0 | void SuppressBranch(const cel::Expr* expr) { |
1627 | 0 | suppressed_branches_.insert(expr); |
1628 | 0 | } |
1629 | | |
1630 | | void AddResolvedFunctionStep(const cel::CallExpr* call_expr, |
1631 | | const cel::Expr* expr, |
1632 | 138k | absl::string_view function) { |
1633 | | // Establish the search criteria for a given function. |
1634 | 138k | bool receiver_style = call_expr->has_target(); |
1635 | 138k | size_t num_args = call_expr->args().size() + (receiver_style ? 1 : 0); |
1636 | | |
1637 | | // First, search for lazily defined function overloads. |
1638 | | // Lazy functions shadow eager functions with the same signature. |
1639 | 138k | auto lazy_overloads = resolver_.FindLazyOverloads( |
1640 | 138k | function, call_expr->has_target(), num_args, expr->id()); |
1641 | 138k | if (!lazy_overloads.empty()) { |
1642 | 0 | if (auto depth = RecursionEligible(); depth.has_value()) { |
1643 | 0 | auto args = program_builder_.current()->ExtractRecursiveDependencies(); |
1644 | 0 | SetRecursiveStep(CreateDirectLazyFunctionStep( |
1645 | 0 | expr->id(), *call_expr, std::move(args), |
1646 | 0 | std::move(lazy_overloads)), |
1647 | 0 | *depth + 1); |
1648 | 0 | return; |
1649 | 0 | } |
1650 | 0 | AddStep(CreateFunctionStep(*call_expr, expr->id(), |
1651 | 0 | std::move(lazy_overloads))); |
1652 | 0 | return; |
1653 | 0 | } |
1654 | | |
1655 | | // Second, search for eagerly defined function overloads. |
1656 | 138k | auto overloads = |
1657 | 138k | resolver_.FindOverloads(function, receiver_style, num_args, expr->id()); |
1658 | 138k | if (overloads.empty()) { |
1659 | | // Create a warning that the overload could not be found. Depending on the |
1660 | | // builder_warnings configuration, this could result in termination of the |
1661 | | // CelExpression creation or an inspectable warning for use within runtime |
1662 | | // logging. |
1663 | 122 | auto status = issue_collector_.AddIssue(RuntimeIssue::CreateWarning( |
1664 | 122 | absl::InvalidArgumentError( |
1665 | 122 | "No overloads provided for FunctionStep creation"), |
1666 | 122 | RuntimeIssue::ErrorCode::kNoMatchingOverload)); |
1667 | 122 | if (!status.ok()) { |
1668 | 122 | SetProgressStatusError(status); |
1669 | 122 | return; |
1670 | 122 | } |
1671 | 122 | } |
1672 | | |
1673 | 137k | if (auto recursion_depth = RecursionEligible(); |
1674 | 137k | recursion_depth.has_value()) { |
1675 | | // Nonnull while active -- nullptr indicates logic error elsewhere in the |
1676 | | // builder. |
1677 | 0 | ABSL_DCHECK(program_builder_.current() != nullptr); |
1678 | 0 | auto args = program_builder_.current()->ExtractRecursiveDependencies(); |
1679 | 0 | SetRecursiveStep( |
1680 | 0 | CreateDirectFunctionStep(expr->id(), *call_expr, std::move(args), |
1681 | 0 | std::move(overloads)), |
1682 | 0 | *recursion_depth + 1); |
1683 | 0 | return; |
1684 | 0 | } |
1685 | 137k | AddStep(CreateFunctionStep(*call_expr, expr->id(), std::move(overloads))); |
1686 | 137k | } |
1687 | | |
1688 | | // Add a step to the program, taking ownership. If successful, returns the |
1689 | | // pointer to the step. Otherwise, returns nullptr. |
1690 | | // |
1691 | | // Note: the pointer is only guaranteed to stay valid until the parent |
1692 | | // subexpression is finalized. Optimizers may modify the program plan which |
1693 | | // may free the step at that point. |
1694 | | ExpressionStep* AddStep( |
1695 | 365k | absl::StatusOr<std::unique_ptr<ExpressionStep>> step) { |
1696 | 365k | if (step.ok()) { |
1697 | 365k | return AddStep(*std::move(step)); |
1698 | 365k | } else { |
1699 | 0 | SetProgressStatusError(step.status()); |
1700 | 0 | } |
1701 | 0 | return nullptr; |
1702 | 365k | } |
1703 | | |
1704 | | template <typename T> |
1705 | | std::enable_if_t<std::is_base_of_v<ExpressionStep, T>, T*> AddStep( |
1706 | 388k | std::unique_ptr<T> step) { |
1707 | 388k | if (progress_status_.ok() && !PlanningSuppressed()) { |
1708 | 388k | return static_cast<T*>(program_builder_.AddStep(std::move(step))); |
1709 | 388k | } |
1710 | 110 | return nullptr; |
1711 | 388k | } flat_expr_builder.cc:_ZN6google3api4expr7runtime12_GLOBAL__N_115FlatExprVisitor7AddStepINS2_14ExpressionStepEEENSt3__19enable_ifIXsr3stdE12is_base_of_vIS6_T_EEPS9_E4typeENS7_10unique_ptrIS9_NS7_14default_deleteIS9_EEEE Line | Count | Source | 1706 | 375k | std::unique_ptr<T> step) { | 1707 | 375k | if (progress_status_.ok() && !PlanningSuppressed()) { | 1708 | 375k | return static_cast<T*>(program_builder_.AddStep(std::move(step))); | 1709 | 375k | } | 1710 | 110 | return nullptr; | 1711 | 375k | } |
flat_expr_builder.cc:_ZN6google3api4expr7runtime12_GLOBAL__N_115FlatExprVisitor7AddStepINS2_12JumpStepBaseEEENSt3__19enable_ifIXsr3stdE12is_base_of_vINS2_14ExpressionStepET_EEPSA_E4typeENS7_10unique_ptrISA_NS7_14default_deleteISA_EEEE Line | Count | Source | 1706 | 12.7k | std::unique_ptr<T> step) { | 1707 | 12.7k | if (progress_status_.ok() && !PlanningSuppressed()) { | 1708 | 12.7k | return static_cast<T*>(program_builder_.AddStep(std::move(step))); | 1709 | 12.7k | } | 1710 | 0 | return nullptr; | 1711 | 12.7k | } |
Unexecuted instantiation: flat_expr_builder.cc:_ZN6google3api4expr7runtime12_GLOBAL__N_115FlatExprVisitor7AddStepINS2_21ComprehensionInitStepEEENSt3__19enable_ifIXsr3stdE12is_base_of_vINS2_14ExpressionStepET_EEPSA_E4typeENS7_10unique_ptrISA_NS7_14default_deleteISA_EEEE Unexecuted instantiation: flat_expr_builder.cc:_ZN6google3api4expr7runtime12_GLOBAL__N_115FlatExprVisitor7AddStepINS2_21ComprehensionNextStepEEENSt3__19enable_ifIXsr3stdE12is_base_of_vINS2_14ExpressionStepET_EEPSA_E4typeENS7_10unique_ptrISA_NS7_14default_deleteISA_EEEE Unexecuted instantiation: flat_expr_builder.cc:_ZN6google3api4expr7runtime12_GLOBAL__N_115FlatExprVisitor7AddStepINS2_21ComprehensionCondStepEEENSt3__19enable_ifIXsr3stdE12is_base_of_vINS2_14ExpressionStepET_EEPSA_E4typeENS7_10unique_ptrISA_NS7_14default_deleteISA_EEEE |
1712 | | |
1713 | 0 | void SetRecursiveStep(std::unique_ptr<DirectExpressionStep> step, int depth) { |
1714 | 0 | if (!progress_status_.ok() || PlanningSuppressed()) { |
1715 | 0 | return; |
1716 | 0 | } |
1717 | 0 | if (program_builder_.current() == nullptr) { |
1718 | 0 | SetProgressStatusError(absl::InternalError( |
1719 | 0 | "CEL AST traversal out of order in flat_expr_builder.")); |
1720 | 0 | return; |
1721 | 0 | } |
1722 | 0 | program_builder_.current()->set_recursive_program(std::move(step), depth); |
1723 | 0 | if (depth > max_recursion_depth_) { |
1724 | 0 | SetProgressStatusError(absl::InvalidArgumentError( |
1725 | 0 | absl::StrCat("Maximum recursion depth of ", |
1726 | 0 | options_.max_recursion_depth, " exceeded"))); |
1727 | 0 | } |
1728 | 0 | } |
1729 | | |
1730 | 12.8k | void SetProgressStatusError(const absl::Status& status) { |
1731 | 12.8k | if (progress_status_.ok() && !status.ok()) { |
1732 | 410 | progress_status_ = status; |
1733 | 410 | } |
1734 | 12.8k | } |
1735 | | |
1736 | | // Index of the next step to be inserted, in terms of the current |
1737 | | // subexpression |
1738 | 25.1k | ProgramStepIndex GetCurrentIndex() const { |
1739 | | // Nonnull while active -- nullptr indicates logic error in the builder. |
1740 | 25.1k | ABSL_DCHECK(program_builder_.current() != nullptr); |
1741 | 25.1k | return {static_cast<int>(program_builder_.current()->elements().size()), |
1742 | 25.1k | program_builder_.current()}; |
1743 | 25.1k | } |
1744 | | |
1745 | 474k | CondVisitor* FindCondVisitor(const cel::Expr* expr) const { |
1746 | 474k | if (cond_visitor_stack_.empty()) { |
1747 | 293k | return nullptr; |
1748 | 293k | } |
1749 | | |
1750 | 180k | const auto& latest = cond_visitor_stack_.top(); |
1751 | | |
1752 | 180k | return (latest.first == expr) ? latest.second.get() : nullptr; |
1753 | 474k | } |
1754 | | |
1755 | 0 | IndexManager& index_manager() { return index_manager_; } |
1756 | | |
1757 | 9.92k | size_t slot_count() const { return index_manager_.max_slot_count(); } |
1758 | | |
1759 | 0 | void AddOptimizer(std::unique_ptr<ProgramOptimizer> optimizer) { |
1760 | 0 | program_optimizers_.push_back(std::move(optimizer)); |
1761 | 0 | } |
1762 | | |
1763 | | // Tests the boolean predicate, and if false produces an InvalidArgumentError |
1764 | | // which concatenates the error_message and any optional message_parts as the |
1765 | | // error status message. |
1766 | | template <typename... MP> |
1767 | | bool ValidateOrError(bool valid_expression, absl::string_view error_message, |
1768 | 514k | MP... message_parts) { |
1769 | 514k | if (valid_expression) { |
1770 | 514k | return true; |
1771 | 514k | } |
1772 | 13 | SetProgressStatusError(absl::InvalidArgumentError( |
1773 | 13 | absl::StrCat(error_message, message_parts...))); |
1774 | 13 | return false; |
1775 | 514k | } |
1776 | | |
1777 | | private: |
1778 | | struct ComprehensionStackRecord { |
1779 | | const cel::Expr* expr; |
1780 | | const cel::ComprehensionExpr* comprehension; |
1781 | | size_t iter_slot; |
1782 | | size_t iter2_slot; |
1783 | | size_t accu_slot; |
1784 | | size_t slot_count; |
1785 | | // -1 indicates this shouldn't be used. |
1786 | | int subexpression; |
1787 | | bool is_optimizable_list_append; |
1788 | | bool is_optimizable_map_insert; |
1789 | | bool is_optimizable_bind; |
1790 | | bool iter_var_in_scope; |
1791 | | bool iter_var2_in_scope; |
1792 | | bool accu_var_in_scope; |
1793 | | bool in_accu_init; |
1794 | | std::unique_ptr<ComprehensionVisitor> visitor; |
1795 | | }; |
1796 | | |
1797 | | struct BlockInfo { |
1798 | | // True if we are currently visiting the `cel.@block` node or any of its |
1799 | | // children. |
1800 | | bool in = false; |
1801 | | // Pointer to the `cel.@block` node. |
1802 | | const cel::Expr* expr = nullptr; |
1803 | | // Pointer to the `cel.@block` bindings, that is the first argument to the |
1804 | | // function. |
1805 | | const cel::Expr* bindings = nullptr; |
1806 | | // Set of pointers to the elements of `bindings` above. |
1807 | | absl::flat_hash_set<const cel::Expr*> bindings_set; |
1808 | | // Pointer to the `cel.@block` bound expression, that is the second argument |
1809 | | // to the function. |
1810 | | const cel::Expr* bound = nullptr; |
1811 | | // The number of entries in the `cel.@block`. |
1812 | | size_t size = 0; |
1813 | | // Starting slot index for `cel.@block`. We occupy he slot indices `index` |
1814 | | // through `index + size + (var_size * 2)`. |
1815 | | size_t index = 0; |
1816 | | // The total number of slots needed for evaluating the bound expressions. |
1817 | | size_t slot_count = 0; |
1818 | | // The current slot index we are processing, any index references must be |
1819 | | // less than this to be valid. |
1820 | | size_t current_index = 0; |
1821 | | // Pointer to the current `cel.@block` being processed, that is one of the |
1822 | | // elements within the first argument. |
1823 | | const cel::Expr* current_binding = nullptr; |
1824 | | // Mapping between block indices and their subexpressions, fixed size with |
1825 | | // exactly `size` elements. Unprocessed indices are set to `-1`. |
1826 | | std::vector<int> subexpressions; |
1827 | | }; |
1828 | | |
1829 | 388k | bool PlanningSuppressed() const { |
1830 | 388k | return resume_from_suppressed_branch_ != nullptr; |
1831 | 388k | } |
1832 | | |
1833 | | absl::Status MaybeExtractSubexpression(const cel::Expr* expr, |
1834 | 0 | ComprehensionStackRecord& record) { |
1835 | 0 | if (!record.is_optimizable_bind) { |
1836 | 0 | return absl::OkStatus(); |
1837 | 0 | } |
1838 | | |
1839 | 0 | int index = program_builder_.ExtractSubexpression(expr); |
1840 | 0 | if (index == -1) { |
1841 | 0 | return absl::InternalError("Failed to extract subexpression"); |
1842 | 0 | } |
1843 | | |
1844 | 0 | record.subexpression = index; |
1845 | |
|
1846 | 0 | record.visitor->MarkAccuInitExtracted(); |
1847 | |
|
1848 | 0 | return absl::OkStatus(); |
1849 | 0 | } |
1850 | | |
1851 | | // Resolve the name of the message type being created and the names of set |
1852 | | // fields. |
1853 | | absl::StatusOr<std::pair<std::string, std::vector<std::string>>> |
1854 | | ResolveCreateStructFields(const cel::StructExpr& create_struct_expr, |
1855 | 2.16k | int64_t expr_id) { |
1856 | 2.16k | absl::string_view ast_name = create_struct_expr.name(); |
1857 | | |
1858 | 2.16k | std::optional<std::pair<std::string, cel::Type>> type; |
1859 | 2.16k | CEL_ASSIGN_OR_RETURN(type, resolver_.FindType(ast_name, expr_id)); |
1860 | | |
1861 | 2.16k | if (!type.has_value()) { |
1862 | 234 | return absl::InvalidArgumentError(absl::StrCat( |
1863 | 234 | "Invalid struct creation: missing type info for '", ast_name, "'")); |
1864 | 234 | } |
1865 | | |
1866 | 1.93k | std::string resolved_name = std::move(type).value().first; |
1867 | | |
1868 | 1.93k | std::vector<std::string> fields; |
1869 | 1.93k | fields.reserve(create_struct_expr.fields().size()); |
1870 | 1.93k | for (const auto& entry : create_struct_expr.fields()) { |
1871 | 41 | if (entry.name().empty()) { |
1872 | 0 | return absl::InvalidArgumentError("Struct field missing name"); |
1873 | 0 | } |
1874 | 41 | if (!entry.has_value()) { |
1875 | 0 | return absl::InvalidArgumentError("Struct field missing value"); |
1876 | 0 | } |
1877 | 82 | CEL_ASSIGN_OR_RETURN(auto field, type_provider_.FindStructTypeFieldByName( |
1878 | 82 | resolved_name, entry.name())); |
1879 | 82 | if (!field.has_value()) { |
1880 | 41 | return absl::InvalidArgumentError( |
1881 | 41 | absl::StrCat("Invalid message creation: field '", entry.name(), |
1882 | 41 | "' not found in '", resolved_name, "'")); |
1883 | 41 | } |
1884 | 0 | fields.push_back(entry.name()); |
1885 | 0 | } |
1886 | | |
1887 | 1.89k | return std::make_pair(std::move(resolved_name), std::move(fields)); |
1888 | 1.93k | } |
1889 | | |
1890 | | CallHandlerResult HandleIndex(const cel::Expr& expr, |
1891 | | const cel::CallExpr& call); |
1892 | | CallHandlerResult HandleBlock(const cel::Expr& expr, |
1893 | | const cel::CallExpr& call); |
1894 | | CallHandlerResult HandleListAppend(const cel::Expr& expr, |
1895 | | const cel::CallExpr& call); |
1896 | | CallHandlerResult HandleNot(const cel::Expr& expr, const cel::CallExpr& call); |
1897 | | CallHandlerResult HandleNotStrictlyFalse(const cel::Expr& expr, |
1898 | | const cel::CallExpr& call); |
1899 | | |
1900 | | CallHandlerResult HandleHeterogeneousEquality(const cel::Expr& expr, |
1901 | | const cel::CallExpr& call, |
1902 | | bool inequality); |
1903 | | |
1904 | | CallHandlerResult HandleHeterogeneousEqualityIn(const cel::Expr& expr, |
1905 | | const cel::CallExpr& call); |
1906 | | |
1907 | | const Resolver& resolver_; |
1908 | | const cel::TypeProvider& type_provider_; |
1909 | | absl::Status progress_status_; |
1910 | | absl::flat_hash_map<std::string, CallHandler> call_handlers_; |
1911 | | |
1912 | | std::stack<std::pair<const cel::Expr*, std::unique_ptr<CondVisitor>>> |
1913 | | cond_visitor_stack_; |
1914 | | |
1915 | | // Tracks SELECT-...SELECT-IDENT chains. |
1916 | | std::deque<std::pair<const cel::Expr*, std::string>> namespace_stack_; |
1917 | | |
1918 | | // When multiple SELECT-...SELECT-IDENT chain is resolved as namespace, this |
1919 | | // field is used as marker suppressing CelExpression creation for SELECTs. |
1920 | | const cel::Expr* resolved_select_expr_; |
1921 | | |
1922 | | const cel::RuntimeOptions& options_; |
1923 | | |
1924 | | std::vector<ComprehensionStackRecord> comprehension_stack_; |
1925 | | absl::flat_hash_set<const cel::Expr*> suppressed_branches_; |
1926 | | const cel::Expr* resume_from_suppressed_branch_ = nullptr; |
1927 | | std::vector<std::unique_ptr<ProgramOptimizer>> program_optimizers_; |
1928 | | IssueCollector& issue_collector_; |
1929 | | |
1930 | | ProgramBuilder& program_builder_; |
1931 | | PlannerContext& extension_context_; |
1932 | | IndexManager index_manager_; |
1933 | | |
1934 | | bool enable_optional_types_; |
1935 | | std::optional<FlatExprVisitor::BlockInfo> block_; |
1936 | | int max_recursion_depth_ = 0; |
1937 | | }; |
1938 | | |
1939 | | FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleIndex( |
1940 | 6.18k | const cel::Expr& expr, const cel::CallExpr& call_expr) { |
1941 | 6.18k | ABSL_DCHECK(call_expr.function() == cel::builtin::kIndex); |
1942 | 6.18k | if (!ValidateOrError( |
1943 | 6.18k | (call_expr.args().size() == 2 && !call_expr.has_target()) || |
1944 | | // TODO(uncreated-issue/79): A few clients use the index operator with a |
1945 | | // target in custom ASTs. |
1946 | 0 | (call_expr.args().size() == 1 && call_expr.has_target()), |
1947 | 6.18k | "unexpected number of args for builtin index operator")) { |
1948 | 0 | return CallHandlerResult::kIntercepted; |
1949 | 0 | } |
1950 | | |
1951 | 6.18k | if (auto depth = RecursionEligible(); depth.has_value()) { |
1952 | 0 | auto args = ExtractRecursiveDependencies(); |
1953 | 0 | if (args.size() != 2) { |
1954 | 0 | SetProgressStatusError(absl::InvalidArgumentError( |
1955 | 0 | "unexpected number of args for builtin index operator")); |
1956 | 0 | return CallHandlerResult::kIntercepted; |
1957 | 0 | } |
1958 | 0 | SetRecursiveStep( |
1959 | 0 | CreateDirectContainerAccessStep(std::move(args[0]), std::move(args[1]), |
1960 | 0 | enable_optional_types_, expr.id()), |
1961 | 0 | *depth + 1); |
1962 | 0 | return CallHandlerResult::kIntercepted; |
1963 | 0 | } |
1964 | 6.18k | AddStep( |
1965 | 6.18k | CreateContainerAccessStep(call_expr, expr.id(), enable_optional_types_)); |
1966 | 6.18k | return CallHandlerResult::kIntercepted; |
1967 | 6.18k | } |
1968 | | |
1969 | | FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleNot( |
1970 | 829 | const cel::Expr& expr, const cel::CallExpr& call_expr) { |
1971 | 829 | ABSL_DCHECK(call_expr.function() == cel::builtin::kNot); |
1972 | | |
1973 | 829 | if (!ValidateOrError(call_expr.args().size() == 1 && !call_expr.has_target(), |
1974 | 829 | "unexpected number of args for builtin not operator")) { |
1975 | 0 | return CallHandlerResult::kIntercepted; |
1976 | 0 | } |
1977 | | |
1978 | 829 | if (auto depth = RecursionEligible(); depth.has_value()) { |
1979 | 0 | auto args = ExtractRecursiveDependencies(); |
1980 | 0 | if (args.size() != 1) { |
1981 | 0 | SetProgressStatusError(absl::InvalidArgumentError( |
1982 | 0 | "unexpected number of args for builtin not operator")); |
1983 | 0 | return CallHandlerResult::kIntercepted; |
1984 | 0 | } |
1985 | 0 | SetRecursiveStep(CreateDirectNotStep(std::move(args[0]), expr.id()), |
1986 | 0 | *depth + 1); |
1987 | 0 | return CallHandlerResult::kIntercepted; |
1988 | 0 | } |
1989 | 829 | AddStep(CreateNotStep(expr.id())); |
1990 | 829 | return CallHandlerResult::kIntercepted; |
1991 | 829 | } |
1992 | | |
1993 | | FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleNotStrictlyFalse( |
1994 | 0 | const cel::Expr& expr, const cel::CallExpr& call_expr) { |
1995 | 0 | if (!ValidateOrError(call_expr.args().size() == 1 && !call_expr.has_target(), |
1996 | 0 | "unexpected number of args for builtin " |
1997 | 0 | "not_strictly_false operator")) { |
1998 | 0 | return CallHandlerResult::kIntercepted; |
1999 | 0 | } |
2000 | | |
2001 | 0 | if (auto depth = RecursionEligible(); depth.has_value()) { |
2002 | 0 | auto args = ExtractRecursiveDependencies(); |
2003 | 0 | if (args.size() != 1) { |
2004 | 0 | SetProgressStatusError( |
2005 | 0 | absl::InvalidArgumentError("unexpected number of args for builtin " |
2006 | 0 | "@not_strictly_false operator")); |
2007 | 0 | return CallHandlerResult::kIntercepted; |
2008 | 0 | } |
2009 | 0 | SetRecursiveStep( |
2010 | 0 | CreateDirectNotStrictlyFalseStep(std::move(args[0]), expr.id()), |
2011 | 0 | *depth + 1); |
2012 | 0 | return CallHandlerResult::kIntercepted; |
2013 | 0 | } |
2014 | 0 | AddStep(CreateNotStrictlyFalseStep(expr.id())); |
2015 | 0 | return CallHandlerResult::kIntercepted; |
2016 | 0 | } |
2017 | | |
2018 | | FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleBlock( |
2019 | 0 | const cel::Expr& expr, const cel::CallExpr& call_expr) { |
2020 | 0 | ABSL_DCHECK(call_expr.function() == kBlock); |
2021 | 0 | if (!block_.has_value() || block_->expr != &expr || |
2022 | 0 | call_expr.args().size() != 2 || call_expr.has_target()) { |
2023 | 0 | SetProgressStatusError( |
2024 | 0 | absl::InvalidArgumentError("unexpected call to internal cel.@block")); |
2025 | 0 | return CallHandlerResult::kIntercepted; |
2026 | 0 | } |
2027 | | |
2028 | 0 | BlockInfo& block = *block_; |
2029 | 0 | block.in = false; |
2030 | 0 | index_manager().ReleaseSlots(block.slot_count); |
2031 | | |
2032 | | // Check if eligible for recursion and update the plan if so. |
2033 | | // |
2034 | | // The first argument to @block is the list of initializers. These don't |
2035 | | // generate a plan in the main program (they are tracked separately to support |
2036 | | // lazy evaluation) so we only need to extract the second argument -- the body |
2037 | | // of the block that uses the initializers. |
2038 | 0 | ProgramBuilder::Subexpression* body_subexpression = |
2039 | 0 | program_builder_.GetSubexpression(&call_expr.args()[1]); |
2040 | |
|
2041 | 0 | if (options_.max_recursion_depth != 0 && body_subexpression != nullptr && |
2042 | 0 | body_subexpression->IsRecursive() && |
2043 | 0 | (options_.max_recursion_depth < 0 || |
2044 | 0 | body_subexpression->recursive_program().depth < |
2045 | 0 | options_.max_recursion_depth)) { |
2046 | 0 | auto recursive_program = body_subexpression->ExtractRecursiveProgram(); |
2047 | 0 | SetRecursiveStep( |
2048 | 0 | CreateDirectBlockStep(block.index, block.slot_count, |
2049 | 0 | std::move(recursive_program.step), expr.id()), |
2050 | 0 | recursive_program.depth + 1); |
2051 | 0 | return CallHandlerResult::kIntercepted; |
2052 | 0 | } |
2053 | | |
2054 | | // Otherwise, iterative plan. |
2055 | 0 | AddStep(CreateClearSlotsStep(block.index, block.slot_count, expr.id())); |
2056 | |
|
2057 | 0 | return CallHandlerResult::kIntercepted; |
2058 | 0 | } |
2059 | | |
2060 | | FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleListAppend( |
2061 | 9.21k | const cel::Expr& expr, const cel::CallExpr& call_expr) { |
2062 | 9.21k | ABSL_DCHECK(call_expr.function() == cel::builtin::kAdd); |
2063 | | |
2064 | | // Check to see if this is a special case of add that should really be |
2065 | | // treated as a list append |
2066 | 9.21k | if (!comprehension_stack_.empty() && |
2067 | 0 | comprehension_stack_.back().is_optimizable_list_append) { |
2068 | | // Already checked that this is an optimizeable comprehension, |
2069 | | // check that this is the correct list append node. |
2070 | 0 | const cel::ComprehensionExpr* comprehension = |
2071 | 0 | comprehension_stack_.back().comprehension; |
2072 | 0 | const cel::Expr& loop_step = comprehension->loop_step(); |
2073 | | // Macro loop_step for a map() will contain a list concat operation: |
2074 | | // accu_var + [elem] |
2075 | 0 | if (&loop_step == &expr) { |
2076 | 0 | AddResolvedFunctionStep(&call_expr, &expr, |
2077 | 0 | cel::builtin::kRuntimeListAppend); |
2078 | 0 | return CallHandlerResult::kIntercepted; |
2079 | 0 | } |
2080 | | // Macro loop_step for a filter() will contain a ternary: |
2081 | | // filter ? accu_var + [elem] : accu_var |
2082 | 0 | if (loop_step.has_call_expr() && |
2083 | 0 | loop_step.call_expr().function() == cel::builtin::kTernary && |
2084 | 0 | loop_step.call_expr().args().size() == 3 && |
2085 | 0 | &(loop_step.call_expr().args()[1]) == &expr) { |
2086 | 0 | AddResolvedFunctionStep(&call_expr, &expr, |
2087 | 0 | cel::builtin::kRuntimeListAppend); |
2088 | 0 | return CallHandlerResult::kIntercepted; |
2089 | 0 | } |
2090 | 0 | } |
2091 | | |
2092 | 9.21k | return CallHandlerResult::kNotIntercepted; |
2093 | 9.21k | } |
2094 | | |
2095 | | FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleHeterogeneousEquality( |
2096 | 4.50k | const cel::Expr& expr, const cel::CallExpr& call, bool inequality) { |
2097 | 4.50k | if (!ValidateOrError( |
2098 | 4.50k | call.args().size() == 2 && !call.has_target(), |
2099 | 4.50k | "unexpected number of args for builtin equality operator")) { |
2100 | 0 | return CallHandlerResult::kIntercepted; |
2101 | 0 | } |
2102 | | |
2103 | 4.50k | if (auto depth = RecursionEligible(); depth.has_value()) { |
2104 | 0 | auto args = ExtractRecursiveDependencies(); |
2105 | 0 | if (args.size() != 2) { |
2106 | 0 | SetProgressStatusError(absl::InvalidArgumentError( |
2107 | 0 | "unexpected number of args for builtin equality operator")); |
2108 | 0 | return CallHandlerResult::kIntercepted; |
2109 | 0 | } |
2110 | 0 | SetRecursiveStep( |
2111 | 0 | CreateDirectEqualityStep(std::move(args[0]), std::move(args[1]), |
2112 | 0 | inequality, expr.id()), |
2113 | 0 | *depth + 1); |
2114 | 0 | return CallHandlerResult::kIntercepted; |
2115 | 0 | } |
2116 | 4.50k | AddStep(CreateEqualityStep(inequality, expr.id())); |
2117 | 4.50k | return CallHandlerResult::kIntercepted; |
2118 | 4.50k | } |
2119 | | |
2120 | | FlatExprVisitor::CallHandlerResult |
2121 | | FlatExprVisitor::HandleHeterogeneousEqualityIn(const cel::Expr& expr, |
2122 | 2.75k | const cel::CallExpr& call) { |
2123 | 2.75k | if (!ValidateOrError(call.args().size() == 2 && !call.has_target(), |
2124 | 2.75k | "unexpected number of args for builtin 'in' operator")) { |
2125 | 13 | return CallHandlerResult::kIntercepted; |
2126 | 13 | } |
2127 | | |
2128 | 2.74k | if (auto depth = RecursionEligible(); depth.has_value()) { |
2129 | 0 | auto args = ExtractRecursiveDependencies(); |
2130 | 0 | if (args.size() != 2) { |
2131 | 0 | SetProgressStatusError(absl::InvalidArgumentError( |
2132 | 0 | "unexpected number of args for builtin 'in' operator")); |
2133 | 0 | return CallHandlerResult::kIntercepted; |
2134 | 0 | } |
2135 | 0 | SetRecursiveStep( |
2136 | 0 | CreateDirectInStep(std::move(args[0]), std::move(args[1]), expr.id()), |
2137 | 0 | *depth + 1); |
2138 | 0 | return CallHandlerResult::kIntercepted; |
2139 | 0 | } |
2140 | | |
2141 | 2.74k | AddStep(CreateInStep(expr.id())); |
2142 | 2.74k | return CallHandlerResult::kIntercepted; |
2143 | 2.74k | } |
2144 | | |
2145 | 10.4k | void BinaryCondVisitor::PreVisit(const cel::Expr* expr) { |
2146 | 10.4k | switch (cond_) { |
2147 | 6.58k | case BinaryCond::kAnd: |
2148 | 6.58k | ABSL_FALLTHROUGH_INTENDED; |
2149 | 10.4k | case BinaryCond::kOr: |
2150 | 10.4k | visitor_->ValidateOrError( |
2151 | 10.4k | !expr->call_expr().has_target() && |
2152 | 10.4k | expr->call_expr().args().size() == 2, |
2153 | 10.4k | "Invalid argument count for a binary function call."); |
2154 | 10.4k | break; |
2155 | 0 | case BinaryCond::kOptionalOr: |
2156 | 0 | ABSL_FALLTHROUGH_INTENDED; |
2157 | 0 | case BinaryCond::kOptionalOrValue: |
2158 | 0 | visitor_->ValidateOrError(expr->call_expr().has_target() && |
2159 | 0 | expr->call_expr().args().size() == 1, |
2160 | 0 | "Invalid argument count for or/orValue call."); |
2161 | 0 | break; |
2162 | 10.4k | } |
2163 | 10.4k | } |
2164 | | |
2165 | 20.7k | void BinaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { |
2166 | 20.7k | if (visitor_->PlanRecursiveProgram()) { |
2167 | 0 | return; |
2168 | 0 | } |
2169 | 20.7k | if (short_circuiting_ && arg_num == 0 && |
2170 | 10.3k | (cond_ == BinaryCond::kAnd || cond_ == BinaryCond::kOr)) { |
2171 | | // If first branch evaluation result is enough to determine output, |
2172 | | // jump over the second branch and provide result of the first argument as |
2173 | | // final output. |
2174 | | // Retain a pointer to the jump step so we can update the target after |
2175 | | // planning the second argument. |
2176 | 10.3k | std::unique_ptr<JumpStepBase> jump_step; |
2177 | 10.3k | switch (cond_) { |
2178 | 6.57k | case BinaryCond::kAnd: |
2179 | 6.57k | jump_step = CreateCondJumpStep(false, true, {}, expr->id()); |
2180 | 6.57k | break; |
2181 | 3.81k | case BinaryCond::kOr: |
2182 | 3.81k | jump_step = CreateCondJumpStep(true, true, {}, expr->id()); |
2183 | 3.81k | break; |
2184 | 0 | default: |
2185 | 0 | ABSL_UNREACHABLE(); |
2186 | 10.3k | } |
2187 | 10.3k | ProgramStepIndex index = visitor_->GetCurrentIndex(); |
2188 | 10.3k | if (JumpStepBase* jump_step_ptr = visitor_->AddStep(std::move(jump_step)); |
2189 | 10.3k | jump_step_ptr) { |
2190 | 10.3k | jump_step_ = Jump(index, jump_step_ptr); |
2191 | 10.3k | } |
2192 | 10.3k | } |
2193 | 20.7k | } |
2194 | | |
2195 | 0 | void BinaryCondVisitor::PostVisitTarget(const cel::Expr* expr) { |
2196 | 0 | if (visitor_->PlanRecursiveProgram()) { |
2197 | 0 | return; |
2198 | 0 | } |
2199 | 0 | if (short_circuiting_ && (cond_ == BinaryCond::kOptionalOr || |
2200 | 0 | cond_ == BinaryCond::kOptionalOrValue)) { |
2201 | | // If first branch evaluation result is enough to determine output, |
2202 | | // jump over the second branch and provide result of the first argument as |
2203 | | // final output. |
2204 | | // Retain a pointer to the jump step so we can update the target after |
2205 | | // planning the second argument. |
2206 | 0 | std::unique_ptr<JumpStepBase> jump_step; |
2207 | 0 | switch (cond_) { |
2208 | 0 | case BinaryCond::kOptionalOr: |
2209 | 0 | jump_step = CreateOptionalHasValueJumpStep(false, expr->id()); |
2210 | 0 | break; |
2211 | 0 | case BinaryCond::kOptionalOrValue: |
2212 | 0 | jump_step = CreateOptionalHasValueJumpStep(true, expr->id()); |
2213 | 0 | break; |
2214 | 0 | default: |
2215 | 0 | ABSL_UNREACHABLE(); |
2216 | 0 | } |
2217 | 0 | ProgramStepIndex index = visitor_->GetCurrentIndex(); |
2218 | 0 | if (JumpStepBase* jump_step_ptr = visitor_->AddStep(std::move(jump_step)); |
2219 | 0 | jump_step_ptr) { |
2220 | 0 | jump_step_ = Jump(index, jump_step_ptr); |
2221 | 0 | } |
2222 | 0 | } |
2223 | 0 | } |
2224 | | |
2225 | 10.3k | void BinaryCondVisitor::PostVisit(const cel::Expr* expr) { |
2226 | 10.3k | if (visitor_->PlanRecursiveProgram()) { |
2227 | 0 | switch (cond_) { |
2228 | 0 | case BinaryCond::kAnd: |
2229 | 0 | visitor_->MakeShortcircuitRecursive(expr, /*is_or=*/false); |
2230 | 0 | break; |
2231 | 0 | case BinaryCond::kOr: |
2232 | 0 | visitor_->MakeShortcircuitRecursive(expr, /*is_or=*/true); |
2233 | 0 | break; |
2234 | 0 | case BinaryCond::kOptionalOr: |
2235 | 0 | visitor_->MakeOptionalShortcircuit(expr, |
2236 | 0 | /*is_or_value=*/false); |
2237 | 0 | break; |
2238 | 0 | case BinaryCond::kOptionalOrValue: |
2239 | 0 | visitor_->MakeOptionalShortcircuit(expr, |
2240 | 0 | /*is_or_value=*/true); |
2241 | 0 | break; |
2242 | 0 | default: |
2243 | 0 | ABSL_UNREACHABLE(); |
2244 | 0 | } |
2245 | 0 | return; |
2246 | 0 | } |
2247 | | |
2248 | 10.3k | switch (cond_) { |
2249 | 6.56k | case BinaryCond::kAnd: |
2250 | 6.56k | visitor_->AddStep(CreateAndStep(expr->id())); |
2251 | 6.56k | break; |
2252 | 3.81k | case BinaryCond::kOr: |
2253 | 3.81k | visitor_->AddStep(CreateOrStep(expr->id())); |
2254 | 3.81k | break; |
2255 | 0 | case BinaryCond::kOptionalOr: |
2256 | 0 | visitor_->AddStep( |
2257 | 0 | CreateOptionalOrStep(/*is_or_value=*/false, expr->id())); |
2258 | 0 | break; |
2259 | 0 | case BinaryCond::kOptionalOrValue: |
2260 | 0 | visitor_->AddStep(CreateOptionalOrStep(/*is_or_value=*/true, expr->id())); |
2261 | 0 | break; |
2262 | 0 | default: |
2263 | 0 | ABSL_UNREACHABLE(); |
2264 | 10.3k | } |
2265 | 10.3k | if (short_circuiting_) { |
2266 | | // If short-circuiting is enabled, point the conditional jump past the |
2267 | | // boolean operator step. |
2268 | 10.3k | visitor_->SetProgressStatusError( |
2269 | 10.3k | jump_step_.set_target(visitor_->GetCurrentIndex())); |
2270 | 10.3k | } |
2271 | 10.3k | } |
2272 | | |
2273 | 783 | void TernaryCondVisitor::PreVisit(const cel::Expr* expr) { |
2274 | 783 | visitor_->ValidateOrError( |
2275 | 783 | !expr->call_expr().has_target() && expr->call_expr().args().size() == 3, |
2276 | 783 | "Invalid argument count for a ternary function call."); |
2277 | 783 | } |
2278 | | |
2279 | 2.21k | void TernaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { |
2280 | 2.21k | if (visitor_->PlanRecursiveProgram()) { |
2281 | 0 | return; |
2282 | 0 | } |
2283 | | // Ternary operator "_?_:_" requires a special handing. |
2284 | | // In contrary to regular function call, its execution affects the control |
2285 | | // flow of the overall CEL expression. |
2286 | | // If condition value (argument 0) is True, then control flow is unaffected |
2287 | | // as it is passed to the first conditional branch. Then, at the end of this |
2288 | | // branch, the jump is performed over the second conditional branch. |
2289 | | // If condition value is False, then jump is performed and control is passed |
2290 | | // to the beginning of the second conditional branch. |
2291 | | // If condition value is Error, then jump is peformed to bypass both |
2292 | | // conditional branches and provide Error as result of ternary operation. |
2293 | | |
2294 | | // condition argument for ternary operator |
2295 | 2.21k | if (arg_num == 0) { |
2296 | | // Jump in case of error or non-bool |
2297 | 781 | ProgramStepIndex error_jump_pos = visitor_->GetCurrentIndex(); |
2298 | 781 | auto* error_jump = |
2299 | 781 | visitor_->AddStep(CreateBoolCheckJumpStep({}, expr->id())); |
2300 | 781 | if (error_jump) { |
2301 | 781 | error_jump_ = Jump(error_jump_pos, error_jump); |
2302 | 781 | } |
2303 | | |
2304 | | // Jump to the second branch of execution |
2305 | | // Value is to be removed from the stack. |
2306 | 781 | ProgramStepIndex cond_jump_pos = visitor_->GetCurrentIndex(); |
2307 | 781 | auto* jump_to_second = |
2308 | 781 | visitor_->AddStep(CreateCondJumpStep(false, false, {}, expr->id())); |
2309 | 781 | if (jump_to_second) { |
2310 | 781 | jump_to_second_ = |
2311 | 781 | Jump(cond_jump_pos, static_cast<JumpStepBase*>(jump_to_second)); |
2312 | 781 | } |
2313 | 1.43k | } else if (arg_num == 1) { |
2314 | | // Jump after the first and over the second branch of execution. |
2315 | | // Value is to be removed from the stack. |
2316 | 780 | ProgramStepIndex jump_pos = visitor_->GetCurrentIndex(); |
2317 | 780 | auto* jump_after_first = visitor_->AddStep(CreateJumpStep({}, expr->id())); |
2318 | 780 | if (!jump_after_first) { |
2319 | 0 | return; |
2320 | 0 | } |
2321 | 780 | jump_after_first_ = Jump(jump_pos, jump_after_first); |
2322 | | |
2323 | 780 | if (visitor_->ValidateOrError( |
2324 | 780 | jump_to_second_.exists(), |
2325 | 780 | "Error configuring ternary operator: jump_to_second_ is null")) { |
2326 | 780 | visitor_->SetProgressStatusError( |
2327 | 780 | jump_to_second_.set_target(visitor_->GetCurrentIndex())); |
2328 | 780 | } |
2329 | 780 | } |
2330 | | // Code executed after traversing the final branch of execution |
2331 | | // (arg_num == 2) is placed in PostVisitCall, to make this method less |
2332 | | // clattered. |
2333 | 2.21k | } |
2334 | | |
2335 | 652 | void TernaryCondVisitor::PostVisit(const cel::Expr* expr) { |
2336 | 652 | if (visitor_->PlanRecursiveProgram()) { |
2337 | 0 | visitor_->MakeTernaryRecursive(expr); |
2338 | 0 | return; |
2339 | 0 | } |
2340 | | // Determine and set jump offset in jump instruction. |
2341 | 652 | if (visitor_->ValidateOrError( |
2342 | 652 | error_jump_.exists(), |
2343 | 652 | "Error configuring ternary operator: error_jump_ is null")) { |
2344 | 652 | visitor_->SetProgressStatusError( |
2345 | 652 | error_jump_.set_target(visitor_->GetCurrentIndex())); |
2346 | 652 | } |
2347 | 652 | if (visitor_->ValidateOrError( |
2348 | 652 | jump_after_first_.exists(), |
2349 | 652 | "Error configuring ternary operator: jump_after_first_ is null")) { |
2350 | 652 | visitor_->SetProgressStatusError( |
2351 | 652 | jump_after_first_.set_target(visitor_->GetCurrentIndex())); |
2352 | 652 | } |
2353 | 652 | } |
2354 | | |
2355 | 0 | void ExhaustiveTernaryCondVisitor::PreVisit(const cel::Expr* expr) { |
2356 | 0 | visitor_->ValidateOrError( |
2357 | 0 | !expr->call_expr().has_target() && expr->call_expr().args().size() == 3, |
2358 | 0 | "Invalid argument count for a ternary function call."); |
2359 | 0 | } |
2360 | | |
2361 | 0 | void ExhaustiveTernaryCondVisitor::PostVisit(const cel::Expr* expr) { |
2362 | 0 | if (visitor_->PlanRecursiveProgram()) { |
2363 | 0 | visitor_->MakeTernaryRecursive(expr); |
2364 | 0 | return; |
2365 | 0 | } |
2366 | 0 | visitor_->AddStep(CreateTernaryStep(expr->id())); |
2367 | 0 | } |
2368 | | |
2369 | 0 | void ComprehensionVisitor::PreVisit(const cel::Expr* expr) { |
2370 | 0 | if (is_trivial_) { |
2371 | 0 | visitor_->SuppressBranch(&expr->comprehension_expr().iter_range()); |
2372 | 0 | visitor_->SuppressBranch(&expr->comprehension_expr().loop_condition()); |
2373 | 0 | visitor_->SuppressBranch(&expr->comprehension_expr().loop_step()); |
2374 | 0 | } |
2375 | 0 | } |
2376 | | |
2377 | | absl::Status ComprehensionVisitor::PostVisitArgDefault( |
2378 | 0 | cel::ComprehensionArg arg_num, const cel::Expr* expr) { |
2379 | 0 | if (visitor_->PlanRecursiveProgram()) { |
2380 | 0 | return absl::OkStatus(); |
2381 | 0 | } |
2382 | 0 | switch (arg_num) { |
2383 | 0 | case cel::ITER_RANGE: { |
2384 | 0 | init_step_pos_ = visitor_->GetCurrentIndex(); |
2385 | 0 | init_step_ = visitor_->AddStep( |
2386 | 0 | std::make_unique<ComprehensionInitStep>(expr->id())); |
2387 | 0 | break; |
2388 | 0 | } |
2389 | 0 | case cel::ACCU_INIT: { |
2390 | 0 | next_step_pos_ = visitor_->GetCurrentIndex(); |
2391 | 0 | next_step_ = visitor_->AddStep(std::make_unique<ComprehensionNextStep>( |
2392 | 0 | iter_slot_, iter2_slot_, accu_slot_, expr->id())); |
2393 | 0 | break; |
2394 | 0 | } |
2395 | 0 | case cel::LOOP_CONDITION: { |
2396 | 0 | cond_step_pos_ = visitor_->GetCurrentIndex(); |
2397 | 0 | cond_step_ = visitor_->AddStep(std::make_unique<ComprehensionCondStep>( |
2398 | 0 | iter_slot_, iter2_slot_, accu_slot_, short_circuiting_, expr->id())); |
2399 | 0 | break; |
2400 | 0 | } |
2401 | 0 | case cel::LOOP_STEP: { |
2402 | 0 | ProgramStepIndex index = visitor_->GetCurrentIndex(); |
2403 | 0 | auto* jump_to_next = visitor_->AddStep(CreateJumpStep({}, expr->id())); |
2404 | 0 | if (!jump_to_next) { |
2405 | 0 | break; |
2406 | 0 | } |
2407 | 0 | Jump jump_helper(index, jump_to_next); |
2408 | 0 | visitor_->SetProgressStatusError(jump_helper.set_target(next_step_pos_)); |
2409 | | |
2410 | | // Set offsets jumping to the result step. |
2411 | 0 | if (cond_step_) { |
2412 | 0 | CEL_ASSIGN_OR_RETURN( |
2413 | 0 | int jump_from_cond, |
2414 | 0 | Jump::CalculateOffset(cond_step_pos_, visitor_->GetCurrentIndex())); |
2415 | 0 | cond_step_->set_jump_offset(jump_from_cond); |
2416 | 0 | } |
2417 | | |
2418 | 0 | if (next_step_) { |
2419 | 0 | CEL_ASSIGN_OR_RETURN( |
2420 | 0 | int jump_from_next, |
2421 | 0 | Jump::CalculateOffset(next_step_pos_, visitor_->GetCurrentIndex())); |
2422 | |
|
2423 | 0 | next_step_->set_jump_offset(jump_from_next); |
2424 | 0 | } |
2425 | 0 | break; |
2426 | 0 | } |
2427 | 0 | case cel::RESULT: { |
2428 | 0 | if (!init_step_ || !next_step_ || !cond_step_) { |
2429 | | // Encountered an error earlier. Can't determine where to jump. |
2430 | 0 | break; |
2431 | 0 | } |
2432 | 0 | visitor_->AddStep(CreateComprehensionFinishStep(accu_slot_, expr->id())); |
2433 | | // Set offsets jumping past the result step in case of errors. |
2434 | 0 | CEL_ASSIGN_OR_RETURN( |
2435 | 0 | int jump_from_init, |
2436 | 0 | Jump::CalculateOffset(init_step_pos_, visitor_->GetCurrentIndex())); |
2437 | 0 | init_step_->set_error_jump_offset(jump_from_init); |
2438 | |
|
2439 | 0 | CEL_ASSIGN_OR_RETURN( |
2440 | 0 | int jump_from_next, |
2441 | 0 | Jump::CalculateOffset(next_step_pos_, visitor_->GetCurrentIndex())); |
2442 | 0 | next_step_->set_error_jump_offset(jump_from_next); |
2443 | |
|
2444 | 0 | CEL_ASSIGN_OR_RETURN( |
2445 | 0 | int jump_from_cond, |
2446 | 0 | Jump::CalculateOffset(cond_step_pos_, visitor_->GetCurrentIndex())); |
2447 | 0 | cond_step_->set_error_jump_offset(jump_from_cond); |
2448 | 0 | break; |
2449 | 0 | } |
2450 | 0 | } |
2451 | 0 | return absl::OkStatus(); |
2452 | 0 | } |
2453 | | |
2454 | | void ComprehensionVisitor::PostVisitArgTrivial(cel::ComprehensionArg arg_num, |
2455 | 0 | const cel::Expr* expr) { |
2456 | 0 | if (visitor_->PlanRecursiveProgram()) { |
2457 | 0 | return; |
2458 | 0 | } |
2459 | 0 | switch (arg_num) { |
2460 | 0 | case cel::ITER_RANGE: { |
2461 | 0 | break; |
2462 | 0 | } |
2463 | 0 | case cel::ACCU_INIT: { |
2464 | 0 | if (!accu_init_extracted_) { |
2465 | 0 | visitor_->AddStep(CreateAssignSlotAndPopStep(accu_slot_)); |
2466 | 0 | } |
2467 | 0 | break; |
2468 | 0 | } |
2469 | 0 | case cel::LOOP_CONDITION: { |
2470 | 0 | break; |
2471 | 0 | } |
2472 | 0 | case cel::LOOP_STEP: { |
2473 | 0 | break; |
2474 | 0 | } |
2475 | 0 | case cel::RESULT: { |
2476 | 0 | visitor_->AddStep(CreateClearSlotStep(accu_slot_, expr->id())); |
2477 | 0 | break; |
2478 | 0 | } |
2479 | 0 | } |
2480 | 0 | } |
2481 | | |
2482 | 0 | void ComprehensionVisitor::PostVisit(const cel::Expr* expr) { |
2483 | 0 | if (is_trivial_) { |
2484 | 0 | visitor_->MaybeMakeBindRecursive(expr, &expr->comprehension_expr(), |
2485 | 0 | accu_slot_); |
2486 | 0 | return; |
2487 | 0 | } |
2488 | 0 | visitor_->MaybeMakeComprehensionRecursive( |
2489 | 0 | expr, &expr->comprehension_expr(), iter_slot_, iter2_slot_, accu_slot_); |
2490 | 0 | } |
2491 | | |
2492 | | // Flattens the expression table into the end of the mainline expression vector |
2493 | | // and returns an index to the individual sub expressions. |
2494 | | std::vector<ExecutionPathView> FlattenExpressionTable( |
2495 | 9.92k | ProgramBuilder& program_builder, ExecutionPath& main) { |
2496 | 9.92k | std::vector<std::pair<size_t, size_t>> ranges; |
2497 | 9.92k | main = program_builder.FlattenMain(); |
2498 | 9.92k | ranges.push_back(std::make_pair(0, main.size())); |
2499 | | |
2500 | 9.92k | std::vector<ExecutionPath> subexpressions = |
2501 | 9.92k | program_builder.FlattenSubexpressions(); |
2502 | 9.92k | for (auto& subexpression : subexpressions) { |
2503 | 0 | ranges.push_back(std::make_pair(main.size(), subexpression.size())); |
2504 | 0 | absl::c_move(subexpression, std::back_inserter(main)); |
2505 | 0 | } |
2506 | | |
2507 | 9.92k | std::vector<ExecutionPathView> subexpression_indexes; |
2508 | 9.92k | subexpression_indexes.reserve(ranges.size()); |
2509 | 9.92k | for (const auto& range : ranges) { |
2510 | 9.92k | subexpression_indexes.push_back( |
2511 | 9.92k | absl::MakeSpan(main).subspan(range.first, range.second)); |
2512 | 9.92k | } |
2513 | 9.92k | return subexpression_indexes; |
2514 | 9.92k | } |
2515 | | |
2516 | | absl::Status CheckAstExtensions( |
2517 | 10.3k | const std::vector<cel::ExtensionSpec>& extensions) { |
2518 | 10.3k | for (const cel::ExtensionSpec& extension : extensions) { |
2519 | 0 | if (extension.id() == "cel_block" && extension.version().major() == 1) { |
2520 | | // cel_block v1 is always supported. |
2521 | 0 | continue; |
2522 | 0 | } |
2523 | | |
2524 | | // TODO(uncreated-issue/89): Add support for json field names. |
2525 | 0 | return absl::InvalidArgumentError(absl::StrCat( |
2526 | 0 | "unsupported CEL extension: ", extension.id(), "@", |
2527 | 0 | extension.version().major(), ".", extension.version().minor())); |
2528 | 0 | } |
2529 | 10.3k | return absl::OkStatus(); |
2530 | 10.3k | } |
2531 | | |
2532 | | } // namespace |
2533 | | |
2534 | | absl::StatusOr<FlatExpression> FlatExprBuilder::CreateExpressionImpl( |
2535 | 10.3k | std::unique_ptr<Ast> ast, std::vector<RuntimeIssue>* issues) const { |
2536 | 10.3k | if (absl::StartsWith(container_, ".") || absl::EndsWith(container_, ".")) { |
2537 | 0 | return absl::InvalidArgumentError( |
2538 | 0 | absl::StrCat("Invalid expression container: '", container_, "'")); |
2539 | 0 | } |
2540 | | |
2541 | 10.3k | RuntimeIssue::Severity max_severity = options_.fail_on_warnings |
2542 | 10.3k | ? RuntimeIssue::Severity::kWarning |
2543 | 10.3k | : RuntimeIssue::Severity::kError; |
2544 | 10.3k | IssueCollector issue_collector(max_severity); |
2545 | | |
2546 | 10.3k | absl::StatusOr<std::vector<cel::ExtensionSpec>> runtime_extensions = |
2547 | 10.3k | ExtractAndValidateRuntimeExtensions(*ast); |
2548 | | |
2549 | 10.3k | if (!runtime_extensions.ok()) { |
2550 | 0 | CEL_RETURN_IF_ERROR(issue_collector.AddIssue( |
2551 | 0 | RuntimeIssue::CreateError(runtime_extensions.status()))); |
2552 | 0 | } |
2553 | | |
2554 | 10.3k | auto status = CheckAstExtensions(*runtime_extensions); |
2555 | 10.3k | if (!status.ok()) { |
2556 | 0 | CEL_RETURN_IF_ERROR( |
2557 | 0 | issue_collector.AddIssue(RuntimeIssue::CreateError(status))); |
2558 | 0 | } |
2559 | | |
2560 | 10.3k | Resolver resolver(container_, function_registry_, type_registry_, |
2561 | 10.3k | GetTypeProvider(), |
2562 | 10.3k | options_.enable_qualified_type_identifiers); |
2563 | | |
2564 | 10.3k | std::shared_ptr<google::protobuf::Arena> arena; |
2565 | 10.3k | ProgramBuilder program_builder; |
2566 | 10.3k | PlannerContext extension_context(env_, resolver, options_, GetTypeProvider(), |
2567 | 10.3k | issue_collector, program_builder, arena); |
2568 | | |
2569 | 10.3k | for (const std::unique_ptr<AstTransform>& transform : ast_transforms_) { |
2570 | 10.3k | CEL_RETURN_IF_ERROR(transform->UpdateAst(extension_context, *ast)); |
2571 | 10.3k | } |
2572 | | |
2573 | 10.3k | std::vector<std::unique_ptr<ProgramOptimizer>> optimizers; |
2574 | 10.3k | for (const ProgramOptimizerFactory& optimizer_factory : program_optimizers_) { |
2575 | 0 | CEL_ASSIGN_OR_RETURN(auto optimizer, |
2576 | 0 | optimizer_factory(extension_context, *ast)); |
2577 | 0 | if (optimizer != nullptr) { |
2578 | 0 | optimizers.push_back(std::move(optimizer)); |
2579 | 0 | } |
2580 | 0 | } |
2581 | | |
2582 | | // These objects are expected to remain scoped to one build call -- references |
2583 | | // to them shouldn't be persisted in any part of the result expression. |
2584 | 10.3k | FlatExprVisitor visitor(resolver, options_, std::move(optimizers), |
2585 | 10.3k | ast->reference_map(), GetTypeProvider(), |
2586 | 10.3k | issue_collector, program_builder, extension_context, |
2587 | 10.3k | enable_optional_types_); |
2588 | | |
2589 | 10.3k | if (options_.max_recursion_depth == -1 || options_.max_recursion_depth > 0) { |
2590 | 0 | int depth_limit = options_.max_recursion_depth == -1 |
2591 | 0 | ? std::numeric_limits<int>::max() |
2592 | 0 | : options_.max_recursion_depth; |
2593 | 0 | visitor.SetMaxRecursionDepth(depth_limit); |
2594 | 0 | } |
2595 | | |
2596 | 10.3k | cel::TraversalOptions opts; |
2597 | 10.3k | opts.use_comprehension_callbacks = true; |
2598 | 10.3k | AstTraverse(ast->root_expr(), visitor, opts); |
2599 | | |
2600 | 10.3k | if (!visitor.progress_status().ok()) { |
2601 | 410 | return visitor.progress_status(); |
2602 | 410 | } |
2603 | | |
2604 | 9.92k | if (issues != nullptr) { |
2605 | 0 | (*issues) = issue_collector.ExtractIssues(); |
2606 | 0 | } |
2607 | | |
2608 | 9.92k | ExecutionPath execution_path; |
2609 | 9.92k | std::vector<ExecutionPathView> subexpressions = |
2610 | 9.92k | FlattenExpressionTable(program_builder, execution_path); |
2611 | | |
2612 | 9.92k | return FlatExpression(std::move(execution_path), std::move(subexpressions), |
2613 | 9.92k | visitor.slot_count(), GetTypeProvider(), options_, |
2614 | 9.92k | std::move(arena)); |
2615 | 10.3k | } |
2616 | 40.9k | const cel::TypeProvider& FlatExprBuilder::GetTypeProvider() const { |
2617 | 40.9k | return use_legacy_type_provider_ |
2618 | 40.9k | ? static_cast<const cel::TypeProvider&>( |
2619 | 40.9k | *GetLegacyRuntimeTypeProvider(type_registry_)) |
2620 | 40.9k | : GetRuntimeTypeProvider(type_registry_); |
2621 | 40.9k | } |
2622 | | |
2623 | | } // namespace google::api::expr::runtime |