Coverage Report

Created: 2025-11-16 06:14

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/spirv-tools/source/opt/const_folding_rules.cpp
Line
Count
Source
1
// Copyright (c) 2018 Google LLC
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
#include "source/opt/const_folding_rules.h"
16
17
#include "source/opt/ir_context.h"
18
19
namespace spvtools {
20
namespace opt {
21
namespace {
22
constexpr uint32_t kExtractCompositeIdInIdx = 0;
23
24
// Returns a constants with the value NaN of the given type.  Only works for
25
// 32-bit and 64-bit float point types.  Returns |nullptr| if an error occurs.
26
const analysis::Constant* GetNan(const analysis::Type* type,
27
1.12k
                                 analysis::ConstantManager* const_mgr) {
28
1.12k
  const analysis::Float* float_type = type->AsFloat();
29
1.12k
  if (float_type == nullptr) {
30
0
    return nullptr;
31
0
  }
32
33
1.12k
  switch (float_type->width()) {
34
1.12k
    case 32:
35
1.12k
      return const_mgr->GetFloatConst(std::numeric_limits<float>::quiet_NaN());
36
0
    case 64:
37
0
      return const_mgr->GetDoubleConst(
38
0
          std::numeric_limits<double>::quiet_NaN());
39
0
    default:
40
0
      return nullptr;
41
1.12k
  }
42
1.12k
}
43
44
// Returns a constants with the value INF of the given type.  Only works for
45
// 32-bit and 64-bit float point types.  Returns |nullptr| if an error occurs.
46
const analysis::Constant* GetInf(const analysis::Type* type,
47
965
                                 analysis::ConstantManager* const_mgr) {
48
965
  const analysis::Float* float_type = type->AsFloat();
49
965
  if (float_type == nullptr) {
50
0
    return nullptr;
51
0
  }
52
53
965
  switch (float_type->width()) {
54
965
    case 32:
55
965
      return const_mgr->GetFloatConst(std::numeric_limits<float>::infinity());
56
0
    case 64:
57
0
      return const_mgr->GetDoubleConst(std::numeric_limits<double>::infinity());
58
0
    default:
59
0
      return nullptr;
60
965
  }
61
965
}
62
63
// Returns true if |type| is Float or a vector of Float.
64
80
bool HasFloatingPoint(const analysis::Type* type) {
65
80
  if (type->AsFloat()) {
66
0
    return true;
67
80
  } else if (const analysis::Vector* vec_type = type->AsVector()) {
68
80
    return vec_type->element_type()->AsFloat() != nullptr;
69
80
  }
70
71
0
  return false;
72
80
}
73
74
// Returns a constants with the value |-val| of the given type.  Only works for
75
// 32-bit and 64-bit float point types.  Returns |nullptr| if an error occurs.
76
const analysis::Constant* NegateFPConst(const analysis::Type* result_type,
77
                                        const analysis::Constant* val,
78
712
                                        analysis::ConstantManager* const_mgr) {
79
712
  const analysis::Float* float_type = result_type->AsFloat();
80
712
  assert(float_type != nullptr);
81
712
  if (float_type->width() == 32) {
82
712
    float fa = val->GetFloat();
83
712
    return const_mgr->GetFloatConst(-fa);
84
712
  } else if (float_type->width() == 64) {
85
0
    double da = val->GetDouble();
86
0
    return const_mgr->GetDoubleConst(-da);
87
0
  }
88
0
  return nullptr;
89
712
}
90
91
// Returns a constants with the value |-val| of the given type.
92
const analysis::Constant* NegateIntConst(const analysis::Type* result_type,
93
                                         const analysis::Constant* val,
94
1.24k
                                         analysis::ConstantManager* const_mgr) {
95
1.24k
  const analysis::Integer* int_type = result_type->AsInteger();
96
1.24k
  assert(int_type != nullptr);
97
98
1.24k
  if (val->AsNullConstant()) {
99
2
    return val;
100
2
  }
101
102
1.24k
  uint64_t new_value = static_cast<uint64_t>(-val->GetSignExtendedValue());
103
1.24k
  return const_mgr->GetIntConst(new_value, int_type->width(),
104
1.24k
                                int_type->IsSigned());
105
1.24k
}
106
107
// Folds an OpcompositeExtract where input is a composite constant.
108
11.6k
ConstantFoldingRule FoldExtractWithConstants() {
109
11.6k
  return [](IRContext* context, Instruction* inst,
110
11.6k
            const std::vector<const analysis::Constant*>& constants)
111
336k
             -> const analysis::Constant* {
112
336k
    const analysis::Constant* c = constants[kExtractCompositeIdInIdx];
113
336k
    if (c == nullptr) {
114
301k
      return nullptr;
115
301k
    }
116
117
70.7k
    for (uint32_t i = 1; i < inst->NumInOperands(); ++i) {
118
35.4k
      uint32_t element_index = inst->GetSingleWordInOperand(i);
119
35.4k
      if (c->AsNullConstant()) {
120
        // Return Null for the return type.
121
141
        analysis::ConstantManager* const_mgr = context->get_constant_mgr();
122
141
        analysis::TypeManager* type_mgr = context->get_type_mgr();
123
141
        return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), {});
124
141
      }
125
126
35.2k
      auto cc = c->AsCompositeConstant();
127
35.2k
      assert(cc != nullptr);
128
35.2k
      auto components = cc->GetComponents();
129
      // Protect against invalid IR.  Refuse to fold if the index is out
130
      // of bounds.
131
35.2k
      if (element_index >= components.size()) return nullptr;
132
35.2k
      c = components[element_index];
133
35.2k
    }
134
35.2k
    return c;
135
35.4k
  };
136
11.6k
}
137
138
// Folds an OpcompositeInsert where input is a composite constant.
139
11.6k
ConstantFoldingRule FoldInsertWithConstants() {
140
11.6k
  return [](IRContext* context, Instruction* inst,
141
11.6k
            const std::vector<const analysis::Constant*>& constants)
142
73.8k
             -> const analysis::Constant* {
143
73.8k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
144
73.8k
    const analysis::Constant* object = constants[0];
145
73.8k
    const analysis::Constant* composite = constants[1];
146
73.8k
    if (object == nullptr || composite == nullptr) {
147
69.1k
      return nullptr;
148
69.1k
    }
149
150
    // If there is more than 1 index, then each additional constant used by the
151
    // index will need to be recreated to use the inserted object.
152
4.71k
    std::vector<const analysis::Constant*> chain;
153
4.71k
    std::vector<const analysis::Constant*> components;
154
4.71k
    const analysis::Type* type = nullptr;
155
4.71k
    const uint32_t final_index = (inst->NumInOperands() - 1);
156
157
    // Work down hierarchy of all indexes
158
9.42k
    for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
159
4.71k
      type = composite->type();
160
161
4.71k
      if (composite->AsNullConstant()) {
162
        // Make new composite so it can be inserted in the index with the
163
        // non-null value
164
77
        if (const auto new_composite =
165
77
                const_mgr->GetNullCompositeConstant(type)) {
166
          // Keep track of any indexes along the way to last index
167
77
          if (i != final_index) {
168
0
            chain.push_back(new_composite);
169
0
          }
170
77
          components = new_composite->AsCompositeConstant()->GetComponents();
171
77
        } else {
172
          // Unsupported input type (such as structs)
173
0
          return nullptr;
174
0
        }
175
4.63k
      } else {
176
        // Keep track of any indexes along the way to last index
177
4.63k
        if (i != final_index) {
178
0
          chain.push_back(composite);
179
0
        }
180
4.63k
        components = composite->AsCompositeConstant()->GetComponents();
181
4.63k
      }
182
4.71k
      const uint32_t index = inst->GetSingleWordInOperand(i);
183
4.71k
      composite = components[index];
184
4.71k
    }
185
186
    // Final index in hierarchy is inserted with new object.
187
4.71k
    const uint32_t final_operand = inst->GetSingleWordInOperand(final_index);
188
4.71k
    std::vector<uint32_t> ids;
189
14.2k
    for (size_t i = 0; i < components.size(); i++) {
190
9.51k
      const analysis::Constant* constant =
191
9.51k
          (i == final_operand) ? object : components[i];
192
9.51k
      Instruction* member_inst = const_mgr->GetDefiningInstruction(constant);
193
9.51k
      ids.push_back(member_inst->result_id());
194
9.51k
    }
195
4.71k
    const analysis::Constant* new_constant = const_mgr->GetConstant(type, ids);
196
197
    // Work backwards up the chain and replace each index with new constant.
198
4.71k
    for (size_t i = chain.size(); i > 0; i--) {
199
      // Need to insert any previous instruction into the module first.
200
      // Can't just insert in types_values_begin() because it will move above
201
      // where the types are declared.
202
      // Can't compare with location of inst because not all new added
203
      // instructions are added to types_values_
204
0
      auto iter = context->types_values_end();
205
0
      Module::inst_iterator* pos = &iter;
206
0
      const_mgr->BuildInstructionAndAddToModule(new_constant, pos);
207
208
0
      composite = chain[i - 1];
209
0
      components = composite->AsCompositeConstant()->GetComponents();
210
0
      type = composite->type();
211
0
      ids.clear();
212
0
      for (size_t k = 0; k < components.size(); k++) {
213
0
        const uint32_t index =
214
0
            inst->GetSingleWordInOperand(1 + static_cast<uint32_t>(i));
215
0
        const analysis::Constant* constant =
216
0
            (k == index) ? new_constant : components[k];
217
0
        const uint32_t constant_id =
218
0
            const_mgr->FindDeclaredConstant(constant, 0);
219
0
        ids.push_back(constant_id);
220
0
      }
221
0
      new_constant = const_mgr->GetConstant(type, ids);
222
0
    }
223
224
    // If multiple constants were created, only need to return the top index.
225
4.71k
    return new_constant;
226
4.71k
  };
227
11.6k
}
228
229
11.6k
ConstantFoldingRule FoldVectorShuffleWithConstants() {
230
11.6k
  return [](IRContext* context, Instruction* inst,
231
11.6k
            const std::vector<const analysis::Constant*>& constants)
232
23.9k
             -> const analysis::Constant* {
233
23.9k
    assert(inst->opcode() == spv::Op::OpVectorShuffle);
234
23.9k
    const analysis::Constant* c1 = constants[0];
235
23.9k
    const analysis::Constant* c2 = constants[1];
236
23.9k
    if (c1 == nullptr || c2 == nullptr) {
237
23.6k
      return nullptr;
238
23.6k
    }
239
240
295
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
241
295
    const analysis::Type* element_type = c1->type()->AsVector()->element_type();
242
243
295
    std::vector<const analysis::Constant*> c1_components;
244
295
    if (const analysis::VectorConstant* vec_const = c1->AsVectorConstant()) {
245
243
      c1_components = vec_const->GetComponents();
246
243
    } else {
247
52
      assert(c1->AsNullConstant());
248
52
      const analysis::Constant* element =
249
52
          const_mgr->GetConstant(element_type, {});
250
52
      c1_components.resize(c1->type()->AsVector()->element_count(), element);
251
52
    }
252
295
    std::vector<const analysis::Constant*> c2_components;
253
295
    if (const analysis::VectorConstant* vec_const = c2->AsVectorConstant()) {
254
205
      c2_components = vec_const->GetComponents();
255
205
    } else {
256
90
      assert(c2->AsNullConstant());
257
90
      const analysis::Constant* element =
258
90
          const_mgr->GetConstant(element_type, {});
259
90
      c2_components.resize(c2->type()->AsVector()->element_count(), element);
260
90
    }
261
262
295
    std::vector<uint32_t> ids;
263
295
    const uint32_t undef_literal_value = 0xffffffff;
264
607
    for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
265
474
      uint32_t index = inst->GetSingleWordInOperand(i);
266
474
      if (index == undef_literal_value) {
267
        // Don't fold shuffle with undef literal value.
268
162
        return nullptr;
269
312
      } else if (index < c1_components.size()) {
270
223
        Instruction* member_inst =
271
223
            const_mgr->GetDefiningInstruction(c1_components[index]);
272
223
        ids.push_back(member_inst->result_id());
273
223
      } else {
274
89
        Instruction* member_inst = const_mgr->GetDefiningInstruction(
275
89
            c2_components[index - c1_components.size()]);
276
89
        ids.push_back(member_inst->result_id());
277
89
      }
278
474
    }
279
280
133
    analysis::TypeManager* type_mgr = context->get_type_mgr();
281
133
    return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
282
295
  };
