Coverage Report

Created: 2026-06-08 06:54

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/spirv-tools/source/opt/const_folding_rules.cpp
Line
Count
Source
1
// Copyright (c) 2018 Google LLC
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
#include "source/opt/const_folding_rules.h"
16
17
#include <optional>
18
19
#include "source/opt/ir_context.h"
20
21
namespace spvtools {
22
namespace opt {
23
namespace {
24
constexpr uint32_t kExtractCompositeIdInIdx = 0;
25
26
// Returns a constant with the value NaN of the given type.  Only works for
27
// 32-bit and 64-bit float point types.  Returns |nullptr| if an error occurs.
28
const analysis::Constant* GetNan(const analysis::Type* type,
29
1.41k
                                 analysis::ConstantManager* const_mgr) {
30
1.41k
  const analysis::Float* float_type = type->AsFloat();
31
1.41k
  if (float_type == nullptr) {
32
0
    return nullptr;
33
0
  }
34
35
1.41k
  switch (float_type->width()) {
36
1.41k
    case 32:
37
1.41k
      return const_mgr->GetFloatConst(std::numeric_limits<float>::quiet_NaN());
38
0
    case 64:
39
0
      return const_mgr->GetDoubleConst(
40
0
          std::numeric_limits<double>::quiet_NaN());
41
0
    default:
42
0
      return nullptr;
43
1.41k
  }
44
1.41k
}
45
46
// Returns a constant with the value INF of the given type.  Only works for
47
// 32-bit and 64-bit float point types.  Returns |nullptr| if an error occurs.
48
const analysis::Constant* GetInf(const analysis::Type* type,
49
1.06k
                                 analysis::ConstantManager* const_mgr) {
50
1.06k
  const analysis::Float* float_type = type->AsFloat();
51
1.06k
  if (float_type == nullptr) {
52
0
    return nullptr;
53
0
  }
54
55
1.06k
  switch (float_type->width()) {
56
1.06k
    case 32:
57
1.06k
      return const_mgr->GetFloatConst(std::numeric_limits<float>::infinity());
58
0
    case 64:
59
0
      return const_mgr->GetDoubleConst(std::numeric_limits<double>::infinity());
60
0
    default:
61
0
      return nullptr;
62
1.06k
  }
63
1.06k
}
64
65
// Returns true if |type| is Float or a vector of Float.
66
522
bool HasFloatingPoint(const analysis::Type* type) {
67
522
  if (type->AsFloat()) {
68
0
    return true;
69
522
  } else if (const analysis::Vector* vec_type = type->AsVector()) {
70
522
    return vec_type->element_type()->AsFloat() != nullptr;
71
522
  }
72
73
0
  return false;
74
522
}
75
76
// Returns a constant with the value |-val| of the given type.  Only works for
77
// 32-bit and 64-bit float point types.  Returns |nullptr| if an error occurs.
78
const analysis::Constant* NegateFPConst(const analysis::Type* result_type,
79
                                        const analysis::Constant* val,
80
911
                                        analysis::ConstantManager* const_mgr) {
81
911
  const analysis::Float* float_type = result_type->AsFloat();
82
911
  assert(float_type != nullptr);
83
911
  if (float_type->width() == 32) {
84
911
    float fa = val->GetFloat();
85
911
    return const_mgr->GetFloatConst(-fa);
86
911
  } else if (float_type->width() == 64) {
87
0
    double da = val->GetDouble();
88
0
    return const_mgr->GetDoubleConst(-da);
89
0
  }
90
0
  return nullptr;
91
911
}
92
93
// Returns a constant with the value |-val| of the given type.
94
const analysis::Constant* NegateIntConst(const analysis::Type* result_type,
95
                                         const analysis::Constant* val,
96
2.81k
                                         analysis::ConstantManager* const_mgr) {
97
2.81k
  const analysis::Integer* int_type = result_type->AsInteger();
98
2.81k
  assert(int_type != nullptr);
99
100
2.81k
  if (val->AsNullConstant()) {
101
7
    return val;
102
7
  }
103
104
2.80k
  uint64_t new_value = static_cast<uint64_t>(-val->GetSignExtendedValue());
105
2.80k
  return const_mgr->GetIntConst(new_value, int_type->width(),
106
2.80k
                                int_type->IsSigned());
107
2.81k
}
108
109
// Folds an OpCompositeExtract where input is a composite constant.
110
16.0k
ConstantFoldingRule FoldExtractWithConstants() {
111
16.0k
  return [](IRContext* context, Instruction* inst,
112
16.0k
            const std::vector<const analysis::Constant*>& constants)
113
472k
             -> const analysis::Constant* {
114
472k
    const analysis::Constant* c = constants[kExtractCompositeIdInIdx];
115
472k
    if (c == nullptr) {
116
427k
      return nullptr;
117
427k
    }
118
119
88.1k
    for (uint32_t i = 1; i < inst->NumInOperands(); ++i) {
120
44.1k
      uint32_t element_index = inst->GetSingleWordInOperand(i);
121
44.1k
      if (c->AsNullConstant()) {
122
        // Return Null for the return type.
123
181
        analysis::ConstantManager* const_mgr = context->get_constant_mgr();
124
181
        analysis::TypeManager* type_mgr = context->get_type_mgr();
125
181
        return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), {});
126
181
      }
127
128
43.9k
      auto cc = c->AsCompositeConstant();
129
43.9k
      assert(cc != nullptr);
130
43.9k
      auto components = cc->GetComponents();
131
      // Protect against invalid IR.  Refuse to fold if the index is out
132
      // of bounds.
133
43.9k
      if (element_index >= components.size()) return nullptr;
134
43.9k
      c = components[element_index];
135
43.9k
    }
136
43.9k
    return c;
137
44.1k
  };
138
16.0k
}
139
140
// Folds an OpCompositeInsert where input is a composite constant.
141
16.0k
ConstantFoldingRule FoldInsertWithConstants() {
142
16.0k
  return [](IRContext* context, Instruction* inst,
143
16.0k
            const std::vector<const analysis::Constant*>& constants)
144
102k
             -> const analysis::Constant* {
145
102k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
146
102k
    const analysis::Constant* object = constants[0];
147
102k
    const analysis::Constant* composite = constants[1];
148
102k
    if (object == nullptr || composite == nullptr) {
149
97.5k
      return nullptr;
150
97.5k
    }
151
152
    // If there is more than 1 index, then each additional constant used by the
153
    // index will need to be recreated to use the inserted object.
154
4.86k
    std::vector<const analysis::Constant*> chain;
155
4.86k
    std::vector<const analysis::Constant*> components;
156
4.86k
    const analysis::Type* type = nullptr;
157
4.86k
    const uint32_t final_index = (inst->NumInOperands() - 1);
158
159
    // Work down hierarchy of all indexes
160
9.74k
    for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
161
4.87k
      type = composite->type();
162
163
4.87k
      if (composite->AsNullConstant()) {
164
        // Make new composite so it can be inserted in the index with the
165
        // non-null value
166
22
        if (const auto new_composite =
167
22
                const_mgr->GetNullCompositeConstant(type)) {
168
          // Keep track of any indexes along the way to last index
169
22
          if (i != final_index) {
170
0
            chain.push_back(new_composite);
171
0
          }
172
22
          components = new_composite->AsCompositeConstant()->GetComponents();
173
22
        } else {
174
          // Unsupported input type (such as structs)
175
0
          return nullptr;
176
0
        }
177
4.85k
      } else {
178
        // Keep track of any indexes along the way to last index
179
4.85k
        if (i != final_index) {
180
4
          chain.push_back(composite);
181
4
        }
182
4.85k
        components = composite->AsCompositeConstant()->GetComponents();
183
4.85k
      }
184
4.87k
      const uint32_t index = inst->GetSingleWordInOperand(i);
185
4.87k
      composite = components[index];
186
4.87k
    }
187
188
    // Final index in hierarchy is inserted with new object.
189
4.86k
    const uint32_t final_operand = inst->GetSingleWordInOperand(final_index);
190
4.86k
    std::vector<uint32_t> ids;
191
14.7k
    for (size_t i = 0; i < components.size(); i++) {
192
9.90k
      const analysis::Constant* constant =
193
9.90k
          (i == final_operand) ? object : components[i];
194
9.90k
      Instruction* member_inst = const_mgr->GetDefiningInstruction(constant);
195
9.90k
      ids.push_back(member_inst->result_id());
196
9.90k
    }
197
4.86k
    const analysis::Constant* new_constant = const_mgr->GetConstant(type, ids);
198
199
    // Work backwards up the chain and replace each index with new constant.
200
4.87k
    for (size_t i = chain.size(); i > 0; i--) {
201
      // Need to insert any previous instruction into the module first.
202
      // Can't just insert in types_values_begin() because it will move above
203
      // where the types are declared.
204
      // Can't compare with location of inst because not all new added
205
      // instructions are added to types_values_
206
4
      auto iter = context->types_values_end();
207
4
      Module::inst_iterator* pos = &iter;
208
4
      const_mgr->BuildInstructionAndAddToModule(new_constant, pos);
209
210
4
      composite = chain[i - 1];
211
4
      components = composite->AsCompositeConstant()->GetComponents();
212
4
      type = composite->type();
213
4
      ids.clear();
214
8
      for (size_t k = 0; k < components.size(); k++) {
215
4
        const uint32_t index =
216
4
            inst->GetSingleWordInOperand(1 + static_cast<uint32_t>(i));
217
4
        const analysis::Constant* constant =
218
4
            (k == index) ? new_constant : components[k];
219
4
        const uint32_t constant_id =
220
4
            const_mgr->FindDeclaredConstant(constant, 0);
221
4
        ids.push_back(constant_id);
222
4
      }
223
4
      new_constant = const_mgr->GetConstant(type, ids);
224
4
    }
225
226
    // If multiple constants were created, only need to return the top index.
227
4.86k
    return new_constant;
228
4.86k
  };
229
16.0k
}
230
231
16.0k
ConstantFoldingRule FoldVectorShuffleWithConstants() {
232
16.0k
  return [](IRContext* context, Instruction* inst,
233
16.0k
            const std::vector<const analysis::Constant*>& constants)
234
34.9k
             -> const analysis::Constant* {
235
34.9k
    assert(inst->opcode() == spv::Op::OpVectorShuffle);
236
34.9k
    const analysis::Constant* c1 = constants[0];
237
34.9k
    const analysis::Constant* c2 = constants[1];
238
34.9k
    if (c1 == nullptr || c2 == nullptr) {
239
34.6k
      return nullptr;
240
34.6k
    }
241
242
293
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
243
293
    const analysis::Type* element_type = c1->type()->AsVector()->element_type();
244
245
293
    std::vector<const analysis::Constant*> c1_components;
246
293
    if (const analysis::VectorConstant* vec_const = c1->AsVectorConstant()) {
247
253
      c1_components = vec_const->GetComponents();
248
253
    } else {
249
40
      assert(c1->AsNullConstant());
250
40
      const analysis::Constant* element =
251
40
          const_mgr->GetConstant(element_type, {});
252
40
      c1_components.resize(c1->type()->AsVector()->element_count(), element);
253
40
    }
254
293
    std::vector<const analysis::Constant*> c2_components;
255
293
    if (const analysis::VectorConstant* vec_const = c2->AsVectorConstant()) {
256
217
      c2_components = vec_const->GetComponents();
257
217
    } else {
258
76
      assert(c2->AsNullConstant());
259
76
      const analysis::Constant* element =
260
76
          const_mgr->GetConstant(element_type, {});
261
76
      c2_components.resize(c2->type()->AsVector()->element_count(), element);
262
76
    }
263
264
293
    std::vector<uint32_t> ids;
265
293
    const uint32_t undef_literal_value = 0xffffffff;
266
656
    for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
267
508
      uint32_t index = inst->GetSingleWordInOperand(i);
268
508
      if (index == undef_literal_value) {
269
        // Don't fold shuffle with undef literal value.
270
145
        return nullptr;
271
363
      } else if (index < c1_components.size()) {
272
274
        Instruction* member_inst =
273
274
            const_mgr->GetDefiningInstruction(c1_components[index]);
274
274
        ids.push_back(member_inst->result_id());
275
274
      } else {
276
89
        Instruction* member_inst = const_mgr->GetDefiningInstruction(
277
89
            c2_components[index - c1_components.size()]);
278
89
        ids.push_back(member_inst->result_id());
279
89
      }
280
508
    }
281
282
148
    analysis::TypeManager* type_mgr = context->get_type_mgr();
283
148
    return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
284
293
  };
285
16.0k
}
286
287
16.0k
ConstantFoldingRule FoldVectorTimesScalar() {
288
16.0k
  return [](IRContext* context, Instruction* inst,
289
16.0k
            const std::vector<const analysis::Constant*>& constants)
290
398k
             -> const analysis::Constant* {
291
398k
    assert(inst->opcode() == spv::Op::OpVectorTimesScalar);
292
398k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
293
398k
    analysis::TypeManager* type_mgr = context->get_type_mgr();
294
295
398k
    if (!inst->IsFloatingPointFoldingAllowed()) {
296
522
      if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
297
522
        return nullptr;
298
522
      }
299
522
    }
300
301
397k
    const analysis::Constant* c1 = constants[0];
302
397k
    const analysis::Constant* c2 = constants[1];
303
304
397k
    if (c1 && c1->IsZero()) {
305
1.76k
      return c1;
306
1.76k
    }
307
308
395k
    if (c2 && c2->IsZero()) {
309
      // Get or create the NullConstant for this type.
310
4.64k
      std::vector<uint32_t> ids;
311
4.64k
      return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
312
4.64k
    }
313
314
391k
    if (c1 == nullptr || c2 == nullptr) {
315
345k
      return nullptr;
316
345k
    }
317
318
    // Check result type.
319
45.9k
    const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
320
45.9k
    const analysis::Vector* vector_type = result_type->AsVector();
321
45.9k
    assert(vector_type != nullptr);
322
45.9k
    const analysis::Type* element_type = vector_type->element_type();
323
45.9k
    assert(element_type != nullptr);
324
45.9k
    const analysis::Float* float_type = element_type->AsFloat();
325
45.9k
    assert(float_type != nullptr);
326
327
    // Check types of c1 and c2.
328
45.9k
    assert(c1->type()->AsVector() == vector_type);
329
45.9k
    assert(c1->type()->AsVector()->element_type() == element_type &&
330
45.9k
           c2->type() == element_type);
331
332
    // Get a float vector that is the result of vector-times-scalar.
333
45.9k
    std::vector<const analysis::Constant*> c1_components =
334
45.9k
        c1->GetVectorComponents(const_mgr);
335
45.9k
    std::vector<uint32_t> ids;
336
45.9k
    if (float_type->width() == 32) {
337
45.9k
      float scalar = c2->GetFloat();
338
227k
      for (uint32_t i = 0; i < c1_components.size(); ++i) {
339
181k
        utils::FloatProxy<float> result(c1_components[i]->GetFloat() * scalar);
340
181k
        std::vector<uint32_t> words = result.GetWords();
341
181k
        const analysis::Constant* new_elem =
342
181k
            const_mgr->GetConstant(float_type, words);
343
181k
        ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
344
181k
      }
345
45.9k
      return const_mgr->GetConstant(vector_type, ids);
346
45.9k
    } else if (float_type->width() == 64) {
347
0
      double scalar = c2->GetDouble();
348
0
      for (uint32_t i = 0; i < c1_components.size(); ++i) {
349
0
        utils::FloatProxy<double> result(c1_components[i]->GetDouble() *
350
0
                                         scalar);
351
0
        std::vector<uint32_t> words = result.GetWords();
352
0
        const analysis::Constant* new_elem =
353
0
            const_mgr->GetConstant(float_type, words);
354
0
        ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
355
0
      }
356
0
      return const_mgr->GetConstant(vector_type, ids);
357
0
    }
358
0
    return nullptr;
359
45.9k
  };
