Coverage Report

Created: 2024-09-14 07:19

/src/skia/src/sksl/SkSLConstantFolder.cpp
Line
Count
Source (jump to first uncovered line)
1
/*
2
 * Copyright 2020 Google LLC
3
 *
4
 * Use of this source code is governed by a BSD-style license that can be
5
 * found in the LICENSE file.
6
 */
7
8
#include "src/sksl/SkSLConstantFolder.h"
9
10
#include "include/core/SkTypes.h"
11
#include "include/private/base/SkFloatingPoint.h"
12
#include "include/private/base/SkTArray.h"
13
#include "src/sksl/SkSLAnalysis.h"
14
#include "src/sksl/SkSLContext.h"
15
#include "src/sksl/SkSLErrorReporter.h"
16
#include "src/sksl/SkSLPosition.h"
17
#include "src/sksl/SkSLProgramSettings.h"
18
#include "src/sksl/ir/SkSLBinaryExpression.h"
19
#include "src/sksl/ir/SkSLConstructorCompound.h"
20
#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
21
#include "src/sksl/ir/SkSLConstructorSplat.h"
22
#include "src/sksl/ir/SkSLExpression.h"
23
#include "src/sksl/ir/SkSLLiteral.h"
24
#include "src/sksl/ir/SkSLModifierFlags.h"
25
#include "src/sksl/ir/SkSLPrefixExpression.h"
26
#include "src/sksl/ir/SkSLType.h"
27
#include "src/sksl/ir/SkSLVariable.h"
28
#include "src/sksl/ir/SkSLVariableReference.h"
29
30
#include <cstdint>
31
#include <float.h>
32
#include <limits>
33
#include <optional>
34
#include <string>
35
#include <utility>
36
37
using namespace skia_private;
38
39
namespace SkSL {
40
41
23.0k
static bool is_vec_or_mat(const Type& type) {
42
23.0k
    switch (type.typeKind()) {
43
6.04k
        case Type::TypeKind::kMatrix:
44
18.1k
        case Type::TypeKind::kVector:
45
18.1k
            return true;
46
47
4.90k
        default:
48
4.90k
            return false;
49
23.0k
    }
50
23.0k
}
51
52
static std::unique_ptr<Expression> eliminate_no_op_boolean(Position pos,
53
                                                           const Expression& left,
54
                                                           Operator op,
55
2.81k
                                                           const Expression& right) {
56
2.81k
    bool rightVal = right.as<Literal>().boolValue();
57
58
    // Detect no-op Boolean expressions and optimize them away.
59
2.81k
    if ((op.kind() == Operator::Kind::LOGICALAND && rightVal)  ||  // (expr && true)  -> (expr)
60
2.81k
        (op.kind() == Operator::Kind::LOGICALOR  && !rightVal) ||  // (expr || false) -> (expr)
61
2.81k
        (op.kind() == Operator::Kind::LOGICALXOR && !rightVal) ||  // (expr ^^ false) -> (expr)
62
2.81k
        (op.kind() == Operator::Kind::EQEQ       && rightVal)  ||  // (expr == true)  -> (expr)
63
2.81k
        (op.kind() == Operator::Kind::NEQ        && !rightVal)) {  // (expr != false) -> (expr)
64
65
1.25k
        return left.clone(pos);
66
1.25k
    }
67
68
1.56k
    return nullptr;
69
2.81k
}
70
71
static std::unique_ptr<Expression> short_circuit_boolean(Position pos,
72
                                                         const Expression& left,
73
                                                         Operator op,
74
2.75k
                                                         const Expression& right) {
75
2.75k
    bool leftVal = left.as<Literal>().boolValue();
76
77
    // When the literal is on the left, we can sometimes eliminate the other expression entirely.
78
2.75k
    if ((op.kind() == Operator::Kind::LOGICALAND && !leftVal) ||  // (false && expr) -> (false)
79
2.75k
        (op.kind() == Operator::Kind::LOGICALOR  && leftVal)) {   // (true  || expr) -> (true)
80
81
329
        return left.clone(pos);
82
329
    }
83
84
    // We can't eliminate the right-side expression via short-circuit, but we might still be able to
85
    // simplify away a no-op expression.
86
2.42k
    return eliminate_no_op_boolean(pos, right, op, left);
87
2.75k
}
88
89
static std::unique_ptr<Expression> simplify_constant_equality(const Context& context,
90
                                                              Position pos,
91
                                                              const Expression& left,
92
                                                              Operator op,
93
11.0k
                                                              const Expression& right) {
94
11.0k
    if (op.kind() == Operator::Kind::EQEQ || op.kind() == Operator::Kind::NEQ) {
95
274
        bool equality = (op.kind() == Operator::Kind::EQEQ);
96
97
274
        switch (left.compareConstant(right)) {
98
173
            case Expression::ComparisonResult::kNotEqual:
99
173
                equality = !equality;
100
173
                [[fallthrough]];
101
102
274
            case Expression::ComparisonResult::kEqual:
103
274
                return Literal::MakeBool(context, pos, equality);
104
105
0
            case Expression::ComparisonResult::kUnknown:
106
0
                break;
107
274
        }
108
274
    }
109
10.7k
    return nullptr;
110
11.0k
}
111
112
static std::unique_ptr<Expression> simplify_matrix_multiplication(const Context& context,
113
                                                                  Position pos,
114
                                                                  const Expression& left,
115
                                                                  const Expression& right,
116
                                                                  int leftColumns,
117
                                                                  int leftRows,
118
                                                                  int rightColumns,
119
223
                                                                  int rightRows) {
120
223
    const Type& componentType = left.type().componentType();
121
223
    SkASSERT(componentType.matches(right.type().componentType()));
122
123
    // Fetch the left matrix.
124
223
    double leftVals[4][4];
125
975
    for (int c = 0; c < leftColumns; ++c) {
126
3.14k
        for (int r = 0; r < leftRows; ++r) {
127
2.38k
            leftVals[c][r] = *left.getConstantValue((c * leftRows) + r);
128
2.38k
        }
129
752
    }
130
    // Fetch the right matrix.
131
223
    double rightVals[4][4];
132
905
    for (int c = 0; c < rightColumns; ++c) {
133
3.10k
        for (int r = 0; r < rightRows; ++r) {
134
2.42k
            rightVals[c][r] = *right.getConstantValue((c * rightRows) + r);
135
2.42k
        }
136
682
    }
137
138
223
    SkASSERT(leftColumns == rightRows);
139
223
    int outColumns   = rightColumns,
140
223
        outRows      = leftRows;
141
142
223
    double args[16];
143
223
    int argIndex = 0;
144
863
    for (int c = 0; c < outColumns; ++c) {
145
2.69k
        for (int r = 0; r < outRows; ++r) {
146
            // Compute a dot product for this position.
147
2.05k
            double val = 0;
148
9.76k
            for (int dotIdx = 0; dotIdx < leftColumns; ++dotIdx) {
149
7.71k
                val += leftVals[dotIdx][r] * rightVals[c][dotIdx];
150
7.71k
            }
151
152
2.05k
            if (val >= -FLT_MAX && val <= FLT_MAX) {
153
2.03k
                args[argIndex++] = val;
154
2.03k
            } else {
155
                // The value is outside the 32-bit float range, or is NaN; do not optimize.
156
21
                return nullptr;
157
21
            }
158
2.05k
        }
159
661
    }
160
161
202
    if (outColumns == 1) {
162
        // Matrix-times-vector conceptually makes a 1-column N-row matrix, but we return vecN.
163
25
        std::swap(outColumns, outRows);
164
25
    }
165
166
202
    const Type& resultType = componentType.toCompound(context, outColumns, outRows);
167
202
    return ConstructorCompound::MakeFromConstants(context, pos, resultType, args);
168
223
}
169
170
static std::unique_ptr<Expression> simplify_matrix_times_matrix(const Context& context,
171
                                                                Position pos,
172
                                                                const Expression& left,
173
165
                                                                const Expression& right) {
174
165
    const Type& leftType = left.type();
175
165
    const Type& rightType = right.type();
176
177
165
    SkASSERT(leftType.isMatrix());
178
165
    SkASSERT(rightType.isMatrix());
179
180
165
    return simplify_matrix_multiplication(context, pos, left, right,
181
165
                                          leftType.columns(), leftType.rows(),
182
165
                                          rightType.columns(), rightType.rows());
183
165
}
184
185
static std::unique_ptr<Expression> simplify_vector_times_matrix(const Context& context,
186
                                                                Position pos,
187
                                                                const Expression& left,
188
33
                                                                const Expression& right) {
189
33
    const Type& leftType = left.type();
190
33
    const Type& rightType = right.type();
191
192
33
    SkASSERT(leftType.isVector());
193
33
    SkASSERT(rightType.isMatrix());
194
195
33
    return simplify_matrix_multiplication(context, pos, left, right,
196
33
                                          /*leftColumns=*/leftType.columns(), /*leftRows=*/1,
197
33
                                          rightType.columns(), rightType.rows());
198
33
}
199
200
static std::unique_ptr<Expression> simplify_matrix_times_vector(const Context& context,
201
                                                                Position pos,
202
                                                                const Expression& left,
203
25
                                                                const Expression& right) {
204
25
    const Type& leftType = left.type();
205
25
    const Type& rightType = right.type();
206
207
25
    SkASSERT(leftType.isMatrix());
208
25
    SkASSERT(rightType.isVector());
209
210
25
    return simplify_matrix_multiplication(context, pos, left, right,
211
25
                                          leftType.columns(), leftType.rows(),
212
25
                                          /*rightColumns=*/1, /*rightRows=*/rightType.columns());
213
25
}
214
215
static std::unique_ptr<Expression> simplify_componentwise(const Context& context,
216
                                                          Position pos,
217
                                                          const Expression& left,
218
                                                          Operator op,
219
10.9k
                                                          const Expression& right) {
220
10.9k
    SkASSERT(is_vec_or_mat(left.type()));
221
10.9k
    SkASSERT(left.type().matches(right.type()));
222
10.9k
    const Type& type = left.type();
223
224
    // Handle equality operations: == !=
225
10.9k
    if (std::unique_ptr<Expression> result = simplify_constant_equality(context, pos, left, op,
226
10.9k
            right)) {
227
274
        return result;
228
274
    }
229
230
    // Handle floating-point arithmetic: + - * /
231
10.7k
    using FoldFn = double (*)(double, double);
232
10.7k
    FoldFn foldFn;
233
10.7k
    switch (op.kind()) {
234
16.1k
        case Operator::Kind::PLUS:  foldFn = +[](double a, double b) { return a + b; }; break;
235
10.3k
        case Operator::Kind::MINUS: foldFn = +[](double a, double b) { return a - b; }; break;
236
26.1k
        case Operator::Kind::STAR:  foldFn = +[](double a, double b) { return a * b; }; break;
SkSLConstantFolder.cpp:SkSL::simplify_componentwise(SkSL::Context const&, SkSL::Position, SkSL::Expression const&, SkSL::Operator, SkSL::Expression const&)::$_2::operator()(double, double) const
Line
Count
Source
236
26.1k
        case Operator::Kind::STAR:  foldFn = +[](double a, double b) { return a * b; }; break;
Unexecuted instantiation: SkSLConstantFolder.cpp:SkSL::simplify_componentwise(SkSL::Context const&, SkSL::Position, SkSL::Expression const&, SkSL::Operator, SkSL::Expression const&)::$_4::operator()(double, double) const
237
2.88k
        case Operator::Kind::SLASH: foldFn = +[](double a, double b) { return a / b; }; break;
SkSLConstantFolder.cpp:SkSL::simplify_componentwise(SkSL::Context const&, SkSL::Position, SkSL::Expression const&, SkSL::Operator, SkSL::Expression const&)::$_3::operator()(double, double) const
Line
Count
Source
237
2.88k
        case Operator::Kind::SLASH: foldFn = +[](double a, double b) { return a / b; }; break;
Unexecuted instantiation: SkSLConstantFolder.cpp:SkSL::simplify_componentwise(SkSL::Context const&, SkSL::Position, SkSL::Expression const&, SkSL::Operator, SkSL::Expression const&)::$_5::operator()(double, double) const
238
54
        default:
239
54
            return nullptr;
240
10.7k
    }
241
242
10.6k
    const Type& componentType = type.componentType();
243
10.6k
    SkASSERT(componentType.isNumber());
244
245
10.6k
    double minimumValue = componentType.minimumValue();
246
10.6k
    double maximumValue = componentType.maximumValue();
247
248
10.6k
    double args[16];
249
10.6k
    int numSlots = type.slotCount();
250
66.1k
    for (int i = 0; i < numSlots; i++) {
251
55.5k
        double value = foldFn(*left.getConstantValue(i), *right.getConstantValue(i));
252
55.5k
        if (value < minimumValue || value > maximumValue) {
253
72
            return nullptr;
254
72
        }
255
55.4k
        args[i] = value;
256
55.4k
    }
257
10.5k
    return ConstructorCompound::MakeFromConstants(context, pos, type, args);
258
10.6k
}
SkSLConstantFolder.cpp:SkSL::simplify_componentwise(SkSL::Context const&, SkSL::Position, SkSL::Expression const&, SkSL::Operator, SkSL::Expression const&)
Line
Count
Source
219
10.9k
                                                          const Expression& right) {
220
10.9k
    SkASSERT(is_vec_or_mat(left.type()));
221
10.9k
    SkASSERT(left.type().matches(right.type()));
222
10.9k
    const Type& type = left.type();
223
224
    // Handle equality operations: == !=
225
10.9k
    if (std::unique_ptr<Expression> result = simplify_constant_equality(context, pos, left, op,
226
10.9k
            right)) {
227
274
        return result;
228
274
    }
229
230
    // Handle floating-point arithmetic: + - * /
231
10.7k
    using FoldFn = double (*)(double, double);
232
10.7k
    FoldFn foldFn;
233
10.7k
    switch (op.kind()) {
234
2.70k
        case Operator::Kind::PLUS:  foldFn = +[](double a, double b) { return a + b; }; break;
235
1.69k
        case Operator::Kind::MINUS: foldFn = +[](double a, double b) { return a - b; }; break;
236
5.63k
        case Operator::Kind::STAR:  foldFn = +[](double a, double b) { return a * b; }; break;
237
613
        case Operator::Kind::SLASH: foldFn = +[](double a, double b) { return a / b; }; break;
238
54
        default:
239
54
            return nullptr;
240
10.7k
    }
241
242
10.6k
    const Type& componentType = type.componentType();
243
10.6k
    SkASSERT(componentType.isNumber());
244
245
10.6k
    double minimumValue = componentType.minimumValue();
246
10.6k
    double maximumValue = componentType.maximumValue();
247
248
10.6k
    double args[16];
249
10.6k
    int numSlots = type.slotCount();
250
66.1k
    for (int i = 0; i < numSlots; i++) {
251
55.5k
        double value = foldFn(*left.getConstantValue(i), *right.getConstantValue(i));
252
55.5k
        if (value < minimumValue || value > maximumValue) {
253
72
            return nullptr;
254
72
        }
255
55.4k
        args[i] = value;
256
55.4k
    }
257
10.5k
    return ConstructorCompound::MakeFromConstants(context, pos, type, args);
258
10.6k
}
Unexecuted instantiation: SkSLConstantFolder.cpp:SkSL::simplify_componentwise(SkSL::Context const&, SkSL::Position, SkSL::Expression const&, SkSL::Operator, SkSL::Expression const&)
259
260
static std::unique_ptr<Expression> splat_scalar(const Context& context,
261
                                                const Expression& scalar,
262
10.2k
                                                const Type& type) {
263
10.2k
    if (type.isVector()) {
264
6.73k
        return ConstructorSplat::Make(context, scalar.fPosition, type, scalar.clone());
265
6.73k
    }
266
3.50k
    if (type.isMatrix()) {
267
3.50k
        int numSlots = type.slotCount();
268
3.50k
        ExpressionArray splatMatrix;
269
3.50k
        splatMatrix.reserve_exact(numSlots);
270
38.5k
        for (int index = 0; index < numSlots; ++index) {
271
35.0k
            splatMatrix.push_back(scalar.clone());
272
35.0k
        }
273
3.50k
        return ConstructorCompound::Make(context, scalar.fPosition, type, std::move(splatMatrix));
274
3.50k
    }
275
0
    SkDEBUGFAILF("unsupported type %s", type.description().c_str());
276
0
    return nullptr;
277
3.50k
}
278
279
static std::unique_ptr<Expression> cast_expression(const Context& context,
280
                                                   Position pos,
281
                                                   const Expression& expr,
282
6.47k
                                                   const Type& type) {
283
6.47k
    SkASSERT(type.componentType().matches(expr.type().componentType()));
284
6.47k
    if (expr.type().isScalar()) {
285
2.04k
        if (type.isMatrix()) {
286
167
            return ConstructorDiagonalMatrix::Make(context, pos, type, expr.clone());
287
167
        }
288
1.87k
        if (type.isVector()) {
289
104
            return ConstructorSplat::Make(context, pos, type, expr.clone());
290
104
        }
291
1.87k
    }
292
6.20k
    if (type.matches(expr.type())) {
293
6.10k
        return expr.clone(pos);
294
6.10k
    }
295
    // We can't cast matrices into vectors or vice-versa.
296
101
    return nullptr;
297
6.20k
}
298
299
static std::unique_ptr<Expression> zero_expression(const Context& context,
300
                                                   Position pos,
301
3.19k
                                                   const Type& type) {
302
3.19k
    std::unique_ptr<Expression> zero = Literal::Make(pos, 0.0, &type.componentType());
303
3.19k
    if (type.isScalar()) {
304
640
        return zero;
305
640
    }
306
2.55k
    if (type.isVector()) {
307
916
        return ConstructorSplat::Make(context, pos, type, std::move(zero));
308
916
    }
309
1.63k
    if (type.isMatrix()) {
310
1.63k
        return ConstructorDiagonalMatrix::Make(context, pos, type, std::move(zero));
311
1.63k
    }
312
0
    SkDEBUGFAILF("unsupported type %s", type.description().c_str());
313
0
    return nullptr;
314
1.63k
}
SkSLConstantFolder.cpp:SkSL::zero_expression(SkSL::Context const&, SkSL::Position, SkSL::Type const&)
Line
Count
Source
301
3.19k
                                                   const Type& type) {
302
3.19k
    std::unique_ptr<Expression> zero = Literal::Make(pos, 0.0, &type.componentType());
303
3.19k
    if (type.isScalar()) {
304
640
        return zero;
305
640
    }
306
2.55k
    if (type.isVector()) {
307
916
        return ConstructorSplat::Make(context, pos, type, std::move(zero));
308
916
    }
309
1.63k
    if (type.isMatrix()) {
310
1.63k
        return ConstructorDiagonalMatrix::Make(context, pos, type, std::move(zero));
311
1.63k
    }
312
0
    SkDEBUGFAILF("unsupported type %s", type.description().c_str());
313
0
    return nullptr;
314
1.63k
}
Unexecuted instantiation: SkSLConstantFolder.cpp:SkSL::zero_expression(SkSL::Context const&, SkSL::Position, SkSL::Type const&)
315
316
static std::unique_ptr<Expression> negate_expression(const Context& context,
317
                                                     Position pos,
318
                                                     const Expression& expr,
319
1.53k
                                                     const Type& type) {
320
1.53k
    std::unique_ptr<Expression> ctor = cast_expression(context, pos, expr, type);
321
1.53k
    return ctor ? PrefixExpression::Make(context, pos, Operator::Kind::MINUS, std::move(ctor))
322
1.53k
                : nullptr;
323
1.53k
}
324
325
4.39k
bool ConstantFolder::GetConstantInt(const Expression& value, SKSL_INT* out) {
326
4.39k
    const Expression* expr = GetConstantValueForVariable(value);
327
4.39k
    if (!expr->isIntLiteral()) {
328
1.68k
        return false;
329
1.68k
    }
330
2.70k
    *out = expr->as<Literal>().intValue();
331
2.70k
    return true;
332
4.39k
}
333
334
605
bool ConstantFolder::GetConstantValue(const Expression& value, double* out) {
335
605
    const Expression* expr = GetConstantValueForVariable(value);
336
605
    if (!expr->is<Literal>()) {
337
489
        return false;
338
489
    }
339
116
    *out = expr->as<Literal>().value();
340
116
    return true;
341
605
}
342
343
12.4k
static bool contains_constant_zero(const Expression& expr) {
344
12.4k
    int numSlots = expr.type().slotCount();
345
30.9k
    for (int index = 0; index < numSlots; ++index) {
346
22.4k
        std::optional<double> slotVal = expr.getConstantValue(index);
347
22.4k
        if (slotVal.has_value() && *slotVal == 0.0) {
348
3.94k
            return true;
349
3.94k
        }
350
22.4k
    }
351
8.48k
    return false;
352
12.4k
}
353
354
43.3k
bool ConstantFolder::IsConstantSplat(const Expression& expr, double value) {
355
43.3k
    int numSlots = expr.type().slotCount();
356
59.6k
    for (int index = 0; index < numSlots; ++index) {
357
50.1k
        std::optional<double> slotVal = expr.getConstantValue(index);
358
50.1k
        if (!slotVal.has_value() || *slotVal != value) {
359
33.8k
            return false;
360
33.8k
        }
361
50.1k
    }
362
9.47k
    return true;
363
43.3k
}
364
365
// Returns true if the expression is a square diagonal matrix containing `value`.
366
8.44k
static bool is_constant_diagonal(const Expression& expr, double value) {
367
8.44k
    SkASSERT(expr.type().isMatrix());
368
8.44k
    int columns = expr.type().columns();
369
8.44k
    int rows = expr.type().rows();
370
8.44k
    if (columns != rows) {
371
0
        return false;
372
0
    }
373
8.44k
    int slotIdx = 0;
374
9.72k
    for (int c = 0; c < columns; ++c) {
375
13.3k
        for (int r = 0; r < rows; ++r) {
376
12.1k
            double expectation = (c == r) ? value : 0;
377
12.1k
            std::optional<double> slotVal = expr.getConstantValue(slotIdx++);
378
12.1k
            if (!slotVal.has_value() || *slotVal != expectation) {
379
8.01k
                return false;
380
8.01k
            }
381
12.1k
        }
382
9.29k
    }
383
434
    return true;
384
8.44k
}
385
386
// Returns true if the expression is a scalar, vector, or diagonal matrix containing `value`.
387
36.5k
static bool is_constant_value(const Expression& expr, double value) {
388
36.5k
    return expr.type().isMatrix() ? is_constant_diagonal(expr, value)
389
36.5k
                                  : ConstantFolder::IsConstantSplat(expr, value);
390
36.5k
}
391
392
// The expression represents the right-hand side of a division op. If the division can be
393
// strength-reduced into multiplication by a reciprocal, returns that reciprocal as an expression.
394
// Note that this only supports literal values with safe-to-use reciprocals, and returns null if
395
// Expression contains anything else.
396
static std::unique_ptr<Expression> make_reciprocal_expression(const Context& context,
397
2.26k
                                                              const Expression& right) {
398
2.26k
    if (right.type().isMatrix() || !right.type().componentType().isFloat()) {
399
741
        return nullptr;
400
741
    }
401
    // Verify that each slot contains a finite, non-zero literal, take its reciprocal.
402
1.52k
    double values[4];
403
1.52k
    int nslots = right.type().slotCount();
404
3.12k
    for (int index = 0; index < nslots; ++index) {
405
2.07k
        std::optional<double> value = right.getConstantValue(index);
406
2.07k
        if (!value) {
407
462
            return nullptr;
408
462
        }
409
1.60k
        *value = sk_ieee_double_divide(1.0, *value);
410
1.60k
        if (*value >= -FLT_MAX && *value <= FLT_MAX && *value != 0.0) {
411
            // The reciprocal can be represented safely as a finite 32-bit float.
412
1.60k
            values[index] = *value;
413
1.60k
        } else {
414
            // The value is outside the 32-bit float range, or is NaN; do not optimize.
415
0
            return nullptr;
416
0
        }
417
1.60k
    }
418
    // Turn the expression array into a compound constructor. (If this is a single-slot expression,
419
    // this will return the literal as-is.)
420
1.05k
    return ConstructorCompound::MakeFromConstants(context, right.fPosition, right.type(), values);
421
1.52k
}
422
423
static bool error_on_divide_by_zero(const Context& context, Position pos, Operator op,
424
112k
                                    const Expression& right) {
425
112k
    switch (op.kind()) {
426
11.3k
        case Operator::Kind::SLASH:
427
12.4k
        case Operator::Kind::SLASHEQ:
428
12.4k
        case Operator::Kind::PERCENT:
429
12.4k
        case Operator::Kind::PERCENTEQ:
430
12.4k
            if (contains_constant_zero(right)) {
431
3.94k
                context.fErrors->error(pos, "division by zero");
432
3.94k
                return true;
433
3.94k
            }
434
8.48k
            return false;
435
100k
        default:
436
100k
            return false;
437
112k
    }
438
112k
}
439
440
612k
// Follows read-only references to `const` variables down to their initial values. Returns the
// final expression if Analysis::IsCompileTimeConstant accepts it. Returns null if the chain hits:
//  - a reference that is not a plain read,
//  - a non-const variable,
//  - a variable with no initial value, or
//  - a final expression that is not a compile-time constant.
const Expression* ConstantFolder::GetConstantValueOrNull(const Expression& inExpr) {
    const Expression* current = &inExpr;
    for (;;) {
        if (!current->is<VariableReference>()) {
            break;
        }
        const VariableReference& ref = current->as<VariableReference>();
        if (ref.refKind() != VariableRefKind::kRead) {
            return nullptr;
        }
        const Variable& var = *ref.variable();
        if (!var.modifierFlags().isConst()) {
            return nullptr;
        }
        current = var.initialValue();
        if (current == nullptr) {
            // Const variables normally carry an initial value, but const function parameters
            // are an exception and have none.
            return nullptr;
        }
    }
    return Analysis::IsCompileTimeConstant(*current) ? current : nullptr;
}
460
461
338k
// Like GetConstantValueOrNull, but falls back to `inExpr` itself when no constant is found, so the
// returned pointer is never null.
const Expression* ConstantFolder::GetConstantValueForVariable(const Expression& inExpr) {
    if (const Expression* constant = GetConstantValueOrNull(inExpr)) {
        return constant;
    }
    return &inExpr;
}
465
466
std::unique_ptr<Expression> ConstantFolder::MakeConstantValueForVariable(
467
272k
        Position pos, std::unique_ptr<Expression> inExpr) {
468
272k
    const Expression* expr = GetConstantValueOrNull(*inExpr);
469
272k
    return expr ? expr->clone(pos) : std::move(inExpr);
470
272k
}
471
472
16.9k
static bool is_scalar_op_matrix(const Expression& left, const Expression& right) {
473
16.9k
    return left.type().isScalar() && right.type().isMatrix();
474
16.9k
}
475
476
6.74k
static bool is_matrix_op_scalar(const Expression& left, const Expression& right) {
477
6.74k
    return is_scalar_op_matrix(right, left);
478
6.74k
}
479
480
static std::unique_ptr<Expression> simplify_arithmetic(const Context& context,
481
                                                       Position pos,
482
                                                       const Expression& left,
483
                                                       Operator op,
484
                                                       const Expression& right,
485
25.4k
                                                       const Type& resultType) {
486
25.4k
    switch (op.kind()) {
487
4.65k
        case Operator::Kind::PLUS:
488
4.65k
            if (!is_scalar_op_matrix(left, right) &&
489
4.65k
                ConstantFolder::IsConstantSplat(right, 0.0)) {  // x + 0
490
467
                if (std::unique_ptr<Expression> expr = cast_expression(context, pos, left,
491
467
                                                                       resultType)) {
492
467
                    return expr;
493
467
                }
494
467
            }
495
4.19k
            if (!is_matrix_op_scalar(left, right) &&
496
4.19k
                ConstantFolder::IsConstantSplat(left, 0.0)) {  // 0 + x
497
2.17k
                if (std::unique_ptr<Expression> expr = cast_expression(context, pos, right,
498
2.17k
                                                                       resultType)) {
499
2.17k
                    return expr;
500
2.17k
                }
501
2.17k
            }
502
2.01k
            break;
503
504
7.85k
        case Operator::Kind::STAR:
505
7.85k
            if (is_constant_value(right, 1.0)) {  // x * 1
506
325
                if (std::unique_ptr<Expression> expr = cast_expression(context, pos, left,
507
325
                                                                       resultType)) {
508
295
                    return expr;
509
295
                }
510
325
            }
511
7.56k
            if (is_constant_value(left, 1.0)) {   // 1 * x
512
592
                if (std::unique_ptr<Expression> expr = cast_expression(context, pos, right,
513
592
                                                                       resultType)) {
514
565
                    return expr;
515
565
                }
516
592
            }
517
6.99k
            if (is_constant_value(right, 0.0) && !Analysis::HasSideEffects(left)) {  // x * 0
518
591
                return zero_expression(context, pos, resultType);
519
591
            }
520
6.40k
            if (is_constant_value(left, 0.0) && !Analysis::HasSideEffects(right)) {  // 0 * x
521
2.60k
                return zero_expression(context, pos, resultType);
522
2.60k
            }
523
3.80k
            if (is_constant_value(right, -1.0)) {  // x * -1 (to `-x`)
524
368
                if (std::unique_ptr<Expression> expr = negate_expression(context, pos, left,
525
368
                                                                         resultType)) {
526
337
                    return expr;
527
337
                }
528
368
            }
529
3.46k
            if (is_constant_value(left, -1.0)) {  // -1 * x (to `-x`)
530
121
                if (std::unique_ptr<Expression> expr = negate_expression(context, pos, right,
531
121
                                                                         resultType)) {
532
108
                    return expr;
533
108
                }
534
121
            }
535
3.35k
            break;
536
537
3.35k
        case Operator::Kind::MINUS:
538
3.03k
            if (!is_scalar_op_matrix(left, right) &&
539
3.03k
                ConstantFolder::IsConstantSplat(right, 0.0)) {  // x - 0
540
487
                if (std::unique_ptr<Expression> expr = cast_expression(context, pos, left,
541
487
                                                                       resultType)) {
542
487
                    return expr;
543
487
                }
544
487
            }
545
2.55k
            if (!is_matrix_op_scalar(left, right) &&
546
2.55k
                ConstantFolder::IsConstantSplat(left, 0.0)) {  // 0 - x
547
1.05k
                if (std::unique_ptr<Expression> expr = negate_expression(context, pos, right,
548
1.05k
                                                                         resultType)) {
549
1.05k
                    return expr;
550
1.05k
                }
551
1.05k
            }
552
1.50k
            break;
553
554
2.48k
        case Operator::Kind::SLASH:
555
2.48k
            if (!is_scalar_op_matrix(left, right) &&
556
2.48k
                ConstantFolder::IsConstantSplat(right, 1.0)) {  // x / 1
557
332
                if (std::unique_ptr<Expression> expr = cast_expression(context, pos, left,
558
332
                                                                       resultType)) {
559
332
                    return expr;
560
332
                }
561
332
            }
562
2.14k
            if (!left.type().isMatrix()) {  // convert `x / 2` into `x * 0.5`
563
1.85k
                if (std::unique_ptr<Expression> expr = make_reciprocal_expression(context, right)) {
564
694
                    return BinaryExpression::Make(context, pos, left.clone(), Operator::Kind::STAR,
565
694
                                                  std::move(expr));
566
694
                }
567
1.85k
            }
568
1.45k
            break;
569
570
1.45k
        case Operator::Kind::PLUSEQ:
571
283
        case Operator::Kind::MINUSEQ:
572
283
            if (ConstantFolder::IsConstantSplat(right, 0.0)) {  // x += 0, x -= 0
573
178
                if (std::unique_ptr<Expression> var = cast_expression(context, pos, left,
574
178
                                                                      resultType)) {
575
178
                    Analysis::UpdateVariableRefKind(var.get(), VariableRefKind::kRead);
576
178
                    return var;
577
178
                }
578
178
            }
579
105
            break;
580
581
425
        case Operator::Kind::STAREQ:
582
425
            if (is_constant_value(right, 1.0)) {  // x *= 1
583
30
                if (std::unique_ptr<Expression> var = cast_expression(context, pos, left,
584
30
                                                                      resultType)) {
585
30
                    Analysis::UpdateVariableRefKind(var.get(), VariableRefKind::kRead);
586
30
                    return var;
587
30
                }
588
30
            }
589
395
            break;
590
591
756
        case Operator::Kind::SLASHEQ:
592
756
            if (ConstantFolder::IsConstantSplat(right, 1.0)) {  // x /= 1
593
354
                if (std::unique_ptr<Expression> var = cast_expression(context, pos, left,
594
354
                                                                      resultType)) {
595
354
                    Analysis::UpdateVariableRefKind(var.get(), VariableRefKind::kRead);
596
354
                    return var;
597
354
                }
598
354
            }
599
402
            if (std::unique_ptr<Expression> expr = make_reciprocal_expression(context, right)) {
600
364
                return BinaryExpression::Make(context, pos, left.clone(), Operator::Kind::STAREQ,
601
364
                                              std::move(expr));
602
364
            }
603
38
            break;
604
605
5.97k
        default:
606
5.97k
            break;
607
25.4k
    }
608
609
14.8k
    return nullptr;
610
25.4k
}
611
612
// The expression must be scalar, and represents the right-hand side of a division op. It can
613
// contain anything, not just literal values. This returns the binary expression `1.0 / expr`. The
614
// expression might be further simplified by the constant folding, if possible.
615
static std::unique_ptr<Expression> one_over_scalar(const Context& context,
616
301
                                                   const Expression& right) {
617
301
    SkASSERT(right.type().isScalar());
618
301
    Position pos = right.fPosition;
619
301
    return BinaryExpression::Make(context, pos,
620
301
                                  Literal::Make(pos, 1.0, &right.type()),
621
301
                                  Operator::Kind::SLASH,
622
301
                                  right.clone());
623
301
}
624
625
static std::unique_ptr<Expression> simplify_matrix_division(const Context& context,
626
                                                            Position pos,
627
                                                            const Expression& left,
628
                                                            Operator op,
629
                                                            const Expression& right,
630
38.1k
                                                            const Type& resultType) {
631
    // Convert matrix-over-scalar `x /= y` into `x *= (1.0 / y)`. This generates better
632
    // code in SPIR-V and Metal, and should be roughly equivalent elsewhere.
633
38.1k
    switch (op.kind()) {
634
1.86k
        case OperatorKind::SLASH:
635
2.07k
        case OperatorKind::SLASHEQ:
636
2.07k
            if (left.type().isMatrix() && right.type().isScalar()) {
637
301
                Operator multiplyOp = op.isAssignment() ? OperatorKind::STAREQ
638
301
                                                        : OperatorKind::STAR;
639
301
                return BinaryExpression::Make(context, pos,
640
301
                                              left.clone(),
641
301
                                              multiplyOp,
642
301
                                              one_over_scalar(context, right));
643
301
            }
644
1.77k
            break;
645
646
36.0k
        default:
647
36.0k
            break;
648
38.1k
    }
649
650
37.8k
    return nullptr;
651
38.1k
}
652
653
static std::unique_ptr<Expression> fold_expression(Position pos,
654
                                                   double result,
655
46.9k
                                                   const Type* resultType) {
656
46.9k
    if (resultType->isNumber()) {
657
25.6k
        if (result >= resultType->minimumValue() && result <= resultType->maximumValue()) {
658
            // This result will fit inside its type.
659
25.1k
        } else {
660
            // The value is outside the range or is NaN (all if-checks fail); do not optimize.
661
512
            return nullptr;
662
512
        }
663
25.6k
    }
664
665
46.4k
    return Literal::Make(pos, result, resultType);
666
46.9k
}
667
668
static std::unique_ptr<Expression> fold_two_constants(const Context& context,
669
                                                      Position pos,
670
                                                      const Expression* left,
671
                                                      Operator op,
672
                                                      const Expression* right,
673
59.8k
                                                      const Type& resultType) {
674
59.8k
    SkASSERT(Analysis::IsCompileTimeConstant(*left));
675
59.8k
    SkASSERT(Analysis::IsCompileTimeConstant(*right));
676
59.8k
    const Type& leftType = left->type();
677
59.8k
    const Type& rightType = right->type();
678
679
    // Handle pairs of integer literals.
680
59.8k
    if (left->isIntLiteral() && right->isIntLiteral()) {
681
40.4k
        using SKSL_UINT = uint64_t;
682
40.4k
        SKSL_INT leftVal  = left->as<Literal>().intValue();
683
40.4k
        SKSL_INT rightVal = right->as<Literal>().intValue();
684
685
        // Note that fold_expression returns null if the result would overflow its type.
686
40.4k
        #define RESULT(Op)   fold_expression(pos, (SKSL_INT)(leftVal) Op \
687
21.8k
                                                  (SKSL_INT)(rightVal), &resultType)
688
40.4k
        #define URESULT(Op)  fold_expression(pos, (SKSL_INT)((SKSL_UINT)(leftVal) Op \
689
17.6k
                                                  (SKSL_UINT)(rightVal)), &resultType)