283
11.6k
}
284
285
11.6k
ConstantFoldingRule FoldVectorTimesScalar() {
286
11.6k
  return [](IRContext* context, Instruction* inst,
287
11.6k
            const std::vector<const analysis::Constant*>& constants)
288
414k
             -> const analysis::Constant* {
289
414k
    assert(inst->opcode() == spv::Op::OpVectorTimesScalar);
290
414k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
291
414k
    analysis::TypeManager* type_mgr = context->get_type_mgr();
292
293
414k
    if (!inst->IsFloatingPointFoldingAllowed()) {
294
80
      if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
295
80
        return nullptr;
296
80
      }
297
80
    }
298
299
414k
    const analysis::Constant* c1 = constants[0];
300
414k
    const analysis::Constant* c2 = constants[1];
301
302
414k
    if (c1 && c1->IsZero()) {
303
1.40k
      return c1;
304
1.40k
    }
305
306
412k
    if (c2 && c2->IsZero()) {
307
      // Get or create the NullConstant for this type.
308
5.65k
      std::vector<uint32_t> ids;
309
5.65k
      return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
310
5.65k
    }
311
312
407k
    if (c1 == nullptr || c2 == nullptr) {
313
377k
      return nullptr;
314
377k
    }
315
316
    // Check result type.
317
29.7k
    const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
318
29.7k
    const analysis::Vector* vector_type = result_type->AsVector();
319
29.7k
    assert(vector_type != nullptr);
320
29.7k
    const analysis::Type* element_type = vector_type->element_type();
321
29.7k
    assert(element_type != nullptr);
322
29.7k
    const analysis::Float* float_type = element_type->AsFloat();
323
29.7k
    assert(float_type != nullptr);
324
325
    // Check types of c1 and c2.
326
29.7k
    assert(c1->type()->AsVector() == vector_type);
327
29.7k
    assert(c1->type()->AsVector()->element_type() == element_type &&
328
29.7k
           c2->type() == element_type);
329
330
    // Get a float vector that is the result of vector-times-scalar.
331
29.7k
    std::vector<const analysis::Constant*> c1_components =
332
29.7k
        c1->GetVectorComponents(const_mgr);
333
29.7k
    std::vector<uint32_t> ids;
334
29.7k
    if (float_type->width() == 32) {
335
29.7k
      float scalar = c2->GetFloat();
336
147k
      for (uint32_t i = 0; i < c1_components.size(); ++i) {
337
118k
        utils::FloatProxy<float> result(c1_components[i]->GetFloat() * scalar);
338
118k
        std::vector<uint32_t> words = result.GetWords();
339
118k
        const analysis::Constant* new_elem =
340
118k
            const_mgr->GetConstant(float_type, words);
341
118k
        ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
342
118k
      }
343
29.7k
      return const_mgr->GetConstant(vector_type, ids);
344
29.7k
    } else if (float_type->width() == 64) {
345
0
      double scalar = c2->GetDouble();
346
0
      for (uint32_t i = 0; i < c1_components.size(); ++i) {
347
0
        utils::FloatProxy<double> result(c1_components[i]->GetDouble() *
348
0
                                         scalar);
349
0
        std::vector<uint32_t> words = result.GetWords();
350
0
        const analysis::Constant* new_elem =
351
0
            const_mgr->GetConstant(float_type, words);
352
0
        ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
353
0
      }
354
0
      return const_mgr->GetConstant(vector_type, ids);
355
0
    }
356
0
    return nullptr;
357
29.7k
  };
358
11.6k
}
359
360
// Returns to the constant that results from tranposing |matrix|. The result
361
// will have type |result_type|, and |matrix| must exist in |context|. The
362
// result constant will also exist in |context|.
363
const analysis::Constant* TransposeMatrix(const analysis::Constant* matrix,
364
                                          analysis::Matrix* result_type,
365
0
                                          IRContext* context) {
366
0
  analysis::ConstantManager* const_mgr = context->get_constant_mgr();
367
0
  if (matrix->AsNullConstant() != nullptr) {
368
0
    return const_mgr->GetNullCompositeConstant(result_type);
369
0
  }
370
371
0
  const auto& columns = matrix->AsMatrixConstant()->GetComponents();
372
0
  uint32_t number_of_rows = columns[0]->type()->AsVector()->element_count();
373
374
  // Collect the ids of the elements in their new positions.
375
0
  std::vector<std::vector<uint32_t>> result_elements(number_of_rows);
376
0
  for (const analysis::Constant* column : columns) {
377
0
    if (column->AsNullConstant()) {
378
0
      column = const_mgr->GetNullCompositeConstant(column->type());
379
0
    }
380
0
    const auto& column_components = column->AsVectorConstant()->GetComponents();
381
382
0
    for (uint32_t row = 0; row < number_of_rows; ++row) {
383
0
      result_elements[row].push_back(
384
0
          const_mgr->GetDefiningInstruction(column_components[row])
385
0
              ->result_id());
386
0
    }
387
0
  }
388
389
  // Create the constant for each row in the result, and collect the ids.
390
0
  std::vector<uint32_t> result_columns(number_of_rows);
391
0
  for (uint32_t col = 0; col < number_of_rows; ++col) {
392
0
    auto* element = const_mgr->GetConstant(result_type->element_type(),
393
0
                                           result_elements[col]);
394
0
    result_columns[col] =
395
0
        const_mgr->GetDefiningInstruction(element)->result_id();
396
0
  }
397
398
  // Create the matrix constant from the row ids, and return it.
399
0
  return const_mgr->GetConstant(result_type, result_columns);
400
0
}
401
402
const analysis::Constant* FoldTranspose(
403
    IRContext* context, Instruction* inst,
404
0
    const std::vector<const analysis::Constant*>& constants) {
405
0
  assert(inst->opcode() == spv::Op::OpTranspose);
406
407
0
  analysis::TypeManager* type_mgr = context->get_type_mgr();
408
0
  if (!inst->IsFloatingPointFoldingAllowed()) {
409
0
    if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
410
0
      return nullptr;
411
0
    }
412
0
  }
413
414
0
  const analysis::Constant* matrix = constants[0];
415
0
  if (matrix == nullptr) {
416
0
    return nullptr;
417
0
  }
418
419
0
  auto* result_type = type_mgr->GetType(inst->type_id());
420
0
  return TransposeMatrix(matrix, result_type->AsMatrix(), context);
421
0
}
422
423
11.6k
ConstantFoldingRule FoldVectorTimesMatrix() {
424
11.6k
  return [](IRContext* context, Instruction* inst,
425
11.6k
            const std::vector<const analysis::Constant*>& constants)
426
11.6k
             -> const analysis::Constant* {
427
2.84k
    assert(inst->opcode() == spv::Op::OpVectorTimesMatrix);
428
2.84k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
429
2.84k
    analysis::TypeManager* type_mgr = context->get_type_mgr();
430
431
2.84k
    if (!inst->IsFloatingPointFoldingAllowed()) {
432
0
      if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
433
0
        return nullptr;
434
0
      }
435
0
    }
436
437
2.84k
    const analysis::Constant* c1 = constants[0];
438
2.84k
    const analysis::Constant* c2 = constants[1];
439
440
2.84k
    if (c1 == nullptr || c2 == nullptr) {
441
2.81k
      return nullptr;
442
2.81k
    }
443
444
    // Check result type.
445
24
    const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
446
24
    const analysis::Vector* vector_type = result_type->AsVector();
447
24
    assert(vector_type != nullptr);
448
24
    const analysis::Type* element_type = vector_type->element_type();
449
24
    assert(element_type != nullptr);
450
24
    const analysis::Float* float_type = element_type->AsFloat();
451
24
    assert(float_type != nullptr);
452
453
    // Check types of c1 and c2.
454
24
    assert(c1->type()->AsVector() == vector_type);
455
24
    assert(c1->type()->AsVector()->element_type() == element_type &&
456
24
           c2->type()->AsMatrix()->element_type() == vector_type);
457
458
24
    uint32_t resultVectorSize = result_type->AsVector()->element_count();
459
24
    std::vector<uint32_t> ids;
460
461
24
    if ((c1 && c1->IsZero()) || (c2 && c2->IsZero())) {
462
11
      std::vector<uint32_t> words(float_type->width() / 32, 0);
463
33
      for (uint32_t i = 0; i < resultVectorSize; ++i) {
464
22
        const analysis::Constant* new_elem =
465
22
            const_mgr->GetConstant(float_type, words);
466
22
        ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
467
22
      }
468
11
      return const_mgr->GetConstant(vector_type, ids);
469
11
    }
470
471
    // Get a float vector that is the result of vector-times-matrix.
472
13
    std::vector<const analysis::Constant*> c1_components =
473
13
        c1->GetVectorComponents(const_mgr);
474
13
    std::vector<const analysis::Constant*> c2_components =
475
13
        c2->AsMatrixConstant()->GetComponents();
476
477
13
    if (float_type->width() == 32) {
478
39
      for (uint32_t i = 0; i < resultVectorSize; ++i) {
479
26
        float result_scalar = 0.0f;
480
26
        if (!c2_components[i]->AsNullConstant()) {
481
18
          const analysis::VectorConstant* c2_vec =
482
18
              c2_components[i]->AsVectorConstant();
483
54
          for (uint32_t j = 0; j < c2_vec->GetComponents().size(); ++j) {
484
36
            float c1_scalar = c1_components[j]->GetFloat();
485
36
            float c2_scalar = c2_vec->GetComponents()[j]->GetFloat();
486
36
            result_scalar += c1_scalar * c2_scalar;
487
36
          }
488
18
        }
489
26
        utils::FloatProxy<float> result(result_scalar);
490
26
        std::vector<uint32_t> words = result.GetWords();
491
26
        const analysis::Constant* new_elem =
492
26
            const_mgr->GetConstant(float_type, words);
493
26
        ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
494
26
      }
495
13
      return const_mgr->GetConstant(vector_type, ids);
496
13
    } else if (float_type->width() == 64) {
497
0
      for (uint32_t i = 0; i < c2_components.size(); ++i) {
498
0
        double result_scalar = 0.0;
499
0
        if (!c2_components[i]->AsNullConstant()) {
500
0
          const analysis::VectorConstant* c2_vec =
501
0
              c2_components[i]->AsVectorConstant();
502
0
          for (uint32_t j = 0; j < c2_vec->GetComponents().size(); ++j) {
503
0
            double c1_scalar = c1_components[j]->GetDouble();
504
0
            double c2_scalar = c2_vec->GetComponents()[j]->GetDouble();
505
0
            result_scalar += c1_scalar * c2_scalar;
506
0
          }
507
0
        }
508
0
        utils::FloatProxy<double> result(result_scalar);
509
0
        std::vector<uint32_t> words = result.GetWords();
510
0
        const analysis::Constant* new_elem =
511
0
            const_mgr->GetConstant(float_type, words);
512
0
        ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
513
0
      }
514
0
      return const_mgr->GetConstant(vector_type, ids);
515
0
    }
516
0
    return nullptr;
517
13
  };
518
11.6k
}
519
520
11.6k
ConstantFoldingRule FoldMatrixTimesVector() {
521
11.6k
  return [](IRContext* context, Instruction* inst,
522
11.6k
            const std::vector<const analysis::Constant*>& constants)
523
11.6k
             -> const analysis::Constant* {
524
139
    assert(inst->opcode() == spv::Op::OpMatrixTimesVector);
525
139
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
526
139
    analysis::TypeManager* type_mgr = context->get_type_mgr();
527
528
139
    if (!inst->IsFloatingPointFoldingAllowed()) {
529
0
      if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
530
0
        return nullptr;
531
0
      }
532
0
    }
533
534
139
    const analysis::Constant* c1 = constants[0];
535
139
    const analysis::Constant* c2 = constants[1];
536
537
139
    if (c1 == nullptr || c2 == nullptr) {
538
137
      return nullptr;
539
137
    }
540
541
    // Check result type.
542
2
    const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
543
2
    const analysis::Vector* vector_type = result_type->AsVector();
544
2
    assert(vector_type != nullptr);
545
2
    const analysis::Type* element_type = vector_type->element_type();
546
2
    assert(element_type != nullptr);
547
2
    const analysis::Float* float_type = element_type->AsFloat();
548
2
    assert(float_type != nullptr);
549
550
    // Check types of c1 and c2.
551
2
    assert(c1->type()->AsMatrix()->element_type() == vector_type);
552
2
    assert(c2->type()->AsVector()->element_type() == element_type);
553
554
2
    uint32_t resultVectorSize = result_type->AsVector()->element_count();
555
2
    std::vector<uint32_t> ids;
556
557
2
    if ((c1 && c1->IsZero()) || (c2 && c2->IsZero())) {
558
2
      std::vector<uint32_t> words(float_type->width() / 32, 0);
559
6
      for (uint32_t i = 0; i < resultVectorSize; ++i) {
560
4
        const analysis::Constant* new_elem =
561
4
            const_mgr->GetConstant(float_type, words);
562
4
        ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
563
4
      }
564
2
      return const_mgr->GetConstant(vector_type, ids);
565
2
    }
566
567
    // Get a float vector that is the result of matrix-times-vector.
568
0
    std::vector<const analysis::Constant*> c1_components =
569
0
        c1->AsMatrixConstant()->GetComponents();
570
0
    std::vector<const analysis::Constant*> c2_components =
571
0
        c2->GetVectorComponents(const_mgr);
572
573
0
    if (float_type->width() == 32) {
574
0
      for (uint32_t i = 0; i < resultVectorSize; ++i) {
575
0
        float result_scalar = 0.0f;
576
0
        for (uint32_t j = 0; j < c1_components.size(); ++j) {
577
0
          if (!c1_components[j]->AsNullConstant()) {
578
0
            float c1_scalar = c1_components[j]
579
0
                                  ->AsVectorConstant()
580
0
                                  ->GetComponents()[i]
581
0
                                  ->GetFloat();
582
0
            float c2_scalar = c2_components[j]->GetFloat();
583
0
            result_scalar += c1_scalar * c2_scalar;
584
0
          }
585
0
        }
586
0
        utils::FloatProxy<float> result(result_scalar);
587
0
        std::vector<uint32_t> words = result.GetWords();
588
0
        const analysis::Constant* new_elem =
589
0
            const_mgr->GetConstant(float_type, words);
590
0
        ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
591
0
      }
592
0
      return const_mgr->GetConstant(vector_type, ids);
593
0
    } else if (float_type->width() == 64) {
594
0
      for (uint32_t i = 0; i < resultVectorSize; ++i) {
595
0
        double result_scalar = 0.0;
596
0
        for (uint32_t j = 0; j < c1_components.size(); ++j) {
597
0
          if (!c1_components[j]->AsNullConstant()) {
598
0
            double c1_scalar = c1_components[j]
599
0
                                   ->AsVectorConstant()
600
0
                                   ->GetComponents()[i]
601
0
                                   ->GetDouble();
602
0
            double c2_scalar = c2_components[j]->GetDouble();
603
0
            result_scalar += c1_scalar * c2_scalar;
604
0
          }
605
0
        }
606
0
        utils::FloatProxy<double> result(result_scalar);
607
0
        std::vector<uint32_t> words = result.GetWords();
608
0
        const analysis::Constant* new_elem =
609
0
            const_mgr->GetConstant(float_type, words);
610
0
        ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
611
0
      }
612
0
      return const_mgr->GetConstant(vector_type, ids);
613
0
    }
614
0
    return nullptr;
615
0
  };