360
16.0k
}
361
362
// Returns the constant that results from transposing |matrix|. The result
363
// will have type |result_type|, and |matrix| must exist in |context|. The
364
// result constant will also exist in |context|.
365
const analysis::Constant* TransposeMatrix(const analysis::Constant* matrix,
366
                                          analysis::Matrix* result_type,
367
0
                                          IRContext* context) {
368
0
  analysis::ConstantManager* const_mgr = context->get_constant_mgr();
369
0
  if (matrix->AsNullConstant() != nullptr) {
370
0
    return const_mgr->GetNullCompositeConstant(result_type);
371
0
  }
372
373
0
  const auto& columns = matrix->AsMatrixConstant()->GetComponents();
374
0
  uint32_t number_of_rows = columns[0]->type()->AsVector()->element_count();
375
376
  // Collect the ids of the elements in their new positions.
377
0
  std::vector<std::vector<uint32_t>> result_elements(number_of_rows);
378
0
  for (const analysis::Constant* column : columns) {
379
0
    if (column->AsNullConstant()) {
380
0
      column = const_mgr->GetNullCompositeConstant(column->type());
381
0
    }
382
0
    const auto& column_components = column->AsVectorConstant()->GetComponents();
383
384
0
    for (uint32_t row = 0; row < number_of_rows; ++row) {
385
0
      result_elements[row].push_back(
386
0
          const_mgr->GetDefiningInstruction(column_components[row])
387
0
              ->result_id());
388
0
    }
389
0
  }
390
391
  // Create the constant for each row in the result, and collect the ids.
392
0
  std::vector<uint32_t> result_columns(number_of_rows);
393
0
  for (uint32_t col = 0; col < number_of_rows; ++col) {
394
0
    auto* element = const_mgr->GetConstant(result_type->element_type(),
395
0
                                           result_elements[col]);
396
0
    result_columns[col] =
397
0
        const_mgr->GetDefiningInstruction(element)->result_id();
398
0
  }
399
400
  // Create the matrix constant from the row ids, and return it.
401
0
  return const_mgr->GetConstant(result_type, result_columns);
402
0
}
403
404
const analysis::Constant* FoldTranspose(
405
    IRContext* context, Instruction* inst,
406
0
    const std::vector<const analysis::Constant*>& constants) {
407
0
  assert(inst->opcode() == spv::Op::OpTranspose);
408
409
0
  analysis::TypeManager* type_mgr = context->get_type_mgr();
410
0
  if (!inst->IsFloatingPointFoldingAllowed()) {
411
0
    if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
412
0
      return nullptr;
413
0
    }
414
0
  }
415
416
0
  const analysis::Constant* matrix = constants[0];
417
0
  if (matrix == nullptr) {
418
0
    return nullptr;
419
0
  }
420
421
0
  auto* result_type = type_mgr->GetType(inst->type_id());
422
0
  return TransposeMatrix(matrix, result_type->AsMatrix(), context);
423
0
}
424
425
16.0k
ConstantFoldingRule FoldVectorTimesMatrix() {
426
16.0k
  return [](IRContext* context, Instruction* inst,
427
16.0k
            const std::vector<const analysis::Constant*>& constants)
428
16.0k
             -> const analysis::Constant* {
429
6.24k
    assert(inst->opcode() == spv::Op::OpVectorTimesMatrix);
430
6.24k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
431
6.24k
    analysis::TypeManager* type_mgr = context->get_type_mgr();
432
433
6.24k
    if (!inst->IsFloatingPointFoldingAllowed()) {
434
0
      if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
435
0
        return nullptr;
436
0
      }
437
0
    }
438
439
6.24k
    const analysis::Constant* c1 = constants[0];
440
6.24k
    const analysis::Constant* c2 = constants[1];
441
442
6.24k
    if (c1 == nullptr || c2 == nullptr) {
443
6.18k
      return nullptr;
444
6.18k
    }
445
446
    // Check result type.
447
61
    const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
448
61
    const analysis::Vector* vector_type = result_type->AsVector();
449
61
    assert(vector_type != nullptr);
450
61
    const analysis::Type* element_type = vector_type->element_type();
451
61
    assert(element_type != nullptr);
452
61
    const analysis::Float* float_type = element_type->AsFloat();
453
61
    assert(float_type != nullptr);
454
455
    // Check types of c1 and c2.
456
61
    assert(c1->type()->AsVector() == vector_type);
457
61
    assert(c1->type()->AsVector()->element_type() == element_type &&
458
61
           c2->type()->AsMatrix()->element_type() == vector_type);
459
460
61
    uint32_t resultVectorSize = result_type->AsVector()->element_count();
461
61
    std::vector<uint32_t> ids;
462
463
61
    if ((c1 && c1->IsZero()) || (c2 && c2->IsZero())) {
464
43
      std::vector<uint32_t> words(float_type->width() / 32, 0);
465
129
      for (uint32_t i = 0; i < resultVectorSize; ++i) {
466
86
        const analysis::Constant* new_elem =
467
86
            const_mgr->GetConstant(float_type, words);
468
86
        ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
469
86
      }
470
43
      return const_mgr->GetConstant(vector_type, ids);
471
43
    }
472
473
    // Get a float vector that is the result of vector-times-matrix.
474
18
    std::vector<const analysis::Constant*> c1_components =
475
18
        c1->GetVectorComponents(const_mgr);
476
18
    std::vector<const analysis::Constant*> c2_components =
477
18
        c2->AsMatrixConstant()->GetComponents();
478
479
18
    if (float_type->width() == 32) {
480
54
      for (uint32_t i = 0; i < resultVectorSize; ++i) {
481
36
        float result_scalar = 0.0f;
482
36
        if (!c2_components[i]->AsNullConstant()) {
483
28
          const analysis::VectorConstant* c2_vec =
484
28
              c2_components[i]->AsVectorConstant();
485
84
          for (uint32_t j = 0; j < c2_vec->GetComponents().size(); ++j) {
486
56
            float c1_scalar = c1_components[j]->GetFloat();
487
56
            float c2_scalar = c2_vec->GetComponents()[j]->GetFloat();
488
56
            result_scalar += c1_scalar * c2_scalar;
489
56
          }
490
28
        }
491
36
        utils::FloatProxy<float> result(result_scalar);
492
36
        std::vector<uint32_t> words = result.GetWords();
493
36
        const analysis::Constant* new_elem =
494
36
            const_mgr->GetConstant(float_type, words);
495
36
        ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
496
36
      }
497
18
      return const_mgr->GetConstant(vector_type, ids);
498
18
    } else if (float_type->width() == 64) {
499
0
      for (uint32_t i = 0; i < c2_components.size(); ++i) {
500
0
        double result_scalar = 0.0;
501
0
        if (!c2_components[i]->AsNullConstant()) {
502
0
          const analysis::VectorConstant* c2_vec =
503
0
              c2_components[i]->AsVectorConstant();
504
0
          for (uint32_t j = 0; j < c2_vec->GetComponents().size(); ++j) {
505
0
            double c1_scalar = c1_components[j]->GetDouble();
506
0
            double c2_scalar = c2_vec->GetComponents()[j]->GetDouble();
507
0
            result_scalar += c1_scalar * c2_scalar;
508
0
          }
509
0
        }
510
0
        utils::FloatProxy<double> result(result_scalar);
511
0
        std::vector<uint32_t> words = result.GetWords();
512
0
        const analysis::Constant* new_elem =
513
0
            const_mgr->GetConstant(float_type, words);
514
0
        ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
515
0
      }
516
0
      return const_mgr->GetConstant(vector_type, ids);
517
0
    }
518
0
    return nullptr;
519
18
  };
520
16.0k
}
521
522
16.0k
ConstantFoldingRule FoldMatrixTimesVector() {
523
16.0k
  return [](IRContext* context, Instruction* inst,
524
16.0k
            const std::vector<const analysis::Constant*>& constants)
525
16.0k
             -> const analysis::Constant* {
526
185
    assert(inst->opcode() == spv::Op::OpMatrixTimesVector);
527
185
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
528
185
    analysis::TypeManager* type_mgr = context->get_type_mgr();
529
530
185
    if (!inst->IsFloatingPointFoldingAllowed()) {
531
0
      if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
532
0
        return nullptr;
533
0
      }
534
0
    }
535
536
185
    const analysis::Constant* c1 = constants[0];
537
185
    const analysis::Constant* c2 = constants[1];
538
539
185
    if (c1 == nullptr || c2 == nullptr) {
540
183
      return nullptr;
541
183
    }
542
543
    // Check result type.
544
2
    const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
545
2
    const analysis::Vector* vector_type = result_type->AsVector();
546
2
    assert(vector_type != nullptr);
547
2
    const analysis::Type* element_type = vector_type->element_type();
548
2
    assert(element_type != nullptr);
549
2
    const analysis::Float* float_type = element_type->AsFloat();
550
2
    assert(float_type != nullptr);
551
552
    // Check types of c1 and c2.
553
2
    assert(c1->type()->AsMatrix()->element_type() == vector_type);
554
2
    assert(c2->type()->AsVector()->element_type() == element_type);
555
556
2
    uint32_t resultVectorSize = result_type->AsVector()->element_count();
557
2
    std::vector<uint32_t> ids;
558
559
2
    if ((c1 && c1->IsZero()) || (c2 && c2->IsZero())) {
560
2
      std::vector<uint32_t> words(float_type->width() / 32, 0);
561
6
      for (uint32_t i = 0; i < resultVectorSize; ++i) {
562
4
        const analysis::Constant* new_elem =
563
4
            const_mgr->GetConstant(float_type, words);
564
4
        ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
565
4
      }
566
2
      return const_mgr->GetConstant(vector_type, ids);
567
2
    }
568
569
    // Get a float vector that is the result of matrix-times-vector.
570
0
    std::vector<const analysis::Constant*> c1_components =
571
0
        c1->AsMatrixConstant()->GetComponents();
572
0
    std::vector<const analysis::Constant*> c2_components =
573
0
        c2->GetVectorComponents(const_mgr);
574
575
0
    if (float_type->width() == 32) {
576
0
      for (uint32_t i = 0; i < resultVectorSize; ++i) {
577
0
        float result_scalar = 0.0f;
578
0
        for (uint32_t j = 0; j < c1_components.size(); ++j) {
579
0
          if (!c1_components[j]->AsNullConstant()) {
580
0
            float c1_scalar = c1_components[j]
581
0
                                  ->AsVectorConstant()
582
0
                                  ->GetComponents()[i]
583
0
                                  ->GetFloat();
584
0
            float c2_scalar = c2_components[j]->GetFloat();
585
0
            result_scalar += c1_scalar * c2_scalar;
586
0
          }
587
0
        }
588
0
        utils::FloatProxy<float> result(result_scalar);
589
0
        std::vector<uint32_t> words = result.GetWords();
590
0
        const analysis::Constant* new_elem =
591
0
            const_mgr->GetConstant(float_type, words);
592
0
        ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
593
0
      }
594
0
      return const_mgr->GetConstant(vector_type, ids);
595
0
    } else if (float_type->width() == 64) {
596
0
      for (uint32_t i = 0; i < resultVectorSize; ++i) {
597
0
        double result_scalar = 0.0;
598
0
        for (uint32_t j = 0; j < c1_components.size(); ++j) {
599
0
          if (!c1_components[j]->AsNullConstant()) {
600
0
            double c1_scalar = c1_components[j]
601
0
                                   ->AsVectorConstant()
602
0
                                   ->GetComponents()[i]
603
0
                                   ->GetDouble();
604
0
            double c2_scalar = c2_components[j]->GetDouble();
605
0
            result_scalar += c1_scalar * c2_scalar;
606
0
          }
607
0
        }
608
0
        utils::FloatProxy<double> result(result_scalar);
609
0
        std::vector<uint32_t> words = result.GetWords();
610
0
        const analysis::Constant* new_elem =
611
0
            const_mgr->GetConstant(float_type, words);
612
0
        ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
613
0
      }
614
0
      return const_mgr->GetConstant(vector_type, ids);
615
0
    }
616
0
    return nullptr;
617
0
  };
618
16.0k
}
619
620
16.0k
ConstantFoldingRule FoldCompositeWithConstants() {
621
  // Folds an OpCompositeConstruct where all of the inputs are constants to a
622
  // constant.  A new constant is created if necessary.
623
16.0k
  return [](IRContext* context, Instruction* inst,
624
16.0k
            const std::vector<const analysis::Constant*>& constants)
625
282k
             -> const analysis::Constant* {
626
282k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
627
282k
    analysis::TypeManager* type_mgr = context->get_type_mgr();
628
282k
    const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
629
282k
    Instruction* type_inst =
630
282k
        context->get_def_use_mgr()->GetDef(inst->type_id());
631
632
282k
    std::vector<uint32_t> ids;
633
655k
    for (uint32_t i = 0; i < constants.size(); ++i) {
634
497k
      const analysis::Constant* element_const = constants[i];
635
497k
      if (element_const == nullptr) {
636
124k
        return nullptr;
637
124k
      }
638
639
373k
      uint32_t component_type_id = 0;
640
373k
      if (type_inst->opcode() == spv::Op::OpTypeStruct) {
641
23.6k
        component_type_id = type_inst->GetSingleWordInOperand(i);
642
349k
      } else if (type_inst->opcode() == spv::Op::OpTypeArray) {
643
552
        component_type_id = type_inst->GetSingleWordInOperand(0);
644
552
      }
645
646
373k
      uint32_t element_id =
647
373k
          const_mgr->FindDeclaredConstant(element_const, component_type_id);
648
373k
      if (element_id == 0) {
649
178
        return nullptr;
650
178
      }
651
373k
      ids.push_back(element_id);
652
373k
    }
653
157k
    return const_mgr->GetConstant(new_type, ids);
654
282k
  };
655
16.0k
}
656
657
// The interface for a function that returns the result of applying a scalar
658
// floating-point unary operation on |a|.  The type of the return value
659
// will be |type|.  The input constants must also be of type |type|.
660
using UnaryScalarFoldingRule = std::function<const analysis::Constant*(
661
    const analysis::Type* result_type, const analysis::Constant* a,
662
    analysis::ConstantManager*)>;
663
664
// The interface for a function that returns the result of applying a scalar
665
// floating-point binary operation on |a| and |b|.  The type of the return value
666
// will be |type|.  The input constants must also be of type |type|.
667
using BinaryScalarFoldingRule = std::function<const analysis::Constant*(
668
    const analysis::Type* result_type, const analysis::Constant* a,
669
    const analysis::Constant* b, analysis::ConstantManager*)>;