690
40.4k
        switch (op.kind()) {
691
7.56k
            case Operator::Kind::PLUS:       return URESULT(+);
692
4.19k
            case Operator::Kind::MINUS:      return URESULT(-);
693
5.93k
            case Operator::Kind::STAR:       return URESULT(*);
694
3.21k
            case Operator::Kind::SLASH:
695
3.21k
                if (leftVal == std::numeric_limits<SKSL_INT>::min() && rightVal == -1) {
696
0
                    context.fErrors->error(pos, "arithmetic overflow");
697
0
                    return nullptr;
698
0
                }
699
3.21k
                return RESULT(/);
700
701
21
            case Operator::Kind::PERCENT:
702
21
                if (leftVal == std::numeric_limits<SKSL_INT>::min() && rightVal == -1) {
703
0
                    context.fErrors->error(pos, "arithmetic overflow");
704
0
                    return nullptr;
705
0
                }
706
21
                return RESULT(%);
707
708
4
            case Operator::Kind::BITWISEAND: return RESULT(&);
709
1
            case Operator::Kind::BITWISEOR:  return RESULT(|);
710
0
            case Operator::Kind::BITWISEXOR: return RESULT(^);
711
1.32k
            case Operator::Kind::EQEQ:       return RESULT(==);
712
1.15k
            case Operator::Kind::NEQ:        return RESULT(!=);
713
8.17k
            case Operator::Kind::GT:         return RESULT(>);
714
374
            case Operator::Kind::GTEQ:       return RESULT(>=);
715
7.24k
            case Operator::Kind::LT:         return RESULT(<);
716
357
            case Operator::Kind::LTEQ:       return RESULT(<=);
717
0
            case Operator::Kind::SHL:
718
0
                if (rightVal >= 0 && rightVal <= 31) {
719
                    // Left-shifting a negative (or really, any signed) value is undefined behavior
720
                    // in C++, but not in GLSL. Do the shift on unsigned values to avoid triggering
721
                    // an UBSAN error.
722
0
                    return URESULT(<<);
723
0
                }
724
0
                context.fErrors->error(pos, "shift value out of range");
725
0
                return nullptr;
726
727
0
            case Operator::Kind::SHR:
728
0
                if (rightVal >= 0 && rightVal <= 31) {
729
0
                    return RESULT(>>);
730
0
                }
731
0
                context.fErrors->error(pos, "shift value out of range");
732
0
                return nullptr;
733
734
864
            default:
735
864
                break;
736
40.4k
        }
