/src/skia/src/sksl/analysis/SkSLGetLoopUnrollInfo.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | /* |
2 | | * Copyright 2021 Google LLC |
3 | | * |
4 | | * Use of this source code is governed by a BSD-style license that can be |
5 | | * found in the LICENSE file. |
6 | | */ |
7 | | |
8 | | #include "include/core/SkTypes.h" |
9 | | #include "include/private/base/SkFloatingPoint.h" |
10 | | #include "src/sksl/SkSLAnalysis.h" |
11 | | #include "src/sksl/SkSLConstantFolder.h" |
12 | | #include "src/sksl/SkSLErrorReporter.h" |
13 | | #include "src/sksl/SkSLOperator.h" |
14 | | #include "src/sksl/SkSLPosition.h" |
15 | | #include "src/sksl/analysis/SkSLNoOpErrorReporter.h" |
16 | | #include "src/sksl/ir/SkSLBinaryExpression.h" |
17 | | #include "src/sksl/ir/SkSLExpression.h" |
18 | | #include "src/sksl/ir/SkSLForStatement.h" |
19 | | #include "src/sksl/ir/SkSLIRNode.h" |
20 | | #include "src/sksl/ir/SkSLPostfixExpression.h" |
21 | | #include "src/sksl/ir/SkSLPrefixExpression.h" |
22 | | #include "src/sksl/ir/SkSLStatement.h" |
23 | | #include "src/sksl/ir/SkSLType.h" |
24 | | #include "src/sksl/ir/SkSLVarDeclarations.h" |
25 | | #include "src/sksl/ir/SkSLVariable.h" |
26 | | #include "src/sksl/ir/SkSLVariableReference.h" |
27 | | |
28 | | #include <cmath> |
29 | | #include <memory> |
30 | | |
31 | | namespace SkSL { |
32 | | |
class Context;

// Upper bound on loop trip counts. A loop that runs for 100000 or more iterations would exceed
// our program size limit once unrolled, so this value doubles as the sentinel iteration count
// for loops that never terminate or run for too long.
static constexpr int kLoopTerminationLimit = 100'000;
37 | | |
38 | 44 | static int calculate_count(double start, double end, double delta, bool forwards, bool inclusive) { |
39 | 44 | if ((forwards && start > end) || (!forwards && start < end)) { |
40 | | // The loop starts in a completed state (the start has already advanced past the end). |
41 | 0 | return 0; |
42 | 0 | } |
43 | 44 | if ((delta == 0.0) || forwards != (delta > 0.0)) { |
44 | | // The loop does not progress toward a completed state, and will never terminate. |
45 | 0 | return kLoopTerminationLimit; |
46 | 0 | } |
47 | 44 | double iterations = sk_ieee_double_divide(end - start, delta); |
48 | 44 | double count = std::ceil(iterations); |
49 | 44 | if (inclusive && (count == iterations)) { |
50 | 8 | count += 1.0; |
51 | 8 | } |
52 | 44 | if (count > kLoopTerminationLimit || !std::isfinite(count)) { |
53 | | // The loop runs for more iterations than we can safely unroll. |
54 | 0 | return kLoopTerminationLimit; |
55 | 0 | } |
56 | 44 | return (int)count; |
57 | 44 | } |
58 | | |
59 | | std::unique_ptr<LoopUnrollInfo> Analysis::GetLoopUnrollInfo(const Context& context, |
60 | | Position loopPos, |
61 | | const ForLoopPositions& positions, |
62 | | const Statement* loopInitializer, |
63 | | std::unique_ptr<Expression>* loopTest, |
64 | | const Expression* loopNext, |
65 | | const Statement* loopStatement, |
66 | 318 | ErrorReporter* errorPtr) { |
67 | 318 | NoOpErrorReporter unused; |
68 | 318 | ErrorReporter& errors = errorPtr ? *errorPtr : unused; |
69 | | |
70 | 318 | auto loopInfo = std::make_unique<LoopUnrollInfo>(); |
71 | | |
72 | | // |
73 | | // init_declaration has the form: type_specifier identifier = constant_expression |
74 | | // |
75 | 318 | if (!loopInitializer) { |
76 | 145 | Position pos = positions.initPosition.valid() ? positions.initPosition : loopPos; |
77 | 145 | errors.error(pos, "missing init declaration"); |
78 | 145 | return nullptr; |
79 | 145 | } |
80 | 173 | if (!loopInitializer->is<VarDeclaration>()) { |
81 | 125 | errors.error(loopInitializer->fPosition, "invalid init declaration"); |
82 | 125 | return nullptr; |
83 | 125 | } |
84 | 48 | const VarDeclaration& initDecl = loopInitializer->as<VarDeclaration>(); |
85 | 48 | if (!initDecl.baseType().isNumber()) { |
86 | 2 | errors.error(loopInitializer->fPosition, "invalid type for loop index"); |
87 | 2 | return nullptr; |
88 | 2 | } |
89 | 46 | if (initDecl.arraySize() != 0) { |
90 | 2 | errors.error(loopInitializer->fPosition, "invalid type for loop index"); |
91 | 2 | return nullptr; |
92 | 2 | } |
93 | 44 | if (!initDecl.value()) { |
94 | 0 | errors.error(loopInitializer->fPosition, "missing loop index initializer"); |
95 | 0 | return nullptr; |
96 | 0 | } |
97 | 44 | if (!ConstantFolder::GetConstantValue(*initDecl.value(), &loopInfo->fStart)) { |
98 | 0 | errors.error(loopInitializer->fPosition, |
99 | 0 | "loop index initializer must be a constant expression"); |
100 | 0 | return nullptr; |
101 | 0 | } |
102 | | |
103 | 44 | loopInfo->fIndex = initDecl.var(); |
104 | | |
105 | 88 | auto is_loop_index = [&](const std::unique_ptr<Expression>& expr) { |
106 | 88 | return expr->is<VariableReference>() && |
107 | 88 | expr->as<VariableReference>().variable() == loopInfo->fIndex; |
108 | 88 | }; |
109 | | |
110 | | // |
111 | | // condition has the form: loop_index relational_operator constant_expression |
112 | | // |
113 | 44 | if (!loopTest || !*loopTest) { |
114 | 0 | Position pos = positions.conditionPosition.valid() ? positions.conditionPosition : loopPos; |
115 | 0 | errors.error(pos, "missing condition"); |
116 | 0 | return nullptr; |
117 | 0 | } |
118 | 44 | if (!loopTest->get()->is<BinaryExpression>()) { |
119 | 0 | errors.error(loopTest->get()->fPosition, "invalid condition"); |
120 | 0 | return nullptr; |
121 | 0 | } |
122 | 44 | const BinaryExpression* cond = &loopTest->get()->as<BinaryExpression>(); |
123 | 44 | if (!is_loop_index(cond->left())) { |
124 | 0 | errors.error(cond->fPosition, "expected loop index on left hand side of condition"); |
125 | 0 | return nullptr; |
126 | 0 | } |
127 | | // relational_operator is one of: > >= < <= == or != |
128 | 44 | switch (cond->getOperator().kind()) { |
129 | 0 | case Operator::Kind::GT: |
130 | 0 | case Operator::Kind::GTEQ: |
131 | 36 | case Operator::Kind::LT: |
132 | 44 | case Operator::Kind::LTEQ: |
133 | 44 | case Operator::Kind::EQEQ: |
134 | 44 | case Operator::Kind::NEQ: |
135 | 44 | break; |
136 | 0 | default: |
137 | 0 | errors.error(cond->fPosition, "invalid relational operator"); |
138 | 0 | return nullptr; |
139 | 44 | } |
140 | 44 | double loopEnd = 0; |
141 | 44 | if (!ConstantFolder::GetConstantValue(*cond->right(), &loopEnd)) { |
142 | 0 | errors.error(cond->fPosition, "loop index must be compared with a constant expression"); |
143 | 0 | return nullptr; |
144 | 0 | } |
145 | | |
146 | | // |
147 | | // expression has one of the following forms: |
148 | | // loop_index++ |
149 | | // loop_index-- |
150 | | // loop_index += constant_expression |
151 | | // loop_index -= constant_expression |
152 | | // The spec doesn't mention prefix increment and decrement, but there is some consensus that |
153 | | // it's an oversight, so we allow those as well. |
154 | | // |
155 | 44 | if (!loopNext) { |
156 | 0 | Position pos = positions.nextPosition.valid() ? positions.nextPosition : loopPos; |
157 | 0 | errors.error(pos, "missing loop expression"); |
158 | 0 | return nullptr; |
159 | 0 | } |
160 | 44 | switch (loopNext->kind()) { |
161 | 0 | case Expression::Kind::kBinary: { |
162 | 0 | const BinaryExpression& next = loopNext->as<BinaryExpression>(); |
163 | 0 | if (!is_loop_index(next.left())) { |
164 | 0 | errors.error(loopNext->fPosition, "expected loop index in loop expression"); |
165 | 0 | return nullptr; |
166 | 0 | } |
167 | 0 | if (!ConstantFolder::GetConstantValue(*next.right(), &loopInfo->fDelta)) { |
168 | 0 | errors.error(loopNext->fPosition, |
169 | 0 | "loop index must be modified by a constant expression"); |
170 | 0 | return nullptr; |
171 | 0 | } |
172 | 0 | switch (next.getOperator().kind()) { |
173 | 0 | case Operator::Kind::PLUSEQ: break; |
174 | 0 | case Operator::Kind::MINUSEQ: loopInfo->fDelta = -loopInfo->fDelta; break; |
175 | 0 | default: |
176 | 0 | errors.error(loopNext->fPosition, "invalid operator in loop expression"); |
177 | 0 | return nullptr; |
178 | 0 | } |
179 | 0 | break; |
180 | 0 | } |
181 | 44 | case Expression::Kind::kPrefix: { |
182 | 44 | const PrefixExpression& next = loopNext->as<PrefixExpression>(); |
183 | 44 | if (!is_loop_index(next.operand())) { |
184 | 0 | errors.error(loopNext->fPosition, "expected loop index in loop expression"); |
185 | 0 | return nullptr; |
186 | 0 | } |
187 | 44 | switch (next.getOperator().kind()) { |
188 | 44 | case Operator::Kind::PLUSPLUS: loopInfo->fDelta = 1; break; |
189 | 0 | case Operator::Kind::MINUSMINUS: loopInfo->fDelta = -1; break; |
190 | 0 | default: |
191 | 0 | errors.error(loopNext->fPosition, "invalid operator in loop expression"); |
192 | 0 | return nullptr; |
193 | 44 | } |
194 | 44 | break; |
195 | 44 | } |
196 | 44 | case Expression::Kind::kPostfix: { |
197 | 0 | const PostfixExpression& next = loopNext->as<PostfixExpression>(); |
198 | 0 | if (!is_loop_index(next.operand())) { |
199 | 0 | errors.error(loopNext->fPosition, "expected loop index in loop expression"); |
200 | 0 | return nullptr; |
201 | 0 | } |
202 | 0 | switch (next.getOperator().kind()) { |
203 | 0 | case Operator::Kind::PLUSPLUS: loopInfo->fDelta = 1; break; |
204 | 0 | case Operator::Kind::MINUSMINUS: loopInfo->fDelta = -1; break; |
205 | 0 | default: |
206 | 0 | errors.error(loopNext->fPosition, "invalid operator in loop expression"); |
207 | 0 | return nullptr; |
208 | 0 | } |
209 | 0 | break; |
210 | 0 | } |
211 | 0 | default: |
212 | 0 | errors.error(loopNext->fPosition, "invalid loop expression"); |
213 | 0 | return nullptr; |
214 | 44 | } |
215 | | |
216 | | // |
217 | | // Within the body of the loop, the loop index is not statically assigned to, nor is it used as |
218 | | // argument to a function 'out' or 'inout' parameter. |
219 | | // |
220 | 44 | if (Analysis::StatementWritesToVariable(*loopStatement, *initDecl.var())) { |
221 | 0 | errors.error(loopStatement->fPosition, |
222 | 0 | "loop index must not be modified within body of the loop"); |
223 | 0 | return nullptr; |
224 | 0 | } |
225 | | |
226 | | // Finally, compute the iteration count, based on the bounds, and the termination operator. |
227 | 44 | loopInfo->fCount = 0; |
228 | | |
229 | 44 | switch (cond->getOperator().kind()) { |
230 | 36 | case Operator::Kind::LT: |
231 | 36 | loopInfo->fCount = calculate_count(loopInfo->fStart, loopEnd, loopInfo->fDelta, |
232 | 36 | /*forwards=*/true, /*inclusive=*/false); |
233 | 36 | break; |
234 | | |
235 | 0 | case Operator::Kind::GT: |
236 | 0 | loopInfo->fCount = calculate_count(loopInfo->fStart, loopEnd, loopInfo->fDelta, |
237 | 0 | /*forwards=*/false, /*inclusive=*/false); |
238 | 0 | break; |
239 | | |
240 | 8 | case Operator::Kind::LTEQ: |
241 | 8 | loopInfo->fCount = calculate_count(loopInfo->fStart, loopEnd, loopInfo->fDelta, |
242 | 8 | /*forwards=*/true, /*inclusive=*/true); |
243 | 8 | break; |
244 | | |
245 | 0 | case Operator::Kind::GTEQ: |
246 | 0 | loopInfo->fCount = calculate_count(loopInfo->fStart, loopEnd, loopInfo->fDelta, |
247 | 0 | /*forwards=*/false, /*inclusive=*/true); |
248 | 0 | break; |
249 | | |
250 | 0 | case Operator::Kind::NEQ: { |
251 | 0 | float iterations = sk_ieee_double_divide(loopEnd - loopInfo->fStart, loopInfo->fDelta); |
252 | 0 | loopInfo->fCount = std::ceil(iterations); |
253 | 0 | if (loopInfo->fCount < 0 || loopInfo->fCount != iterations || |
254 | 0 | !std::isfinite(iterations)) { |
255 | | // The loop doesn't reach the exact endpoint and so will never terminate. |
256 | 0 | loopInfo->fCount = kLoopTerminationLimit; |
257 | 0 | } |
258 | 0 | if (loopInfo->fIndex->type().componentType().isFloat()) { |
259 | | // Rewrite `x != n` tests as `x < n` or `x > n` depending on the loop direction. |
260 | | // Less-than and greater-than tests avoid infinite loops caused by rounding error. |
261 | 0 | Operator::Kind op = (loopInfo->fDelta > 0) ? Operator::Kind::LT |
262 | 0 | : Operator::Kind::GT; |
263 | 0 | *loopTest = BinaryExpression::Make(context, |
264 | 0 | cond->fPosition, |
265 | 0 | cond->left()->clone(), |
266 | 0 | op, |
267 | 0 | cond->right()->clone()); |
268 | 0 | cond = &loopTest->get()->as<BinaryExpression>(); |
269 | 0 | } |
270 | 0 | break; |
271 | 0 | } |
272 | 0 | case Operator::Kind::EQEQ: { |
273 | 0 | if (loopInfo->fStart == loopEnd) { |
274 | | // Start and end begin in the same place, so we can run one iteration... |
275 | 0 | if (loopInfo->fDelta) { |
276 | | // ... and then they diverge, so the loop terminates. |
277 | 0 | loopInfo->fCount = 1; |
278 | 0 | } else { |
279 | | // ... but they never diverge, so the loop runs forever. |
280 | 0 | loopInfo->fCount = kLoopTerminationLimit; |
281 | 0 | } |
282 | 0 | } else { |
283 | | // Start never equals end, so the loop will not run a single iteration. |
284 | 0 | loopInfo->fCount = 0; |
285 | 0 | } |
286 | 0 | break; |
287 | 0 | } |
288 | 0 | default: SkUNREACHABLE; |
289 | 44 | } |
290 | | |
291 | 44 | SkASSERT(loopInfo->fCount >= 0); |
292 | 44 | if (loopInfo->fCount >= kLoopTerminationLimit) { |
293 | 0 | errors.error(loopPos, "loop must guarantee termination in fewer iterations"); |
294 | 0 | return nullptr; |
295 | 0 | } |
296 | | |
297 | 44 | return loopInfo; |
298 | 44 | } |
299 | | |
300 | | } // namespace SkSL |