670
671
// Returns a |ConstantFoldingRule| that folds unary scalar ops
672
// using |scalar_rule| and unary vectors ops by applying
673
// |scalar_rule| to the elements of the vector.  The |ConstantFoldingRule|
674
// that is returned assumes that |constants| contains 1 entry.  If they are
675
// not |nullptr|, then their type is either |Float| or |Integer| or a |Vector|
676
// whose element type is |Float| or |Integer|.
677
250k
ConstantFoldingRule FoldUnaryOp(UnaryScalarFoldingRule scalar_rule) {
678
250k
  return [scalar_rule](IRContext* context, Instruction* inst,
679
250k
                       const std::vector<const analysis::Constant*>& constants)
680
455k
             -> const analysis::Constant* {
681
455k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
682
455k
    analysis::TypeManager* type_mgr = context->get_type_mgr();
683
455k
    const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
684
455k
    const analysis::Vector* vector_type = result_type->AsVector();
685
686
455k
    const analysis::Constant* arg =
687
455k
        (inst->opcode() == spv::Op::OpExtInst) ? constants[1] : constants[0];
688
689
455k
    if (arg == nullptr) {
690
119k
      return nullptr;
691
119k
    }
692
693
335k
    if (vector_type != nullptr) {
694
1.50k
      std::vector<const analysis::Constant*> a_components;
695
1.50k
      std::vector<const analysis::Constant*> results_components;
696
697
1.50k
      a_components = arg->GetVectorComponents(const_mgr);
698
699
      // Fold each component of the vector.
700
6.78k
      for (uint32_t i = 0; i < a_components.size(); ++i) {
701
5.28k
        results_components.push_back(scalar_rule(vector_type->element_type(),
702
5.28k
                                                 a_components[i], const_mgr));
703
5.28k
        if (results_components[i] == nullptr) {
704
0
          return nullptr;
705
0
        }
706
5.28k
      }
707
708
      // Build the constant object and return it.
709
1.50k
      std::vector<uint32_t> ids;
710
5.28k
      for (const analysis::Constant* member : results_components) {
711
5.28k
        ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
712
5.28k
      }
713
1.50k
      return const_mgr->GetConstant(vector_type, ids);
714
334k
    } else {
715
334k
      return scalar_rule(result_type, arg, const_mgr);
716
334k
    }
717
335k
  };
718
250k
}
719
720
// Returns a |ConstantFoldingRule| that folds binary scalar ops
721
// using |scalar_rule| and binary vectors ops by applying
722
// |scalar_rule| to the elements of the vector. The folding rule assumes that op
723
// has two inputs. For regular instruction, those are in operands 0 and 1. For
724
// extended instruction, they are in operands 1 and 2. If an element in
725
// |constants| is not nullprt, then the constant's type is |Float|, |Integer|,
726
// or |Vector| whose element type is |Float| or |Integer|.
727
128k
ConstantFoldingRule FoldBinaryOp(BinaryScalarFoldingRule scalar_rule) {
728
128k
  return [scalar_rule](IRContext* context, Instruction* inst,
729
128k
                       const std::vector<const analysis::Constant*>& constants)
730
1.08M
             -> const analysis::Constant* {
731
1.08M
    assert(constants.size() == inst->NumInOperands());
732
1.08M
    assert(constants.size() == (inst->opcode() == spv::Op::OpExtInst ? 3 : 2));
733
1.08M
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
734
1.08M
    analysis::TypeManager* type_mgr = context->get_type_mgr();
735
1.08M
    const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
736
1.08M
    const analysis::Vector* vector_type = result_type->AsVector();
737
738
1.08M
    const analysis::Constant* arg1 =
739
1.08M
        (inst->opcode() == spv::Op::OpExtInst) ? constants[1] : constants[0];
740
1.08M
    const analysis::Constant* arg2 =
741
1.08M
        (inst->opcode() == spv::Op::OpExtInst) ? constants[2] : constants[1];
742
743
1.08M
    if (arg1 == nullptr || arg2 == nullptr) {
744
222k
      return nullptr;
745
222k
    }
746
747
857k
    if (vector_type == nullptr) {
748
857k
      return scalar_rule(result_type, arg1, arg2, const_mgr);
749
857k
    }
750
751
811
    std::vector<const analysis::Constant*> a_components;
752
811
    std::vector<const analysis::Constant*> b_components;
753
811
    std::vector<const analysis::Constant*> results_components;
754
755
811
    a_components = arg1->GetVectorComponents(const_mgr);
756
811
    b_components = arg2->GetVectorComponents(const_mgr);
757
811
    assert(a_components.size() == b_components.size());
758
759
    // Fold each component of the vector.
760
4.03k
    for (uint32_t i = 0; i < a_components.size(); ++i) {
761
3.22k
      results_components.push_back(scalar_rule(vector_type->element_type(),
762
3.22k
                                               a_components[i], b_components[i],
763
3.22k
                                               const_mgr));
764
3.22k
      if (results_components[i] == nullptr) {
765
0
        return nullptr;
766
0
      }
767
3.22k
    }
768
769
    // Build the constant object and return it.
770
811
    std::vector<uint32_t> ids;
771
3.22k
    for (const analysis::Constant* member : results_components) {
772
3.22k
      ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
773
3.22k
    }
774
811
    return const_mgr->GetConstant(vector_type, ids);
775
811
  };
776
128k
}
777
778
// Returns a |ConstantFoldingRule| that folds unary floating point scalar ops
779
// using |scalar_rule| and unary float point vectors ops by applying
780
// |scalar_rule| to the elements of the vector.  The |ConstantFoldingRule|
781
// that is returned assumes that |constants| contains 1 entry.  If they are
782
// not |nullptr|, then their type is either |Float| or |Integer| or a |Vector|
783
// whose element type is |Float| or |Integer|.
784
202k
ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
785
202k
  auto folding_rule = FoldUnaryOp(scalar_rule);
786
202k
  return [folding_rule](IRContext* context, Instruction* inst,
787
202k
                        const std::vector<const analysis::Constant*>& constants)
788
448k
             -> const analysis::Constant* {
789
448k
    if (!inst->IsFloatingPointFoldingAllowed()) {
790
72
      return nullptr;
791
72
    }
792
793
447k
    return folding_rule(context, inst, constants);
794
448k
  };
795
202k
}
796
797
// Returns the result of folding the constants in |constants| according the
798
// |scalar_rule|.  If |result_type| is a vector, then |scalar_rule| is applied
799
// per component.
800
const analysis::Constant* FoldFPBinaryOp(
801
    BinaryScalarFoldingRule scalar_rule, uint32_t result_type_id,
802
    const std::vector<const analysis::Constant*>& constants,
803
1.47M
    IRContext* context) {
804
1.47M
  analysis::ConstantManager* const_mgr = context->get_constant_mgr();
805
1.47M
  analysis::TypeManager* type_mgr = context->get_type_mgr();
806
1.47M
  const analysis::Type* result_type = type_mgr->GetType(result_type_id);
807
1.47M
  const analysis::Vector* vector_type = result_type->AsVector();
808
809
1.47M
  if (constants[0] == nullptr || constants[1] == nullptr) {
810
1.14M
    return nullptr;
811
1.14M
  }
812
813
338k
  if (vector_type != nullptr) {
814
60.4k
    std::vector<const analysis::Constant*> a_components;
815
60.4k
    std::vector<const analysis::Constant*> b_components;
816
60.4k
    std::vector<const analysis::Constant*> results_components;
817
818
60.4k
    a_components = constants[0]->GetVectorComponents(const_mgr);
819
60.4k
    b_components = constants[1]->GetVectorComponents(const_mgr);
820
821
    // Fold each component of the vector.
822
280k
    for (uint32_t i = 0; i < a_components.size(); ++i) {
823
220k
      results_components.push_back(scalar_rule(vector_type->element_type(),
824
220k
                                               a_components[i], b_components[i],
825
220k
                                               const_mgr));
826
220k
      if (results_components[i] == nullptr) {
827
0
        return nullptr;
828
0
      }
829
220k
    }
830
831
    // Build the constant object and return it.
832
60.4k
    std::vector<uint32_t> ids;
833
220k
    for (const analysis::Constant* member : results_components) {
834
220k
      Instruction* def = const_mgr->GetDefiningInstruction(member);
835
220k
      if (!def) return nullptr;
836
220k
      ids.push_back(def->result_id());
837
220k
    }
838
60.4k
    return const_mgr->GetConstant(vector_type, ids);
839
277k
  } else {
840
277k
    return scalar_rule(result_type, constants[0], constants[1], const_mgr);
841
277k
  }
842
338k
}
843
844
// Returns a |ConstantFoldingRule| that folds floating point scalars using
845
// |scalar_rule| and vectors of floating point by applying |scalar_rule| to the
846
// elements of the vector.  The |ConstantFoldingRule| that is returned assumes
847
// that |constants| contains 2 entries.  If they are not |nullptr|, then their
848
// type is either |Float| or a |Vector| whose element type is |Float|.
849
353k
ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) {
850
353k
  return [scalar_rule](IRContext* context, Instruction* inst,
851
353k
                       const std::vector<const analysis::Constant*>& constants)
852
1.47M
             -> const analysis::Constant* {
853
1.47M
    if (!inst->IsFloatingPointFoldingAllowed()) {
854
4.58k
      return nullptr;
855
4.58k
    }
856
1.47M
    if (inst->opcode() == spv::Op::OpExtInst) {
857
16.0k
      return FoldFPBinaryOp(scalar_rule, inst->type_id(),
858
16.0k
                            {constants[1], constants[2]}, context);
859
16.0k
    }
860
1.45M
    return FoldFPBinaryOp(scalar_rule, inst->type_id(), constants, context);
861
1.47M
  };
862
353k
}
863
864
// This macro defines a |UnaryScalarFoldingRule| that performs float to
865
// integer conversion.
866
// TODO(greg-lunarg): Support for 64-bit integer types.
867
32.1k
UnaryScalarFoldingRule FoldFToIOp() {
868
32.1k
  return [](const analysis::Type* result_type, const analysis::Constant* a,
869
32.1k
            analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
870
1.70k
    assert(result_type != nullptr && a != nullptr);
871
1.70k
    const analysis::Integer* integer_type = result_type->AsInteger();
872
1.70k
    const analysis::Float* float_type = a->type()->AsFloat();
873
1.70k
    assert(float_type != nullptr);
874
1.70k
    assert(integer_type != nullptr);
875
1.70k
    if (integer_type->width() != 32) return nullptr;
876
1.70k
    if (float_type->width() == 32) {
877
1.70k
      float fa = a->GetFloat();
878
1.70k
      uint32_t result = integer_type->IsSigned()
879
1.70k
                            ? static_cast<uint32_t>(static_cast<int32_t>(fa))
880
1.70k
                            : static_cast<uint32_t>(fa);
881
1.70k
      std::vector<uint32_t> words = {result};
882
1.70k
      return const_mgr->GetConstant(result_type, words);
883
1.70k
    } else if (float_type->width() == 64) {
884
0
      double fa = a->GetDouble();
885
0
      uint32_t result = integer_type->IsSigned()
886
0
                            ? static_cast<uint32_t>(static_cast<int32_t>(fa))
887
0
                            : static_cast<uint32_t>(fa);
888
0
      std::vector<uint32_t> words = {result};
889
0
      return const_mgr->GetConstant(result_type, words);
890
0
    }
891
0
    return nullptr;
892
1.70k
  };
893
32.1k
}
894
895
// This function defines a |UnaryScalarFoldingRule| that performs integer to
896
// float conversion.
897
// TODO(greg-lunarg): Support for 64-bit integer types.
898
32.1k
UnaryScalarFoldingRule FoldIToFOp() {
899
32.1k
  return [](const analysis::Type* result_type, const analysis::Constant* a,
900
329k
            analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
901
329k
    assert(result_type != nullptr && a != nullptr);
902
329k
    const analysis::Integer* integer_type = a->type()->AsInteger();
903
329k
    const analysis::Float* float_type = result_type->AsFloat();
904
329k
    assert(float_type != nullptr);
905
329k
    assert(integer_type != nullptr);
906
329k
    if (integer_type->width() != 32) return nullptr;
907
329k
    uint32_t ua = a->GetU32();
908
329k
    if (float_type->width() == 32) {
909
329k
      float result_val = integer_type->IsSigned()
910
329k
                             ? static_cast<float>(static_cast<int32_t>(ua))
911
329k
                             : static_cast<float>(ua);
912
329k
      utils::FloatProxy<float> result(result_val);
913
329k
      std::vector<uint32_t> words = {result.data()};
914
329k
      return const_mgr->GetConstant(result_type, words);
915
329k
    } else if (float_type->width() == 64) {
916
0
      double result_val = integer_type->IsSigned()
917
0
                              ? static_cast<double>(static_cast<int32_t>(ua))
918
0
                              : static_cast<double>(ua);
919
0
      utils::FloatProxy<double> result(result_val);
920
0
      std::vector<uint32_t> words = result.GetWords();
921
0
      return const_mgr->GetConstant(result_type, words);
922
0
    }
923
0
    return nullptr;
924
329k
  };
925
32.1k
}
926
927
// This defines a |UnaryScalarFoldingRule| that performs |OpQuantizeToF16|.
928
16.0k
UnaryScalarFoldingRule FoldQuantizeToF16Scalar() {
929
16.0k
  return [](const analysis::Type* result_type, const analysis::Constant* a,
930
16.0k
            analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
931
1.85k
    assert(result_type != nullptr && a != nullptr);
932
1.85k
    const analysis::Float* float_type = a->type()->AsFloat();
933
1.85k
    assert(float_type != nullptr);
934
1.85k
    if (float_type->width() != 32) {
935
0
      return nullptr;
936
0
    }
937
938
1.85k
    float fa = a->GetFloat();
939
1.85k
    utils::HexFloat<utils::FloatProxy<float>> orignal(fa);
940
1.85k
    utils::HexFloat<utils::FloatProxy<utils::Float16>> quantized(0);
941
1.85k
    utils::HexFloat<utils::FloatProxy<float>> result(0.0f);
942
1.85k
    orignal.castTo(quantized, utils::round_direction::kToZero);
943
1.85k
    quantized.castTo(result, utils::round_direction::kToZero);
944
1.85k
    std::vector<uint32_t> words = {result.getBits()};
945
1.85k
    return const_mgr->GetConstant(result_type, words);
946
1.85k
  };
947
16.0k
}
948
949
// Expands to a |BinaryScalarFoldingRule| lambda that computes "a op b" for two
// 32-bit or 64-bit floating point constants whose type is the result type.
// The operator |op| must work for both float and double, and use the syntax
// "f1 op f2".  The lambda yields |nullptr| for any other float width.  The
// identifiers carry an |_in_macro| suffix, presumably to keep them from
// clashing with names at the expansion site.
#define FOLD_FPARITH_OP(op)                                                   \
  [](const analysis::Type* result_type_in_macro, const analysis::Constant* a, \
     const analysis::Constant* b,                                             \
     analysis::ConstantManager* const_mgr_in_macro)                           \
      -> const analysis::Constant* {                                          \
    assert(result_type_in_macro != nullptr && a != nullptr && b != nullptr);  \
    assert(result_type_in_macro == a->type() &&                               \
           result_type_in_macro == b->type());                                \
    const analysis::Float* float_type_in_macro =                              \
        result_type_in_macro->AsFloat();                                      \
    assert(float_type_in_macro != nullptr);                                   \
    switch (float_type_in_macro->width()) {                                   \
      case 32: {                                                              \
        const float lhs_in_macro = a->GetFloat();                             \
        const float rhs_in_macro = b->GetFloat();                             \
        utils::FloatProxy<float> folded_in_macro(lhs_in_macro op              \
                                                 rhs_in_macro);               \
        std::vector<uint32_t> words_in_macro = folded_in_macro.GetWords();    \
        return const_mgr_in_macro->GetConstant(result_type_in_macro,          \
                                               words_in_macro);               \
      }                                                                       \
      case 64: {                                                              \
        const double lhs_in_macro = a->GetDouble();                           \
        const double rhs_in_macro = b->GetDouble();                           \
        utils::FloatProxy<double> folded_in_macro(lhs_in_macro op             \
                                                  rhs_in_macro);              \
        std::vector<uint32_t> words_in_macro = folded_in_macro.GetWords();    \
        return const_mgr_in_macro->GetConstant(result_type_in_macro,          \
                                               words_in_macro);               \
      }                                                                       \
      default:                                                                \
        return nullptr;                                                       \
    }                                                                         \
  }