737
864
        #undef RESULT
738
864
        #undef URESULT
739
740
864
        return nullptr;
741
40.4k
    }
742
743
    // Handle pairs of floating-point literals.
744
19.4k
    if (left->isFloatLiteral() && right->isFloatLiteral()) {
745
7.50k
        SKSL_FLOAT leftVal  = left->as<Literal>().floatValue();
746
7.50k
        SKSL_FLOAT rightVal = right->as<Literal>().floatValue();
747
748
7.50k
        #define RESULT(Op) fold_expression(pos, leftVal Op rightVal, &resultType)
749
7.50k
        switch (op.kind()) {
750
1.21k
            case Operator::Kind::PLUS:  return RESULT(+);
751
1.49k
            case Operator::Kind::MINUS: return RESULT(-);
752
1.17k
            case Operator::Kind::STAR:  return RESULT(*);
753
817
            case Operator::Kind::SLASH: return RESULT(/);
754
423
            case Operator::Kind::EQEQ:  return RESULT(==);
755
467
            case Operator::Kind::NEQ:   return RESULT(!=);
756
736
            case Operator::Kind::GT:    return RESULT(>);
757
87
            case Operator::Kind::GTEQ:  return RESULT(>=);
758
871
            case Operator::Kind::LT:    return RESULT(<);
759
114
            case Operator::Kind::LTEQ:  return RESULT(<=);
760
95
            default:                    break;
761
7.50k
        }
762
95
        #undef RESULT
763
764
95
        return nullptr;
765
7.50k
    }
