Coverage Report

Created: 2024-09-14 07:19

/src/skia/src/sksl/analysis/SkSLGetLoopUnrollInfo.cpp
Line
Count
Source (jump to first uncovered line)
1
/*
2
 * Copyright 2021 Google LLC
3
 *
4
 * Use of this source code is governed by a BSD-style license that can be
5
 * found in the LICENSE file.
6
 */
7
8
#include "include/core/SkTypes.h"
9
#include "include/private/base/SkFloatingPoint.h"
10
#include "src/sksl/SkSLAnalysis.h"
11
#include "src/sksl/SkSLConstantFolder.h"
12
#include "src/sksl/SkSLErrorReporter.h"
13
#include "src/sksl/SkSLOperator.h"
14
#include "src/sksl/SkSLPosition.h"
15
#include "src/sksl/analysis/SkSLNoOpErrorReporter.h"
16
#include "src/sksl/ir/SkSLBinaryExpression.h"
17
#include "src/sksl/ir/SkSLExpression.h"
18
#include "src/sksl/ir/SkSLForStatement.h"
19
#include "src/sksl/ir/SkSLIRNode.h"
20
#include "src/sksl/ir/SkSLPostfixExpression.h"
21
#include "src/sksl/ir/SkSLPrefixExpression.h"
22
#include "src/sksl/ir/SkSLStatement.h"
23
#include "src/sksl/ir/SkSLType.h"
24
#include "src/sksl/ir/SkSLVarDeclarations.h"
25
#include "src/sksl/ir/SkSLVariable.h"
26
#include "src/sksl/ir/SkSLVariableReference.h"
27
28
#include <cmath>
29
#include <memory>
30
31
namespace SkSL {
32
33
class Context;

// Upper bound on the number of iterations we are willing to unroll. Unrolling a loop that runs
// 100000 or more times would exceed our program size limit, so any loop whose iteration count
// reaches this value (including loops that never terminate) is rejected.
static constexpr int kLoopTerminationLimit = 100'000;
37
38
44
static int calculate_count(double start, double end, double delta, bool forwards, bool inclusive) {
39
44
    if ((forwards && start > end) || (!forwards && start < end)) {
40
        // The loop starts in a completed state (the start has already advanced past the end).
41
0
        return 0;
42
0
    }
43
44
    if ((delta == 0.0) || forwards != (delta > 0.0)) {
44
        // The loop does not progress toward a completed state, and will never terminate.
45
0
        return kLoopTerminationLimit;
46
0
    }
47
44
    double iterations = sk_ieee_double_divide(end - start, delta);
48
44
    double count = std::ceil(iterations);
49
44
    if (inclusive && (count == iterations)) {
50
8
        count += 1.0;
51
8
    }
52
44
    if (count > kLoopTerminationLimit || !std::isfinite(count)) {
53
        // The loop runs for more iterations than we can safely unroll.
54
0
        return kLoopTerminationLimit;
55
0
    }
56
44
    return (int)count;
57
44
}
58
59
std::unique_ptr<LoopUnrollInfo> Analysis::GetLoopUnrollInfo(const Context& context,
60
                                                            Position loopPos,
61
                                                            const ForLoopPositions& positions,
62
                                                            const Statement* loopInitializer,
63
                                                            std::unique_ptr<Expression>* loopTest,
64
                                                            const Expression* loopNext,
65
                                                            const Statement* loopStatement,
66
318
                                                            ErrorReporter* errorPtr) {
67
318
    NoOpErrorReporter unused;
68
318
    ErrorReporter& errors = errorPtr ? *errorPtr : unused;
69
70
318
    auto loopInfo = std::make_unique<LoopUnrollInfo>();
71
72
    //
73
    // init_declaration has the form: type_specifier identifier = constant_expression
74
    //
75
318
    if (!loopInitializer) {
76
145
        Position pos = positions.initPosition.valid() ? positions.initPosition : loopPos;
77
145
        errors.error(pos, "missing init declaration");
78
145
        return nullptr;
79
145
    }
80
173
    if (!loopInitializer->is<VarDeclaration>()) {
81
125
        errors.error(loopInitializer->fPosition, "invalid init declaration");
82
125
        return nullptr;
83
125
    }
84
48
    const VarDeclaration& initDecl = loopInitializer->as<VarDeclaration>();
85
48
    if (!initDecl.baseType().isNumber()) {
86
2
        errors.error(loopInitializer->fPosition, "invalid type for loop index");
87
2
        return nullptr;
88
2
    }
89
46
    if (initDecl.arraySize() != 0) {
90
2
        errors.error(loopInitializer->fPosition, "invalid type for loop index");
91
2
        return nullptr;
92
2
    }
93
44
    if (!initDecl.value()) {
94
0
        errors.error(loopInitializer->fPosition, "missing loop index initializer");
95
0
        return nullptr;
96
0
    }
97
44
    if (!ConstantFolder::GetConstantValue(*initDecl.value(), &loopInfo->fStart)) {
98
0
        errors.error(loopInitializer->fPosition,
99
0
                     "loop index initializer must be a constant expression");
100
0
        return nullptr;
101
0
    }
102
103
44
    loopInfo->fIndex = initDecl.var();
104
105
88
    auto is_loop_index = [&](const std::unique_ptr<Expression>& expr) {
106
88
        return expr->is<VariableReference>() &&
107
88
               expr->as<VariableReference>().variable() == loopInfo->fIndex;
108
88
    };
109
110
    //
111
    // condition has the form: loop_index relational_operator constant_expression
112
    //
113
44
    if (!loopTest || !*loopTest) {
114
0
        Position pos = positions.conditionPosition.valid() ? positions.conditionPosition : loopPos;
115
0
        errors.error(pos, "missing condition");
116
0
        return nullptr;
117
0
    }
118
44
    if (!loopTest->get()->is<BinaryExpression>()) {
119
0
        errors.error(loopTest->get()->fPosition, "invalid condition");
120
0
        return nullptr;
121
0
    }
122
44
    const BinaryExpression* cond = &loopTest->get()->as<BinaryExpression>();
123
44
    if (!is_loop_index(cond->left())) {
124
0
        errors.error(cond->fPosition, "expected loop index on left hand side of condition");
125
0
        return nullptr;
126
0
    }
127
    // relational_operator is one of: > >= < <= == or !=
128
44
    switch (cond->getOperator().kind()) {
129
0
        case Operator::Kind::GT:
130
0
        case Operator::Kind::GTEQ:
131
36
        case Operator::Kind::LT:
132
44
        case Operator::Kind::LTEQ:
133
44
        case Operator::Kind::EQEQ:
134
44
        case Operator::Kind::NEQ:
135
44
            break;
136
0
        default:
137
0
            errors.error(cond->fPosition, "invalid relational operator");
138
0
            return nullptr;
139
44
    }
140
44
    double loopEnd = 0;
141
44
    if (!ConstantFolder::GetConstantValue(*cond->right(), &loopEnd)) {
142
0
        errors.error(cond->fPosition, "loop index must be compared with a constant expression");
143
0
        return nullptr;
144
0
    }
145
146
    //
147
    // expression has one of the following forms:
148
    //   loop_index++
149
    //   loop_index--
150
    //   loop_index += constant_expression
151
    //   loop_index -= constant_expression
152
    // The spec doesn't mention prefix increment and decrement, but there is some consensus that
153
    // it's an oversight, so we allow those as well.
154
    //
155
44
    if (!loopNext) {
156
0
        Position pos = positions.nextPosition.valid() ? positions.nextPosition : loopPos;
157
0
        errors.error(pos, "missing loop expression");
158
0
        return nullptr;
159
0
    }
160
44
    switch (loopNext->kind()) {
161
0
        case Expression::Kind::kBinary: {
162
0
            const BinaryExpression& next = loopNext->as<BinaryExpression>();
163
0
            if (!is_loop_index(next.left())) {
164
0
                errors.error(loopNext->fPosition, "expected loop index in loop expression");
165
0
                return nullptr;
166
0
            }
167
0
            if (!ConstantFolder::GetConstantValue(*next.right(), &loopInfo->fDelta)) {
168
0
                errors.error(loopNext->fPosition,
169
0
                             "loop index must be modified by a constant expression");
170
0
                return nullptr;
171
0
            }
172
0
            switch (next.getOperator().kind()) {
173
0
                case Operator::Kind::PLUSEQ:                                        break;
174
0
                case Operator::Kind::MINUSEQ: loopInfo->fDelta = -loopInfo->fDelta; break;
175
0
                default:
176
0
                    errors.error(loopNext->fPosition, "invalid operator in loop expression");
177
0
                    return nullptr;
178
0
            }
179
0
            break;
180
0
        }
181
44
        case Expression::Kind::kPrefix: {
182
44
            const PrefixExpression& next = loopNext->as<PrefixExpression>();
183
44
            if (!is_loop_index(next.operand())) {
184
0
                errors.error(loopNext->fPosition, "expected loop index in loop expression");
185
0
                return nullptr;
186
0
            }
187
44
            switch (next.getOperator().kind()) {
188
44
                case Operator::Kind::PLUSPLUS:   loopInfo->fDelta =  1; break;
189
0
                case Operator::Kind::MINUSMINUS: loopInfo->fDelta = -1; break;
190
0
                default:
191
0
                    errors.error(loopNext->fPosition, "invalid operator in loop expression");
192
0
                    return nullptr;
193
44
            }
194
44
            break;
195
44
        }
196
44
        case Expression::Kind::kPostfix: {
197
0
            const PostfixExpression& next = loopNext->as<PostfixExpression>();
198
0
            if (!is_loop_index(next.operand())) {
199
0
                errors.error(loopNext->fPosition, "expected loop index in loop expression");
200
0
                return nullptr;
201
0
            }
202
0
            switch (next.getOperator().kind()) {
203
0
                case Operator::Kind::PLUSPLUS:   loopInfo->fDelta =  1; break;
204
0
                case Operator::Kind::MINUSMINUS: loopInfo->fDelta = -1; break;
205
0
                default:
206
0
                    errors.error(loopNext->fPosition, "invalid operator in loop expression");
207
0
                    return nullptr;
208
0
            }
209
0
            break;
210
0
        }
211
0
        default:
212
0
            errors.error(loopNext->fPosition, "invalid loop expression");
213
0
            return nullptr;
214
44
    }
215
216
    //
217
    // Within the body of the loop, the loop index is not statically assigned to, nor is it used as
218
    // argument to a function 'out' or 'inout' parameter.
219
    //
220
44
    if (Analysis::StatementWritesToVariable(*loopStatement, *initDecl.var())) {
221
0
        errors.error(loopStatement->fPosition,
222
0
                     "loop index must not be modified within body of the loop");
223
0
        return nullptr;
224
0
    }
225
226
    // Finally, compute the iteration count, based on the bounds, and the termination operator.
227
44
    loopInfo->fCount = 0;
228
229
44
    switch (cond->getOperator().kind()) {
230
36
        case Operator::Kind::LT:
231
36
            loopInfo->fCount = calculate_count(loopInfo->fStart, loopEnd, loopInfo->fDelta,
232
36
                                              /*forwards=*/true, /*inclusive=*/false);
233
36
            break;
234
235
0
        case Operator::Kind::GT:
236
0
            loopInfo->fCount = calculate_count(loopInfo->fStart, loopEnd, loopInfo->fDelta,
237
0
                                              /*forwards=*/false, /*inclusive=*/false);
238
0
            break;
239
240
8
        case Operator::Kind::LTEQ:
241
8
            loopInfo->fCount = calculate_count(loopInfo->fStart, loopEnd, loopInfo->fDelta,
242
8
                                              /*forwards=*/true, /*inclusive=*/true);
243
8
            break;
244
245
0
        case Operator::Kind::GTEQ:
246
0
            loopInfo->fCount = calculate_count(loopInfo->fStart, loopEnd, loopInfo->fDelta,
247
0
                                              /*forwards=*/false, /*inclusive=*/true);
248
0
            break;
249
250
0
        case Operator::Kind::NEQ: {
251
0
            float iterations = sk_ieee_double_divide(loopEnd - loopInfo->fStart, loopInfo->fDelta);
252
0
            loopInfo->fCount = std::ceil(iterations);
253
0
            if (loopInfo->fCount < 0 || loopInfo->fCount != iterations ||
254
0
                !std::isfinite(iterations)) {
255
                // The loop doesn't reach the exact endpoint and so will never terminate.
256
0
                loopInfo->fCount = kLoopTerminationLimit;
257
0
            }
258
0
            if (loopInfo->fIndex->type().componentType().isFloat()) {
259
                // Rewrite `x != n` tests as `x < n` or `x > n` depending on the loop direction.
260
                // Less-than and greater-than tests avoid infinite loops caused by rounding error.
261
0
                Operator::Kind op = (loopInfo->fDelta > 0) ? Operator::Kind::LT
262
0
                                                           : Operator::Kind::GT;
263
0
                *loopTest = BinaryExpression::Make(context,
264
0
                                                   cond->fPosition,
265
0
                                                   cond->left()->clone(),
266
0
                                                   op,
267
0
                                                   cond->right()->clone());
268
0
                cond = &loopTest->get()->as<BinaryExpression>();
269
0
            }
270
0
            break;
271
0
        }
272
0
        case Operator::Kind::EQEQ: {
273
0
            if (loopInfo->fStart == loopEnd) {
274
                // Start and end begin in the same place, so we can run one iteration...
275
0
                if (loopInfo->fDelta) {
276
                    // ... and then they diverge, so the loop terminates.
277
0
                    loopInfo->fCount = 1;
278
0
                } else {
279
                    // ... but they never diverge, so the loop runs forever.
280
0
                    loopInfo->fCount = kLoopTerminationLimit;
281
0
                }
282
0
            } else {
283
                // Start never equals end, so the loop will not run a single iteration.
284
0
                loopInfo->fCount = 0;
285
0
            }
286
0
            break;
287
0
        }
288
0
        default: SkUNREACHABLE;
289
44
    }
290
291
44
    SkASSERT(loopInfo->fCount >= 0);
292
44
    if (loopInfo->fCount >= kLoopTerminationLimit) {
293
0
        errors.error(loopPos, "loop must guarantee termination in fewer iterations");
294
0
        return nullptr;
295
0
    }
296
297
44
    return loopInfo;
298
44
}
299
300
}  // namespace SkSL