979
980
// Define the folding rule for conversion between floating point and integer
981
32.1k
ConstantFoldingRule FoldFToI() { return FoldFPUnaryOp(FoldFToIOp()); }
982
32.1k
ConstantFoldingRule FoldIToF() { return FoldFPUnaryOp(FoldIToFOp()); }
983
16.0k
ConstantFoldingRule FoldQuantizeToF16() {
984
16.0k
  return FoldFPUnaryOp(FoldQuantizeToF16Scalar());
985
16.0k
}
986
987
// Define the folding rules for subtraction, addition, multiplication, and
988
// division for floating point values.
989
42.0k
ConstantFoldingRule FoldFSub() { return FoldFPBinaryOp(FOLD_FPARITH_OP(-)); }
990
207k
ConstantFoldingRule FoldFAdd() { return FoldFPBinaryOp(FOLD_FPARITH_OP(+)); }
991
251k
ConstantFoldingRule FoldFMul() { return FoldFPBinaryOp(FOLD_FPARITH_OP(*)); }
992
993
// x - x = 0
994
32.1k
ConstantFoldingRule FoldRedundantSub() {
995
32.1k
  return [](IRContext* context, Instruction* inst,
996
32.1k
            const std::vector<const analysis::Constant*>&)
997
187k
             -> const analysis::Constant* {
998
187k
    assert(inst->opcode() == spv::Op::OpFSub ||
999
187k
           inst->opcode() == spv::Op::OpISub);
1000
1001
187k
    if (inst->GetSingleWordInOperand(0) == inst->GetSingleWordInOperand(1)) {
1002
2.43k
      bool use_float = inst->opcode() == spv::Op::OpFSub;
1003
2.43k
      if (use_float && !inst->IsFloatingPointFoldingAllowed()) {
1004
2
        return nullptr;
1005
2
      }
1006
2.43k
      analysis::TypeManager* type_mgr = context->get_type_mgr();
1007
2.43k
      const analysis::Type* type = type_mgr->GetType(inst->type_id());
1008
2.43k
      if (type->IsCooperativeMatrix()) {
1009
0
        return nullptr;
1010
0
      }
1011
2.43k
      analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1012
2.43k
      uint32_t null_id = const_mgr->GetNullConstId(type);
1013
2.43k
      return const_mgr->FindDeclaredConstant(null_id);
1014
2.43k
    }
1015
185k
    return nullptr;
1016
187k
  };
1017
32.1k
}
1018
1019
// Returns the constant that results from evaluating |numerator| / 0.0.  Returns
1020
// |nullptr| if the result could not be evaluated.
1021
const analysis::Constant* FoldFPScalarDivideByZero(
1022
    const analysis::Type* result_type, const analysis::Constant* numerator,
1023
2.48k
    analysis::ConstantManager* const_mgr) {
1024
2.48k
  if (numerator == nullptr) {
1025
0
    return nullptr;
1026
0
  }
1027
1028
2.48k
  if (numerator->IsZero()) {
1029
1.41k
    return GetNan(result_type, const_mgr);
1030
1.41k
  }
1031
1032
1.06k
  const analysis::Constant* result = GetInf(result_type, const_mgr);
1033
1.06k
  if (result == nullptr) {
1034
0
    return nullptr;
1035
0
  }
1036
1037
1.06k
  if (numerator->AsFloatConstant()->GetValueAsDouble() < 0.0) {
1038
233
    result = NegateFPConst(result_type, result, const_mgr);
1039
233
  }
1040
1.06k
  return result;
1041
1.06k
}
1042
1043
// Returns the result of folding |numerator| / |denominator|.  Returns |nullptr|
1044
// if it cannot be folded.
1045
const analysis::Constant* FoldScalarFPDivide(
1046
    const analysis::Type* result_type, const analysis::Constant* numerator,
1047
    const analysis::Constant* denominator,
1048
12.1k
    analysis::ConstantManager* const_mgr) {
1049
12.1k
  if (denominator == nullptr) {
1050
0
    return nullptr;
1051
0
  }
1052
1053
12.1k
  if (denominator->IsZero()) {
1054
2.13k
    return FoldFPScalarDivideByZero(result_type, numerator, const_mgr);
1055
2.13k
  }
1056
1057
10.0k
  uint32_t width = denominator->type()->AsFloat()->width();
1058
10.0k
  if (width != 32 && width != 64) {
1059
0
    return nullptr;
1060
0
  }
1061
1062
10.0k
  const analysis::FloatConstant* denominator_float =
1063
10.0k
      denominator->AsFloatConstant();
1064
10.0k
  if (denominator_float && denominator->GetValueAsDouble() == -0.0) {
1065
342
    const analysis::Constant* result =
1066
342
        FoldFPScalarDivideByZero(result_type, numerator, const_mgr);
1067
342
    if (result != nullptr)
1068
342
      result = NegateFPConst(result_type, result, const_mgr);
1069
342
    return result;
1070
9.70k
  } else {
1071
19.4k
    return FOLD_FPARITH_OP(/)(result_type, numerator, denominator, const_mgr);
1072
9.70k
  }
1073
10.0k
}
1074
1075
// Returns the constant folding rule to fold |OpFDiv| with two constants.
1076
16.0k
ConstantFoldingRule FoldFDiv() { return FoldFPBinaryOp(FoldScalarFPDivide); }
1077
1078
// Get a singular uniform value, which is repeated when the |type| is a vector.
1079
const analysis::Constant* GetConstantUniformValue(
1080
    analysis::ConstantManager* const_mgr, const analysis::Type* type,
1081
1.75k
    std::optional<double> f = {}, std::optional<uint64_t> i = {}) {
1082
1.75k
  const analysis::Constant* uniform = nullptr;
1083
1.75k
  bool is_vector = false;
1084
1.75k
  const analysis::Type* base_type = type;
1085
1086
1.75k
  if (base_type->AsVector()) {
1087
1.57k
    is_vector = true;
1088
1.57k
    base_type = base_type->AsVector()->element_type();
1089
1.57k
  }
1090
1091
1.75k
  if (f && base_type->AsFloat()) {
1092
1.68k
    if (base_type->AsFloat()->width() == 32) {
1093
1.68k
      uniform = const_mgr->GetConstant(
1094
1.68k
          base_type, utils::FloatProxy<float>((float)f.value()).GetWords());
1095
1.68k
    } else if (base_type->AsFloat()->width() == 64) {
1096
0
      uniform = const_mgr->GetConstant(
1097
0
          base_type, utils::FloatProxy<double>(f.value()).GetWords());
1098
0
    }
1099
1.68k
  } else if (i && base_type->AsInteger()) {
1100
73
    uniform =
1101
73
        const_mgr->GenerateIntegerConstant(base_type->AsInteger(), i.value());
1102
73
  }
1103
1104
1.75k
  if (!uniform) {
1105
0
    return nullptr;
1106
0
  }
1107
1108
1.75k
  if (is_vector) {
1109
1.57k
    Instruction* uniform_inst = const_mgr->GetDefiningInstruction(uniform);
1110
1.57k
    if (!uniform_inst) return nullptr;
1111
1112
1.57k
    uint32_t uniform_id = uniform_inst->result_id();
1113
1.57k
    uniform =
1114
1.57k
        const_mgr->GetConstant(type, std::vector<uint32_t>(4, uniform_id));
1115
1.57k
  }
1116
1117
1.75k
  return uniform;
1118
1.75k
}
1119
1120
//  x /  x =  1
1121
// -x /  x = -1
1122
//  x / -x = -1
1123
48.2k
ConstantFoldingRule FoldRedundantDiv() {
1124
48.2k
  return [](IRContext* context, Instruction* inst,
1125
48.2k
            const std::vector<const analysis::Constant*>& constants)
1126
177k
             -> const analysis::Constant* {
1127
177k
    assert(inst->opcode() == spv::Op::OpFDiv ||
1128
177k
           inst->opcode() == spv::Op::OpSDiv ||
1129
177k
           inst->opcode() == spv::Op::OpUDiv);
1130
1131
177k
    if (constants[0] || constants[1]) {
1132
138k
      return nullptr;
1133
138k
    }
1134
1135
38.6k
    analysis::TypeManager* type_mgr = context->get_type_mgr();
1136
38.6k
    const analysis::Type* type = type_mgr->GetType(inst->type_id());
1137
1138
38.6k
    if (type->IsCooperativeMatrix()) {
1139
0
      return nullptr;
1140
0
    }
1141
1142
38.6k
    bool use_float = inst->opcode() == spv::Op::OpFDiv;
1143
38.6k
    if (use_float && !inst->IsFloatingPointFoldingAllowed()) {
1144
30
      return nullptr;
1145
30
    }
1146
1147
38.6k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1148
1149
38.6k
    if (inst->GetSingleWordInOperand(0) == inst->GetSingleWordInOperand(1)) {
1150
1.75k
      return GetConstantUniformValue(const_mgr, type, 1.0, 1);
1151
1.75k
    }
1152
1153
36.8k
    if (inst->opcode() == spv::Op::OpUDiv) {
1154
40
      return nullptr;
1155
40
    }
1156
1157
36.8k
    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1158
1159
36.8k
    Instruction* lhs = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
1160
36.8k
    if ((lhs->opcode() == spv::Op::OpSNegate ||
1161
36.8k
         lhs->opcode() == spv::Op::OpFNegate) &&
1162
37
        lhs->GetSingleWordInOperand(0) == inst->GetSingleWordInOperand(1) &&
1163
1
        (!use_float || lhs->IsFloatingPointFoldingAllowed())) {
1164
1
      return GetConstantUniformValue(const_mgr, type, -1.0, UINT64_MAX);
1165
1
    }
1166
1167
36.8k
    Instruction* rhs = def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
1168
36.8k
    if ((rhs->opcode() == spv::Op::OpSNegate ||
1169
36.8k
         rhs->opcode() == spv::Op::OpFNegate) &&
1170
32
        rhs->GetSingleWordInOperand(0) == inst->GetSingleWordInOperand(0) &&
1171
2
        (!use_float || rhs->IsFloatingPointFoldingAllowed())) {
1172
2
      return GetConstantUniformValue(const_mgr, type, -1.0, UINT64_MAX);
1173
2
    }
1174
1175
36.8k
    return nullptr;
1176
36.8k
  };
1177
48.2k
}
1178
1179
// Combines the outcome of a floating point comparison with the orderedness of
// its operands.  |op_result| is the raw result of the comparison and
// |op_unordered| tells whether the operands are unordered.  With
// |need_ordered| set, the result is true only if the operands are ordered and
// the comparison holds; otherwise it is true if the operands are unordered or
// the comparison holds.
bool CompareFloatingPoint(bool op_result, bool op_unordered,
                          bool need_ordered) {
  return need_ordered ? (op_result && !op_unordered)
                      : (op_result || op_unordered);
}
1189
1190
// Expands to a |BinaryScalarFoldingRule| lambda that compares two 32-bit or
// 64-bit floating point constants of the same type with |op| and produces a
// constant of the boolean |result_type|.  The operator |op| must work for both
// float and double, and use the syntax "f1 op f2".  |ord| selects between the
// ordered (true) and unordered (false) flavour of the comparison; see
// |CompareFloatingPoint|.  The lambda yields |nullptr| for any other width.
#define FOLD_FPCMP_OP(op, ord)                                            \
  [](const analysis::Type* result_type, const analysis::Constant* a,      \
     const analysis::Constant* b,                                         \
     analysis::ConstantManager* const_mgr) -> const analysis::Constant* { \
    assert(result_type != nullptr && a != nullptr && b != nullptr);       \
    assert(result_type->AsBool());                                        \
    assert(a->type() == b->type());                                       \
    const analysis::Float* operand_type = a->type()->AsFloat();           \
    assert(operand_type != nullptr);                                      \
    bool outcome = false;                                                 \
    switch (operand_type->width()) {                                      \
      case 32: {                                                          \
        const float lhs = a->GetFloat();                                  \
        const float rhs = b->GetFloat();                                  \
        outcome = CompareFloatingPoint(                                   \
            lhs op rhs, std::isnan(lhs) || std::isnan(rhs), ord);         \
        break;                                                            \
      }                                                                   \
      case 64: {                                                          \
        const double lhs = a->GetDouble();                                \
        const double rhs = b->GetDouble();                                \
        outcome = CompareFloatingPoint(                                   \
            lhs op rhs, std::isnan(lhs) || std::isnan(rhs), ord);         \
        break;                                                            \
      }                                                                   \
      default:                                                            \
        return nullptr;                                                   \
    }                                                                     \
    std::vector<uint32_t> words = {uint32_t(outcome)};                    \
    return const_mgr->GetConstant(result_type, words);                    \
  }
1218
1219
// Define the folding rules for ordered and unordered comparison for floating
1220
// point values.
1221
16.0k
ConstantFoldingRule FoldFOrdEqual() {
1222
16.8k
  return FoldFPBinaryOp(FOLD_FPCMP_OP(==, true));
1223
16.0k
}
1224
16.0k
ConstantFoldingRule FoldFUnordEqual() {
1225
16.5k
  return FoldFPBinaryOp(FOLD_FPCMP_OP(==, false));
1226
16.0k
}
1227
16.0k
ConstantFoldingRule FoldFOrdNotEqual() {
1228
16.8k
  return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, true));
1229
16.0k
}
1230
16.0k
ConstantFoldingRule FoldFUnordNotEqual() {
1231
17.6k
  return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, false));