616
11.6k
}
617
618
11.6k
ConstantFoldingRule FoldCompositeWithConstants() {
619
  // Folds an OpCompositeConstruct where all of the inputs are constants to a
620
  // constant.  A new constant is created if necessary.
621
11.6k
  return [](IRContext* context, Instruction* inst,
622
11.6k
            const std::vector<const analysis::Constant*>& constants)
623
223k
             -> const analysis::Constant* {
624
223k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
625
223k
    analysis::TypeManager* type_mgr = context->get_type_mgr();
626
223k
    const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
627
223k
    Instruction* type_inst =
628
223k
        context->get_def_use_mgr()->GetDef(inst->type_id());
629
630
223k
    std::vector<uint32_t> ids;
631
536k
    for (uint32_t i = 0; i < constants.size(); ++i) {
632
396k
      const analysis::Constant* element_const = constants[i];
633
396k
      if (element_const == nullptr) {
634
83.4k
        return nullptr;
635
83.4k
      }
636
637
313k
      uint32_t component_type_id = 0;
638
313k
      if (type_inst->opcode() == spv::Op::OpTypeStruct) {
639
20.7k
        component_type_id = type_inst->GetSingleWordInOperand(i);
640
292k
      } else if (type_inst->opcode() == spv::Op::OpTypeArray) {
641
67
        component_type_id = type_inst->GetSingleWordInOperand(0);
642
67
      }
643
644
313k
      uint32_t element_id =
645
313k
          const_mgr->FindDeclaredConstant(element_const, component_type_id);
646
313k
      if (element_id == 0) {
647
118
        return nullptr;
648
118
      }
649
313k
      ids.push_back(element_id);
650
313k
    }
651
139k
    return const_mgr->GetConstant(new_type, ids);
652
223k
  };
653
11.6k
}
654
655
// The interface for a function that returns the result of applying a scalar
656
// floating-point binary operation on |a| and |b|.  The type of the return value
657
// will be |type|.  The input constants must also be of type |type|.
658
using UnaryScalarFoldingRule = std::function<const analysis::Constant*(
659
    const analysis::Type* result_type, const analysis::Constant* a,
660
    analysis::ConstantManager*)>;
661
662
// The interface for a function that returns the result of applying a scalar
663
// floating-point binary operation on |a| and |b|.  The type of the return value
664
// will be |type|.  The input constants must also be of type |type|.
665
using BinaryScalarFoldingRule = std::function<const analysis::Constant*(
666
    const analysis::Type* result_type, const analysis::Constant* a,
667
    const analysis::Constant* b, analysis::ConstantManager*)>;
668
669
// Returns a |ConstantFoldingRule| that folds unary scalar ops
670
// using |scalar_rule| and unary vectors ops by applying
671
// |scalar_rule| to the elements of the vector.  The |ConstantFoldingRule|
672
// that is returned assumes that |constants| contains 1 entry.  If they are
673
// not |nullptr|, then their type is either |Float| or |Integer| or a |Vector|
674
// whose element type is |Float| or |Integer|.
675
180k
ConstantFoldingRule FoldUnaryOp(UnaryScalarFoldingRule scalar_rule) {
676
180k
  return [scalar_rule](IRContext* context, Instruction* inst,
677
180k
                       const std::vector<const analysis::Constant*>& constants)
678
421k
             -> const analysis::Constant* {
679
421k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
680
421k
    analysis::TypeManager* type_mgr = context->get_type_mgr();
681
421k
    const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
682
421k
    const analysis::Vector* vector_type = result_type->AsVector();
683
684
421k
    const analysis::Constant* arg =
685
421k
        (inst->opcode() == spv::Op::OpExtInst) ? constants[1] : constants[0];
686
687
421k
    if (arg == nullptr) {
688
88.6k
      return nullptr;
689
88.6k
    }
690
691
333k
    if (vector_type != nullptr) {
692
887
      std::vector<const analysis::Constant*> a_components;
693
887
      std::vector<const analysis::Constant*> results_components;
694
695
887
      a_components = arg->GetVectorComponents(const_mgr);
696
697
      // Fold each component of the vector.
698
3.89k
      for (uint32_t i = 0; i < a_components.size(); ++i) {
699
3.01k
        results_components.push_back(scalar_rule(vector_type->element_type(),
700
3.01k
                                                 a_components[i], const_mgr));
701
3.01k
        if (results_components[i] == nullptr) {
702
0
          return nullptr;
703
0
        }
704
3.01k
      }
705
706
      // Build the constant object and return it.
707
887
      std::vector<uint32_t> ids;
708
3.01k
      for (const analysis::Constant* member : results_components) {
709
3.01k
        ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
710
3.01k
      }
711
887
      return const_mgr->GetConstant(vector_type, ids);
712
332k
    } else {
713
332k
      return scalar_rule(result_type, arg, const_mgr);
714
332k
    }
715
333k
  };
716
180k
}
717
718
// Returns a |ConstantFoldingRule| that folds binary scalar ops
719
// using |scalar_rule| and binary vectors ops by applying
720
// |scalar_rule| to the elements of the vector. The folding rule assumes that op
721
// has two inputs. For regular instruction, those are in operands 0 and 1. For
722
// extended instruction, they are in operands 1 and 2. If an element in
723
// |constants| is not nullprt, then the constant's type is |Float|, |Integer|,
724
// or |Vector| whose element type is |Float| or |Integer|.
725
92.9k
ConstantFoldingRule FoldBinaryOp(BinaryScalarFoldingRule scalar_rule) {
726
92.9k
  return [scalar_rule](IRContext* context, Instruction* inst,
727
92.9k
                       const std::vector<const analysis::Constant*>& constants)
728
1.03M
             -> const analysis::Constant* {
729
1.03M
    assert(constants.size() == inst->NumInOperands());
730
1.03M
    assert(constants.size() == (inst->opcode() == spv::Op::OpExtInst ? 3 : 2));
731
1.03M
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
732
1.03M
    analysis::TypeManager* type_mgr = context->get_type_mgr();
733
1.03M
    const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
734
1.03M
    const analysis::Vector* vector_type = result_type->AsVector();
735
736
1.03M
    const analysis::Constant* arg1 =
737
1.03M
        (inst->opcode() == spv::Op::OpExtInst) ? constants[1] : constants[0];
738
1.03M
    const analysis::Constant* arg2 =
739
1.03M
        (inst->opcode() == spv::Op::OpExtInst) ? constants[2] : constants[1];
740
741
1.03M
    if (arg1 == nullptr || arg2 == nullptr) {
742
159k
      return nullptr;
743
159k
    }
744
745
877k
    if (vector_type == nullptr) {
746
877k
      return scalar_rule(result_type, arg1, arg2, const_mgr);
747
877k
    }
748
749
255
    std::vector<const analysis::Constant*> a_components;
750
255
    std::vector<const analysis::Constant*> b_components;
751
255
    std::vector<const analysis::Constant*> results_components;
752
753
255
    a_components = arg1->GetVectorComponents(const_mgr);
754
255
    b_components = arg2->GetVectorComponents(const_mgr);
755
255
    assert(a_components.size() == b_components.size());
756
757
    // Fold each component of the vector.
758
1.25k
    for (uint32_t i = 0; i < a_components.size(); ++i) {
759
1.00k
      results_components.push_back(scalar_rule(vector_type->element_type(),
760
1.00k
                                               a_components[i], b_components[i],
761
1.00k
                                               const_mgr));
762
1.00k
      if (results_components[i] == nullptr) {
763
0
        return nullptr;
764
0
      }
765
1.00k
    }
766
767
    // Build the constant object and return it.
768
255
    std::vector<uint32_t> ids;
769
1.00k
    for (const analysis::Constant* member : results_components) {
770
1.00k
      ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
771
1.00k
    }
772
255
    return const_mgr->GetConstant(vector_type, ids);
773
255
  };
774
92.9k
}
775
776
// Returns a |ConstantFoldingRule| that folds unary floating point scalar ops
777
// using |scalar_rule| and unary float point vectors ops by applying
778
// |scalar_rule| to the elements of the vector.  The |ConstantFoldingRule|
779
// that is returned assumes that |constants| contains 1 entry.  If they are
780
// not |nullptr|, then their type is either |Float| or |Integer| or a |Vector|
781
// whose element type is |Float| or |Integer|.
782
145k
ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
783
145k
  auto folding_rule = FoldUnaryOp(scalar_rule);
784
145k
  return [folding_rule](IRContext* context, Instruction* inst,
785
145k
                        const std::vector<const analysis::Constant*>& constants)
786
416k
             -> const analysis::Constant* {
787
416k
    if (!inst->IsFloatingPointFoldingAllowed()) {
788
23
      return nullptr;
789
23
    }
790
791
416k
    return folding_rule(context, inst, constants);
792
416k
  };
