/proc/self/cwd/common/ast_traverse.cc
Line | Count | Source |
1 | | // Copyright 2018 Google LLC |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // https://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | #include "common/ast_traverse.h" |
16 | | |
17 | | #include <memory> |
18 | | #include <stack> |
19 | | |
20 | | #include "absl/log/absl_log.h" |
21 | | #include "absl/types/variant.h" |
22 | | #include "common/ast_visitor.h" |
23 | | #include "common/constant.h" |
24 | | #include "common/expr.h" |
25 | | |
26 | | namespace cel { |
27 | | |
28 | | namespace { |
29 | | |
30 | | struct ArgRecord { |
31 | | // Not null. |
32 | | const Expr* expr; |
33 | | |
34 | | // For records that are direct arguments to call, we need to call |
35 | | // the CallArg visitor immediately after the argument is evaluated. |
36 | | const Expr* calling_expr; |
37 | | int call_arg; |
38 | | }; |
39 | | |
40 | | struct ComprehensionRecord { |
41 | | // Not null. |
42 | | const Expr* expr; |
43 | | |
44 | | const ComprehensionExpr* comprehension; |
45 | | const Expr* comprehension_expr; |
46 | | ComprehensionArg comprehension_arg; |
47 | | bool use_comprehension_callbacks; |
48 | | }; |
49 | | |
50 | | struct ExprRecord { |
51 | | // Not null. |
52 | | const Expr* expr; |
53 | | }; |
54 | | |
55 | | using StackRecordKind = |
56 | | std::variant<ExprRecord, ArgRecord, ComprehensionRecord>; |
57 | | |
58 | | struct StackRecord { |
59 | | public: |
60 | | static constexpr int kTarget = -2; |
61 | | |
62 | 385k | explicit StackRecord(const Expr* e) { |
63 | 385k | ExprRecord record; |
64 | 385k | record.expr = e; |
65 | 385k | record_variant = record; |
66 | 385k | } |
67 | | |
68 | | StackRecord(const Expr* e, const ComprehensionExpr* comprehension, |
69 | | const Expr* comprehension_expr, |
70 | | ComprehensionArg comprehension_arg, |
71 | 0 | bool use_comprehension_callbacks) { |
72 | 0 | if (use_comprehension_callbacks) { |
73 | 0 | ComprehensionRecord record; |
74 | 0 | record.expr = e; |
75 | 0 | record.comprehension = comprehension; |
76 | 0 | record.comprehension_expr = comprehension_expr; |
77 | 0 | record.comprehension_arg = comprehension_arg; |
78 | 0 | record.use_comprehension_callbacks = use_comprehension_callbacks; |
79 | 0 | record_variant = record; |
80 | 0 | return; |
81 | 0 | } |
82 | 0 | ArgRecord record; |
83 | 0 | record.expr = e; |
84 | 0 | record.calling_expr = comprehension_expr; |
85 | 0 | record.call_arg = comprehension_arg; |
86 | 0 | record_variant = record; |
87 | 0 | } |
88 | | |
89 | 318k | StackRecord(const Expr* e, const Expr* call, int argnum) { |
90 | 318k | ArgRecord record; |
91 | 318k | record.expr = e; |
92 | 318k | record.calling_expr = call; |
93 | 318k | record.call_arg = argnum; |
94 | 318k | record_variant = record; |
95 | 318k | } |
96 | | StackRecordKind record_variant; |
97 | | bool visited = false; |
98 | | }; |
99 | | |
100 | | struct PreVisitor { |
101 | 385k | void operator()(const ExprRecord& record) { |
102 | 385k | const Expr* expr = record.expr; |
103 | 385k | visitor->PreVisitExpr(*expr); |
104 | 385k | if (expr->has_select_expr()) { |
105 | 17.2k | visitor->PreVisitSelect(*expr, expr->select_expr()); |
106 | 368k | } else if (expr->has_call_expr()) { |
107 | 167k | visitor->PreVisitCall(*expr, expr->call_expr()); |
108 | 200k | } else if (expr->has_comprehension_expr()) { |
109 | 0 | visitor->PreVisitComprehension(*expr, expr->comprehension_expr()); |
110 | 200k | } else { |
111 | | // No pre-visit action. |
112 | 200k | } |
113 | 385k | } |
114 | | |
115 | | // Do nothing for Arg variant. |
116 | 318k | void operator()(const ArgRecord&) {} |
117 | | |
118 | 0 | void operator()(const ComprehensionRecord& record) { |
119 | 0 | visitor->PreVisitComprehensionSubexpression(*record.comprehension_expr, |
120 | 0 | *record.comprehension, |
121 | 0 | record.comprehension_arg); |
122 | 0 | } |
123 | | |
124 | | AstVisitor* visitor; |
125 | | }; |
126 | | |
127 | 703k | void PreVisit(const StackRecord& record, AstVisitor* visitor) { |
128 | 703k | absl::visit(PreVisitor{visitor}, record.record_variant); |
129 | 703k | } |
130 | | |
131 | | struct PostVisitor { |
132 | 385k | void operator()(const ExprRecord& record) { |
133 | 385k | const Expr* expr = record.expr; |
134 | 385k | struct { |
135 | 385k | AstVisitor* visitor; |
136 | 385k | const Expr* expr; |
137 | 385k | void operator()(const Constant& constant) { |
138 | 118k | visitor->PostVisitConst(*expr, expr->const_expr()); |
139 | 118k | } |
140 | 385k | void operator()(const IdentExpr& ident) { |
141 | 61.6k | visitor->PostVisitIdent(*expr, expr->ident_expr()); |
142 | 61.6k | } |
143 | 385k | void operator()(const SelectExpr& select) { |
144 | 17.2k | visitor->PostVisitSelect(*expr, expr->select_expr()); |
145 | 17.2k | } |
146 | 385k | void operator()(const CallExpr& call) { |
147 | 167k | visitor->PostVisitCall(*expr, expr->call_expr()); |
148 | 167k | } |
149 | 385k | void operator()(const ListExpr& create_list) { |
150 | 9.20k | visitor->PostVisitList(*expr, expr->list_expr()); |
151 | 9.20k | } |
152 | 385k | void operator()(const StructExpr& create_struct) { |
153 | 2.45k | visitor->PostVisitStruct(*expr, expr->struct_expr()); |
154 | 2.45k | } |
155 | 385k | void operator()(const MapExpr& map_expr) { |
156 | 8.70k | visitor->PostVisitMap(*expr, expr->map_expr()); |
157 | 8.70k | } |
158 | 385k | void operator()(const ComprehensionExpr& comprehension) { |
159 | 0 | visitor->PostVisitComprehension(*expr, expr->comprehension_expr()); |
160 | 0 | } |
161 | 385k | void operator()(const UnspecifiedExpr&) { |
162 | 0 | ABSL_LOG(ERROR) << "Unsupported Expr kind"; |
163 | 0 | } |
164 | 385k | } handler{visitor, record.expr}; |
165 | 385k | absl::visit(handler, record.expr->kind()); |
166 | | |
167 | 385k | visitor->PostVisitExpr(*expr); |
168 | 385k | } |
169 | | |
170 | 318k | void operator()(const ArgRecord& record) { |
171 | 318k | if (record.call_arg == StackRecord::kTarget) { |
172 | 782 | visitor->PostVisitTarget(*record.calling_expr); |
173 | 317k | } else { |
174 | 317k | visitor->PostVisitArg(*record.calling_expr, record.call_arg); |
175 | 317k | } |
176 | 318k | } |
177 | | |
178 | 0 | void operator()(const ComprehensionRecord& record) { |
179 | 0 | visitor->PostVisitComprehensionSubexpression(*record.comprehension_expr, |
180 | 0 | *record.comprehension, |
181 | 0 | record.comprehension_arg); |
182 | 0 | } |
183 | | |
184 | | AstVisitor* visitor; |
185 | | }; |
186 | | |
187 | 703k | void PostVisit(const StackRecord& record, AstVisitor* visitor) { |
188 | 703k | absl::visit(PostVisitor{visitor}, record.record_variant); |
189 | 703k | } |
190 | | |
191 | | void PushSelectDeps(const SelectExpr* select_expr, |
192 | 17.2k | std::stack<StackRecord>* stack) { |
193 | 17.2k | if (select_expr->has_operand()) { |
194 | 17.2k | stack->push(StackRecord(&select_expr->operand())); |
195 | 17.2k | } |
196 | 17.2k | } |
197 | | |
198 | | void PushCallDeps(const CallExpr* call_expr, const Expr* expr, |
199 | 167k | std::stack<StackRecord>* stack) { |
200 | 167k | const int arg_size = call_expr->args().size(); |
201 | | // Our contract is that we visit arguments in order. To do that, we need |
202 | | // to push them onto the stack in reverse order. |
203 | 485k | for (int i = arg_size - 1; i >= 0; --i) { |
204 | 317k | stack->push(StackRecord(&call_expr->args()[i], expr, i)); |
205 | 317k | } |
206 | | // Are we receiver-style? |
207 | 167k | if (call_expr->has_target()) { |
208 | 782 | stack->push(StackRecord(&call_expr->target(), expr, StackRecord::kTarget)); |
209 | 782 | } |
210 | 167k | } |
211 | | |
212 | 9.20k | void PushListDeps(const ListExpr* list_expr, std::stack<StackRecord>* stack) { |
213 | 9.20k | const auto& elements = list_expr->elements(); |
214 | 40.2k | for (auto it = elements.rbegin(); it != elements.rend(); ++it) { |
215 | 31.0k | const auto& element = *it; |
216 | 31.0k | stack->push(StackRecord(&element.expr())); |
217 | 31.0k | } |
218 | 9.20k | } |
219 | | |
220 | | void PushStructDeps(const StructExpr* struct_expr, |
221 | 2.45k | std::stack<StackRecord>* stack) { |
222 | 2.45k | const auto& entries = struct_expr->fields(); |
223 | 3.31k | for (auto it = entries.rbegin(); it != entries.rend(); ++it) { |
224 | 858 | const auto& entry = *it; |
225 | | // The contract is to visit key, then value. So put them on the stack |
226 | | // in the opposite order. |
227 | 858 | if (entry.has_value()) { |
228 | 858 | stack->push(StackRecord(&entry.value())); |
229 | 858 | } |
230 | 858 | } |
231 | 2.45k | } |
232 | | |
233 | 8.70k | void PushMapDeps(const MapExpr* map_expr, std::stack<StackRecord>* stack) { |
234 | 8.70k | const auto& entries = map_expr->entries(); |
235 | 12.5k | for (auto it = entries.rbegin(); it != entries.rend(); ++it) { |
236 | 3.86k | const auto& entry = *it; |
237 | | // The contract is to visit key, then value. So put them on the stack |
238 | | // in the opposite order. |
239 | 3.86k | if (entry.has_value()) { |
240 | 3.86k | stack->push(StackRecord(&entry.value())); |
241 | 3.86k | } |
242 | | // The contract is to visit key, then value. So put them on the stack |
243 | | // in the opposite order. |
244 | 3.86k | if (entry.has_key()) { |
245 | 3.86k | stack->push(StackRecord(&entry.key())); |
246 | 3.86k | } |
247 | 3.86k | } |
248 | 8.70k | } |
249 | | |
250 | | void PushComprehensionDeps(const ComprehensionExpr* c, const Expr* expr, |
251 | | std::stack<StackRecord>* stack, |
252 | 0 | bool use_comprehension_callbacks) { |
253 | 0 | StackRecord iter_range(&c->iter_range(), c, expr, ITER_RANGE, |
254 | 0 | use_comprehension_callbacks); |
255 | 0 | StackRecord accu_init(&c->accu_init(), c, expr, ACCU_INIT, |
256 | 0 | use_comprehension_callbacks); |
257 | 0 | StackRecord loop_condition(&c->loop_condition(), c, expr, LOOP_CONDITION, |
258 | 0 | use_comprehension_callbacks); |
259 | 0 | StackRecord loop_step(&c->loop_step(), c, expr, LOOP_STEP, |
260 | 0 | use_comprehension_callbacks); |
261 | 0 | StackRecord result(&c->result(), c, expr, RESULT, |
262 | 0 | use_comprehension_callbacks); |
263 | | // Push them in reverse order. |
264 | 0 | stack->push(result); |
265 | 0 | stack->push(loop_step); |
266 | 0 | stack->push(loop_condition); |
267 | 0 | stack->push(accu_init); |
268 | 0 | stack->push(iter_range); |
269 | 0 | } |
270 | | |
271 | | struct PushDepsVisitor { |
272 | 385k | void operator()(const ExprRecord& record) { |
273 | 385k | struct { |
274 | 385k | std::stack<StackRecord>& stack; |
275 | 385k | const TraversalOptions& options; |
276 | 385k | const ExprRecord& record; |
277 | 385k | void operator()(const Constant& constant) {} |
278 | 385k | void operator()(const IdentExpr& ident) {} |
279 | 385k | void operator()(const SelectExpr& select) { |
280 | 17.2k | PushSelectDeps(&record.expr->select_expr(), &stack); |
281 | 17.2k | } |
282 | 385k | void operator()(const CallExpr& call) { |
283 | 167k | PushCallDeps(&record.expr->call_expr(), record.expr, &stack); |
284 | 167k | } |
285 | 385k | void operator()(const ListExpr& create_list) { |
286 | 9.20k | PushListDeps(&record.expr->list_expr(), &stack); |
287 | 9.20k | } |
288 | 385k | void operator()(const StructExpr& create_struct) { |
289 | 2.45k | PushStructDeps(&record.expr->struct_expr(), &stack); |
290 | 2.45k | } |
291 | 385k | void operator()(const MapExpr& map_expr) { |
292 | 8.70k | PushMapDeps(&record.expr->map_expr(), &stack); |
293 | 8.70k | } |
294 | 385k | void operator()(const ComprehensionExpr& comprehension) { |
295 | 0 | PushComprehensionDeps(&record.expr->comprehension_expr(), record.expr, |
296 | 0 | &stack, options.use_comprehension_callbacks); |
297 | 0 | } |
298 | 385k | void operator()(const UnspecifiedExpr&) {} |
299 | 385k | } handler{stack, options, record}; |
300 | 385k | absl::visit(handler, record.expr->kind()); |
301 | 385k | } |
302 | | |
303 | 318k | void operator()(const ArgRecord& record) { |
304 | 318k | stack.push(StackRecord(record.expr)); |
305 | 318k | } |
306 | | |
307 | 0 | void operator()(const ComprehensionRecord& record) { |
308 | 0 | stack.push(StackRecord(record.expr)); |
309 | 0 | } |
310 | | |
311 | | std::stack<StackRecord>& stack; |
312 | | const TraversalOptions& options; |
313 | | }; |
314 | | |
315 | | void PushDependencies(const StackRecord& record, std::stack<StackRecord>& stack, |
316 | 703k | const TraversalOptions& options) { |
317 | 703k | absl::visit(PushDepsVisitor{stack, options}, record.record_variant); |
318 | 703k | } |
319 | | |
320 | | } // namespace |
321 | | |
322 | | namespace common_internal { |
323 | | struct AstTraversalState { |
324 | | std::stack<StackRecord> stack; |
325 | | }; |
326 | | } // namespace common_internal |
327 | | |
328 | | AstTraversal AstTraversal::Create(const cel::Expr& ast, |
329 | 0 | const TraversalOptions& options) { |
330 | 0 | AstTraversal instance(options); |
331 | 0 | instance.state_ = std::make_unique<common_internal::AstTraversalState>(); |
332 | 0 | instance.state_->stack.push(StackRecord(&ast)); |
333 | 0 | return instance; |
334 | 0 | } |
335 | | |
336 | 0 | AstTraversal::AstTraversal(TraversalOptions options) : options_(options) {} |
337 | | |
338 | 0 | AstTraversal::~AstTraversal() = default; |
339 | | |
340 | 0 | bool AstTraversal::Step(AstVisitor& visitor) { |
341 | 0 | if (IsDone()) { |
342 | 0 | return false; |
343 | 0 | } |
344 | 0 | auto& stack = state_->stack; |
345 | 0 | StackRecord& record = stack.top(); |
346 | 0 | if (!record.visited) { |
347 | 0 | PreVisit(record, &visitor); |
348 | 0 | PushDependencies(record, stack, options_); |
349 | 0 | record.visited = true; |
350 | 0 | } else { |
351 | 0 | PostVisit(record, &visitor); |
352 | 0 | stack.pop(); |
353 | 0 | } |
354 | |
|
355 | 0 | return !stack.empty(); |
356 | 0 | } |
357 | | |
358 | 0 | bool AstTraversal::IsDone() { |
359 | 0 | return state_ == nullptr || state_->stack.empty(); |
360 | 0 | } |
361 | | |
362 | | void AstTraverse(const Expr& expr, AstVisitor& visitor, |
363 | 10.3k | TraversalOptions options) { |
364 | 10.3k | std::stack<StackRecord> stack; |
365 | 10.3k | stack.push(StackRecord(&expr)); |
366 | | |
367 | 1.41M | while (!stack.empty()) { |
368 | 1.40M | StackRecord& record = stack.top(); |
369 | 1.40M | if (!record.visited) { |
370 | 703k | PreVisit(record, &visitor); |
371 | 703k | PushDependencies(record, stack, options); |
372 | 703k | record.visited = true; |
373 | 703k | } else { |
374 | 703k | PostVisit(record, &visitor); |
375 | 703k | stack.pop(); |
376 | 703k | } |
377 | 1.40M | } |
378 | 10.3k | } |
379 | | |
380 | | } // namespace cel |