1232
16.0k
}
1233
16.0k
ConstantFoldingRule FoldFOrdLessThan() {
1234
18.2k
  return FoldFPBinaryOp(FOLD_FPCMP_OP(<, true));
1235
16.0k
}
1236
16.0k
ConstantFoldingRule FoldFUnordLessThan() {
1237
17.7k
  return FoldFPBinaryOp(FOLD_FPCMP_OP(<, false));
1238
16.0k
}
1239
16.0k
ConstantFoldingRule FoldFOrdGreaterThan() {
1240
16.6k
  return FoldFPBinaryOp(FOLD_FPCMP_OP(>, true));
1241
16.0k
}
1242
16.0k
ConstantFoldingRule FoldFUnordGreaterThan() {
1243
17.0k
  return FoldFPBinaryOp(FOLD_FPCMP_OP(>, false));
1244
16.0k
}
1245
16.0k
ConstantFoldingRule FoldFOrdLessThanEqual() {
1246
16.9k
  return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, true));
1247
16.0k
}
1248
16.0k
ConstantFoldingRule FoldFUnordLessThanEqual() {
1249
17.3k
  return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, false));
1250
16.0k
}
1251
16.0k
ConstantFoldingRule FoldFOrdGreaterThanEqual() {
1252
17.0k
  return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, true));
1253
16.0k
}
1254
16.0k
ConstantFoldingRule FoldFUnordGreaterThanEqual() {
1255
18.9k
  return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, false));