793
145k
}
794
795
// Returns the result of folding the constants in |constants| according the
796
// |scalar_rule|.  If |result_type| is a vector, then |scalar_rule| is applied
797
// per component.
798
const analysis::Constant* FoldFPBinaryOp(
799
    BinaryScalarFoldingRule scalar_rule, uint32_t result_type_id,
800
    const std::vector<const analysis::Constant*>& constants,
801
1.25M
    IRContext* context) {
802
1.25M
  analysis::ConstantManager* const_mgr = context->get_constant_mgr();
803
1.25M
  analysis::TypeManager* type_mgr = context->get_type_mgr();
804
1.25M
  const analysis::Type* result_type = type_mgr->GetType(result_type_id);
805
1.25M
  const analysis::Vector* vector_type = result_type->AsVector();
806
807
1.25M
  if (constants[0] == nullptr || constants[1] == nullptr) {
808
965k
    return nullptr;
809
965k
  }
810
811
291k
  if (vector_type != nullptr) {
812
36.9k
    std::vector<const analysis::Constant*> a_components;
813
36.9k
    std::vector<const analysis::Constant*> b_components;
814
36.9k
    std::vector<const analysis::Constant*> results_components;
815
816
36.9k
    a_components = constants[0]->GetVectorComponents(const_mgr);
817
36.9k
    b_components = constants[1]->GetVectorComponents(const_mgr);
818
819
    // Fold each component of the vector.
820
177k
    for (uint32_t i = 0; i < a_components.size(); ++i) {
821
140k
      results_components.push_back(scalar_rule(vector_type->element_type(),
822
140k
                                               a_components[i], b_components[i],
823
140k
                                               const_mgr));
824
140k
      if (results_components[i] == nullptr) {
825
0
        return nullptr;
826
0
      }
827
140k
    }
828
829
    // Build the constant object and return it.
830
36.9k
    std::vector<uint32_t> ids;
831
140k
    for (const analysis::Constant* member : results_components) {
832
140k
      Instruction* def = const_mgr->GetDefiningInstruction(member);
833
140k
      if (!def) return nullptr;
834
140k
      ids.push_back(def->result_id());
835
140k
    }
836
36.9k
    return const_mgr->GetConstant(vector_type, ids);
837
254k
  } else {
838
254k
    return scalar_rule(result_type, constants[0], constants[1], const_mgr);
839
254k
  }
840
291k
}
841
842
// Returns a |ConstantFoldingRule| that folds floating point scalars using
843
// |scalar_rule| and vectors of floating point by applying |scalar_rule| to the
844
// elements of the vector.  The |ConstantFoldingRule| that is returned assumes
845
// that |constants| contains 2 entries.  If they are not |nullptr|, then their
846
// type is either |Float| or a |Vector| whose element type is |Float|.
847
241k
ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) {
848
241k
  return [scalar_rule](IRContext* context, Instruction* inst,
849
241k
                       const std::vector<const analysis::Constant*>& constants)
850
1.25M
             -> const analysis::Constant* {
851
1.25M
    if (!inst->IsFloatingPointFoldingAllowed()) {
852
3.10k
      return nullptr;
853
3.10k
    }
854
1.25M
    if (inst->opcode() == spv::Op::OpExtInst) {
855
10.7k
      return FoldFPBinaryOp(scalar_rule, inst->type_id(),
856
10.7k
                            {constants[1], constants[2]}, context);
857
10.7k
    }
858
1.24M
    return FoldFPBinaryOp(scalar_rule, inst->type_id(), constants, context);
859
1.25M
  };
860
241k
}
861
862
// This macro defines a |UnaryScalarFoldingRule| that performs float to
863
// integer conversion.
864
// TODO(greg-lunarg): Support for 64-bit integer types.
865
23.2k
UnaryScalarFoldingRule FoldFToIOp() {
866
23.2k
  return [](const analysis::Type* result_type, const analysis::Constant* a,
867
23.2k
            analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
868
981
    assert(result_type != nullptr && a != nullptr);
869
981
    const analysis::Integer* integer_type = result_type->AsInteger();
870
981
    const analysis::Float* float_type = a->type()->AsFloat();
871
981
    assert(float_type != nullptr);
872
981
    assert(integer_type != nullptr);
873
981
    if (integer_type->width() != 32) return nullptr;
874
981
    if (float_type->width() == 32) {
875
981
      float fa = a->GetFloat();
876
981
      uint32_t result = integer_type->IsSigned()
877
981
                            ? static_cast<uint32_t>(static_cast<int32_t>(fa))
878
981
                            : static_cast<uint32_t>(fa);
879
981
      std::vector<uint32_t> words = {result};
880
981
      return const_mgr->GetConstant(result_type, words);
881
981
    } else if (float_type->width() == 64) {
882
0
      double fa = a->GetDouble();
883
0
      uint32_t result = integer_type->IsSigned()
884
0
                            ? static_cast<uint32_t>(static_cast<int32_t>(fa))
885
0
                            : static_cast<uint32_t>(fa);
886
0
      std::vector<uint32_t> words = {result};
887
0
      return const_mgr->GetConstant(result_type, words);
888
0
    }
889
0
    return nullptr;
890
981
  };
891
23.2k
}
892
893
// This function defines a |UnaryScalarFoldingRule| that performs integer to
894
// float conversion.
895
// TODO(greg-lunarg): Support for 64-bit integer types.
896
23.2k
UnaryScalarFoldingRule FoldIToFOp() {
897
23.2k
  return [](const analysis::Type* result_type, const analysis::Constant* a,
898
330k
            analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
899
330k
    assert(result_type != nullptr && a != nullptr);
900
330k
    const analysis::Integer* integer_type = a->type()->AsInteger();
901
330k
    const analysis::Float* float_type = result_type->AsFloat();
902
330k
    assert(float_type != nullptr);
903
330k
    assert(integer_type != nullptr);
904
330k
    if (integer_type->width() != 32) return nullptr;
905
330k
    uint32_t ua = a->GetU32();
906
330k
    if (float_type->width() == 32) {
907
330k
      float result_val = integer_type->IsSigned()
908
330k
                             ? static_cast<float>(static_cast<int32_t>(ua))
909
330k
                             : static_cast<float>(ua);
910
330k
      utils::FloatProxy<float> result(result_val);
911
330k
      std::vector<uint32_t> words = {result.data()};
912
330k
      return const_mgr->GetConstant(result_type, words);
913
330k
    } else if (float_type->width() == 64) {
914
0
      double result_val = integer_type->IsSigned()
915
0
                              ? static_cast<double>(static_cast<int32_t>(ua))
916
0
                              : static_cast<double>(ua);
917
0
      utils::FloatProxy<double> result(result_val);
918
0
      std::vector<uint32_t> words = result.GetWords();
919
0
      return const_mgr->GetConstant(result_type, words);
920
0
    }
921
0
    return nullptr;
922
330k
  };
923
23.2k
}
924
925
// This defines a |UnaryScalarFoldingRule| that performs |OpQuantizeToF16|.
926
11.6k
UnaryScalarFoldingRule FoldQuantizeToF16Scalar() {
927
11.6k
  return [](const analysis::Type* result_type, const analysis::Constant* a,
928
11.6k
            analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
929
1.73k
    assert(result_type != nullptr && a != nullptr);
930
1.73k
    const analysis::Float* float_type = a->type()->AsFloat();
931
1.73k
    assert(float_type != nullptr);
932
1.73k
    if (float_type->width() != 32) {
933
0
      return nullptr;
934
0
    }
935
936
1.73k
    float fa = a->GetFloat();
937
1.73k
    utils::HexFloat<utils::FloatProxy<float>> orignal(fa);
938
1.73k
    utils::HexFloat<utils::FloatProxy<utils::Float16>> quantized(0);
939
1.73k
    utils::HexFloat<utils::FloatProxy<float>> result(0.0f);
940
1.73k
    orignal.castTo(quantized, utils::round_direction::kToZero);
941
1.73k
    quantized.castTo(result, utils::round_direction::kToZero);
942
1.73k
    std::vector<uint32_t> words = {result.getBits()};
943
1.73k
    return const_mgr->GetConstant(result_type, words);
944
1.73k
  };
945
11.6k
}
946
947
// This macro defines a |BinaryScalarFoldingRule| that applies |op|.  The
948
// operator |op| must work for both float and double, and use syntax "f1 op f2".
949
#define FOLD_FPARITH_OP(op)                                                   \
950
54.2k
  [](const analysis::Type* result_type_in_macro, const analysis::Constant* a, \
951
54.2k
     const analysis::Constant* b,                                             \
952
54.2k
     analysis::ConstantManager* const_mgr_in_macro)                           \
953
377k
      -> const analysis::Constant* {                                          \
954
377k
    assert(result_type_in_macro != nullptr && a != nullptr && b != nullptr);  \
955
377k
    assert(result_type_in_macro == a->type() &&                               \
956
377k
           result_type_in_macro == b->type());                                \
957
377k
    const analysis::Float* float_type_in_macro =                              \
958
377k
        result_type_in_macro->AsFloat();                                      \
959
377k
    assert(float_type_in_macro != nullptr);                                   \
960
377k
    if (float_type_in_macro->width() == 32) {                                 \
961
377k
      float fa = a->GetFloat();                                               \
962
377k
      float fb = b->GetFloat();                                               \
963
377k
      utils::FloatProxy<float> result_in_macro(fa op fb);                     \
964
377k
      std::vector<uint32_t> words_in_macro = result_in_macro.GetWords();      \
965
377k
      return const_mgr_in_macro->GetConstant(result_type_in_macro,            \
966
377k
                                             words_in_macro);                 \
967
377k
    } else if (float_type_in_macro->width() == 64) {                          \
968
0
      double fa = a->GetDouble();                                             \
969
0
      double fb = b->GetDouble();                                             \
970
0
      utils::FloatProxy<double> result_in_macro(fa op fb);                    \
971
0
      std::vector<uint32_t> words_in_macro = result_in_macro.GetWords();      \
972
0
      return const_mgr_in_macro->GetConstant(result_type_in_macro,            \
973
0
                                             words_in_macro);                 \
974
0
    }                                                                         \
975
377k
    return nullptr;                                                           \
976
377k
  }
977
978
// Define the folding rule for conversion between floating point and integer
979
23.2k
ConstantFoldingRule FoldFToI() { return FoldFPUnaryOp(FoldFToIOp()); }
980
23.2k
ConstantFoldingRule FoldIToF() { return FoldFPUnaryOp(FoldIToFOp()); }
981
11.6k
ConstantFoldingRule FoldQuantizeToF16() {
982
11.6k
  return FoldFPUnaryOp(FoldQuantizeToF16Scalar());
983
11.6k
}
984
985
// Define the folding rules for subtraction, addition, multiplication, and
986
// division for floating point values.
987
26.6k
ConstantFoldingRule FoldFSub() { return FoldFPBinaryOp(FOLD_FPARITH_OP(-)); }
988
143k
ConstantFoldingRule FoldFAdd() { return FoldFPBinaryOp(FOLD_FPARITH_OP(+)); }
989
231k
ConstantFoldingRule FoldFMul() { return FoldFPBinaryOp(FOLD_FPARITH_OP(*)); }
990
991
// Returns the constant that results from evaluating |numerator| / 0.0.  Returns
992
// |nullptr| if the result could not be evaluated.
993
const analysis::Constant* FoldFPScalarDivideByZero(
994
    const analysis::Type* result_type, const analysis::Constant* numerator,
995
2.08k
    analysis::ConstantManager* const_mgr) {
996
2.08k
  if (numerator == nullptr) {
997
0
    return nullptr;
998
0
  }
999
1000
2.08k
  if (numerator->IsZero()) {
1001
1.12k
    return GetNan(result_type, const_mgr);
1002
1.12k
  }
1003
1004
965
  const analysis::Constant* result = GetInf(result_type, const_mgr);
1005
965
  if (result == nullptr) {
1006
0
    return nullptr;
1007
0
  }
1008
1009
965
  if (numerator->AsFloatConstant()->GetValueAsDouble() < 0.0) {
1010
227
    result = NegateFPConst(result_type, result, const_mgr);
1011
227
  }
1012
965
  return result;
1013
965
}
1014
1015
// Returns the result of folding |numerator| / |denominator|.  Returns |nullptr|
1016
// if it cannot be folded.
1017
const analysis::Constant* FoldScalarFPDivide(
1018
    const analysis::Type* result_type, const analysis::Constant* numerator,
1019
    const analysis::Constant* denominator,
1020
11.5k
    analysis::ConstantManager* const_mgr) {
1021
11.5k
  if (denominator == nullptr) {
1022
0
    return nullptr;
1023
0
  }
1024
1025
11.5k
  if (denominator->IsZero()) {
1026
1.76k
    return FoldFPScalarDivideByZero(result_type, numerator, const_mgr);
1027
1.76k
  }
1028
1029
9.81k
  uint32_t width = denominator->type()->AsFloat()->width();
1030
9.81k
  if (width != 32 && width != 64) {
1031
0
    return nullptr;
1032
0
  }
1033
1034
9.81k
  const analysis::FloatConstant* denominator_float =
1035
9.81k
      denominator->AsFloatConstant();
1036
9.81k
  if (denominator_float && denominator->GetValueAsDouble() == -0.0) {
1037
323
    const analysis::Constant* result =
1038
323
        FoldFPScalarDivideByZero(result_type, numerator, const_mgr);
1039
323
    if (result != nullptr)
1040
323
      result = NegateFPConst(result_type, result, const_mgr);
1041
323
    return result;
1042
9.49k
  } else {
1043
18.9k
    return FOLD_FPARITH_OP(/)(result_type, numerator, denominator, const_mgr);
1044
9.49k
  }
1045
9.81k
}
1046
1047
// Returns the constant folding rule to fold |OpFDiv| with two constants.
1048
11.6k
ConstantFoldingRule FoldFDiv() { return FoldFPBinaryOp(FoldScalarFPDivide); }
1049
1050
bool CompareFloatingPoint(bool op_result, bool op_unordered,
1051
10.9k
                          bool need_ordered) {
1052
10.9k
  if (need_ordered) {
1053
    // operands are ordered and Operand 1 is |op| Operand 2
1054
4.63k
    return !op_unordered && op_result;
1055
6.28k
  } else {
1056
    // operands are unordered or Operand 1 is |op| Operand 2
1057
6.28k
    return op_unordered || op_result;
1058
6.28k
  }
1059
10.9k
}
1060
1061
// This macro defines a |BinaryScalarFoldingRule| that applies |op|.  The
1062
// operator |op| must work for both float and double, and use syntax "f1 op f2".
1063
#define FOLD_FPCMP_OP(op, ord)                                            \
1064
139k
  [](const analysis::Type* result_type, const analysis::Constant* a,      \
1065
139k
     const analysis::Constant* b,                                         \
1066
139k
     analysis::ConstantManager* const_mgr) -> const analysis::Constant* { \
1067
10.9k
    assert(result_type != nullptr && a != nullptr && b != nullptr);       \
1068
10.9k
    assert(result_type->AsBool());                                        \
1069
10.9k
    assert(a->type() == b->type());                                       \
1070
10.9k
    const analysis::Float* float_type = a->type()->AsFloat();             \
1071
10.9k
    assert(float_type != nullptr);                                        \
1072
10.9k
    if (float_type->width() == 32) {                                      \
1073
10.9k
      float fa = a->GetFloat();                                           \
1074
10.9k
      float fb = b->GetFloat();                                           \
1075
10.9k
      bool result = CompareFloatingPoint(                                 \
1076
10.9k
          fa op fb, std::isnan(fa) || std::isnan(fb), ord);               \
1077
10.9k
      std::vector<uint32_t> words = {uint32_t(result)};                   \
1078
10.9k
      return const_mgr->GetConstant(result_type, words);                  \
1079
10.9k
    } else if (float_type->width() == 64) {                               \
1080
0
      double fa = a->GetDouble();                                         \
1081
0
      double fb = b->GetDouble();                                         \
1082
0
      bool result = CompareFloatingPoint(                                 \
1083
0
          fa op fb, std::isnan(fa) || std::isnan(fb), ord);               \
1084
0
      std::vector<uint32_t> words = {uint32_t(result)};                   \
1085
0
      return const_mgr->GetConstant(result_type, words);                  \
1086
0
    }                                                                     \
1087
10.9k
    return nullptr;                                                       \
1088
10.9k
  }