766
767
    // Perform matrix multiplication.
768
11.9k
    if (op.kind() == Operator::Kind::STAR) {
769
5.85k
        if (leftType.isMatrix() && rightType.isMatrix()) {
770
165
            return simplify_matrix_times_matrix(context, pos, *left, *right);
771
165
        }
772
5.69k
        if (leftType.isVector() && rightType.isMatrix()) {
773
33
            return simplify_vector_times_matrix(context, pos, *left, *right);
774
33
        }
775
5.66k
        if (leftType.isMatrix() && rightType.isVector()) {
776
25
            return simplify_matrix_times_vector(context, pos, *left, *right);
777
25
        }
778
5.66k
    }
779
780
    // Perform constant folding on pairs of vectors/matrices.
781
11.7k
    if (is_vec_or_mat(leftType) && leftType.matches(rightType)) {
782
743
        return simplify_componentwise(context, pos, *left, op, *right);
783
743
    }
784
785
    // Perform constant folding on vectors/matrices against scalars, e.g.: half4(2) + 2
786
10.9k
    if (rightType.isScalar() && is_vec_or_mat(leftType) &&
787
10.9k
        leftType.componentType().matches(rightType)) {
788
6.71k
        return simplify_componentwise(context, pos,
789
6.71k
                                      *left, op, *splat_scalar(context, *right, left->type()));
790
6.71k
    }