1256
16.0k
}
1257
1258
16.0k
// Folds an |OpSelect| whose two value operands are constants that make the
// condition irrelevant: either both are the same constant object, or both
// report |IsZero()|.  In those cases the first value operand
// (|constants[1]|) is returned; otherwise the rule yields |nullptr|.
ConstantFoldingRule FoldInvariantSelect() {
  return [](IRContext*, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants)
             -> const analysis::Constant* {
    assert(inst->opcode() == spv::Op::OpSelect);
    (void)inst;

    const analysis::Constant* true_value = constants[1];
    const analysis::Constant* false_value = constants[2];
    if (true_value == nullptr || false_value == nullptr) {
      return nullptr;
    }

    const bool condition_is_irrelevant =
        true_value == false_value ||
        (true_value->IsZero() && false_value->IsZero());
    return condition_is_irrelevant ? true_value : nullptr;
  };
}
1277
1278
// Folds an OpDot where all of the inputs are constants to a
1279
// constant.  A new constant is created if necessary.
1280
16.0k
ConstantFoldingRule FoldOpDotWithConstants() {
1281
16.0k
  return [](IRContext* context, Instruction* inst,
1282
16.0k
            const std::vector<const analysis::Constant*>& constants)
1283
16.0k
             -> const analysis::Constant* {
1284
77
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1285
77
    analysis::TypeManager* type_mgr = context->get_type_mgr();
1286
77
    const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
1287
77
    assert(new_type->AsFloat() && "OpDot should have a float return type.");
1288
77
    const analysis::Float* float_type = new_type->AsFloat();
1289
1290
77
    if (!inst->IsFloatingPointFoldingAllowed()) {
1291
0
      return nullptr;
1292
0
    }
1293
1294
    // If one of the operands is 0, then the result is 0.
1295
77
    bool has_zero_operand = false;
1296
1297
227
    for (int i = 0; i < 2; ++i) {
1298
152
      if (constants[i]) {
1299
67
        if (constants[i]->AsNullConstant() ||
1300
67
            constants[i]->AsVectorConstant()->IsZero()) {
1301
2
          has_zero_operand = true;
1302
2
          break;
1303
2
        }
1304
67
      }
1305
152
    }
1306
1307
77
    if (has_zero_operand) {
1308
2
      if (float_type->width() == 32) {
1309
2
        utils::FloatProxy<float> result(0.0f);
1310
2
        std::vector<uint32_t> words = result.GetWords();
1311
2
        return const_mgr->GetConstant(float_type, words);
1312
2
      }
1313
0
      if (float_type->width() == 64) {
1314
0
        utils::FloatProxy<double> result(0.0);
1315
0
        std::vector<uint32_t> words = result.GetWords();
1316
0
        return const_mgr->GetConstant(float_type, words);
1317
0
      }
1318
0
      return nullptr;
1319
0
    }
1320
1321
75
    if (constants[0] == nullptr || constants[1] == nullptr) {
1322
75
      return nullptr;
1323
75
    }
1324
1325
0
    std::vector<const analysis::Constant*> a_components;
1326
0
    std::vector<const analysis::Constant*> b_components;
1327
1328
0
    a_components = constants[0]->GetVectorComponents(const_mgr);
1329
0
    b_components = constants[1]->GetVectorComponents(const_mgr);
1330
1331
0
    utils::FloatProxy<double> result(0.0);
1332
0
    std::vector<uint32_t> words = result.GetWords();
1333
0
    const analysis::Constant* result_const =
1334
0
        const_mgr->GetConstant(float_type, words);
1335
0
    for (uint32_t i = 0; i < a_components.size() && result_const != nullptr;
1336
0
         ++i) {
1337
0
      if (a_components[i] == nullptr || b_components[i] == nullptr) {
1338
0
        return nullptr;
1339
0
      }
1340
1341
0
      const analysis::Constant* component = FOLD_FPARITH_OP(*)(
1342
0
          new_type, a_components[i], b_components[i], const_mgr);
1343
0
      if (component == nullptr) {
1344
0
        return nullptr;
1345
0
      }
1346
0
      result_const =
1347
0
          FOLD_FPARITH_OP(+)(new_type, result_const, component, const_mgr);
1348
0
    }
1349
0
    return result_const;
1350
0
  };
1351
16.0k
}
1352
1353
16.0k
// Returns the folding rule obtained by lifting the scalar NegateFPConst helper
// through FoldFPUnaryOp.
ConstantFoldingRule FoldFNegate() {
  return FoldFPUnaryOp(NegateFPConst);
}
1354
16.0k
// Returns the folding rule obtained by lifting the scalar NegateIntConst
// helper through FoldUnaryOp.
ConstantFoldingRule FoldSNegate() {
  return FoldUnaryOp(NegateIntConst);
}
1355
1356
128k
// Returns a rule that folds the floating-point comparison |cmp_opcode| when
// one operand is a constant and the other is the result of a GLSL.std.450
// FClamp whose min and/or max operand is a declared constant.  The clamp's
// bounds are compared against the constant; if they already decide the
// comparison, a boolean constant of |inst|'s result type is returned.
//
// Supported opcodes are the ordered and unordered forms of <, <=, > and >=.
// |nullptr| is returned when:
//   * floating-point folding is not allowed for |inst|;
//   * the non-constant operand is not a 32/64-bit float produced by an
//     FClamp from the GLSL.std.450 instruction set;
//   * neither comparison operand is a constant, or neither clamp bound is;
//   * the bounds do not determine the outcome; or
//   * |cmp_opcode| is not one of the supported comparisons.
ConstantFoldingRule FoldFClampFeedingCompare(spv::Op cmp_opcode) {
  return [cmp_opcode](IRContext* context, Instruction* inst,
                      const std::vector<const analysis::Constant*>& constants)
             -> const analysis::Constant* {
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();

    if (!inst->IsFloatingPointFoldingAllowed()) {
      return nullptr;
    }

    // Inspect whichever operand is not already known to be a constant.
    const uint32_t non_const_idx = (constants[0] ? 1 : 0);
    const uint32_t operand_id = inst->GetSingleWordInOperand(non_const_idx);
    Instruction* operand_inst = def_use_mgr->GetDef(operand_id);
    if (operand_inst == nullptr) {
      return nullptr;
    }

    analysis::TypeManager* type_mgr = context->get_type_mgr();
    const analysis::Type* operand_type =
        type_mgr->GetType(operand_inst->type_id());
    const analysis::Float* operand_float =
        (operand_type != nullptr) ? operand_type->AsFloat() : nullptr;
    if (operand_float == nullptr) {
      return nullptr;
    }
    if (operand_float->width() != 32 && operand_float->width() != 64) {
      return nullptr;
    }

    if (operand_inst->opcode() != spv::Op::OpExtInst) {
      return nullptr;
    }

    // The instruction number alone is ambiguous: other extended instruction
    // sets reuse the same numbers.  Make sure this really is GLSL.std.450
    // before interpreting in-operands 3 and 4 as the clamp bounds.
    if (operand_inst->GetSingleWordInOperand(0) !=
        context->get_feature_mgr()->GetExtInstImportId_GLSLstd450()) {
      return nullptr;
    }
    if (operand_inst->GetSingleWordInOperand(1) != GLSLstd450FClamp) {
      return nullptr;
    }

    if (constants[1] == nullptr && constants[0] == nullptr) {
      return nullptr;
    }

    const analysis::Constant* max_const =
        const_mgr->FindDeclaredConstant(operand_inst->GetSingleWordInOperand(4));
    const analysis::Constant* min_const =
        const_mgr->FindDeclaredConstant(operand_inst->GetSingleWordInOperand(3));

    // Each supported opcode is either a "base" relation (< or <=) or the
    // negation of one (>= negates <, > negates <=).
    //   strict       : the base relation is `<` (otherwise `<=`).
    //   cmp_is_base  : |cmp_opcode| is the base relation itself rather than
    //                  its negation.
    bool strict = false;
    bool cmp_is_base = false;
    switch (cmp_opcode) {
      case spv::Op::OpFOrdLessThan:
      case spv::Op::OpFUnordLessThan:
        strict = true;
        cmp_is_base = true;
        break;
      case spv::Op::OpFOrdGreaterThanEqual:
      case spv::Op::OpFUnordGreaterThanEqual:
        strict = true;
        cmp_is_base = false;
        break;
      case spv::Op::OpFOrdLessThanEqual:
      case spv::Op::OpFUnordLessThanEqual:
        strict = false;
        cmp_is_base = true;
        break;
      case spv::Op::OpFOrdGreaterThan:
      case spv::Op::OpFUnordGreaterThan:
        strict = false;
        cmp_is_base = false;
        break;
      default:
        return nullptr;
    }

    // Without at least one constant bound nothing can be decided.
    if (min_const == nullptr && max_const == nullptr) {
      return nullptr;
    }

    bool found_result = false;
    bool result = false;

    // Records that the base relation is known to hold (or known to fail) and
    // translates that into the value of |cmp_opcode|.
    auto record = [&found_result, &result, cmp_is_base](bool base_holds) {
      found_result = true;
      result = (base_holds == cmp_is_base);
    };
    // "Holds" and "fails" are spelled out separately (instead of negating one
    // another) so that comparisons involving NaN decide nothing, exactly as
    // with the explicit operators.
    auto base_holds = [strict](double lhs, double rhs) {
      return strict ? (lhs < rhs) : (lhs <= rhs);
    };
    auto base_fails = [strict](double lhs, double rhs) {
      return strict ? (lhs >= rhs) : (lhs > rhs);
    };

    // Constant on the left: `c REL clamp(x, min, max)`.
    if (constants[0]) {
      const double c = constants[0]->GetValueAsDouble();
      if (min_const && base_holds(c, min_const->GetValueAsDouble())) {
        record(true);
      }
      if (max_const && base_fails(c, max_const->GetValueAsDouble())) {
        record(false);
      }
    }

    // Constant on the right: `clamp(x, min, max) REL c`.
    if (constants[1]) {
      const double c = constants[1]->GetValueAsDouble();
      if (max_const && base_holds(max_const->GetValueAsDouble(), c)) {
        record(true);
      }
      if (min_const && base_fails(min_const->GetValueAsDouble(), c)) {
        record(false);
      }
    }

    if (!found_result) {
      return nullptr;
    }

    const analysis::Type* bool_type =
        context->get_type_mgr()->GetType(inst->type_id());
    const analysis::Constant* result_const =
        const_mgr->GetConstant(bool_type, {static_cast<uint32_t>(result)});
    assert(result_const);
    return result_const;
  };
}
1510
1511
9.59k
// Returns a rule that folds a GLSL.std.450 FMix whose three value operands
// (|constants[1]|, |constants[2]|, |constants[3]|) are all constants.  The
// result is computed as
//     constants[1] * (1 - constants[3]) + constants[2] * constants[3]
// using the file's floating-point binary folders, component-wise for vectors.
// Returns |nullptr| if floating-point folding is not allowed for |inst|, if
// any operand is not constant, if the element type is not a 32-bit or 64-bit
// float, or if any intermediate constant cannot be produced.
ConstantFoldingRule FoldFMix() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants)
             -> const analysis::Constant* {
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
    assert(inst->opcode() == spv::Op::OpExtInst &&
           "Expecting an extended instruction.");
    assert(inst->GetSingleWordInOperand(0) ==
               context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
           "Expecting a GLSLstd450 extended instruction.");
    assert(inst->GetSingleWordInOperand(1) == GLSLstd450FMix &&
           "Expecting and FMix instruction.");

    if (!inst->IsFloatingPointFoldingAllowed()) {
      return nullptr;
    }

    // Make sure all FMix operands are constants.
    for (uint32_t i = 1; i < 4; i++) {
      if (constants[i] == nullptr) {
        return nullptr;
      }
    }

    const analysis::Type* result_type = constants[1]->type();
    const analysis::Vector* vector_type = result_type->AsVector();
    const analysis::Type* base_type =
        (vector_type != nullptr) ? vector_type->element_type() : result_type;
    assert(base_type->AsFloat() != nullptr &&
           "FMix is suppose to act on floats or vectors of floats.");

    // Scalar 1.0 in the element type.
    const analysis::Constant* one = nullptr;
    if (base_type->AsFloat()->width() == 32) {
      one = const_mgr->GetConstant(base_type,
                                   utils::FloatProxy<float>(1.0f).GetWords());
    } else if (base_type->AsFloat()->width() == 64) {
      one = const_mgr->GetConstant(base_type,
                                   utils::FloatProxy<double>(1.0).GetWords());
    } else {
      // We won't support folding half types.
      return nullptr;
    }
    if (one == nullptr) {
      return nullptr;
    }

    if (vector_type != nullptr) {
      // Splat the scalar 1.0 into a vector of the operand type.  The number
      // of component ids must match the vector's own component count rather
      // than a fixed 4.
      Instruction* one_inst = const_mgr->GetDefiningInstruction(one);
      if (one_inst == nullptr) return nullptr;
      const uint32_t one_id = one_inst->result_id();
      one = const_mgr->GetConstant(
          result_type,
          std::vector<uint32_t>(vector_type->element_count(), one_id));
      if (one == nullptr) {
        return nullptr;
      }
    }

    // temp1 = 1 - a
    const analysis::Constant* temp1 = FoldFPBinaryOp(
        FOLD_FPARITH_OP(-), inst->type_id(), {one, constants[3]}, context);
    if (temp1 == nullptr) {
      return nullptr;
    }

    // temp2 = x * (1 - a)
    const analysis::Constant* temp2 = FoldFPBinaryOp(
        FOLD_FPARITH_OP(*), inst->type_id(), {constants[1], temp1}, context);
    if (temp2 == nullptr) {
      return nullptr;
    }

    // temp3 = y * a
    const analysis::Constant* temp3 =
        FoldFPBinaryOp(FOLD_FPARITH_OP(*), inst->type_id(),
                       {constants[2], constants[3]}, context);
    if (temp3 == nullptr) {
      return nullptr;
    }

    // result = x * (1 - a) + y * a
    return FoldFPBinaryOp(FOLD_FPARITH_OP(+), inst->type_id(), {temp2, temp3},
                          context);
  };
}
1586
1587
// Strict "less than" that additionally orders -0.0 before +0.0.  For every
// other pair of values (including NaNs and equal-signed zeros) the result is
// simply |a < b|.
template <typename FloatType>
static bool NegZeroAwareLessThan(FloatType a, FloatType b) {
  const bool both_are_zero = (a == 0.0) && (b == 0.0);
  if (both_are_zero && std::signbit(a) && !std::signbit(b)) {
    // a is -0.0 and b is +0.0.
    return true;
  }
  return a < b;
}
const_folding_rules.cpp:bool spvtools::opt::(anonymous namespace)::NegZeroAwareLessThan<float>(float, float)
Line
Count
Source
1588
15.3k
static bool NegZeroAwareLessThan(FloatType a, FloatType b) {
1589
15.3k
  if (a == 0.0 && b == 0.0) {
1590
2.56k
    bool sba = std::signbit(a);
1591
2.56k
    bool sbb = std::signbit(b);
1592
2.56k
    if (sba && !sbb) {
1593
130
      return true;
1594
130
    }
1595
2.56k
  }
1596
15.2k
  return a < b;
1597
15.3k
}
Unexecuted instantiation: const_folding_rules.cpp:bool spvtools::opt::(anonymous namespace)::NegZeroAwareLessThan<double>(double, double)
1598
1599
// Scalar "minimum" folder: returns whichever of |a| and |b| holds the smaller
// value (|b| when neither compares smaller), interpreting them according to
// |result_type|:
//   * integers up to 32 bits and 64-bit integers, signed or unsigned; a
//     non-IntConstant operand in the <=32-bit case is treated as 0;
//   * 32-bit and 64-bit floats, with -0.0 ordered before +0.0.
// Any other type or width yields |nullptr|.  The constant manager is unused.
const analysis::Constant* FoldMin(const analysis::Type* result_type,
                                  const analysis::Constant* a,
                                  const analysis::Constant* b,
                                  analysis::ConstantManager*) {
  const analysis::Integer* int_type = result_type->AsInteger();
  if (int_type != nullptr) {
    if (int_type->width() <= 32) {
      assert(
          (a->AsIntConstant() != nullptr || a->AsNullConstant() != nullptr) &&
          "Must be an integer or null constant.");
      assert(
          (b->AsIntConstant() != nullptr || b->AsNullConstant() != nullptr) &&
          "Must be an integer or null constant.");

      const auto* a_int = a->AsIntConstant();
      const auto* b_int = b->AsIntConstant();
      if (int_type->IsSigned()) {
        const int32_t lhs = (a_int != nullptr) ? a_int->GetS32BitValue() : 0;
        const int32_t rhs = (b_int != nullptr) ? b_int->GetS32BitValue() : 0;
        return (lhs < rhs) ? a : b;
      }
      const uint32_t lhs = (a_int != nullptr) ? a_int->GetU32BitValue() : 0;
      const uint32_t rhs = (b_int != nullptr) ? b_int->GetU32BitValue() : 0;
      return (lhs < rhs) ? a : b;
    }

    if (int_type->width() == 64) {
      if (int_type->IsSigned()) {
        const int64_t lhs = a->GetS64();
        const int64_t rhs = b->GetS64();
        return (lhs < rhs) ? a : b;
      }
      const uint64_t lhs = a->GetU64();
      const uint64_t rhs = b->GetU64();
      return (lhs < rhs) ? a : b;
    }

    // Integer widths between 33 and 63 bits, or above 64, are not folded.
    return nullptr;
  }

  const analysis::Float* float_type = result_type->AsFloat();
  if (float_type == nullptr) {
    return nullptr;
  }
  switch (float_type->width()) {
    case 32: {
      const float lhs = a->GetFloat();
      const float rhs = b->GetFloat();
      return NegZeroAwareLessThan(lhs, rhs) ? a : b;
    }
    case 64: {
      const double lhs = a->GetDouble();
      const double rhs = b->GetDouble();
      return NegZeroAwareLessThan(lhs, rhs) ? a : b;
    }
    default:
      return nullptr;
  }
}
1653
1654
// Scalar "maximum" folder: returns whichever of |a| and |b| holds the larger
// value (|b| when neither compares larger), interpreting them according to
// |result_type|:
//   * integers up to 32 bits and 64-bit integers, signed or unsigned; a
//     non-IntConstant operand in the <=32-bit case is treated as 0;
//   * 32-bit and 64-bit floats, with +0.0 ordered after -0.0.
// Any other type or width yields |nullptr|.  The constant manager is unused.
const analysis::Constant* FoldMax(const analysis::Type* result_type,
                                  const analysis::Constant* a,
                                  const analysis::Constant* b,
                                  analysis::ConstantManager*) {
  const analysis::Integer* int_type = result_type->AsInteger();
  if (int_type != nullptr) {
    if (int_type->width() <= 32) {
      assert(
          (a->AsIntConstant() != nullptr || a->AsNullConstant() != nullptr) &&
          "Must be an integer or null constant.");
      assert(
          (b->AsIntConstant() != nullptr || b->AsNullConstant() != nullptr) &&
          "Must be an integer or null constant.");

      const auto* a_int = a->AsIntConstant();
      const auto* b_int = b->AsIntConstant();
      if (int_type->IsSigned()) {
        const int32_t lhs = (a_int != nullptr) ? a_int->GetS32BitValue() : 0;
        const int32_t rhs = (b_int != nullptr) ? b_int->GetS32BitValue() : 0;
        return (lhs > rhs) ? a : b;
      }
      const uint32_t lhs = (a_int != nullptr) ? a_int->GetU32BitValue() : 0;
      const uint32_t rhs = (b_int != nullptr) ? b_int->GetU32BitValue() : 0;
      return (lhs > rhs) ? a : b;
    }

    if (int_type->width() == 64) {
      if (int_type->IsSigned()) {
        const int64_t lhs = a->GetS64();
        const int64_t rhs = b->GetS64();
        return (lhs > rhs) ? a : b;
      }
      const uint64_t lhs = a->GetU64();
      const uint64_t rhs = b->GetU64();
      return (lhs > rhs) ? a : b;
    }

    // Integer widths between 33 and 63 bits, or above 64, are not folded.
    return nullptr;
  }

  const analysis::Float* float_type = result_type->AsFloat();
  if (float_type == nullptr) {
    return nullptr;
  }
  switch (float_type->width()) {
    case 32: {
      const float lhs = a->GetFloat();
      const float rhs = b->GetFloat();
      // a is the maximum exactly when b is (zero-sign-aware) less than a.
      return NegZeroAwareLessThan(rhs, lhs) ? a : b;
    }
    case 64: {
      const double lhs = a->GetDouble();
      const double rhs = b->GetDouble();
      return NegZeroAwareLessThan(rhs, lhs) ? a : b;
    }
    default:
      return nullptr;
  }
}
1708
1709
// NaN-avoiding scalar minimum for 32-bit and 64-bit floats.  If |a| is NaN the
// result is |b|; otherwise if |b| is NaN the result is |a|; otherwise the
// smaller of the two, with -0.0 ordered before +0.0 (|b| on ties).  Returns
// |nullptr| for non-float types and other widths.  The constant manager is
// unused.
const analysis::Constant* FoldNMin(const analysis::Type* result_type,
                                   const analysis::Constant* a,
                                   const analysis::Constant* b,
                                   analysis::ConstantManager*) {
  const analysis::Float* float_type = result_type->AsFloat();
  if (float_type == nullptr) {
    return nullptr;
  }

  // Shared selection logic for both supported widths.
  auto select = [a, b](auto lhs, auto rhs) -> const analysis::Constant* {
    if (std::isnan(lhs)) {
      return b;
    }
    if (std::isnan(rhs)) {
      return a;
    }
    return NegZeroAwareLessThan(lhs, rhs) ? a : b;
  };

  switch (float_type->width()) {
    case 32:
      return select(a->GetFloat(), b->GetFloat());
    case 64:
      return select(a->GetDouble(), b->GetDouble());
    default:
      return nullptr;
  }
}
1738
1739
// NaN-avoiding scalar maximum for 32-bit and 64-bit floats.  If |a| is NaN the
// result is |b|; otherwise if |b| is NaN the result is |a|; otherwise the
// larger of the two, with +0.0 ordered after -0.0 (|b| on ties).  Returns
// |nullptr| for non-float types and other widths.  The constant manager is
// unused.
const analysis::Constant* FoldNMax(const analysis::Type* result_type,
                                   const analysis::Constant* a,
                                   const analysis::Constant* b,
                                   analysis::ConstantManager*) {
  const analysis::Float* float_type = result_type->AsFloat();
  if (float_type == nullptr) {
    return nullptr;
  }

  // Shared selection logic for both supported widths.
  auto select = [a, b](auto lhs, auto rhs) -> const analysis::Constant* {
    if (std::isnan(lhs)) {
      return b;
    }
    if (std::isnan(rhs)) {
      return a;
    }
    // a is the maximum exactly when b is (zero-sign-aware) less than a.
    return NegZeroAwareLessThan(rhs, lhs) ? a : b;
  };

  switch (float_type->width()) {
    case 32:
      return select(a->GetFloat(), b->GetFloat());
    case 64:
      return select(a->GetDouble(), b->GetDouble());
    default:
      return nullptr;
  }
}
1768
1769
// Fold an clamp instruction when all three operands are constant.
1770
const analysis::Constant* FoldClamp1(
1771
    IRContext* context, Instruction* inst,
1772
12.7k
    const std::vector<const analysis::Constant*>& constants) {
1773
12.7k
  assert(inst->opcode() == spv::Op::OpExtInst &&
1774
12.7k
         "Expecting an extended instruction.");
1775
12.7k
  assert(inst->GetSingleWordInOperand(0) ==
1776
12.7k
             context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1777
12.7k
         "Expecting a GLSLstd450 extended instruction.");
1778
1779
  // Make sure all Clamp operands are constants.
1780
24.3k
  for (uint32_t i = 1; i < 4; i++) {
1781
22.5k
    if (constants[i] == nullptr) {
1782
10.9k
      return nullptr;
1783
10.9k
    }
1784
22.5k
  }
1785
1786
1.83k
  const analysis::Constant* temp = FoldFPBinaryOp(
1787
1.83k
      FoldMax, inst->type_id(), {constants[1], constants[2]}, context);
1788
1.83k
  if (temp == nullptr) {
1789
0
    return nullptr;
1790
0
  }
1791
1.83k
  return FoldFPBinaryOp(FoldMin, inst->type_id(), {temp, constants[3]},
1792
1.83k
                        context);
1793
1.83k
}
1794
1795
// Fold a clamp instruction when |x <= min_val|.
1796
const analysis::Constant* FoldClamp2(
1797
    IRContext* context, Instruction* inst,
1798
10.9k
    const std::vector<const analysis::Constant*>& constants) {
1799
10.9k
  assert(inst->opcode() == spv::Op::OpExtInst &&
1800
10.9k
         "Expecting an extended instruction.");
1801
10.9k
  assert(inst->GetSingleWordInOperand(0) ==
1802
10.9k
             context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1803
10.9k
         "Expecting a GLSLstd450 extended instruction.");
1804
1805
10.9k
  const analysis::Constant* x = constants[1];
1806
10.9k
  const analysis::Constant* min_val = constants[2];
1807
1808
10.9k
  if (x == nullptr || min_val == nullptr) {
1809
8.15k
    return nullptr;
1810
8.15k
  }
1811
1812
2.77k
  const analysis::Constant* temp =
1813
2.77k
      FoldFPBinaryOp(FoldMax, inst->type_id(), {x, min_val}, context);
1814
2.77k
  if (temp == min_val) {
1815
    // We can assume that |min_val| is less than |max_val|.  Therefore, if the
1816
    // result of the max operation is |min_val|, we know the result of the min
1817
    // operation, even if |max_val| is not a constant.
1818
2.05k
    return min_val;
1819
2.05k
  }
1820
714
  return nullptr;
1821
2.77k
}
1822
1823
// Fold a clamp instruction when |x >= max_val|.
1824
const analysis::Constant* FoldClamp3(
1825
    IRContext* context, Instruction* inst,
1826
8.86k
    const std::vector<const analysis::Constant*>& constants) {
1827
8.86k
  assert(inst->opcode() == spv::Op::OpExtInst &&
1828
8.86k
         "Expecting an extended instruction.");
1829
8.86k
  assert(inst->GetSingleWordInOperand(0) ==
1830
8.86k
             context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1831
8.86k
         "Expecting a GLSLstd450 extended instruction.");
1832
1833
8.86k
  const analysis::Constant* x = constants[1];
1834
8.86k
  const analysis::Constant* max_val = constants[3];
1835
1836
8.86k
  if (x == nullptr || max_val == nullptr) {
1837
8.46k
    return nullptr;
1838
8.46k
  }
1839
1840
404
  const analysis::Constant* temp =
1841
404
      FoldFPBinaryOp(FoldMin, inst->type_id(), {x, max_val}, context);
1842
404
  if (temp == max_val) {
1843
    // We can assume that |min_val| is less than |max_val|.  Therefore, if the
1844
    // result of the max operation is |min_val|, we know the result of the min
1845
    // operation, even if |max_val| is not a constant.
1846
137
    return max_val;
1847
137
  }
1848
267
  return nullptr;
1849
404
}
1850
1851
// Fold an clamp instruction when all three operands are constant.
1852
const analysis::Constant* FoldNClamp1(
1853
    IRContext* context, Instruction* inst,
1854
607
    const std::vector<const analysis::Constant*>& constants) {
1855
607
  assert(inst->opcode() == spv::Op::OpExtInst &&
1856
607
         "Expecting an extended instruction.");
1857
607
  assert(inst->GetSingleWordInOperand(0) ==
1858
607
             context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1859
607
         "Expecting a GLSLstd450 extended instruction.");
1860
1861
  // Make sure all Clamp operands are constants.
1862
1.41k
  for (uint32_t i = 1; i < 4; i++) {
1863
1.27k
    if (constants[i] == nullptr) {
1864
468
      return nullptr;
1865
468
    }
1866
1.27k
  }
1867
1868
139
  const analysis::Constant* temp = FoldFPBinaryOp(
1869
139
      FoldNMax, inst->type_id(), {constants[1], constants[2]}, context);
1870
139
  if (temp == nullptr) {
1871
0
    return nullptr;
1872
0
  }
1873
139
  return FoldFPBinaryOp(FoldNMin, inst->type_id(), {temp, constants[3]},
1874
139
                        context);
1875
139
}
1876
1877
// Fold a clamp instruction when |x <= min_val|.
1878
const analysis::Constant* FoldNClamp2(
1879
    IRContext* context, Instruction* inst,
1880
468
    const std::vector<const analysis::Constant*>& constants) {
1881
468
  assert(inst->opcode() == spv::Op::OpExtInst &&
1882
468
         "Expecting an extended instruction.");
1883
468
  assert(inst->GetSingleWordInOperand(0) ==
1884
468
             context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1885
468
         "Expecting a GLSLstd450 extended instruction.");
1886
1887
468
  const analysis::Constant* x = constants[1];
1888
468
  const analysis::Constant* min_val = constants[2];
1889
1890
468
  if (x == nullptr || min_val == nullptr) {
1891
327
    return nullptr;
1892
327
  }
1893
1894
141
  const analysis::Constant* temp =
1895
141
      FoldFPBinaryOp(FoldNMax, inst->type_id(), {x, min_val}, context);
1896
141
  if (temp == min_val) {
1897
    // We can assume that |min_val| is less than |max_val|.  Therefore, if the
1898
    // result of the max operation is |min_val|, we know the result of the min
1899
    // operation, even if |max_val| is not a constant.
1900
65
    return min_val;
1901
65
  }
1902
76
  return nullptr;
1903
141
}
1904
1905
// Fold a clamp instruction when |x >= max_val|.
1906
const analysis::Constant* FoldNClamp3(
1907
    IRContext* context, Instruction* inst,
1908
403
    const std::vector<const analysis::Constant*>& constants) {
1909
403
  assert(inst->opcode() == spv::Op::OpExtInst &&
1910
403
         "Expecting an extended instruction.");
1911
403
  assert(inst->GetSingleWordInOperand(0) ==
1912
403
             context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
1913
403
         "Expecting a GLSLstd450 extended instruction.");
1914
1915
403
  const analysis::Constant* x = constants[1];
1916
403
  const analysis::Constant* max_val = constants[3];
1917
1918
403
  if (x == nullptr || max_val == nullptr) {
1919
334
    return nullptr;
1920
334
  }
1921
1922
69
  const analysis::Constant* temp =
1923
69
      FoldFPBinaryOp(FoldNMin, inst->type_id(), {x, max_val}, context);
1924
69
  if (temp == max_val) {
1925
    // We can assume that |min_val| is less than |max_val|.  Therefore, if the
1926
    // result of the max operation is |min_val|, we know the result of the min
1927
    // operation, even if |max_val| is not a constant.
1928
20
    return max_val;
1929
20
  }
1930
49
  return nullptr;
1931
69
}
1932
1933
105k
// Wraps the double-precision math function |fp| as a scalar folding rule.
// The returned rule evaluates |fp| on the operand |a| and builds a constant of
// |result_type| from the outcome:
//   * 32-bit floats are widened to double for the call and the result is
//     narrowed back to float;
//   * 64-bit floats are passed through unchanged.
// Other float widths yield |nullptr|.
UnaryScalarFoldingRule FoldFTranscendentalUnary(double (*fp)(double)) {
  return
      [fp](const analysis::Type* result_type, const analysis::Constant* a,
           analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
        assert(result_type != nullptr && a != nullptr);
        const analysis::Float* float_type = a->type()->AsFloat();
        assert(float_type != nullptr);
        assert(float_type == result_type->AsFloat());

        switch (float_type->width()) {
          case 32: {
            const float input = a->GetFloat();
            const float output = static_cast<float>(fp(input));
            return const_mgr->GetConstant(
                result_type, utils::FloatProxy<float>(output).GetWords());
          }
          case 64: {
            const double input = a->GetDouble();
            const double output = fp(input);
            return const_mgr->GetConstant(
                result_type, utils::FloatProxy<double>(output).GetWords());
          }
          default:
            return nullptr;
        }
      };
}
1957
1958
// Wraps the two-argument double-precision math function |fp| as a scalar
// folding rule.  The returned rule evaluates |fp| on the operands |a| and |b|
// and builds a constant of |result_type| from the outcome:
//   * 32-bit floats are widened to double for the call and the result is
//     narrowed back to float;
//   * 64-bit floats are passed through unchanged.
// Other float widths yield |nullptr|.
BinaryScalarFoldingRule FoldFTranscendentalBinary(double (*fp)(double,
                                                               double)) {
  return
      [fp](const analysis::Type* result_type, const analysis::Constant* a,
           const analysis::Constant* b,
           analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
        assert(result_type != nullptr && a != nullptr);
        const analysis::Float* float_type = a->type()->AsFloat();
        assert(float_type != nullptr);
        assert(float_type == result_type->AsFloat());
        assert(float_type == b->type()->AsFloat());

        switch (float_type->width()) {
          case 32: {
            const float lhs = a->GetFloat();
            const float rhs = b->GetFloat();
            const float output = static_cast<float>(fp(lhs, rhs));
            return const_mgr->GetConstant(
                result_type, utils::FloatProxy<float>(output).GetWords());
          }
          case 64: {
            const double lhs = a->GetDouble();
            const double rhs = b->GetDouble();
            const double output = fp(lhs, rhs);
            return const_mgr->GetConstant(
                result_type, utils::FloatProxy<double>(output).GetWords());
          }
          default:
            return nullptr;
        }
      };
}
1987
1988
// Tells an integer folding helper how to widen its operands before applying
// the operation.
enum Sign {
  Signed,    // Operands are sign extended.
  Unsigned,  // Operands are zero extended.
};
1989
1990
// Returns a BinaryScalarFoldingRule that applies `op` to the scalars.
1991
// The `signedness` is used to determine if the operands should be interpreted
1992
// as signed or unsigned. If the operands are signed, the value will be sign
1993
// extended before the value is passed to `op`. Otherwise the values will be
1994
// zero extended.
1995
template <Sign signedness>
1996
BinaryScalarFoldingRule FoldBinaryIntegerOperation(uint64_t (*op)(uint64_t,
1997
128k
                                                                  uint64_t)) {
1998
128k
  return
1999
128k
      [op](const analysis::Type* result_type, const analysis::Constant* a,
2000
128k
           const analysis::Constant* b,
2001
860k
           analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
2002
860k
        assert(result_type != nullptr && a != nullptr && b != nullptr);
2003
860k
        const analysis::Integer* integer_type = result_type->AsInteger();
2004
860k
        assert(integer_type != nullptr);
2005
860k
        assert(a->type()->kind() == analysis::Type::kInteger);
2006
860k
        assert(b->type()->kind() == analysis::Type::kInteger);
2007
860k
        assert(integer_type->width() == a->type()->AsInteger()->width());
2008
860k
        assert(integer_type->width() == b->type()->AsInteger()->width());
2009
2010
        // In SPIR-V, all operations support unsigned types, but the way they
2011
        // are interpreted depends on the opcode. This is why we use the
2012
        // template argument to determine how to interpret the operands.
2013
860k
        uint64_t ia = (signedness == Signed ? a->GetSignExtendedValue()
2014
860k
                                            : a->GetZeroExtendedValue());
2015
860k
        uint64_t ib = (signedness == Signed ? b->GetSignExtendedValue()
2016
860k
                                            : b->GetZeroExtendedValue());
2017
860k
        uint64_t result = op(ia, ib);
2018
2019
860k
        const analysis::Constant* result_constant =
2020
860k
            const_mgr->GenerateIntegerConstant(integer_type, result);
2021
860k
        return result_constant;
2022
860k
      };
const_folding_rules.cpp:spvtools::opt::(anonymous namespace)::FoldBinaryIntegerOperation<(spvtools::opt::(anonymous namespace)::Sign)1>(unsigned long (*)(unsigned long, unsigned long))::{lambda(spvtools::opt::analysis::Type const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::ConstantManager*)#1}::operator()(spvtools::opt::analysis::Type const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::ConstantManager*) const
Line
Count
Source
2001
750k
           analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
2002
750k
        assert(result_type != nullptr && a != nullptr && b != nullptr);
2003
750k
        const analysis::Integer* integer_type = result_type->AsInteger();
2004
750k
        assert(integer_type != nullptr);
2005
750k
        assert(a->type()->kind() == analysis::Type::kInteger);
2006
750k
        assert(b->type()->kind() == analysis::Type::kInteger);
2007
750k
        assert(integer_type->width() == a->type()->AsInteger()->width());
2008
750k
        assert(integer_type->width() == b->type()->AsInteger()->width());
2009
2010
        // In SPIR-V, all operations support unsigned types, but the way they
2011
        // are interpreted depends on the opcode. This is why we use the
2012
        // template argument to determine how to interpret the operands.
2013
750k
        uint64_t ia = (signedness == Signed ? a->GetSignExtendedValue()
2014
750k
                                            : a->GetZeroExtendedValue());
2015
750k
        uint64_t ib = (signedness == Signed ? b->GetSignExtendedValue()
2016
750k
                                            : b->GetZeroExtendedValue());
2017
750k
        uint64_t result = op(ia, ib);
2018
2019
750k
        const analysis::Constant* result_constant =
2020
750k
            const_mgr->GenerateIntegerConstant(integer_type, result);
2021
750k
        return result_constant;
2022
750k
      };
const_folding_rules.cpp:spvtools::opt::(anonymous namespace)::FoldBinaryIntegerOperation<(spvtools::opt::(anonymous namespace)::Sign)0>(unsigned long (*)(unsigned long, unsigned long))::{lambda(spvtools::opt::analysis::Type const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::ConstantManager*)#1}::operator()(spvtools::opt::analysis::Type const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::ConstantManager*) const
Line
Count
Source
2001
110k
           analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
2002
110k
        assert(result_type != nullptr && a != nullptr && b != nullptr);
2003
110k
        const analysis::Integer* integer_type = result_type->AsInteger();
2004
110k
        assert(integer_type != nullptr);
2005
110k
        assert(a->type()->kind() == analysis::Type::kInteger);
2006
110k
        assert(b->type()->kind() == analysis::Type::kInteger);
2007
110k
        assert(integer_type->width() == a->type()->AsInteger()->width());
2008
110k
        assert(integer_type->width() == b->type()->AsInteger()->width());
2009
2010
        // In SPIR-V, all operations support unsigned types, but the way they
2011
        // are interpreted depends on the opcode. This is why we use the
2012
        // template argument to determine how to interpret the operands.
2013
110k
        uint64_t ia = (signedness == Signed ? a->GetSignExtendedValue()
2014
110k
                                            : a->GetZeroExtendedValue());
2015
110k
        uint64_t ib = (signedness == Signed ? b->GetSignExtendedValue()
2016
110k
                                            : b->GetZeroExtendedValue());
2017
110k
        uint64_t result = op(ia, ib);
2018
2019
110k
        const analysis::Constant* result_constant =
2020
110k
            const_mgr->GenerateIntegerConstant(integer_type, result);
2021
110k
        return result_constant;
2022
110k
      };
2023
128k
}
const_folding_rules.cpp:std::__1::function<spvtools::opt::analysis::Constant const* (spvtools::opt::analysis::Type const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::ConstantManager*)> spvtools::opt::(anonymous namespace)::FoldBinaryIntegerOperation<(spvtools::opt::(anonymous namespace)::Sign)1>(unsigned long (*)(unsigned long, unsigned long))
Line
Count
Source
1997
80.4k
                                                                  uint64_t)) {
1998
80.4k
  return
1999
80.4k
      [op](const analysis::Type* result_type, const analysis::Constant* a,
2000
80.4k
           const analysis::Constant* b,
2001
80.4k
           analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
2002
80.4k
        assert(result_type != nullptr && a != nullptr && b != nullptr);
2003
80.4k
        const analysis::Integer* integer_type = result_type->AsInteger();
2004
80.4k
        assert(integer_type != nullptr);
2005
80.4k
        assert(a->type()->kind() == analysis::Type::kInteger);
2006
80.4k
        assert(b->type()->kind() == analysis::Type::kInteger);
2007
80.4k
        assert(integer_type->width() == a->type()->AsInteger()->width());
2008
80.4k
        assert(integer_type->width() == b->type()->AsInteger()->width());
2009
2010
        // In SPIR-V, all operations support unsigned types, but the way they
2011
        // are interpreted depends on the opcode. This is why we use the
2012
        // template argument to determine how to interpret the operands.
2013
80.4k
        uint64_t ia = (signedness == Signed ? a->GetSignExtendedValue()
2014
80.4k
                                            : a->GetZeroExtendedValue());
2015
80.4k
        uint64_t ib = (signedness == Signed ? b->GetSignExtendedValue()
2016
80.4k
                                            : b->GetZeroExtendedValue());
2017
80.4k
        uint64_t result = op(ia, ib);
2018
2019
80.4k
        const analysis::Constant* result_constant =
2020
80.4k
            const_mgr->GenerateIntegerConstant(integer_type, result);
2021
80.4k
        return result_constant;
2022
80.4k
      };
2023
80.4k
}
const_folding_rules.cpp:std::__1::function<spvtools::opt::analysis::Constant const* (spvtools::opt::analysis::Type const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::ConstantManager*)> spvtools::opt::(anonymous namespace)::FoldBinaryIntegerOperation<(spvtools::opt::(anonymous namespace)::Sign)0>(unsigned long (*)(unsigned long, unsigned long))
Line
Count
Source
1997
48.2k
                                                                  uint64_t)) {
1998
48.2k
  return
1999
48.2k
      [op](const analysis::Type* result_type, const analysis::Constant* a,
2000
48.2k
           const analysis::Constant* b,
2001
48.2k
           analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
2002
48.2k
        assert(result_type != nullptr && a != nullptr && b != nullptr);
2003
48.2k
        const analysis::Integer* integer_type = result_type->AsInteger();
2004
48.2k
        assert(integer_type != nullptr);
2005
48.2k
        assert(a->type()->kind() == analysis::Type::kInteger);
2006
48.2k
        assert(b->type()->kind() == analysis::Type::kInteger);
2007
48.2k
        assert(integer_type->width() == a->type()->AsInteger()->width());
2008
48.2k
        assert(integer_type->width() == b->type()->AsInteger()->width());
2009
2010
        // In SPIR-V, all operations support unsigned types, but the way they
2011
        // are interpreted depends on the opcode. This is why we use the
2012
        // template argument to determine how to interpret the operands.
2013
48.2k
        uint64_t ia = (signedness == Signed ? a->GetSignExtendedValue()
2014
48.2k
                                            : a->GetZeroExtendedValue());
2015
48.2k
        uint64_t ib = (signedness == Signed ? b->GetSignExtendedValue()
2016
48.2k
                                            : b->GetZeroExtendedValue());
2017
48.2k
        uint64_t result = op(ia, ib);
2018
2019
48.2k
        const analysis::Constant* result_constant =
2020
48.2k
            const_mgr->GenerateIntegerConstant(integer_type, result);
2021
48.2k
        return result_constant;
2022
48.2k
      };
2023
48.2k
}
2024
2025
// A scalar folding rule that folds OpSConvert.
2026
const analysis::Constant* FoldScalarSConvert(
2027
    const analysis::Type* result_type, const analysis::Constant* a,
2028
0
    analysis::ConstantManager* const_mgr) {
2029
0
  assert(result_type != nullptr);
2030
0
  assert(a != nullptr);
2031
0
  assert(const_mgr != nullptr);
2032
0
  const analysis::Integer* integer_type = result_type->AsInteger();
2033
0
  assert(integer_type && "The result type of an SConvert");
2034
0
  int64_t value = a->GetSignExtendedValue();
2035
0
  return const_mgr->GenerateIntegerConstant(integer_type, value);
2036
0
}
2037
2038
// A scalar folding rule that folds OpUConvert.
2039
const analysis::Constant* FoldScalarUConvert(
2040
    const analysis::Type* result_type, const analysis::Constant* a,
2041
0
    analysis::ConstantManager* const_mgr) {
2042
0
  assert(result_type != nullptr);
2043
0
  assert(a != nullptr);
2044
0
  assert(const_mgr != nullptr);
2045
0
  const analysis::Integer* integer_type = result_type->AsInteger();
2046
0
  assert(integer_type && "The result type of an UConvert");
2047
0
  uint64_t value = a->GetZeroExtendedValue();
2048
2049
  // If the operand was an unsigned value with less than 32-bit, it would have
2050
  // been sign extended earlier, and we need to clear those bits.
2051
0
  auto* operand_type = a->type()->AsInteger();
2052
0
  value = utils::ClearHighBits(value, 64 - operand_type->width());
2053
0
  return const_mgr->GenerateIntegerConstant(integer_type, value);
2054
0
}
2055
}  // namespace
2056
2057
16.0k
void ConstantFoldingRules::AddFoldingRules() {
2058
  // Add all folding rules to the list for the opcodes to which they apply.
2059
  // Note that the order in which rules are added to the list matters. If a rule
2060
  // applies to the instruction, the rest of the rules will not be attempted.
2061
  // Take that into consideration.
2062
2063
16.0k
  rules_[spv::Op::OpCompositeConstruct].push_back(FoldCompositeWithConstants());
2064
2065
16.0k
  rules_[spv::Op::OpCompositeExtract].push_back(FoldExtractWithConstants());
2066
16.0k
  rules_[spv::Op::OpCompositeInsert].push_back(FoldInsertWithConstants());
2067
2068
16.0k
  rules_[spv::Op::OpConvertFToS].push_back(FoldFToI());
2069
16.0k
  rules_[spv::Op::OpConvertFToU].push_back(FoldFToI());
2070
16.0k
  rules_[spv::Op::OpConvertSToF].push_back(FoldIToF());
2071
16.0k
  rules_[spv::Op::OpConvertUToF].push_back(FoldIToF());
2072
16.0k
  rules_[spv::Op::OpSConvert].push_back(FoldUnaryOp(FoldScalarSConvert));
2073
16.0k
  rules_[spv::Op::OpUConvert].push_back(FoldUnaryOp(FoldScalarUConvert));
2074
2075
16.0k
  rules_[spv::Op::OpDot].push_back(FoldOpDotWithConstants());
2076
16.0k
  rules_[spv::Op::OpFAdd].push_back(FoldFAdd());
2077
2078
16.0k
  rules_[spv::Op::OpFDiv].push_back(FoldFDiv());
2079
16.0k
  rules_[spv::Op::OpFDiv].push_back(FoldRedundantDiv());
2080
2081
16.0k
  rules_[spv::Op::OpFMul].push_back(FoldFMul());
2082
2083
16.0k
  rules_[spv::Op::OpFSub].push_back(FoldFSub());
2084
16.0k
  rules_[spv::Op::OpFSub].push_back(FoldRedundantSub());
2085
2086
16.0k
  rules_[spv::Op::OpSelect].push_back(FoldInvariantSelect());
2087
2088
16.0k
  rules_[spv::Op::OpFOrdEqual].push_back(FoldFOrdEqual());
2089
2090
16.0k
  rules_[spv::Op::OpFUnordEqual].push_back(FoldFUnordEqual());
2091
2092
16.0k
  rules_[spv::Op::OpFOrdNotEqual].push_back(FoldFOrdNotEqual());
2093
2094
16.0k
  rules_[spv::Op::OpFUnordNotEqual].push_back(FoldFUnordNotEqual());
2095
2096
16.0k
  rules_[spv::Op::OpFOrdLessThan].push_back(FoldFOrdLessThan());
2097
16.0k
  rules_[spv::Op::OpFOrdLessThan].push_back(
2098
16.0k
      FoldFClampFeedingCompare(spv::Op::OpFOrdLessThan));
2099
2100
16.0k
  rules_[spv::Op::OpFUnordLessThan].push_back(FoldFUnordLessThan());
2101
16.0k
  rules_[spv::Op::OpFUnordLessThan].push_back(
2102
16.0k
      FoldFClampFeedingCompare(spv::Op::OpFUnordLessThan));
2103
2104
16.0k
  rules_[spv::Op::OpFOrdGreaterThan].push_back(FoldFOrdGreaterThan());
2105
16.0k
  rules_[spv::Op::OpFOrdGreaterThan].push_back(
2106
16.0k
      FoldFClampFeedingCompare(spv::Op::OpFOrdGreaterThan));
2107
2108
16.0k
  rules_[spv::Op::OpFUnordGreaterThan].push_back(FoldFUnordGreaterThan());
2109
16.0k
  rules_[spv::Op::OpFUnordGreaterThan].push_back(
2110
16.0k
      FoldFClampFeedingCompare(spv::Op::OpFUnordGreaterThan));
2111
2112
16.0k
  rules_[spv::Op::OpFOrdLessThanEqual].push_back(FoldFOrdLessThanEqual());
2113
16.0k
  rules_[spv::Op::OpFOrdLessThanEqual].push_back(
2114
16.0k
      FoldFClampFeedingCompare(spv::Op::OpFOrdLessThanEqual));
2115
2116
16.0k
  rules_[spv::Op::OpFUnordLessThanEqual].push_back(FoldFUnordLessThanEqual());
2117
16.0k
  rules_[spv::Op::OpFUnordLessThanEqual].push_back(
2118
16.0k
      FoldFClampFeedingCompare(spv::Op::OpFUnordLessThanEqual));
2119
2120
16.0k
  rules_[spv::Op::OpFOrdGreaterThanEqual].push_back(FoldFOrdGreaterThanEqual());
2121
16.0k
  rules_[spv::Op::OpFOrdGreaterThanEqual].push_back(
2122
16.0k
      FoldFClampFeedingCompare(spv::Op::OpFOrdGreaterThanEqual));
2123
2124
16.0k
  rules_[spv::Op::OpFUnordGreaterThanEqual].push_back(
2125
16.0k
      FoldFUnordGreaterThanEqual());
2126
16.0k
  rules_[spv::Op::OpFUnordGreaterThanEqual].push_back(
2127
16.0k
      FoldFClampFeedingCompare(spv::Op::OpFUnordGreaterThanEqual));
2128
2129
16.0k
  rules_[spv::Op::OpVectorShuffle].push_back(FoldVectorShuffleWithConstants());
2130
16.0k
  rules_[spv::Op::OpVectorTimesScalar].push_back(FoldVectorTimesScalar());
2131
16.0k
  rules_[spv::Op::OpVectorTimesMatrix].push_back(FoldVectorTimesMatrix());
2132
16.0k
  rules_[spv::Op::OpMatrixTimesVector].push_back(FoldMatrixTimesVector());
2133
16.0k
  rules_[spv::Op::OpTranspose].push_back(FoldTranspose);
2134
2135
16.0k
  rules_[spv::Op::OpFNegate].push_back(FoldFNegate());
2136
16.0k
  rules_[spv::Op::OpSNegate].push_back(FoldSNegate());
2137
16.0k
  rules_[spv::Op::OpQuantizeToF16].push_back(FoldQuantizeToF16());
2138
2139
16.0k
  rules_[spv::Op::OpIAdd].push_back(
2140
16.0k
      FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
2141
466k
          [](uint64_t a, uint64_t b) { return a + b; })));