1089
1090
// Define the folding rules for ordered and unordered comparison for floating
1091
// point values.
1092
11.6k
ConstantFoldingRule FoldFOrdEqual() {
1093
12.1k
  return FoldFPBinaryOp(FOLD_FPCMP_OP(==, true));
1094
11.6k
}
1095
11.6k
ConstantFoldingRule FoldFUnordEqual() {
1096
11.8k
  return FoldFPBinaryOp(FOLD_FPCMP_OP(==, false));
1097
11.6k
}
1098
11.6k
ConstantFoldingRule FoldFOrdNotEqual() {
1099
12.1k
  return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, true));
1100
11.6k
}
1101
11.6k
ConstantFoldingRule FoldFUnordNotEqual() {
1102
12.7k
  return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, false));
1103
11.6k
}
1104
11.6k
ConstantFoldingRule FoldFOrdLessThan() {
1105
12.8k
  return FoldFPBinaryOp(FOLD_FPCMP_OP(<, true));
1106
11.6k
}
1107
11.6k
ConstantFoldingRule FoldFUnordLessThan() {
1108
12.4k
  return FoldFPBinaryOp(FOLD_FPCMP_OP(<, false));
1109
11.6k
}
1110
11.6k
ConstantFoldingRule FoldFOrdGreaterThan() {
1111
12.4k
  return FoldFPBinaryOp(FOLD_FPCMP_OP(>, true));
1112
11.6k
}
1113
11.6k
ConstantFoldingRule FoldFUnordGreaterThan() {
1114
13.3k
  return FoldFPBinaryOp(FOLD_FPCMP_OP(>, false));
1115
11.6k
}
1116
11.6k
ConstantFoldingRule FoldFOrdLessThanEqual() {
1117
12.3k
  return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, true));
1118
11.6k
}
1119
11.6k
ConstantFoldingRule FoldFUnordLessThanEqual() {
1120
12.9k
  return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, false));
1121
11.6k
}
1122
11.6k
ConstantFoldingRule FoldFOrdGreaterThanEqual() {
1123
12.3k
  return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, true));
1124
11.6k
}
1125
11.6k
ConstantFoldingRule FoldFUnordGreaterThanEqual() {
1126
12.6k
  return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, false));