791
792
    // Perform constant folding on scalars against vectors/matrices, e.g.: 2 + half4(2)
793
4.27k
    if (leftType.isScalar() && is_vec_or_mat(rightType) &&
794
4.27k
        rightType.componentType().matches(leftType)) {
795
3.52k
        return simplify_componentwise(context, pos,
796
3.52k
                                      *splat_scalar(context, *left, right->type()), op, *right);
797
3.52k
    }
798
799
    // Perform constant folding on pairs of matrices, arrays or structs.
800
757
    if ((leftType.isMatrix() && rightType.isMatrix()) ||
801
757
        (leftType.isArray() && rightType.isArray()) ||
802
757
        (leftType.isStruct() && rightType.isStruct())) {
803
64
        return simplify_constant_equality(context, pos, *left, op, *right);
804
64
    }
805
806
    // We aren't able to constant-fold these expressions.
807
693
    return nullptr;
808
757
}
SkSLConstantFolder.cpp:SkSL::fold_two_constants(SkSL::Context const&, SkSL::Position, SkSL::Expression const*, SkSL::Operator, SkSL::Expression const*, SkSL::Type const&)
Line
Count
Source
673
59.8k
                                                      const Type& resultType) {
674
59.8k
    SkASSERT(Analysis::IsCompileTimeConstant(*left));
675
59.8k
    SkASSERT(Analysis::IsCompileTimeConstant(*right));
676
59.8k
    const Type& leftType = left->type();
677
59.8k
    const Type& rightType = right->type();
678
679
    // Handle pairs of integer literals.
680
59.8k
    if (left->isIntLiteral() && right->isIntLiteral()) {
681
40.4k
        using SKSL_UINT = uint64_t;
682
40.4k
        SKSL_INT leftVal  = left->as<Literal>().intValue();
683
40.4k
        SKSL_INT rightVal = right->as<Literal>().intValue();
684
685
        // Note that fold_expression returns null if the result would overflow its type.
686
40.4k
        #define RESULT(Op)   fold_expression(pos, (SKSL_INT)(leftVal) Op \
687
40.4k
                                                  (SKSL_INT)(rightVal), &resultType)
688
40.4k
        #define URESULT(Op)  fold_expression(pos, (SKSL_INT)((SKSL_UINT)(leftVal) Op \
689
40.4k
                                                  (SKSL_UINT)(rightVal)), &resultType)