2142
2143
16.0k
  rules_[spv::Op::OpISub].push_back(
2144
16.0k
      FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
2145
218k
          [](uint64_t a, uint64_t b) { return a - b; })));
2146
16.0k
  rules_[spv::Op::OpISub].push_back(FoldRedundantSub());
2147
2148
16.0k
  rules_[spv::Op::OpIMul].push_back(
2149
16.0k
      FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
2150
64.5k
          [](uint64_t a, uint64_t b) { return a * b; })));
2151
2152
16.0k
  rules_[spv::Op::OpUDiv].push_back(
2153
16.0k
      FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
2154
16.0k
          [](uint64_t a, uint64_t b) { return (b != 0 ? a / b : 0); })));
2155
16.0k
  rules_[spv::Op::OpUDiv].push_back(FoldRedundantDiv());
2156
2157
16.0k
  rules_[spv::Op::OpSDiv].push_back(FoldBinaryOp(
2158
38.8k
      FoldBinaryIntegerOperation<Signed>([](uint64_t a, uint64_t b) {
2159
38.8k
        return (b != 0 ? static_cast<uint64_t>(static_cast<int64_t>(a) /
2160
37.9k
                                               static_cast<int64_t>(b))
2161
38.8k
                       : 0);
2162
38.8k
      })));