1127
11.6k
}
1128
1129
// Folds an OpDot where all of the inputs are constants to a
1130
// constant.  A new constant is created if necessary.
1131
11.6k
ConstantFoldingRule FoldOpDotWithConstants() {
1132
11.6k
  return [](IRContext* context, Instruction* inst,
1133
11.6k
            const std::vector<const analysis::Constant*>& constants)
1134
11.6k
             -> const analysis::Constant* {
1135
172
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1136
172
    analysis::TypeManager* type_mgr = context->get_type_mgr();
1137
172
    const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
1138
172
    assert(new_type->AsFloat() && "OpDot should have a float return type.");
1139
172
    const analysis::Float* float_type = new_type->AsFloat();
1140
1141
172
    if (!inst->IsFloatingPointFoldingAllowed()) {
1142
0
      return nullptr;
1143
0
    }
1144
1145
    // If one of the operands is 0, then the result is 0.
1146
172
    bool has_zero_operand = false;
1147
1148
516
    for (int i = 0; i < 2; ++i) {
1149
344
      if (constants[i]) {
1150
78
        if (constants[i]->AsNullConstant() ||
1151
78
            constants[i]->AsVectorConstant()->IsZero()) {
1152
0
          has_zero_operand = true;
1153
0
          break;
1154
0
        }
1155
78
      }
1156
344
    }
1157
1158
172
    if (has_zero_operand) {
1159
0
      if (float_type->width() == 32) {
1160
0
        utils::FloatProxy<float> result(0.0f);
1161
0
        std::vector<uint32_t> words = result.GetWords();
1162
0
        return const_mgr->GetConstant(float_type, words);
1163
0
      }
1164
0
      if (float_type->width() == 64) {
1165
0
        utils::FloatProxy<double> result(0.0);
1166
0
        std::vector<uint32_t> words = result.GetWords();
1167
0
        return const_mgr->GetConstant(float_type, words);
1168
0
      }
1169
0
      return nullptr;
1170
0
    }
1171
1172
172
    if (constants[0] == nullptr || constants[1] == nullptr) {
1173
172
      return nullptr;
1174
172
    }
1175
1176
0
    std::vector<const analysis::Constant*> a_components;
1177
0
    std::vector<const analysis::Constant*> b_components;
1178
1179
0
    a_components = constants[0]->GetVectorComponents(const_mgr);
1180
0
    b_components = constants[1]->GetVectorComponents(const_mgr);
1181
1182
0
    utils::FloatProxy<double> result(0.0);
1183
0
    std::vector<uint32_t> words = result.GetWords();
1184
0
    const analysis::Constant* result_const =
1185
0
        const_mgr->GetConstant(float_type, words);
1186
0
    for (uint32_t i = 0; i < a_components.size() && result_const != nullptr;
1187
0
         ++i) {
1188
0
      if (a_components[i] == nullptr || b_components[i] == nullptr) {
1189
0
        return nullptr;
1190
0
      }
1191
1192
0
      const analysis::Constant* component = FOLD_FPARITH_OP(*)(
1193
0
          new_type, a_components[i], b_components[i], const_mgr);
1194
0
      if (component == nullptr) {
1195
0
        return nullptr;
1196
0
      }
1197
0
      result_const =
1198
0
          FOLD_FPARITH_OP(+)(new_type, result_const, component, const_mgr);
1199
0
    }
1200
0
    return result_const;
1201
0
  };
1202
11.6k
}
1203
1204
11.6k
ConstantFoldingRule FoldFNegate() { return FoldFPUnaryOp(NegateFPConst); }
1205
11.6k
ConstantFoldingRule FoldSNegate() { return FoldUnaryOp(NegateIntConst); }
1206
1207
92.9k
ConstantFoldingRule FoldFClampFeedingCompare(spv::Op cmp_opcode) {
1208
92.9k
  return [cmp_opcode](IRContext* context, Instruction* inst,
1209
92.9k
                      const std::vector<const analysis::Constant*>& constants)
1210
92.9k
             -> const analysis::Constant* {
1211
66.6k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1212
66.6k
    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1213
1214
66.6k
    if (!inst->IsFloatingPointFoldingAllowed()) {
1215
19
      return nullptr;
1216
19
    }
1217
1218
66.6k
    uint32_t non_const_idx = (constants[0] ? 1 : 0);
1219
66.6k
    uint32_t operand_id = inst->GetSingleWordInOperand(non_const_idx);
1220
66.6k
    Instruction* operand_inst = def_use_mgr->GetDef(operand_id);
1221
1222
66.6k
    analysis::TypeManager* type_mgr = context->get_type_mgr();
1223
66.6k
    const analysis::Type* operand_type =
1224
66.6k
        type_mgr->GetType(operand_inst->type_id());
1225
1226
66.6k
    if (!operand_type->AsFloat()) {
1227
0
      return nullptr;
1228
0
    }
1229
1230
66.6k
    if (operand_type->AsFloat()->width() != 32 &&
1231
0
        operand_type->AsFloat()->width() != 64) {
1232
0
      return nullptr;
1233
0
    }
1234
1235
66.6k
    if (operand_inst->opcode() != spv::Op::OpExtInst) {
1236
60.2k
      return nullptr;
1237
60.2k
    }
1238
1239
6.39k
    if (operand_inst->GetSingleWordInOperand(1) != GLSLstd450FClamp) {
1240
6.39k
      return nullptr;
1241
6.39k
    }
1242
1243
0
    if (constants[1] == nullptr && constants[0] == nullptr) {
1244
0
      return nullptr;
1245
0
    }
1246
1247
0
    uint32_t max_id = operand_inst->GetSingleWordInOperand(4);
1248
0
    const analysis::Constant* max_const =
1249
0
        const_mgr->FindDeclaredConstant(max_id);
1250
1251
0
    uint32_t min_id = operand_inst->GetSingleWordInOperand(3);
1252
0
    const analysis::Constant* min_const =
1253
0
        const_mgr->FindDeclaredConstant(min_id);
1254
1255
0
    bool found_result = false;
1256
0
    bool result = false;
1257
1258
0
    switch (cmp_opcode) {
1259
0
      case spv::Op::OpFOrdLessThan:
1260
0
      case spv::Op::OpFUnordLessThan:
1261
0
      case spv::Op::OpFOrdGreaterThanEqual:
1262
0
      case spv::Op::OpFUnordGreaterThanEqual:
1263
0
        if (constants[0]) {
1264
0
          if (min_const) {
1265
0
            if (constants[0]->GetValueAsDouble() <
1266
0
                min_const->GetValueAsDouble()) {
1267
0
              found_result = true;
1268
0
              result = (cmp_opcode == spv::Op::OpFOrdLessThan ||
1269
0
                        cmp_opcode == spv::Op::OpFUnordLessThan);
1270
0
            }
1271
0
          }
1272
0
          if (max_const) {
1273
0
            if (constants[0]->GetValueAsDouble() >=
1274
0
                max_const->GetValueAsDouble()) {
1275
0
              found_result = true;
1276
0
              result = !(cmp_opcode == spv::Op::OpFOrdLessThan ||
1277
0
                         cmp_opcode == spv::Op::OpFUnordLessThan);
1278
0
            }
1279
0
          }
1280
0
        }
1281
1282
0
        if (constants[1]) {
1283
0
          if (max_const) {
1284
0
            if (max_const->GetValueAsDouble() <
1285
0
                constants[1]->GetValueAsDouble()) {
1286
0
              found_result = true;
1287
0
              result = (cmp_opcode == spv::Op::OpFOrdLessThan ||
1288
0
                        cmp_opcode == spv::Op::OpFUnordLessThan);
1289
0
            }
1290
0
          }
1291
1292
0
          if (min_const) {
1293
0
            if (min_const->GetValueAsDouble() >=
1294
0
                constants[1]->GetValueAsDouble()) {
1295
0
              found_result = true;
1296
0
              result = !(cmp_opcode == spv::Op::OpFOrdLessThan ||
1297
0
                         cmp_opcode == spv::Op::OpFUnordLessThan);
1298
0
            }
1299
0
          }
1300
0
        }
1301
0
        break;
1302
0
      case spv::Op::OpFOrdGreaterThan:
1303
0
      case spv::Op::OpFUnordGreaterThan:
1304
0
      case spv::Op::OpFOrdLessThanEqual:
1305
0
      case spv::Op::OpFUnordLessThanEqual:
1306
0
        if (constants[0]) {
1307
0
          if (min_const) {
1308
0
            if (constants[0]->GetValueAsDouble() <=
1309
0
                min_const->GetValueAsDouble()) {
1310
0
              found_result = true;
1311
0
              result = (cmp_opcode == spv::Op::OpFOrdLessThanEqual ||
1312
0
                        cmp_opcode == spv::Op::OpFUnordLessThanEqual);
1313
0
            }
1314
0
          }
1315
0
          if (max_const) {
1316
0
            if (constants[0]->GetValueAsDouble() >
1317
0
                max_const->GetValueAsDouble()) {
1318
0
              found_result = true;
1319
0
              result = !(cmp_opcode == spv::Op::OpFOrdLessThanEqual ||
1320
0
                         cmp_opcode == spv::Op::OpFUnordLessThanEqual);
1321
0
            }
1322
0
          }
1323
0
        }
1324
1325
0
        if (constants[1]) {
1326
0
          if (max_const) {
1327
0
            if (max_const->GetValueAsDouble() <=
1328
0
                constants[1]->GetValueAsDouble()) {
1329
0
              found_result = true;
1330
0
              result = (cmp_opcode == spv::Op::OpFOrdLessThanEqual ||
1331
0
                        cmp_opcode == spv::Op::OpFUnordLessThanEqual);
1332
0
            }
1333
0
          }
1334
1335
0
          if (min_const) {
1336
0
            if (min_const->GetValueAsDouble() >
1337
0
                constants[1]->GetValueAsDouble()) {
1338
0
              found_result = true;
1339
0
              result = !(cmp_opcode == spv::Op::OpFOrdLessThanEqual ||
1340
0
                         cmp_opcode == spv::Op::OpFUnordLessThanEqual);
1341
0
            }
1342
0
          }
1343
0
        }
1344
0
        break;
1345
0
      default:
1346
0
        return nullptr;
1347
0
    }
1348
1349
0
    if (!found_result) {
1350
0
      return nullptr;
1351
0
    }
1352
1353
0
    const analysis::Type* bool_type =
1354
0
        context->get_type_mgr()->GetType(inst->type_id());
1355
0
    const analysis::Constant* result_const =
1356
0
        const_mgr->GetConstant(bool_type, {static_cast<uint32_t>(result)});
1357
0
    assert(result_const);
1358
0
    return result_const;
1359
0
  };
1360
92.9k
}
1361
1362
6.92k
ConstantFoldingRule FoldFMix() {
1363
6.92k
  return [](IRContext* context, Instruction* inst,
1364
6.92k
            const std::vector<const analysis::Constant*>& constants)
1365
9.04k
             -> const analysis::Constant* {
1366
9.04k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1367
9.04k
    assert(inst->opcode() == spv::Op::OpExtInst &&
1368
9.04k
           "Expecting an extended instruction.");
1369
9.04k
    assert(inst->GetSingleWordInOperand(0) ==
1370
9.04k
               context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1371
9.04k
           "Expecting a GLSLstd450 extended instruction.");
1372
9.04k
    assert(inst->GetSingleWordInOperand(1) == GLSLstd450FMix &&
1373
9.04k
           "Expecting and FMix instruction.");
1374
1375
9.04k
    if (!inst->IsFloatingPointFoldingAllowed()) {
1376
0
      return nullptr;
1377
0
    }
1378
1379
    // Make sure all FMix operands are constants.
1380
9.74k
    for (uint32_t i = 1; i < 4; i++) {
1381
9.64k
      if (constants[i] == nullptr) {
1382
8.94k
        return nullptr;
1383
8.94k
      }
1384
9.64k
    }
1385
1386
103
    const analysis::Constant* one;
1387
103
    bool is_vector = false;
1388
103
    const analysis::Type* result_type = constants[1]->type();
1389
103
    const analysis::Type* base_type = result_type;
1390
103
    if (base_type->AsVector()) {
1391
85
      is_vector = true;
1392
85
      base_type = base_type->AsVector()->element_type();
1393
85
    }
1394
103
    assert(base_type->AsFloat() != nullptr &&
1395
103
           "FMix is suppose to act on floats or vectors of floats.");
1396
1397
103
    if (base_type->AsFloat()->width() == 32) {
1398
103
      one = const_mgr->GetConstant(base_type,
1399
103
                                   utils::FloatProxy<float>(1.0f).GetWords());
1400
103
    } else if (base_type->AsFloat()->width() == 64) {
1401
0
      one = const_mgr->GetConstant(base_type,
1402
0
                                   utils::FloatProxy<double>(1.0).GetWords());
1403
0
    } else {
1404
      // We won't support folding half types.
1405
0
      return nullptr;
1406
0
    }
1407
1408
103
    if (is_vector) {
1409
85
      Instruction* one_inst = const_mgr->GetDefiningInstruction(one);
1410
85
      if (one_inst == nullptr) return nullptr;
1411
85
      uint32_t one_id = one_inst->result_id();
1412
85
      one =
1413
85
          const_mgr->GetConstant(result_type, std::vector<uint32_t>(4, one_id));
1414
85
    }
1415
1416
103
    const analysis::Constant* temp1 = FoldFPBinaryOp(
1417
374
        FOLD_FPARITH_OP(-), inst->type_id(), {one, constants[3]}, context);
1418
103
    if (temp1 == nullptr) {
1419
0
      return nullptr;
1420
0
    }
1421
1422
103
    const analysis::Constant* temp2 = FoldFPBinaryOp(
1423
374
        FOLD_FPARITH_OP(*), inst->type_id(), {constants[1], temp1}, context);
1424
103
    if (temp2 == nullptr) {
1425
0
      return nullptr;
1426
0
    }
1427
103
    const analysis::Constant* temp3 =
1428
374
        FoldFPBinaryOp(FOLD_FPARITH_OP(*), inst->type_id(),
1429
103
                       {constants[2], constants[3]}, context);
1430
103
    if (temp3 == nullptr) {
1431
0
      return nullptr;
1432
0
    }
1433
374
    return FoldFPBinaryOp(FOLD_FPARITH_OP(+), inst->type_id(), {temp2, temp3},
1434
103
                          context);
1435
103
  };
1436
6.92k
}
1437
1438
const analysis::Constant* FoldMin(const analysis::Type* result_type,
1439
                                  const analysis::Constant* a,
1440
                                  const analysis::Constant* b,
1441
1.97k
                                  analysis::ConstantManager*) {
1442
1.97k
  if (const analysis::Integer* int_type = result_type->AsInteger()) {
1443
102
    if (int_type->width() <= 32) {
1444
102
      assert(
1445
102
          (a->AsIntConstant() != nullptr || a->AsNullConstant() != nullptr) &&
1446
102
          "Must be an integer or null constant.");
1447
102
      assert(
1448
102
          (b->AsIntConstant() != nullptr || b->AsNullConstant() != nullptr) &&
1449
102
          "Must be an integer or null constant.");
1450
1451
102
      if (int_type->IsSigned()) {
1452
93
        int32_t va = (a->AsIntConstant() != nullptr)
1453
93
                         ? a->AsIntConstant()->GetS32BitValue()
1454
93
                         : 0;
1455
93
        int32_t vb = (b->AsIntConstant() != nullptr)
1456
93
                         ? b->AsIntConstant()->GetS32BitValue()
1457
93
                         : 0;
1458
93
        return (va < vb ? a : b);
1459
93
      } else {
1460
9
        uint32_t va = (a->AsIntConstant() != nullptr)
1461
9
                          ? a->AsIntConstant()->GetU32BitValue()
1462
9
                          : 0;
1463
9
        uint32_t vb = (b->AsIntConstant() != nullptr)
1464
9
                          ? b->AsIntConstant()->GetU32BitValue()
1465
9
                          : 0;
1466
9
        return (va < vb ? a : b);
1467
9
      }
1468
102
    } else if (int_type->width() == 64) {
1469
0
      if (int_type->IsSigned()) {
1470
0
        int64_t va = a->GetS64();
1471
0
        int64_t vb = b->GetS64();
1472
0
        return (va < vb ? a : b);
1473
0
      } else {
1474
0
        uint64_t va = a->GetU64();
1475
0
        uint64_t vb = b->GetU64();
1476
0
        return (va < vb ? a : b);
1477
0
      }
1478
0
    }
1479
1.86k
  } else if (const analysis::Float* float_type = result_type->AsFloat()) {
1480
1.86k
    if (float_type->width() == 32) {
1481
1.86k
      float va = a->GetFloat();
1482
1.86k
      float vb = b->GetFloat();
1483
1.86k
      return (va < vb ? a : b);
1484
1.86k
    } else if (float_type->width() == 64) {
1485
0
      double va = a->GetDouble();
1486
0
      double vb = b->GetDouble();
1487
0
      return (va < vb ? a : b);
1488
0
    }
1489
1.86k
  }
1490
0
  return nullptr;
1491
1.97k
}
1492
1493
const analysis::Constant* FoldMax(const analysis::Type* result_type,
1494
                                  const analysis::Constant* a,
1495
                                  const analysis::Constant* b,
1496
2.86k
                                  analysis::ConstantManager*) {
1497
2.86k
  if (const analysis::Integer* int_type = result_type->AsInteger()) {
1498
11
    if (int_type->width() <= 32) {
1499
11
      assert(
1500
11
          (a->AsIntConstant() != nullptr || a->AsNullConstant() != nullptr) &&
1501
11
          "Must be an integer or null constant.");
1502
11
      assert(
1503
11
          (b->AsIntConstant() != nullptr || b->AsNullConstant() != nullptr) &&
1504
11
          "Must be an integer or null constant.");
1505
1506
11
      if (int_type->IsSigned()) {
1507
11
        int32_t va = (a->AsIntConstant() != nullptr)
1508
11
                         ? a->AsIntConstant()->GetS32BitValue()
1509
11
                         : 0;
1510
11
        int32_t vb = (b->AsIntConstant() != nullptr)
1511
11
                         ? b->AsIntConstant()->GetS32BitValue()
1512
11
                         : 0;
1513
11
        return (va > vb ? a : b);
1514
11
      } else {
1515
0
        uint32_t va = (a->AsIntConstant() != nullptr)
1516
0
                          ? a->AsIntConstant()->GetU32BitValue()
1517
0
                          : 0;
1518
0
        uint32_t vb = (b->AsIntConstant() != nullptr)
1519
0
                          ? b->AsIntConstant()->GetU32BitValue()
1520
0
                          : 0;
1521
0
        return (va > vb ? a : b);
1522
0
      }
1523
11
    } else if (int_type->width() == 64) {
1524
0
      if (int_type->IsSigned()) {
1525
0
        int64_t va = a->GetS64();
1526
0
        int64_t vb = b->GetS64();
1527
0
        return (va > vb ? a : b);
1528
0
      } else {
1529
0
        uint64_t va = a->GetU64();
1530
0
        uint64_t vb = b->GetU64();
1531
0
        return (va > vb ? a : b);
1532
0
      }
1533
0
    }
1534
2.85k
  } else if (const analysis::Float* float_type = result_type->AsFloat()) {
1535
2.85k
    if (float_type->width() == 32) {
1536
2.85k
      float va = a->GetFloat();
1537
2.85k
      float vb = b->GetFloat();
1538
2.85k
      return (va > vb ? a : b);
1539
2.85k
    } else if (float_type->width() == 64) {
1540
0
      double va = a->GetDouble();
1541
0
      double vb = b->GetDouble();
1542
0
      return (va > vb ? a : b);
1543
0
    }
1544
2.85k
  }
1545
0
  return nullptr;
1546
2.86k
}
1547
1548
// Fold an clamp instruction when all three operands are constant.
1549
const analysis::Constant* FoldClamp1(
1550
    IRContext* context, Instruction* inst,
1551
8.03k
    const std::vector<const analysis::Constant*>& constants) {
1552
8.03k
  assert(inst->opcode() == spv::Op::OpExtInst &&
1553
8.03k
         "Expecting an extended instruction.");
1554
8.03k
  assert(inst->GetSingleWordInOperand(0) ==
1555
8.03k
             context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1556
8.03k
         "Expecting a GLSLstd450 extended instruction.");
1557
1558
  // Make sure all Clamp operands are constants.
1559
12.7k
  for (uint32_t i = 1; i < 4; i++) {
1560
12.1k
    if (constants[i] == nullptr) {
1561
7.49k
      return nullptr;
1562
7.49k
    }
1563
12.1k
  }
1564
1565
534
  const analysis::Constant* temp = FoldFPBinaryOp(
1566
534
      FoldMax, inst->type_id(), {constants[1], constants[2]}, context);
1567
534
  if (temp == nullptr) {
1568
0
    return nullptr;
1569
0
  }
1570
534
  return FoldFPBinaryOp(FoldMin, inst->type_id(), {temp, constants[3]},
1571
534
                        context);
1572
534
}
1573
1574
// Fold a clamp instruction when |x <= min_val|.
1575
const analysis::Constant* FoldClamp2(
1576
    IRContext* context, Instruction* inst,
1577
7.49k
    const std::vector<const analysis::Constant*>& constants) {
1578
7.49k
  assert(inst->opcode() == spv::Op::OpExtInst &&
1579
7.49k
         "Expecting an extended instruction.");
1580
7.49k
  assert(inst->GetSingleWordInOperand(0) ==
1581
7.49k
             context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1582
7.49k
         "Expecting a GLSLstd450 extended instruction.");
1583
1584
7.49k
  const analysis::Constant* x = constants[1];
1585
7.49k
  const analysis::Constant* min_val = constants[2];
1586
1587
7.49k
  if (x == nullptr || min_val == nullptr) {
1588
6.07k
    return nullptr;
1589
6.07k
  }
1590
1591
1.42k
  const analysis::Constant* temp =
1592
1.42k
      FoldFPBinaryOp(FoldMax, inst->type_id(), {x, min_val}, context);
1593
1.42k
  if (temp == min_val) {
1594
    // We can assume that |min_val| is less than |max_val|.  Therefore, if the
1595
    // result of the max operation is |min_val|, we know the result of the min
1596
    // operation, even if |max_val| is not a constant.
1597
1.18k
    return min_val;
1598
1.18k
  }
1599
237
  return nullptr;
1600
1.42k
}
1601
1602
// Fold a clamp instruction when |x >= max_val|.
1603
const analysis::Constant* FoldClamp3(
1604
    IRContext* context, Instruction* inst,
1605
6.31k
    const std::vector<const analysis::Constant*>& constants) {
1606
6.31k
  assert(inst->opcode() == spv::Op::OpExtInst &&
1607
6.31k
         "Expecting an extended instruction.");
1608
6.31k
  assert(inst->GetSingleWordInOperand(0) ==
1609
6.31k
             context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1610
6.31k
         "Expecting a GLSLstd450 extended instruction.");
1611
1612
6.31k
  const analysis::Constant* x = constants[1];
1613
6.31k
  const analysis::Constant* max_val = constants[3];
1614
1615
6.31k
  if (x == nullptr || max_val == nullptr) {
1616
6.14k
    return nullptr;
1617
6.14k
  }
1618
1619
166
  const analysis::Constant* temp =
1620
166
      FoldFPBinaryOp(FoldMin, inst->type_id(), {x, max_val}, context);
1621
166
  if (temp == max_val) {
1622
    // We can assume that |min_val| is less than |max_val|.  Therefore, if the
1623
    // result of the max operation is |min_val|, we know the result of the min
1624
    // operation, even if |max_val| is not a constant.
1625
63
    return max_val;
1626
63
  }
1627
103
  return nullptr;
1628
166
}
1629
1630
76.1k
UnaryScalarFoldingRule FoldFTranscendentalUnary(double (*fp)(double)) {
1631
76.1k
  return
1632
76.1k
      [fp](const analysis::Type* result_type, const analysis::Constant* a,
1633
76.1k
           analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
1634
1.15k
        assert(result_type != nullptr && a != nullptr);
1635
1.15k
        const analysis::Float* float_type = a->type()->AsFloat();
1636
1.15k
        assert(float_type != nullptr);
1637
1.15k
        assert(float_type == result_type->AsFloat());
1638
1.15k
        if (float_type->width() == 32) {
1639
1.15k
          float fa = a->GetFloat();
1640
1.15k
          float res = static_cast<float>(fp(fa));
1641
1.15k
          utils::FloatProxy<float> result(res);
1642
1.15k
          std::vector<uint32_t> words = result.GetWords();
1643
1.15k
          return const_mgr->GetConstant(result_type, words);
1644
1.15k
        } else if (float_type->width() == 64) {
1645
0
          double fa = a->GetDouble();
1646
0
          double res = fp(fa);
1647
0
          utils::FloatProxy<double> result(res);
1648
0
          std::vector<uint32_t> words = result.GetWords();
1649
0
          return const_mgr->GetConstant(result_type, words);
1650
0
        }
1651
0
        return nullptr;
1652
1.15k
      };
1653
76.1k
}
1654
1655
BinaryScalarFoldingRule FoldFTranscendentalBinary(double (*fp)(double,
1656
13.8k
                                                               double)) {
1657
13.8k
  return
1658
13.8k
      [fp](const analysis::Type* result_type, const analysis::Constant* a,
1659
13.8k
           const analysis::Constant* b,
1660
13.8k
           analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
1661
135
        assert(result_type != nullptr && a != nullptr);
1662
135
        const analysis::Float* float_type = a->type()->AsFloat();
1663
135
        assert(float_type != nullptr);
1664
135
        assert(float_type == result_type->AsFloat());
1665
135
        assert(float_type == b->type()->AsFloat());
1666
135
        if (float_type->width() == 32) {
1667
135
          float fa = a->GetFloat();
1668
135
          float fb = b->GetFloat();
1669
135
          float res = static_cast<float>(fp(fa, fb));
1670
135
          utils::FloatProxy<float> result(res);
1671
135
          std::vector<uint32_t> words = result.GetWords();
1672
135
          return const_mgr->GetConstant(result_type, words);
1673
135
        } else if (float_type->width() == 64) {
1674
0
          double fa = a->GetDouble();
1675
0
          double fb = b->GetDouble();
1676
0
          double res = fp(fa, fb);
1677
0
          utils::FloatProxy<double> result(res);
1678
0
          std::vector<uint32_t> words = result.GetWords();
1679
0
          return const_mgr->GetConstant(result_type, words);
1680
0
        }
1681
0
        return nullptr;
1682
135
      };
1683
13.8k
}
1684
1685
enum Sign { Signed, Unsigned };
1686
1687
// Returns a BinaryScalarFoldingRule that applies `op` to the scalars.
1688
// The `signedness` is used to determine if the operands should be interpreted
1689
// as signed or unsigned. If the operands are signed, the value will be sign
1690
// extended before the value is passed to `op`. Otherwise the values will be
1691
// zero extended.
1692
template <Sign signedness>
1693
BinaryScalarFoldingRule FoldBinaryIntegerOperation(uint64_t (*op)(uint64_t,
1694
92.9k
                                                                  uint64_t)) {
1695
92.9k
  return
1696
92.9k
      [op](const analysis::Type* result_type, const analysis::Constant* a,
1697
92.9k
           const analysis::Constant* b,
1698
878k
           analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
1699
878k
        assert(result_type != nullptr && a != nullptr && b != nullptr);
1700
878k
        const analysis::Integer* integer_type = result_type->AsInteger();
1701
878k
        assert(integer_type != nullptr);
1702
878k
        assert(a->type()->kind() == analysis::Type::kInteger);
1703
878k
        assert(b->type()->kind() == analysis::Type::kInteger);
1704
878k
        assert(integer_type->width() == a->type()->AsInteger()->width());
1705
878k
        assert(integer_type->width() == b->type()->AsInteger()->width());
1706
1707
        // In SPIR-V, all operations support unsigned types, but the way they
1708
        // are interpreted depends on the opcode. This is why we use the
1709
        // template argument to determine how to interpret the operands.
1710
878k
        uint64_t ia = (signedness == Signed ? a->GetSignExtendedValue()
1711
878k
                                            : a->GetZeroExtendedValue());
1712
878k
        uint64_t ib = (signedness == Signed ? b->GetSignExtendedValue()
1713
878k
                                            : b->GetZeroExtendedValue());
1714
878k
        uint64_t result = op(ia, ib);
1715
1716
878k
        const analysis::Constant* result_constant =
1717
878k
            const_mgr->GenerateIntegerConstant(integer_type, result);
1718
878k
        return result_constant;
1719
878k
      };
const_folding_rules.cpp:spvtools::opt::(anonymous namespace)::FoldBinaryIntegerOperation<(spvtools::opt::(anonymous namespace)::Sign)1>(unsigned long (*)(unsigned long, unsigned long))::{lambda(spvtools::opt::analysis::Type const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::ConstantManager*)#1}::operator()(spvtools::opt::analysis::Type const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::ConstantManager*) const
Line
Count
Source
1698
777k
           analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
1699
777k
        assert(result_type != nullptr && a != nullptr && b != nullptr);
1700
777k
        const analysis::Integer* integer_type = result_type->AsInteger();
1701
777k
        assert(integer_type != nullptr);
1702
777k
        assert(a->type()->kind() == analysis::Type::kInteger);
1703
777k
        assert(b->type()->kind() == analysis::Type::kInteger);
1704
777k
        assert(integer_type->width() == a->type()->AsInteger()->width());
1705
777k
        assert(integer_type->width() == b->type()->AsInteger()->width());
1706
1707
        // In SPIR-V, all operations support unsigned types, but the way they
1708
        // are interpreted depends on the opcode. This is why we use the
1709
        // template argument to determine how to interpret the operands.
1710
777k
        uint64_t ia = (signedness == Signed ? a->GetSignExtendedValue()
1711
777k
                                            : a->GetZeroExtendedValue());
1712
777k
        uint64_t ib = (signedness == Signed ? b->GetSignExtendedValue()
1713
777k
                                            : b->GetZeroExtendedValue());
1714
777k
        uint64_t result = op(ia, ib);
1715
1716
777k
        const analysis::Constant* result_constant =
1717
777k
            const_mgr->GenerateIntegerConstant(integer_type, result);
1718
777k
        return result_constant;
1719
777k
      };
const_folding_rules.cpp:spvtools::opt::(anonymous namespace)::FoldBinaryIntegerOperation<(spvtools::opt::(anonymous namespace)::Sign)0>(unsigned long (*)(unsigned long, unsigned long))::{lambda(spvtools::opt::analysis::Type const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::ConstantManager*)#1}::operator()(spvtools::opt::analysis::Type const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::ConstantManager*) const
Line
Count
Source
1698
100k
           analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
1699
100k
        assert(result_type != nullptr && a != nullptr && b != nullptr);
1700
100k
        const analysis::Integer* integer_type = result_type->AsInteger();
1701
100k
        assert(integer_type != nullptr);
1702
100k
        assert(a->type()->kind() == analysis::Type::kInteger);
1703
100k
        assert(b->type()->kind() == analysis::Type::kInteger);
1704
100k
        assert(integer_type->width() == a->type()->AsInteger()->width());
1705
100k
        assert(integer_type->width() == b->type()->AsInteger()->width());
1706
1707
        // In SPIR-V, all operations support unsigned types, but the way they
1708
        // are interpreted depends on the opcode. This is why we use the
1709
        // template argument to determine how to interpret the operands.
1710
100k
        uint64_t ia = (signedness == Signed ? a->GetSignExtendedValue()
1711
100k
                                            : a->GetZeroExtendedValue());
1712
100k
        uint64_t ib = (signedness == Signed ? b->GetSignExtendedValue()
1713
100k
                                            : b->GetZeroExtendedValue());
1714
100k
        uint64_t result = op(ia, ib);
1715
1716
100k
        const analysis::Constant* result_constant =
1717
100k
            const_mgr->GenerateIntegerConstant(integer_type, result);
1718
100k
        return result_constant;
1719
100k
      };
1720
92.9k
}
const_folding_rules.cpp:std::__1::function<spvtools::opt::analysis::Constant const* (spvtools::opt::analysis::Type const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::ConstantManager*)> spvtools::opt::(anonymous namespace)::FoldBinaryIntegerOperation<(spvtools::opt::(anonymous namespace)::Sign)1>(unsigned long (*)(unsigned long, unsigned long))
Line
Count
Source
1694
58.0k
                                                                  uint64_t)) {
1695
58.0k
  return
1696
58.0k
      [op](const analysis::Type* result_type, const analysis::Constant* a,
1697
58.0k
           const analysis::Constant* b,
1698
58.0k
           analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
1699
58.0k
        assert(result_type != nullptr && a != nullptr && b != nullptr);
1700
58.0k
        const analysis::Integer* integer_type = result_type->AsInteger();
1701
58.0k
        assert(integer_type != nullptr);
1702
58.0k
        assert(a->type()->kind() == analysis::Type::kInteger);
1703
58.0k
        assert(b->type()->kind() == analysis::Type::kInteger);
1704
58.0k
        assert(integer_type->width() == a->type()->AsInteger()->width());
1705
58.0k
        assert(integer_type->width() == b->type()->AsInteger()->width());
1706
1707
        // In SPIR-V, all operations support unsigned types, but the way they
1708
        // are interpreted depends on the opcode. This is why we use the
1709
        // template argument to determine how to interpret the operands.
1710
58.0k
        uint64_t ia = (signedness == Signed ? a->GetSignExtendedValue()
1711
58.0k
                                            : a->GetZeroExtendedValue());
1712
58.0k
        uint64_t ib = (signedness == Signed ? b->GetSignExtendedValue()
1713
58.0k
                                            : b->GetZeroExtendedValue());
1714
58.0k
        uint64_t result = op(ia, ib);
1715
1716
58.0k
        const analysis::Constant* result_constant =
1717
58.0k
            const_mgr->GenerateIntegerConstant(integer_type, result);
1718
58.0k
        return result_constant;
1719
58.0k
      };
1720
58.0k
}
const_folding_rules.cpp:std::__1::function<spvtools::opt::analysis::Constant const* (spvtools::opt::analysis::Type const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::ConstantManager*)> spvtools::opt::(anonymous namespace)::FoldBinaryIntegerOperation<(spvtools::opt::(anonymous namespace)::Sign)0>(unsigned long (*)(unsigned long, unsigned long))
Line
Count
Source
1694
34.8k
                                                                  uint64_t)) {
1695
34.8k
  return
1696
34.8k
      [op](const analysis::Type* result_type, const analysis::Constant* a,
1697
34.8k
           const analysis::Constant* b,
1698
34.8k
           analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
1699
34.8k
        assert(result_type != nullptr && a != nullptr && b != nullptr);
1700
34.8k
        const analysis::Integer* integer_type = result_type->AsInteger();
1701
34.8k
        assert(integer_type != nullptr);
1702
34.8k
        assert(a->type()->kind() == analysis::Type::kInteger);
1703
34.8k
        assert(b->type()->kind() == analysis::Type::kInteger);
1704
34.8k
        assert(integer_type->width() == a->type()->AsInteger()->width());
1705
34.8k
        assert(integer_type->width() == b->type()->AsInteger()->width());
1706
1707
        // In SPIR-V, all operations support unsigned types, but the way they
1708
        // are interpreted depends on the opcode. This is why we use the
1709
        // template argument to determine how to interpret the operands.
1710
34.8k
        uint64_t ia = (signedness == Signed ? a->GetSignExtendedValue()
1711
34.8k
                                            : a->GetZeroExtendedValue());
1712
34.8k
        uint64_t ib = (signedness == Signed ? b->GetSignExtendedValue()
1713
34.8k
                                            : b->GetZeroExtendedValue());
1714
34.8k
        uint64_t result = op(ia, ib);
1715
1716
34.8k
        const analysis::Constant* result_constant =
1717
34.8k
            const_mgr->GenerateIntegerConstant(integer_type, result);
1718
34.8k
        return result_constant;
1719
34.8k
      };
1720
34.8k
}
1721
1722
// A scalar folding rule that folds OpSConvert.
1723
const analysis::Constant* FoldScalarSConvert(
1724
    const analysis::Type* result_type, const analysis::Constant* a,
1725
0
    analysis::ConstantManager* const_mgr) {
1726
0
  assert(result_type != nullptr);
1727
0
  assert(a != nullptr);
1728
0
  assert(const_mgr != nullptr);
1729
0
  const analysis::Integer* integer_type = result_type->AsInteger();
1730
0
  assert(integer_type && "The result type of an SConvert");
1731
0
  int64_t value = a->GetSignExtendedValue();
1732
0
  return const_mgr->GenerateIntegerConstant(integer_type, value);
1733
0
}
1734
1735
// A scalar folding rule that folds OpUConvert.
1736
const analysis::Constant* FoldScalarUConvert(
1737
    const analysis::Type* result_type, const analysis::Constant* a,
1738
0
    analysis::ConstantManager* const_mgr) {
1739
0
  assert(result_type != nullptr);
1740
0
  assert(a != nullptr);
1741
0
  assert(const_mgr != nullptr);
1742
0
  const analysis::Integer* integer_type = result_type->AsInteger();
1743
0
  assert(integer_type && "The result type of an UConvert");
1744
0
  uint64_t value = a->GetZeroExtendedValue();
1745
1746
  // If the operand was an unsigned value with less than 32-bit, it would have
1747
  // been sign extended earlier, and we need to clear those bits.
1748
0
  auto* operand_type = a->type()->AsInteger();
1749
0
  value = utils::ClearHighBits(value, 64 - operand_type->width());
1750
0
  return const_mgr->GenerateIntegerConstant(integer_type, value);
1751
0
}
1752
}  // namespace
1753
1754
11.6k
void ConstantFoldingRules::AddFoldingRules() {
1755
  // Add all folding rules to the list for the opcodes to which they apply.
1756
  // Note that the order in which rules are added to the list matters. If a rule
1757
  // applies to the instruction, the rest of the rules will not be attempted.
1758
  // Take that into consideration.
1759
1760
11.6k
  rules_[spv::Op::OpCompositeConstruct].push_back(FoldCompositeWithConstants());
1761
1762
11.6k
  rules_[spv::Op::OpCompositeExtract].push_back(FoldExtractWithConstants());
1763
11.6k
  rules_[spv::Op::OpCompositeInsert].push_back(FoldInsertWithConstants());
1764
1765
11.6k
  rules_[spv::Op::OpConvertFToS].push_back(FoldFToI());
1766
11.6k
  rules_[spv::Op::OpConvertFToU].push_back(FoldFToI());
1767
11.6k
  rules_[spv::Op::OpConvertSToF].push_back(FoldIToF());
1768
11.6k
  rules_[spv::Op::OpConvertUToF].push_back(FoldIToF());
1769
11.6k
  rules_[spv::Op::OpSConvert].push_back(FoldUnaryOp(FoldScalarSConvert));
1770
11.6k
  rules_[spv::Op::OpUConvert].push_back(FoldUnaryOp(FoldScalarUConvert));
1771
1772
11.6k
  rules_[spv::Op::OpDot].push_back(FoldOpDotWithConstants());
1773
11.6k
  rules_[spv::Op::OpFAdd].push_back(FoldFAdd());
1774
11.6k
  rules_[spv::Op::OpFDiv].push_back(FoldFDiv());
1775
11.6k
  rules_[spv::Op::OpFMul].push_back(FoldFMul());
1776
11.6k
  rules_[spv::Op::OpFSub].push_back(FoldFSub());
1777
1778
11.6k
  rules_[spv::Op::OpFOrdEqual].push_back(FoldFOrdEqual());
1779
1780
11.6k
  rules_[spv::Op::OpFUnordEqual].push_back(FoldFUnordEqual());
1781
1782
11.6k
  rules_[spv::Op::OpFOrdNotEqual].push_back(FoldFOrdNotEqual());
1783
1784
11.6k
  rules_[spv::Op::OpFUnordNotEqual].push_back(FoldFUnordNotEqual());
1785
1786
11.6k
  rules_[spv::Op::OpFOrdLessThan].push_back(FoldFOrdLessThan());
1787
11.6k
  rules_[spv::Op::OpFOrdLessThan].push_back(
1788
11.6k
      FoldFClampFeedingCompare(spv::Op::OpFOrdLessThan));
1789
1790
11.6k
  rules_[spv::Op::OpFUnordLessThan].push_back(FoldFUnordLessThan());
1791
11.6k
  rules_[spv::Op::OpFUnordLessThan].push_back(
1792
11.6k
      FoldFClampFeedingCompare(spv::Op::OpFUnordLessThan));
1793
1794
11.6k
  rules_[spv::Op::OpFOrdGreaterThan].push_back(FoldFOrdGreaterThan());
1795
11.6k
  rules_[spv::Op::OpFOrdGreaterThan].push_back(
1796
11.6k
      FoldFClampFeedingCompare(spv::Op::OpFOrdGreaterThan));
1797
1798
11.6k
  rules_[spv::Op::OpFUnordGreaterThan].push_back(FoldFUnordGreaterThan());
1799
11.6k
  rules_[spv::Op::OpFUnordGreaterThan].push_back(
1800
11.6k
      FoldFClampFeedingCompare(spv::Op::OpFUnordGreaterThan));
1801
1802
11.6k
  rules_[spv::Op::OpFOrdLessThanEqual].push_back(FoldFOrdLessThanEqual());
1803
11.6k
  rules_[spv::Op::OpFOrdLessThanEqual].push_back(
1804
11.6k
      FoldFClampFeedingCompare(spv::Op::OpFOrdLessThanEqual));
1805
1806
11.6k
  rules_[spv::Op::OpFUnordLessThanEqual].push_back(FoldFUnordLessThanEqual());
1807
11.6k
  rules_[spv::Op::OpFUnordLessThanEqual].push_back(
1808
11.6k
      FoldFClampFeedingCompare(spv::Op::OpFUnordLessThanEqual));
1809
1810
11.6k
  rules_[spv::Op::OpFOrdGreaterThanEqual].push_back(FoldFOrdGreaterThanEqual());
1811
11.6k
  rules_[spv::Op::OpFOrdGreaterThanEqual].push_back(
1812
11.6k
      FoldFClampFeedingCompare(spv::Op::OpFOrdGreaterThanEqual));
1813
1814
11.6k
  rules_[spv::Op::OpFUnordGreaterThanEqual].push_back(
1815
11.6k
      FoldFUnordGreaterThanEqual());
1816
11.6k
  rules_[spv::Op::OpFUnordGreaterThanEqual].push_back(
1817
11.6k
      FoldFClampFeedingCompare(spv::Op::OpFUnordGreaterThanEqual));
1818
1819
11.6k
  rules_[spv::Op::OpVectorShuffle].push_back(FoldVectorShuffleWithConstants());
1820
11.6k
  rules_[spv::Op::OpVectorTimesScalar].push_back(FoldVectorTimesScalar());
1821
11.6k
  rules_[spv::Op::OpVectorTimesMatrix].push_back(FoldVectorTimesMatrix());
1822
11.6k
  rules_[spv::Op::OpMatrixTimesVector].push_back(FoldMatrixTimesVector());
1823
11.6k
  rules_[spv::Op::OpTranspose].push_back(FoldTranspose);
1824
1825
11.6k
  rules_[spv::Op::OpFNegate].push_back(FoldFNegate());
1826
11.6k
  rules_[spv::Op::OpSNegate].push_back(FoldSNegate());
1827
11.6k
  rules_[spv::Op::OpQuantizeToF16].push_back(FoldQuantizeToF16());
1828
1829
11.6k
  rules_[spv::Op::OpIAdd].push_back(
1830
11.6k
      FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
1831
458k
          [](uint64_t a, uint64_t b) { return a + b; })));
