Coverage Report

Created: 2026-05-27 07:00

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/proc/self/cwd/common/ast_traverse.cc
Line
Count
Source
1
// Copyright 2018 Google LLC
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//      https://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
#include "common/ast_traverse.h"
16
17
#include <memory>
18
#include <stack>
19
20
#include "absl/log/absl_log.h"
21
#include "absl/types/variant.h"
22
#include "common/ast_visitor.h"
23
#include "common/constant.h"
24
#include "common/expr.h"
25
26
namespace cel {
27
28
namespace {
29
30
struct ArgRecord {
31
  // Not null.
32
  const Expr* expr;
33
34
  // For records that are direct arguments to call, we need to call
35
  // the CallArg visitor immediately after the argument is evaluated.
36
  const Expr* calling_expr;
37
  int call_arg;
38
};
39
40
struct ComprehensionRecord {
41
  // Not null.
42
  const Expr* expr;
43
44
  const ComprehensionExpr* comprehension;
45
  const Expr* comprehension_expr;
46
  ComprehensionArg comprehension_arg;
47
  bool use_comprehension_callbacks;
48
};
49
50
struct ExprRecord {
51
  // Not null.
52
  const Expr* expr;
53
};
54
55
using StackRecordKind =
56
    std::variant<ExprRecord, ArgRecord, ComprehensionRecord>;
57
58
struct StackRecord {
59
 public:
60
  static constexpr int kTarget = -2;
61
62
385k
  explicit StackRecord(const Expr* e) {
63
385k
    ExprRecord record;
64
385k
    record.expr = e;
65
385k
    record_variant = record;
66
385k
  }
67
68
  StackRecord(const Expr* e, const ComprehensionExpr* comprehension,
69
              const Expr* comprehension_expr,
70
              ComprehensionArg comprehension_arg,
71
0
              bool use_comprehension_callbacks) {
72
0
    if (use_comprehension_callbacks) {
73
0
      ComprehensionRecord record;
74
0
      record.expr = e;
75
0
      record.comprehension = comprehension;
76
0
      record.comprehension_expr = comprehension_expr;
77
0
      record.comprehension_arg = comprehension_arg;
78
0
      record.use_comprehension_callbacks = use_comprehension_callbacks;
79
0
      record_variant = record;
80
0
      return;
81
0
    }
82
0
    ArgRecord record;
83
0
    record.expr = e;
84
0
    record.calling_expr = comprehension_expr;
85
0
    record.call_arg = comprehension_arg;
86
0
    record_variant = record;
87
0
  }
88
89
318k
  StackRecord(const Expr* e, const Expr* call, int argnum) {
90
318k
    ArgRecord record;
91
318k
    record.expr = e;
92
318k
    record.calling_expr = call;
93
318k
    record.call_arg = argnum;
94
318k
    record_variant = record;
95
318k
  }
96
  StackRecordKind record_variant;
97
  bool visited = false;
98
};
99
100
struct PreVisitor {
101
385k
  void operator()(const ExprRecord& record) {
102
385k
    const Expr* expr = record.expr;
103
385k
    visitor->PreVisitExpr(*expr);
104
385k
    if (expr->has_select_expr()) {
105
17.2k
      visitor->PreVisitSelect(*expr, expr->select_expr());
106
368k
    } else if (expr->has_call_expr()) {
107
167k
      visitor->PreVisitCall(*expr, expr->call_expr());
108
200k
    } else if (expr->has_comprehension_expr()) {
109
0
      visitor->PreVisitComprehension(*expr, expr->comprehension_expr());
110
200k
    } else {
111
      // No pre-visit action.
112
200k
    }
113
385k
  }
114
115
  // Do nothing for Arg variant.
116
318k
  void operator()(const ArgRecord&) {}
117
118
0
  void operator()(const ComprehensionRecord& record) {
119
0
    visitor->PreVisitComprehensionSubexpression(*record.comprehension_expr,
120
0
                                                *record.comprehension,
121
0
                                                record.comprehension_arg);
122
0
  }
123
124
  AstVisitor* visitor;
125
};
126
127
703k
void PreVisit(const StackRecord& record, AstVisitor* visitor) {
128
703k
  absl::visit(PreVisitor{visitor}, record.record_variant);
129
703k
}
130
131
struct PostVisitor {
132
385k
  void operator()(const ExprRecord& record) {
133
385k
    const Expr* expr = record.expr;
134
385k
    struct {
135
385k
      AstVisitor* visitor;
136
385k
      const Expr* expr;
137
385k
      void operator()(const Constant& constant) {
138
118k
        visitor->PostVisitConst(*expr, expr->const_expr());
139
118k
      }
140
385k
      void operator()(const IdentExpr& ident) {
141
61.6k
        visitor->PostVisitIdent(*expr, expr->ident_expr());
142
61.6k
      }
143
385k
      void operator()(const SelectExpr& select) {
144
17.2k
        visitor->PostVisitSelect(*expr, expr->select_expr());
145
17.2k
      }
146
385k
      void operator()(const CallExpr& call) {
147
167k
        visitor->PostVisitCall(*expr, expr->call_expr());
148
167k
      }
149
385k
      void operator()(const ListExpr& create_list) {
150
9.20k
        visitor->PostVisitList(*expr, expr->list_expr());
151
9.20k
      }
152
385k
      void operator()(const StructExpr& create_struct) {
153
2.45k
        visitor->PostVisitStruct(*expr, expr->struct_expr());
154
2.45k
      }
155
385k
      void operator()(const MapExpr& map_expr) {
156
8.70k
        visitor->PostVisitMap(*expr, expr->map_expr());
157
8.70k
      }
158
385k
      void operator()(const ComprehensionExpr& comprehension) {
159
0
        visitor->PostVisitComprehension(*expr, expr->comprehension_expr());
160
0
      }
161
385k
      void operator()(const UnspecifiedExpr&) {
162
0
        ABSL_LOG(ERROR) << "Unsupported Expr kind";
163
0
      }
164
385k
    } handler{visitor, record.expr};
165
385k
    absl::visit(handler, record.expr->kind());
166
167
385k
    visitor->PostVisitExpr(*expr);
168
385k
  }
169
170
318k
  void operator()(const ArgRecord& record) {
171
318k
    if (record.call_arg == StackRecord::kTarget) {
172
782
      visitor->PostVisitTarget(*record.calling_expr);
173
317k
    } else {
174
317k
      visitor->PostVisitArg(*record.calling_expr, record.call_arg);
175
317k
    }
176
318k
  }
177
178
0
  void operator()(const ComprehensionRecord& record) {
179
0
    visitor->PostVisitComprehensionSubexpression(*record.comprehension_expr,
180
0
                                                 *record.comprehension,
181
0
                                                 record.comprehension_arg);
182
0
  }
183
184
  AstVisitor* visitor;
185
};
186
187
703k
void PostVisit(const StackRecord& record, AstVisitor* visitor) {
188
703k
  absl::visit(PostVisitor{visitor}, record.record_variant);
189
703k
}
190
191
void PushSelectDeps(const SelectExpr* select_expr,
192
17.2k
                    std::stack<StackRecord>* stack) {
193
17.2k
  if (select_expr->has_operand()) {
194
17.2k
    stack->push(StackRecord(&select_expr->operand()));
195
17.2k
  }
196
17.2k
}
197
198
void PushCallDeps(const CallExpr* call_expr, const Expr* expr,
199
167k
                  std::stack<StackRecord>* stack) {
200
167k
  const int arg_size = call_expr->args().size();
201
  // Our contract is that we visit arguments in order.  To do that, we need
202
  // to push them onto the stack in reverse order.
203
485k
  for (int i = arg_size - 1; i >= 0; --i) {
204
317k
    stack->push(StackRecord(&call_expr->args()[i], expr, i));
205
317k
  }
206
  // Are we receiver-style?
207
167k
  if (call_expr->has_target()) {
208
782
    stack->push(StackRecord(&call_expr->target(), expr, StackRecord::kTarget));
209
782
  }
210
167k
}
211
212
9.20k
void PushListDeps(const ListExpr* list_expr, std::stack<StackRecord>* stack) {
213
9.20k
  const auto& elements = list_expr->elements();
214
40.2k
  for (auto it = elements.rbegin(); it != elements.rend(); ++it) {
215
31.0k
    const auto& element = *it;
216
31.0k
    stack->push(StackRecord(&element.expr()));
217
31.0k
  }
218
9.20k
}
219
220
void PushStructDeps(const StructExpr* struct_expr,
221
2.45k
                    std::stack<StackRecord>* stack) {
222
2.45k
  const auto& entries = struct_expr->fields();
223
3.31k
  for (auto it = entries.rbegin(); it != entries.rend(); ++it) {
224
858
    const auto& entry = *it;
225
    // The contract is to visit key, then value.  So put them on the stack
226
    // in the opposite order.
227
858
    if (entry.has_value()) {
228
858
      stack->push(StackRecord(&entry.value()));
229
858
    }
230
858
  }
231
2.45k
}
232
233
8.70k
void PushMapDeps(const MapExpr* map_expr, std::stack<StackRecord>* stack) {
234
8.70k
  const auto& entries = map_expr->entries();
235
12.5k
  for (auto it = entries.rbegin(); it != entries.rend(); ++it) {
236
3.86k
    const auto& entry = *it;
237
    // The contract is to visit key, then value.  So put them on the stack
238
    // in the opposite order.
239
3.86k
    if (entry.has_value()) {
240
3.86k
      stack->push(StackRecord(&entry.value()));
241
3.86k
    }
242
    // The contract is to visit key, then value.  So put them on the stack
243
    // in the opposite order.
244
3.86k
    if (entry.has_key()) {
245
3.86k
      stack->push(StackRecord(&entry.key()));
246
3.86k
    }
247
3.86k
  }
248
8.70k
}
249
250
void PushComprehensionDeps(const ComprehensionExpr* c, const Expr* expr,
251
                           std::stack<StackRecord>* stack,
252
0
                           bool use_comprehension_callbacks) {
253
0
  StackRecord iter_range(&c->iter_range(), c, expr, ITER_RANGE,
254
0
                         use_comprehension_callbacks);
255
0
  StackRecord accu_init(&c->accu_init(), c, expr, ACCU_INIT,
256
0
                        use_comprehension_callbacks);
257
0
  StackRecord loop_condition(&c->loop_condition(), c, expr, LOOP_CONDITION,
258
0
                             use_comprehension_callbacks);
259
0
  StackRecord loop_step(&c->loop_step(), c, expr, LOOP_STEP,
260
0
                        use_comprehension_callbacks);
261
0
  StackRecord result(&c->result(), c, expr, RESULT,
262
0
                     use_comprehension_callbacks);
263
  // Push them in reverse order.
264
0
  stack->push(result);
265
0
  stack->push(loop_step);
266
0
  stack->push(loop_condition);
267
0
  stack->push(accu_init);
268
0
  stack->push(iter_range);
269
0
}
270
271
struct PushDepsVisitor {
272
385k
  void operator()(const ExprRecord& record) {
273
385k
    struct {
274
385k
      std::stack<StackRecord>& stack;
275
385k
      const TraversalOptions& options;
276
385k
      const ExprRecord& record;
277
385k
      void operator()(const Constant& constant) {}
278
385k
      void operator()(const IdentExpr& ident) {}
279
385k
      void operator()(const SelectExpr& select) {
280
17.2k
        PushSelectDeps(&record.expr->select_expr(), &stack);
281
17.2k
      }
282
385k
      void operator()(const CallExpr& call) {
283
167k
        PushCallDeps(&record.expr->call_expr(), record.expr, &stack);
284
167k
      }
285
385k
      void operator()(const ListExpr& create_list) {
286
9.20k
        PushListDeps(&record.expr->list_expr(), &stack);
287
9.20k
      }
288
385k
      void operator()(const StructExpr& create_struct) {
289
2.45k
        PushStructDeps(&record.expr->struct_expr(), &stack);
290
2.45k
      }
291
385k
      void operator()(const MapExpr& map_expr) {
292
8.70k
        PushMapDeps(&record.expr->map_expr(), &stack);
293
8.70k
      }
294
385k
      void operator()(const ComprehensionExpr& comprehension) {
295
0
        PushComprehensionDeps(&record.expr->comprehension_expr(), record.expr,
296
0
                              &stack, options.use_comprehension_callbacks);
297
0
      }
298
385k
      void operator()(const UnspecifiedExpr&) {}
299
385k
    } handler{stack, options, record};
300
385k
    absl::visit(handler, record.expr->kind());
301
385k
  }
302
303
318k
  void operator()(const ArgRecord& record) {
304
318k
    stack.push(StackRecord(record.expr));
305
318k
  }
306
307
0
  void operator()(const ComprehensionRecord& record) {
308
0
    stack.push(StackRecord(record.expr));
309
0
  }
310
311
  std::stack<StackRecord>& stack;
312
  const TraversalOptions& options;
313
};
314
315
void PushDependencies(const StackRecord& record, std::stack<StackRecord>& stack,
316
703k
                      const TraversalOptions& options) {
317
703k
  absl::visit(PushDepsVisitor{stack, options}, record.record_variant);
318
703k
}
319
320
}  // namespace
321
322
namespace common_internal {
323
struct AstTraversalState {
324
  std::stack<StackRecord> stack;
325
};
326
}  // namespace common_internal
327
328
AstTraversal AstTraversal::Create(const cel::Expr& ast,
329
0
                                  const TraversalOptions& options) {
330
0
  AstTraversal instance(options);
331
0
  instance.state_ = std::make_unique<common_internal::AstTraversalState>();
332
0
  instance.state_->stack.push(StackRecord(&ast));
333
0
  return instance;
334
0
}
335
336
0
AstTraversal::AstTraversal(TraversalOptions options) : options_(options) {}
337
338
0
AstTraversal::~AstTraversal() = default;
339
340
0
bool AstTraversal::Step(AstVisitor& visitor) {
341
0
  if (IsDone()) {
342
0
    return false;
343
0
  }
344
0
  auto& stack = state_->stack;
345
0
  StackRecord& record = stack.top();
346
0
  if (!record.visited) {
347
0
    PreVisit(record, &visitor);
348
0
    PushDependencies(record, stack, options_);
349
0
    record.visited = true;
350
0
  } else {
351
0
    PostVisit(record, &visitor);
352
0
    stack.pop();
353
0
  }
354
355
0
  return !stack.empty();
356
0
}
357
358
0
bool AstTraversal::IsDone() {
359
0
  return state_ == nullptr || state_->stack.empty();
360
0
}
361
362
void AstTraverse(const Expr& expr, AstVisitor& visitor,
363
10.3k
                 TraversalOptions options) {
364
10.3k
  std::stack<StackRecord> stack;
365
10.3k
  stack.push(StackRecord(&expr));
366
367
1.41M
  while (!stack.empty()) {
368
1.40M
    StackRecord& record = stack.top();
369
1.40M
    if (!record.visited) {
370
703k
      PreVisit(record, &visitor);
371
703k
      PushDependencies(record, stack, options);
372
703k
      record.visited = true;
373
703k
    } else {
374
703k
      PostVisit(record, &visitor);
375
703k
      stack.pop();
376
703k
    }
377
1.40M
  }
378
10.3k
}
379
380
}  // namespace cel