690
40.4k
        switch (op.kind()) {
691
7.56k
            case Operator::Kind::PLUS:       return URESULT(+);
692
4.19k
            case Operator::Kind::MINUS:      return URESULT(-);
693
5.93k
            case Operator::Kind::STAR:       return URESULT(*);
694
3.21k
            case Operator::Kind::SLASH:
695
3.21k
                if (leftVal == std::numeric_limits<SKSL_INT>::min() && rightVal == -1) {
696
0
                    context.fErrors->error(pos, "arithmetic overflow");
697
0
                    return nullptr;
698
0
                }
699
3.21k
                return RESULT(/);
700
701
21
            case Operator::Kind::PERCENT:
702
21
                if (leftVal == std::numeric_limits<SKSL_INT>::min() && rightVal == -1) {
703
0
                    context.fErrors->error(pos, "arithmetic overflow");
704
0
                    return nullptr;
705
0
                }
706
21
                return RESULT(%);
707
708
4
            case Operator::Kind::BITWISEAND: return RESULT(&);
709
1
            case Operator::Kind::BITWISEOR:  return RESULT(|);
710
0
            case Operator::Kind::BITWISEXOR: return RESULT(^);
711
1.32k
            case Operator::Kind::EQEQ:       return RESULT(==);
712
1.15k
            case Operator::Kind::NEQ:        return RESULT(!=);
713
8.17k
            case Operator::Kind::GT:         return RESULT(>);
714
374
            case Operator::Kind::GTEQ:       return RESULT(>=);
715
7.24k
            case Operator::Kind::LT:         return RESULT(<);
716
357
            case Operator::Kind::LTEQ:       return RESULT(<=);
717
0
            case Operator::Kind::SHL:
718
0
                if (rightVal >= 0 && rightVal <= 31) {
719
                    // Left-shifting a negative (or really, any signed) value is undefined behavior
720
                    // in C++, but not in GLSL. Do the shift on unsigned values to avoid triggering
721
                    // an UBSAN error.
722
0
                    return URESULT(<<);
723
0
                }
724
0
                context.fErrors->error(pos, "shift value out of range");
725
0
                return nullptr;
726
727
0
            case Operator::Kind::SHR:
728
0
                if (rightVal >= 0 && rightVal <= 31) {
729
0
                    return RESULT(>>);
730
0
                }
731
0
                context.fErrors->error(pos, "shift value out of range");
732
0
                return nullptr;
733
734
864
            default:
735
864
                break;
736
40.4k
        }