1832
11.6k
  rules_[spv::Op::OpISub].push_back(
1833
11.6k
      FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
1834
226k
          [](uint64_t a, uint64_t b) { return a - b; })));
1835
11.6k
  rules_[spv::Op::OpIMul].push_back(
1836
11.6k
      FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
1837
92.7k
          [](uint64_t a, uint64_t b) { return a * b; })));
1838
11.6k
  rules_[spv::Op::OpUDiv].push_back(
1839
11.6k
      FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
1840
11.6k
          [](uint64_t a, uint64_t b) { return (b != 0 ? a / b : 0); })));
1841
11.6k
  rules_[spv::Op::OpSDiv].push_back(FoldBinaryOp(
1842
37.4k
      FoldBinaryIntegerOperation<Signed>([](uint64_t a, uint64_t b) {
1843
37.4k
        return (b != 0 ? static_cast<uint64_t>(static_cast<int64_t>(a) /
1844
36.5k
                                               static_cast<int64_t>(b))
1845
37.4k
                       : 0);
1846
37.4k
      })));
1847
11.6k
  rules_[spv::Op::OpUMod].push_back(
1848
11.6k
      FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
1849
11.6k
          [](uint64_t a, uint64_t b) { return (b != 0 ? a % b : 0); })));
1850
1851
11.6k
  rules_[spv::Op::OpSRem].push_back(FoldBinaryOp(
1852
61.3k
      FoldBinaryIntegerOperation<Signed>([](uint64_t a, uint64_t b) {
1853
61.3k
        return (b != 0 ? static_cast<uint64_t>(static_cast<int64_t>(a) %
1854
60.7k
                                               static_cast<int64_t>(b))
1855
61.3k
                       : 0);
1856
61.3k
      })));
