Coverage Report

Created: 2026-05-27 07:00

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/proc/self/cwd/common/ast_rewrite.cc
Line
Count
Source
1
// Copyright 2021 Google LLC
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//      https://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
#include "common/ast_rewrite.h"
16
17
#include <stack>
18
#include <vector>
19
20
#include "absl/log/absl_log.h"
21
#include "absl/types/span.h"
22
#include "absl/types/variant.h"
23
#include "common/ast_visitor.h"
24
#include "common/constant.h"
25
#include "common/expr.h"
26
27
namespace cel {
28
29
namespace {
30
31
struct ArgRecord {
32
  // Not null.
33
  Expr* expr;
34
35
  // For records that are direct arguments to call, we need to call
36
  // the CallArg visitor immediately after the argument is evaluated.
37
  const Expr* calling_expr;
38
  int call_arg;
39
};
40
41
struct ComprehensionRecord {
42
  // Not null.
43
  Expr* expr;
44
45
  const ComprehensionExpr* comprehension;
46
  const Expr* comprehension_expr;
47
  ComprehensionArg comprehension_arg;
48
  bool use_comprehension_callbacks;
49
};
50
51
struct ExprRecord {
52
  // Not null.
53
  Expr* expr;
54
};
55
56
using StackRecordKind =
57
    std::variant<ExprRecord, ArgRecord, ComprehensionRecord>;
58
59
struct StackRecord {
60
 public:
61
  static constexpr int kTarget = -2;
62
63
0
  explicit StackRecord(Expr* e) {
64
0
    ExprRecord record;
65
0
    record.expr = e;
66
0
    record_variant = record;
67
0
  }
68
69
  StackRecord(Expr* e, ComprehensionExpr* comprehension,
70
              Expr* comprehension_expr, ComprehensionArg comprehension_arg,
71
0
              bool use_comprehension_callbacks) {
72
0
    if (use_comprehension_callbacks) {
73
0
      ComprehensionRecord record;
74
0
      record.expr = e;
75
0
      record.comprehension = comprehension;
76
0
      record.comprehension_expr = comprehension_expr;
77
0
      record.comprehension_arg = comprehension_arg;
78
0
      record.use_comprehension_callbacks = use_comprehension_callbacks;
79
0
      record_variant = record;
80
0
      return;
81
0
    }
82
0
    ArgRecord record;
83
0
    record.expr = e;
84
0
    record.calling_expr = comprehension_expr;
85
0
    record.call_arg = comprehension_arg;
86
0
    record_variant = record;
87
0
  }
88
89
0
  StackRecord(Expr* e, const Expr* call, int argnum) {
90
0
    ArgRecord record;
91
0
    record.expr = e;
92
0
    record.calling_expr = call;
93
0
    record.call_arg = argnum;
94
0
    record_variant = record;
95
0
  }
96
97
0
  Expr* expr() const { return absl::get<ExprRecord>(record_variant).expr; }
98
99
0
  bool IsExprRecord() const {
100
0
    return absl::holds_alternative<ExprRecord>(record_variant);
101
0
  }
102
103
  StackRecordKind record_variant;
104
  bool visited = false;
105
};
106
107
struct PreVisitor {
108
0
  void operator()(const ExprRecord& record) {
109
0
    struct {
110
0
      AstVisitor* visitor;
111
0
      const Expr* expr;
112
0
      void operator()(const Constant&) {
113
        // No pre-visit action.
114
0
      }
115
0
      void operator()(const IdentExpr&) {
116
        // No pre-visit action.
117
0
      }
118
0
      void operator()(const SelectExpr& select) {
119
0
        visitor->PreVisitSelect(*expr, select);
120
0
      }
121
0
      void operator()(const CallExpr& call) {
122
0
        visitor->PreVisitCall(*expr, call);
123
0
      }
124
0
      void operator()(const ListExpr&) {
125
        // No pre-visit action.
126
0
      }
127
0
      void operator()(const StructExpr&) {
128
        // No pre-visit action.
129
0
      }
130
0
      void operator()(const MapExpr&) {
131
        // No pre-visit action.
132
0
      }
133
0
      void operator()(const ComprehensionExpr& comprehension) {
134
0
        visitor->PreVisitComprehension(*expr, comprehension);
135
0
      }
136
0
      void operator()(const UnspecifiedExpr&) {
137
        // No pre-visit action.
138
0
      }
139
0
    } handler{visitor, record.expr};
140
0
    visitor->PreVisitExpr(*record.expr);
141
0
    absl::visit(handler, record.expr->kind());
142
0
  }
143
144
  // Do nothing for Arg variant.
145
0
  void operator()(const ArgRecord&) {}
146
147
0
  void operator()(const ComprehensionRecord& record) {
148
0
    visitor->PreVisitComprehensionSubexpression(*record.comprehension_expr,
149
0
                                                *record.comprehension,
150
0
                                                record.comprehension_arg);
151
0
  }
152
153
  AstVisitor* visitor;
154
};
155
156
0
void PreVisit(const StackRecord& record, AstVisitor* visitor) {
157
0
  absl::visit(PreVisitor{visitor}, record.record_variant);
158
0
}
159
160
struct PostVisitor {
161
0
  void operator()(const ExprRecord& record) {
162
0
    struct {
163
0
      AstVisitor* visitor;
164
0
      const Expr* expr;
165
0
      void operator()(const Constant& constant) {
166
0
        visitor->PostVisitConst(*expr, constant);
167
0
      }
168
0
      void operator()(const IdentExpr& ident) {
169
0
        visitor->PostVisitIdent(*expr, ident);
170
0
      }
171
0
      void operator()(const SelectExpr& select) {
172
0
        visitor->PostVisitSelect(*expr, select);
173
0
      }
174
0
      void operator()(const CallExpr& call) {
175
0
        visitor->PostVisitCall(*expr, call);
176
0
      }
177
0
      void operator()(const ListExpr& create_list) {
178
0
        visitor->PostVisitList(*expr, create_list);
179
0
      }
180
0
      void operator()(const StructExpr& create_struct) {
181
0
        visitor->PostVisitStruct(*expr, create_struct);
182
0
      }
183
0
      void operator()(const MapExpr& map_expr) {
184
0
        visitor->PostVisitMap(*expr, map_expr);
185
0
      }
186
0
      void operator()(const ComprehensionExpr& comprehension) {
187
0
        visitor->PostVisitComprehension(*expr, comprehension);
188
0
      }
189
0
      void operator()(const UnspecifiedExpr&) {
190
0
        ABSL_LOG(ERROR) << "Unsupported Expr kind";
191
0
      }
192
0
    } handler{visitor, record.expr};
193
0
    absl::visit(handler, record.expr->kind());
194
195
0
    visitor->PostVisitExpr(*record.expr);
196
0
  }
197
198
0
  void operator()(const ArgRecord& record) {
199
0
    if (record.call_arg == StackRecord::kTarget) {
200
0
      visitor->PostVisitTarget(*record.calling_expr);
201
0
    } else {
202
0
      visitor->PostVisitArg(*record.calling_expr, record.call_arg);
203
0
    }
204
0
  }
205
206
0
  void operator()(const ComprehensionRecord& record) {
207
0
    visitor->PostVisitComprehensionSubexpression(*record.comprehension_expr,
208
0
                                                 *record.comprehension,
209
0
                                                 record.comprehension_arg);
210
0
  }
211
212
  AstVisitor* visitor;
213
};
214
215
0
void PostVisit(const StackRecord& record, AstVisitor* visitor) {
216
0
  absl::visit(PostVisitor{visitor}, record.record_variant);
217
0
}
218
219
0
// Schedules the operand of a select expression for traversal, if it has one.
void PushSelectDeps(SelectExpr* select_expr, std::stack<StackRecord>* stack) {
  if (!select_expr->has_operand()) {
    return;
  }
  Expr* operand = &select_expr->mutable_operand();
  stack->push(StackRecord(operand));
}
224
225
void PushCallDeps(CallExpr* call_expr, Expr* expr,
226
0
                  std::stack<StackRecord>* stack) {
227
0
  const int arg_size = call_expr->args().size();
228
  // Our contract is that we visit arguments in order.  To do that, we need
229
  // to push them onto the stack in reverse order.
230
0
  for (int i = arg_size - 1; i >= 0; --i) {
231
0
    stack->push(StackRecord(&call_expr->mutable_args()[i], expr, i));
232
0
  }
233
  // Are we receiver-style?
234
0
  if (call_expr->has_target()) {
235
0
    stack->push(
236
0
        StackRecord(&call_expr->mutable_target(), expr, StackRecord::kTarget));
237
0
  }
238
0
}
239
240
0
// Schedules every list element for traversal. Elements are pushed from the
// back so they are popped, and visited, in list order.
void PushListDeps(ListExpr* list_expr, std::stack<StackRecord>* stack) {
  auto& elements = list_expr->mutable_elements();
  auto cursor = elements.end();
  while (cursor != elements.begin()) {
    --cursor;
    stack->push(StackRecord(&cursor->mutable_expr()));
  }
}
247
248
0
// Schedules the value of every struct field for traversal. Unlike map entries,
// nothing is pushed for the field itself, only its value. Fields are pushed
// last-to-first so they are visited in declaration order.
void PushStructDeps(StructExpr* struct_expr, std::stack<StackRecord>* stack) {
  auto& fields = struct_expr->mutable_fields();
  for (auto it = fields.rbegin(); it != fields.rend(); ++it) {
    auto& field = *it;
    // A field without a value has nothing to traverse.
    if (field.has_value()) {
      stack->push(StackRecord(&field.mutable_value()));
    }
  }
}
259
260
0
void PushMapDeps(MapExpr* struct_expr, std::stack<StackRecord>* stack) {
261
0
  auto& entries = struct_expr->mutable_entries();
262
0
  for (auto it = entries.rbegin(); it != entries.rend(); ++it) {
263
0
    auto& entry = *it;
264
    // The contract is to visit key, then value.  So put them on the stack
265
    // in the opposite order.
266
0
    if (entry.has_value()) {
267
0
      stack->push(StackRecord(&entry.mutable_value()));
268
0
    }
269
    // The contract is to visit key, then value.  So put them on the stack
270
    // in the opposite order.
271
0
    if (entry.has_key()) {
272
0
      stack->push(StackRecord(&entry.mutable_key()));
273
0
    }
274
0
  }
275
0
}
276
277
void PushComprehensionDeps(ComprehensionExpr* c, Expr* expr,
278
                           std::stack<StackRecord>* stack,
279
0
                           bool use_comprehension_callbacks) {
280
0
  StackRecord iter_range(&c->mutable_iter_range(), c, expr, ITER_RANGE,
281
0
                         use_comprehension_callbacks);
282
0
  StackRecord accu_init(&c->mutable_accu_init(), c, expr, ACCU_INIT,
283
0
                        use_comprehension_callbacks);
284
0
  StackRecord loop_condition(&c->mutable_loop_condition(), c, expr,
285
0
                             LOOP_CONDITION, use_comprehension_callbacks);
286
0
  StackRecord loop_step(&c->mutable_loop_step(), c, expr, LOOP_STEP,
287
0
                        use_comprehension_callbacks);
288
0
  StackRecord result(&c->mutable_result(), c, expr, RESULT,
289
0
                     use_comprehension_callbacks);
290
  // Push them in reverse order.
291
0
  stack->push(result);
292
0
  stack->push(loop_step);
293
0
  stack->push(loop_condition);
294
0
  stack->push(accu_init);
295
0
  stack->push(iter_range);
296
0
}
297
298
struct PushDepsVisitor {
299
0
  void operator()(const ExprRecord& record) {
300
0
    struct {
301
0
      std::stack<StackRecord>& stack;
302
0
      const RewriteTraversalOptions& options;
303
0
      const ExprRecord& record;
304
0
      void operator()(const Constant&) {}
305
0
      void operator()(const IdentExpr&) {}
306
0
      void operator()(const SelectExpr&) {
307
0
        PushSelectDeps(&record.expr->mutable_select_expr(), &stack);
308
0
      }
309
0
      void operator()(const CallExpr&) {
310
0
        PushCallDeps(&record.expr->mutable_call_expr(), record.expr, &stack);
311
0
      }
312
0
      void operator()(const ListExpr&) {
313
0
        PushListDeps(&record.expr->mutable_list_expr(), &stack);
314
0
      }
315
0
      void operator()(const StructExpr&) {
316
0
        PushStructDeps(&record.expr->mutable_struct_expr(), &stack);
317
0
      }
318
0
      void operator()(const MapExpr&) {
319
0
        PushMapDeps(&record.expr->mutable_map_expr(), &stack);
320
0
      }
321
0
      void operator()(const ComprehensionExpr&) {
322
0
        PushComprehensionDeps(&record.expr->mutable_comprehension_expr(),
323
0
                              record.expr, &stack,
324
0
                              options.use_comprehension_callbacks);
325
0
      }
326
0
      void operator()(const UnspecifiedExpr&) {}
327
0
    } handler{stack, options, record};
328
0
    absl::visit(handler, record.expr->kind());
329
0
  }
330
331
0
  void operator()(const ArgRecord& record) {
332
0
    stack.push(StackRecord(record.expr));
333
0
  }
334
335
0
  void operator()(const ComprehensionRecord& record) {
336
0
    stack.push(StackRecord(record.expr));
337
0
  }
338
339
  std::stack<StackRecord>& stack;
340
  const RewriteTraversalOptions& options;
341
};
342
343
void PushDependencies(const StackRecord& record, std::stack<StackRecord>& stack,
344
0
                      const RewriteTraversalOptions& options) {
345
0
  absl::visit(PushDepsVisitor{stack, options}, record.record_variant);
346
0
}
347
348
}  // namespace
349
350
bool AstRewrite(Expr& expr, AstRewriter& visitor,
351
0
                RewriteTraversalOptions options) {
352
0
  std::stack<StackRecord> stack;
353
0
  std::vector<const Expr*> traversal_path;
354
355
0
  stack.push(StackRecord(&expr));
356
0
  bool rewritten = false;
357
358
0
  while (!stack.empty()) {
359
0
    StackRecord& record = stack.top();
360
0
    if (!record.visited) {
361
0
      if (record.IsExprRecord()) {
362
0
        traversal_path.push_back(record.expr());
363
0
        visitor.TraversalStackUpdate(absl::MakeSpan(traversal_path));
364
365
0
        if (visitor.PreVisitRewrite(*record.expr())) {
366
0
          rewritten = true;
367
0
        }
368
0
      }
369
0
      PreVisit(record, &visitor);
370
0
      PushDependencies(record, stack, options);
371
0
      record.visited = true;
372
0
    } else {
373
0
      PostVisit(record, &visitor);
374
0
      if (record.IsExprRecord()) {
375
0
        if (visitor.PostVisitRewrite(*record.expr())) {
376
0
          rewritten = true;
377
0
        }
378
379
0
        traversal_path.pop_back();
380
0
        visitor.TraversalStackUpdate(absl::MakeSpan(traversal_path));
381
0
      }
382
0
      stack.pop();
383
0
    }
384
0
  }
385
386
0
  return rewritten;
387
0
}
388
389
}  // namespace cel