737
864
        #undef RESULT
738
864
        #undef URESULT
739
740
864
        return nullptr;
741
40.4k
    }
742
743
    // Handle pairs of floating-point literals.
744
19.4k
    if (left->isFloatLiteral() && right->isFloatLiteral()) {
745
7.50k
        SKSL_FLOAT leftVal  = left->as<Literal>().floatValue();
746
7.50k
        SKSL_FLOAT rightVal = right->as<Literal>().floatValue();
747
748
7.50k
        #define RESULT(Op) fold_expression(pos, leftVal Op rightVal, &resultType)
749
7.50k
        switch (op.kind()) {
750
1.21k
            case Operator::Kind::PLUS:  return RESULT(+);
751
1.49k
            case Operator::Kind::MINUS: return RESULT(-);
752
1.17k
            case Operator::Kind::STAR:  return RESULT(*);
753
817
            case Operator::Kind::SLASH: return RESULT(/);
754
423
            case Operator::Kind::EQEQ:  return RESULT(==);
755
467
            case Operator::Kind::NEQ:   return RESULT(!=);
756
736
            case Operator::Kind::GT:    return RESULT(>);
757
87
            case Operator::Kind::GTEQ:  return RESULT(>=);
758
871
            case Operator::Kind::LT:    return RESULT(<);
759
114
            case Operator::Kind::LTEQ:  return RESULT(<=);
760
95
            default:                    break;
761
7.50k
        }
762
95
        #undef RESULT
763
764
95
        return nullptr;
765
7.50k
    }
766
767
    // Perform matrix multiplication.
768
11.9k
    if (op.kind() == Operator::Kind::STAR) {
769
5.85k
        if (leftType.isMatrix() && rightType.isMatrix()) {
770
165
            return simplify_matrix_times_matrix(context, pos, *left, *right);
771
165
        }
772
5.69k
        if (leftType.isVector() && rightType.isMatrix()) {
773
33
            return simplify_vector_times_matrix(context, pos, *left, *right);
774
33
        }
775
5.66k
        if (leftType.isMatrix() && rightType.isVector()) {
776
25
            return simplify_matrix_times_vector(context, pos, *left, *right);
777
25
        }
778
5.66k
    }
