/proc/self/cwd/eval/compiler/comprehension_vulnerability_check.cc
Line | Count | Source |
1 | | // |
2 | | // Copyright 2023 Google LLC |
3 | | // |
4 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
5 | | // you may not use this file except in compliance with the License. |
6 | | // You may obtain a copy of the License at |
7 | | // |
8 | | // https://www.apache.org/licenses/LICENSE-2.0 |
9 | | // |
10 | | // Unless required by applicable law or agreed to in writing, software |
11 | | // distributed under the License is distributed on an "AS IS" BASIS, |
12 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | | // See the License for the specific language governing permissions and |
14 | | // limitations under the License. |
15 | | #include "eval/compiler/comprehension_vulnerability_check.h" |
16 | | |
17 | | #include <algorithm> |
18 | | #include <memory> |
19 | | #include <vector> |
20 | | |
21 | | #include "absl/status/status.h" |
22 | | #include "absl/strings/string_view.h" |
23 | | #include "absl/types/variant.h" |
24 | | #include "base/builtins.h" |
25 | | #include "common/ast.h" |
26 | | #include "common/constant.h" |
27 | | #include "common/expr.h" |
28 | | #include "eval/compiler/flat_expr_builder_extensions.h" |
29 | | |
30 | | namespace google::api::expr::runtime { |
31 | | |
32 | | namespace { |
33 | | |
34 | | using ::cel::CallExpr; |
35 | | using ::cel::ComprehensionExpr; |
36 | | using ::cel::Constant; |
37 | | using ::cel::Expr; |
38 | | using ::cel::IdentExpr; |
39 | | using ::cel::ListExpr; |
40 | | using ::cel::MapExpr; |
41 | | using ::cel::SelectExpr; |
42 | | using ::cel::StructExpr; |
43 | | using ::cel::UnspecifiedExpr; |
44 | | |
45 | | // ComprehensionAccumulationReferences recursively walks an expression to count |
46 | | // the locations where the given accumulation var_name is referenced. |
47 | | // |
48 | | // The purpose of this function is to detect cases where the accumulation |
49 | | // variable might be used in hand-rolled ASTs that cause exponential memory |
50 | | // consumption. The var_name is generally not accessible by CEL expression |
51 | | // writers, only by macro authors. However, a hand-rolled AST makes it possible |
52 | | // to misuse the accumulation variable. |
53 | | // |
54 | | // Limitations: |
55 | | // - This check only covers standard operators and functions. |
56 | | // Extension functions may cause the same issue if they allocate an amount of |
57 | | // memory that is dependent on the size of the inputs. |
58 | | // |
59 | | // - This check is not exhaustive. There may be ways to construct an AST to |
60 | | // trigger exponential memory growth not captured by this check. |
61 | | // |
62 | | // The algorithm for reference counting is as follows: |
63 | | // |
64 | | // * Calls - If the call is a concatenation operator, sum the number of places |
65 | | // where the variable appears within the call, as this could result |
66 | | // in memory explosion if the accumulation variable type is a list |
67 | | // or string. Otherwise, return 0. |
68 | | // |
69 | | // accu: ["hello"] |
70 | | // expr: accu + accu // memory grows exponentionally |
71 | | // |
72 | | // * CreateList - If the accumulation var_name appears within multiple elements |
73 | | // of a CreateList call, this means that the accumulation is |
74 | | // generating an ever-expanding tree of values that will likely |
75 | | // exhaust memory. |
76 | | // |
77 | | // accu: ["hello"] |
78 | | // expr: [accu, accu] // memory grows exponentially |
79 | | // |
80 | | // * CreateStruct - If the accumulation var_name as an entry within the |
81 | | // creation of a map or message value, then it's possible that the |
82 | | // comprehension is accumulating an ever-expanding tree of values. |
83 | | // |
84 | | // accu: {"key": "val"} |
85 | | // expr: {1: accu, 2: accu} |
86 | | // |
87 | | // * Comprehension - If the accumulation var_name is not shadowed by a nested |
88 | | // iter_var or accu_var, then it may be accmulating memory within a |
89 | | // nested context. The accumulation may occur on either the |
90 | | // comprehension loop_step or result step. |
91 | | // |
92 | | // Since this behavior generally only occurs within hand-rolled ASTs, it is |
93 | | // very reasonable to opt-in to this check only when using human authored ASTs. |
94 | | int ComprehensionAccumulationReferences(const cel::Expr& expr, |
95 | 0 | absl::string_view var_name) { |
96 | 0 | struct Handler { |
97 | 0 | const Expr& expr; |
98 | 0 | absl::string_view var_name; |
99 | |
|
100 | 0 | int operator()(const CallExpr& call) { |
101 | 0 | int references = 0; |
102 | 0 | absl::string_view function = call.function(); |
103 | | // Return the maximum reference count of each side of the ternary branch. |
104 | 0 | if (function == cel::builtin::kTernary && call.args().size() == 3) { |
105 | 0 | return std::max( |
106 | 0 | ComprehensionAccumulationReferences(call.args()[1], var_name), |
107 | 0 | ComprehensionAccumulationReferences(call.args()[2], var_name)); |
108 | 0 | } |
109 | | // Return the number of times the accumulator var_name appears in the add |
110 | | // expression. There's no arg size check on the add as it may become a |
111 | | // variadic add at a future date. |
112 | 0 | if (function == cel::builtin::kAdd) { |
113 | 0 | for (int i = 0; i < call.args().size(); i++) { |
114 | 0 | references += |
115 | 0 | ComprehensionAccumulationReferences(call.args()[i], var_name); |
116 | 0 | } |
117 | |
|
118 | 0 | return references; |
119 | 0 | } |
120 | | // Return whether the accumulator var_name is used as the operand in an |
121 | | // index expression or in the identity `dyn` function. |
122 | 0 | if ((function == cel::builtin::kIndex && call.args().size() == 2) || |
123 | 0 | (function == cel::builtin::kDyn && call.args().size() == 1)) { |
124 | 0 | return ComprehensionAccumulationReferences(call.args()[0], var_name); |
125 | 0 | } |
126 | 0 | return 0; |
127 | 0 | } |
128 | 0 | int operator()(const ComprehensionExpr& comprehension) { |
129 | 0 | absl::string_view accu_var = comprehension.accu_var(); |
130 | 0 | absl::string_view iter_var = comprehension.iter_var(); |
131 | |
|
132 | 0 | int result_references = 0; |
133 | 0 | int loop_step_references = 0; |
134 | 0 | int sum_of_accumulator_references = 0; |
135 | | |
136 | | // The accumulation or iteration variable shadows the var_name and so will |
137 | | // not manipulate the target var_name in a nested comprehension scope. |
138 | 0 | if (accu_var != var_name && iter_var != var_name) { |
139 | 0 | loop_step_references = ComprehensionAccumulationReferences( |
140 | 0 | comprehension.loop_step(), var_name); |
141 | 0 | } |
142 | | |
143 | | // Accumulator variable (but not necessarily iter var) can shadow an |
144 | | // outer accumulator variable in the result sub-expression. |
145 | 0 | if (accu_var != var_name) { |
146 | 0 | result_references = ComprehensionAccumulationReferences( |
147 | 0 | comprehension.result(), var_name); |
148 | 0 | } |
149 | | |
150 | | // Count the raw number of times the accumulator variable was referenced. |
151 | | // This is to account for cases where the outer accumulator is shadowed by |
152 | | // the inner accumulator, while the inner accumulator is being used as the |
153 | | // iterable range. |
154 | | // |
155 | | // An equivalent expression to this problem: |
156 | | // |
157 | | // outer_accu := outer_accu |
158 | | // for y in outer_accu: |
159 | | // outer_accu += input |
160 | | // return outer_accu |
161 | | |
162 | | // If this is overly restrictive (Ex: when generalized reducers is |
163 | | // implemented), we may need to revisit this solution |
164 | |
|
165 | 0 | sum_of_accumulator_references = ComprehensionAccumulationReferences( |
166 | 0 | comprehension.accu_init(), var_name); |
167 | |
|
168 | 0 | sum_of_accumulator_references += ComprehensionAccumulationReferences( |
169 | 0 | comprehension.iter_range(), var_name); |
170 | | |
171 | | // Count the number of times the accumulator var_name within the loop_step |
172 | | // or the nested comprehension result. |
173 | | // |
174 | | // This doesn't cover cases where the inner accumulator accumulates the |
175 | | // outer accumulator then is returned in the inner comprehension result. |
176 | 0 | return std::max({loop_step_references, result_references, |
177 | 0 | sum_of_accumulator_references}); |
178 | 0 | } |
179 | |
|
180 | 0 | int operator()(const ListExpr& list) { |
181 | | // Count the number of times the accumulator var_name appears within a |
182 | | // create list expression's elements. |
183 | 0 | int references = 0; |
184 | 0 | for (int i = 0; i < list.elements().size(); i++) { |
185 | 0 | references += ComprehensionAccumulationReferences( |
186 | 0 | list.elements()[i].expr(), var_name); |
187 | 0 | } |
188 | 0 | return references; |
189 | 0 | } |
190 | |
|
191 | 0 | int operator()(const StructExpr& map) { |
192 | | // Count the number of times the accumulation variable occurs within |
193 | | // entry values. |
194 | 0 | int references = 0; |
195 | 0 | for (int i = 0; i < map.fields().size(); i++) { |
196 | 0 | const auto& entry = map.fields()[i]; |
197 | 0 | if (entry.has_value()) { |
198 | 0 | references += |
199 | 0 | ComprehensionAccumulationReferences(entry.value(), var_name); |
200 | 0 | } |
201 | 0 | } |
202 | 0 | return references; |
203 | 0 | } |
204 | |
|
205 | 0 | int operator()(const MapExpr& map) { |
206 | | // Count the number of times the accumulation variable occurs within |
207 | | // entry values. |
208 | 0 | int references = 0; |
209 | 0 | for (int i = 0; i < map.entries().size(); i++) { |
210 | 0 | const auto& entry = map.entries()[i]; |
211 | 0 | if (entry.has_value()) { |
212 | 0 | references += |
213 | 0 | ComprehensionAccumulationReferences(entry.value(), var_name); |
214 | 0 | } |
215 | 0 | } |
216 | 0 | return references; |
217 | 0 | } |
218 | |
|
219 | 0 | int operator()(const SelectExpr& select) { |
220 | | // Test only expressions have a boolean return and thus cannot easily |
221 | | // allocate large amounts of memory. |
222 | 0 | if (select.test_only()) { |
223 | 0 | return 0; |
224 | 0 | } |
225 | | // Return whether the accumulator var_name appears within a non-test |
226 | | // select operand. |
227 | 0 | return ComprehensionAccumulationReferences(select.operand(), var_name); |
228 | 0 | } |
229 | |
|
230 | 0 | int operator()(const IdentExpr& ident) { |
231 | | // Return whether the identifier name equals the accumulator var_name. |
232 | 0 | return ident.name() == var_name ? 1 : 0; |
233 | 0 | } |
234 | |
|
235 | 0 | int operator()(const Constant& constant) { return 0; } |
236 | |
|
237 | 0 | int operator()(const UnspecifiedExpr&) { return 0; } |
238 | 0 | } handler{expr, var_name}; |
239 | 0 | return absl::visit(handler, expr.kind()); |
240 | 0 | } |
241 | | |
242 | | bool ComprehensionHasMemoryExhaustionVulnerability( |
243 | 0 | const ComprehensionExpr& comprehension) { |
244 | 0 | absl::string_view accu_var = comprehension.accu_var(); |
245 | 0 | const auto& loop_step = comprehension.loop_step(); |
246 | 0 | return ComprehensionAccumulationReferences(loop_step, accu_var) >= 2; |
247 | 0 | } |
248 | | |
249 | | class ComprehensionVulnerabilityCheck : public ProgramOptimizer { |
250 | | public: |
251 | 0 | absl::Status OnPreVisit(PlannerContext& context, const Expr& node) override { |
252 | 0 | if (node.has_comprehension_expr() && |
253 | 0 | ComprehensionHasMemoryExhaustionVulnerability( |
254 | 0 | node.comprehension_expr())) { |
255 | 0 | return absl::InvalidArgumentError( |
256 | 0 | "Comprehension contains memory exhaustion vulnerability"); |
257 | 0 | } |
258 | 0 | return absl::OkStatus(); |
259 | 0 | } |
260 | | |
261 | | absl::Status OnPostVisit(PlannerContext& context, |
262 | 0 | const cel::Expr& node) override { |
263 | 0 | return absl::OkStatus(); |
264 | 0 | } |
265 | | }; |
266 | | |
267 | | } // namespace |
268 | | |
269 | 0 | ProgramOptimizerFactory CreateComprehensionVulnerabilityCheck() { |
270 | 0 | return [](PlannerContext&, const cel::Ast& ast) { |
271 | 0 | return std::make_unique<ComprehensionVulnerabilityCheck>(); |
272 | 0 | }; |
273 | 0 | } |
274 | | |
275 | | } // namespace google::api::expr::runtime |