1857
1858
11.6k
  rules_[spv::Op::OpSMod].push_back(FoldBinaryOp(
1859
11.6k
      FoldBinaryIntegerOperation<Signed>([](uint64_t a, uint64_t b) {
1860
1.82k
        if (b == 0) return static_cast<uint64_t>(0ull);
1861
1862
1.73k
        int64_t signed_a = static_cast<int64_t>(a);
1863
1.73k
        int64_t signed_b = static_cast<int64_t>(b);
1864
1.73k
        int64_t result = signed_a % signed_b;
1865
1.73k
        if ((signed_b < 0) != (result < 0)) result += signed_b;
1866
1.73k
        return static_cast<uint64_t>(result);
1867
1.82k
      })));
1868
1869
  // Add rules for GLSLstd450
1870
11.6k
  FeatureManager* feature_manager = context_->get_feature_mgr();
1871
11.6k
  uint32_t ext_inst_glslstd450_id =
1872
11.6k
      feature_manager->GetExtInstImportId_GLSLstd450();
1873
11.6k
  if (ext_inst_glslstd450_id != 0) {
1874
6.92k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(FoldFMix());
1875
6.92k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SMin}].push_back(
1876
6.92k
        FoldFPBinaryOp(FoldMin));
1877
6.92k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UMin}].push_back(
1878
6.92k
        FoldFPBinaryOp(FoldMin));
1879
6.92k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMin}].push_back(
1880
6.92k
        FoldFPBinaryOp(FoldMin));
1881
6.92k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SMax}].push_back(
1882
6.92k
        FoldFPBinaryOp(FoldMax));
1883
6.92k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UMax}].push_back(
1884
6.92k
        FoldFPBinaryOp(FoldMax));
1885
6.92k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMax}].push_back(
1886
6.92k
        FoldFPBinaryOp(FoldMax));
1887
6.92k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back(
1888
6.92k
        FoldClamp1);
1889
6.92k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back(
1890
6.92k
        FoldClamp2);
1891
6.92k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back(
1892
6.92k
        FoldClamp3);
1893
6.92k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back(
1894
6.92k
        FoldClamp1);
1895
6.92k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back(
1896
6.92k
        FoldClamp2);
1897
6.92k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back(
1898
6.92k
        FoldClamp3);
1899
6.92k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
1900
6.92k
        FoldClamp1);
1901
6.92k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
1902
6.92k
        FoldClamp2);
1903
6.92k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
1904
6.92k
        FoldClamp3);
1905
6.92k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Sin}].push_back(
1906
6.92k
        FoldFPUnaryOp(FoldFTranscendentalUnary(std::sin)));
1907
6.92k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Cos}].push_back(
1908
6.92k
        FoldFPUnaryOp(FoldFTranscendentalUnary(std::cos)));
1909
6.92k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Tan}].push_back(
1910
6.92k
        FoldFPUnaryOp(FoldFTranscendentalUnary(std::tan)));
1911
6.92k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Asin}].push_back(
1912
6.92k
        FoldFPUnaryOp(FoldFTranscendentalUnary(std::asin)));
1913
6.92k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Acos}].push_back(
1914
6.92k
        FoldFPUnaryOp(FoldFTranscendentalUnary(std::acos)));
1915
6.92k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Atan}].push_back(
1916
6.92k
        FoldFPUnaryOp(FoldFTranscendentalUnary(std::atan)));
1917
6.92k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp}].push_back(
1918
6.92k
        FoldFPUnaryOp(FoldFTranscendentalUnary(std::exp)));
1919
6.92k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log}].push_back(
1920
6.92k
        FoldFPUnaryOp(FoldFTranscendentalUnary(std::log)));
1921
1922
#ifdef __ANDROID__
1923
    // Android NDK r15c targeting ABI 15 doesn't have full support for C++11
1924
    // (no std::exp2/log2). ::exp2 is available from C99 but ::log2 isn't
1925
    // available up until ABI 18 so we use a shim
1926
    auto log2_shim = [](double v) -> double { return log(v) / log(2.0); };
1927
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp2}].push_back(
1928
        FoldFPUnaryOp(FoldFTranscendentalUnary(::exp2)));
1929
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log2}].push_back(
1930
        FoldFPUnaryOp(FoldFTranscendentalUnary(log2_shim)));
1931
#else
1932
6.92k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp2}].push_back(
1933
6.92k
        FoldFPUnaryOp(FoldFTranscendentalUnary(std::exp2)));
1934
6.92k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log2}].push_back(
1935
6.92k
        FoldFPUnaryOp(FoldFTranscendentalUnary(std::log2)));
1936
6.92k
#endif
1937
1938
6.92k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Sqrt}].push_back(
1939
6.92k
        FoldFPUnaryOp(FoldFTranscendentalUnary(std::sqrt)));
1940
6.92k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Atan2}].push_back(
1941
6.92k
        FoldFPBinaryOp(FoldFTranscendentalBinary(std::atan2)));
1942
6.92k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Pow}].push_back(
1943
6.92k
        FoldFPBinaryOp(FoldFTranscendentalBinary(std::pow)));
1944
6.92k
  }
1945
11.6k
}
1946
}  // namespace opt
1947
}  // namespace spvtools