2163
16.0k
  rules_[spv::Op::OpSDiv].push_back(FoldRedundantDiv());
2164
2165
16.0k
  rules_[spv::Op::OpUMod].push_back(
2166
16.0k
      FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>(
2167
16.0k
          [](uint64_t a, uint64_t b) { return (b != 0 ? a % b : 0); })));
2168
2169
16.0k
  rules_[spv::Op::OpSRem].push_back(FoldBinaryOp(
2170
65.9k
      FoldBinaryIntegerOperation<Signed>([](uint64_t a, uint64_t b) {
2171
65.9k
        return (b != 0 ? static_cast<uint64_t>(static_cast<int64_t>(a) %
2172
65.1k
                                               static_cast<int64_t>(b))
2173
65.9k
                       : 0);
2174
65.9k
      })));
2175
2176
16.0k
  rules_[spv::Op::OpSMod].push_back(FoldBinaryOp(
2177
16.0k
      FoldBinaryIntegerOperation<Signed>([](uint64_t a, uint64_t b) {
2178
5.37k
        if (b == 0) return static_cast<uint64_t>(0ull);
2179
2180
5.14k
        int64_t signed_a = static_cast<int64_t>(a);
2181
5.14k
        int64_t signed_b = static_cast<int64_t>(b);
2182
5.14k
        int64_t result = signed_a % signed_b;
2183
5.14k
        if ((signed_b < 0) != (result < 0)) result += signed_b;
2184
5.14k
        return static_cast<uint64_t>(result);
2185
5.37k
      })));
2186
2187
  // Add rules for GLSLstd450
2188
16.0k
  FeatureManager* feature_manager = context_->get_feature_mgr();
2189
16.0k
  uint32_t ext_inst_glslstd450_id =
2190
16.0k
      feature_manager->GetExtInstImportId_GLSLstd450();
2191
16.0k
  if (ext_inst_glslstd450_id != 0) {
2192
9.59k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(FoldFMix());
2193
9.59k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SMin}].push_back(
2194
9.59k
        FoldFPBinaryOp(FoldMin));
2195
9.59k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UMin}].push_back(
2196
9.59k
        FoldFPBinaryOp(FoldMin));
2197
9.59k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMin}].push_back(
2198
9.59k
        FoldFPBinaryOp(FoldMin));
2199
9.59k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450NMin}].push_back(
2200
9.59k
        FoldFPBinaryOp(FoldNMin));
2201
9.59k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SMax}].push_back(
2202
9.59k
        FoldFPBinaryOp(FoldMax));
2203
9.59k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UMax}].push_back(
2204
9.59k
        FoldFPBinaryOp(FoldMax));
2205
9.59k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMax}].push_back(
2206
9.59k
        FoldFPBinaryOp(FoldMax));
2207
9.59k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450NMax}].push_back(
2208
9.59k
        FoldFPBinaryOp(FoldNMax));
2209
9.59k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back(
2210
9.59k
        FoldClamp1);
2211
9.59k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back(
2212
9.59k
        FoldClamp2);
2213
9.59k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back(
2214
9.59k
        FoldClamp3);
2215
9.59k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back(
2216
9.59k
        FoldClamp1);
2217
9.59k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back(
2218
9.59k
        FoldClamp2);
2219
9.59k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back(
2220
9.59k
        FoldClamp3);
2221
9.59k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
2222
9.59k
        FoldClamp1);
2223
9.59k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
2224
9.59k
        FoldClamp2);
2225
9.59k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
2226
9.59k
        FoldClamp3);
2227
9.59k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450NClamp}].push_back(
2228
9.59k
        FoldNClamp1);
2229
9.59k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450NClamp}].push_back(
2230
9.59k
        FoldNClamp2);
2231
9.59k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450NClamp}].push_back(
2232
9.59k
        FoldNClamp3);
2233
9.59k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Sin}].push_back(
2234
9.59k
        FoldFPUnaryOp(FoldFTranscendentalUnary(std::sin)));
2235
9.59k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Cos}].push_back(
2236
9.59k
        FoldFPUnaryOp(FoldFTranscendentalUnary(std::cos)));
2237
9.59k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Tan}].push_back(
2238
9.59k
        FoldFPUnaryOp(FoldFTranscendentalUnary(std::tan)));
2239
9.59k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Asin}].push_back(
2240
9.59k
        FoldFPUnaryOp(FoldFTranscendentalUnary(std::asin)));
2241
9.59k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Acos}].push_back(
2242
9.59k
        FoldFPUnaryOp(FoldFTranscendentalUnary(std::acos)));
2243
9.59k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Atan}].push_back(
2244
9.59k
        FoldFPUnaryOp(FoldFTranscendentalUnary(std::atan)));
2245
9.59k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp}].push_back(
2246
9.59k
        FoldFPUnaryOp(FoldFTranscendentalUnary(std::exp)));
2247
9.59k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log}].push_back(
2248
9.59k
        FoldFPUnaryOp(FoldFTranscendentalUnary(std::log)));
2249
2250
#ifdef __ANDROID__
2251
    // Android NDK r15c targeting ABI 15 doesn't have full support for C++11
2252
    // (no std::exp2/log2). ::exp2 is available from C99 but ::log2 isn't
2253
    // available up until ABI 18 so we use a shim
2254
    auto log2_shim = [](double v) -> double { return log(v) / log(2.0); };
2255
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp2}].push_back(
2256
        FoldFPUnaryOp(FoldFTranscendentalUnary(::exp2)));
2257
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log2}].push_back(
2258
        FoldFPUnaryOp(FoldFTranscendentalUnary(log2_shim)));
2259
#else
2260
9.59k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp2}].push_back(
2261
9.59k
        FoldFPUnaryOp(FoldFTranscendentalUnary(std::exp2)));
2262
9.59k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log2}].push_back(
2263
9.59k
        FoldFPUnaryOp(FoldFTranscendentalUnary(std::log2)));
2264
9.59k
#endif
2265
2266
9.59k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Sqrt}].push_back(
2267
9.59k
        FoldFPUnaryOp(FoldFTranscendentalUnary(std::sqrt)));
2268
9.59k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Atan2}].push_back(
2269
9.59k
        FoldFPBinaryOp(FoldFTranscendentalBinary(std::atan2)));
2270
9.59k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Pow}].push_back(
2271
9.59k
        FoldFPBinaryOp(FoldFTranscendentalBinary(std::pow)));
2272
9.59k
  }
2273
16.0k
}
2274
}  // namespace opt
2275
}  // namespace spvtools