779
780
    // Perform constant folding on pairs of vectors/matrices.
781
11.7k
    if (is_vec_or_mat(leftType) && leftType.matches(rightType)) {
782
743
        return simplify_componentwise(context, pos, *left, op, *right);
783
743
    }
784
785
    // Perform constant folding on vectors/matrices against scalars, e.g.: half4(2) + 2
786
10.9k
    if (rightType.isScalar() && is_vec_or_mat(leftType) &&
787
10.9k
        leftType.componentType().matches(rightType)) {
788
6.71k
        return simplify_componentwise(context, pos,
789
6.71k
                                      *left, op, *splat_scalar(context, *right, left->type()));
790
6.71k
    }
791
792
    // Perform constant folding on scalars against vectors/matrices, e.g.: 2 + half4(2)
793
4.27k
    if (leftType.isScalar() && is_vec_or_mat(rightType) &&
794
4.27k
        rightType.componentType().matches(leftType)) {
795
3.52k
        return simplify_componentwise(context, pos,
796
3.52k
                                      *splat_scalar(context, *left, right->type()), op, *right);
797
3.52k
    }
798
799
    // Perform constant folding on pairs of matrices, arrays or structs.
800
757
    if ((leftType.isMatrix() && rightType.isMatrix()) ||
801
757
        (leftType.isArray() && rightType.isArray()) ||
802
757
        (leftType.isStruct() && rightType.isStruct())) {
803
64
        return simplify_constant_equality(context, pos, *left, op, *right);
804
64
    }
805
806
    // We aren't able to constant-fold these expressions.
807
693
    return nullptr;
808
757
}
Unexecuted instantiation: SkSLConstantFolder.cpp:SkSL::fold_two_constants(SkSL::Context const&, SkSL::Position, SkSL::Expression const*, SkSL::Operator, SkSL::Expression const*, SkSL::Type const&)
809
810
std::unique_ptr<Expression> ConstantFolder::Simplify(const Context& context,
811
                                                     Position pos,
812
                                                     const Expression& leftExpr,
813
                                                     Operator op,
814
                                                     const Expression& rightExpr,
815
118k
                                                     const Type& resultType) {
816
    // Replace constant variables with their literal values.
817
118k
    const Expression* left = GetConstantValueForVariable(leftExpr);
818
118k
    const Expression* right = GetConstantValueForVariable(rightExpr);
819
820
    // If this is the assignment operator, and both sides are the same trivial expression, this is
821
    // self-assignment (i.e., `var = var`) and can be reduced to just a variable reference (`var`).
822
    // This can happen when other parts of the assignment are optimized away.
823
118k
    if (op.kind() == Operator::Kind::EQ && Analysis::IsSameExpressionTree(*left, *right)) {
824
434
        return right->clone(pos);
825
434
    }
826
827
    // Simplify the expression when both sides are constant Boolean literals.
828
117k
    if (left->isBoolLiteral() && right->isBoolLiteral()) {
829
1.37k
        bool leftVal  = left->as<Literal>().boolValue();
830
1.37k
        bool rightVal = right->as<Literal>().boolValue();
831
1.37k
        bool result;
832
1.37k
        switch (op.kind()) {
833
462
            case Operator::Kind::LOGICALAND: result = leftVal && rightVal; break;
834
319
            case Operator::Kind::LOGICALOR:  result = leftVal || rightVal; break;
835
151
            case Operator::Kind::LOGICALXOR: result = leftVal ^  rightVal; break;
836
229
            case Operator::Kind::EQEQ:       result = leftVal == rightVal; break;
837
110
            case Operator::Kind::NEQ:        result = leftVal != rightVal; break;
838
105
            default: return nullptr;
839
1.37k
        }
840
1.27k
        return Literal::MakeBool(context, pos, result);
841
1.37k
    }
842
843
    // If the left side is a Boolean literal, apply short-circuit optimizations.
844
116k
    if (left->isBoolLiteral()) {
845
1.06k
        return short_circuit_boolean(pos, *left, op, *right);
846
1.06k
    }
847
848
    // If the right side is a Boolean literal...
849
115k
    if (right->isBoolLiteral()) {
850
        // ... and the left side has no side effects...
851
2.07k
        if (!Analysis::HasSideEffects(*left)) {
852
            // We can reverse the expressions and short-circuit optimizations are still valid.
853
1.68k
            return short_circuit_boolean(pos, *right, op, *left);
854
1.68k
        }
855
856
        // We can't use short-circuiting, but we can still optimize away no-op Boolean expressions.
857
390
        return eliminate_no_op_boolean(pos, *left, op, *right);
858
2.07k
    }
859
860
113k
    if (op.kind() == Operator::Kind::EQEQ && Analysis::IsSameExpressionTree(*left, *right)) {
861
        // With == comparison, if both sides are the same trivial expression, this is self-
862
        // comparison and is always true. (We are not concerned with NaN.)
863
208
        return Literal::MakeBool(context, pos, /*value=*/true);
864
208
    }
865
866
112k
    if (op.kind() == Operator::Kind::NEQ && Analysis::IsSameExpressionTree(*left, *right)) {
867
        // With != comparison, if both sides are the same trivial expression, this is self-
868
        // comparison and is always false. (We are not concerned with NaN.)
869
219
        return Literal::MakeBool(context, pos, /*value=*/false);
870
219
    }
871
872
112k
    if (error_on_divide_by_zero(context, pos, op, *right)) {
873
3.94k
        return nullptr;
874
3.94k
    }
875
876
    // Perform full constant folding when both sides are compile-time constants.
877
108k
    bool leftSideIsConstant = Analysis::IsCompileTimeConstant(*left);
878
108k
    bool rightSideIsConstant = Analysis::IsCompileTimeConstant(*right);
879
108k
    if (leftSideIsConstant && rightSideIsConstant) {
880
59.8k
        return fold_two_constants(context, pos, left, op, right, resultType);
881
59.8k
    }
882
883
48.7k
    if (context.fConfig->fSettings.fOptimize) {
884
        // If just one side is constant, we might still be able to simplify arithmetic expressions
885
        // like `x * 1`, `x *= 1`, `x + 0`, `x * 0`, `0 / x`, etc.
886
48.7k
        if (leftSideIsConstant || rightSideIsConstant) {
887
25.4k
            if (std::unique_ptr<Expression> expr = simplify_arithmetic(context, pos, *left, op,
888
25.4k
                                                                       *right, resultType)) {
889
10.6k
                return expr;
890
10.6k
            }
891
25.4k
        }
892
893
        // We can simplify some forms of matrix division even when neither side is constant.
894
38.1k
        if (std::unique_ptr<Expression> expr = simplify_matrix_division(context, pos, *left, op,
895
38.1k
                                                                        *right, resultType)) {
896
301
            return expr;
897
301
        }
898
38.1k
    }
899
900
    // We aren't able to constant-fold.
901
37.8k
    return nullptr;
902
48.7k
}
903
904
}  // namespace SkSL