Coverage Report

Created: 2026-06-30 06:51

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/spirv-tools/source/opt/folding_rules.cpp
Line
Count
Source
1
// Copyright (c) 2018 Google LLC
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
#include "source/opt/folding_rules.h"
16
17
#include <limits>
18
#include <memory>
19
#include <optional>
20
#include <utility>
21
22
#include "ir_builder.h"
23
#include "source/latest_version_glsl_std_450_header.h"
24
#include "source/opt/ir_context.h"
25
26
namespace spvtools {
27
namespace opt {
28
namespace {
29
30
// In-operand positions referenced by the folding rules in this file.
constexpr uint32_t kExtractCompositeIdInIdx{0};  // composite of an extract
constexpr uint32_t kInsertObjectIdInIdx{0};      // object of an insert
constexpr uint32_t kInsertCompositeIdInIdx{1};   // composite of an insert
constexpr uint32_t kExtInstSetIdInIdx{0};        // OpExtInst instruction set
constexpr uint32_t kExtInstInstructionInIdx{1};  // OpExtInst instruction number
constexpr uint32_t kFMixXIdInIdx{2};             // FMix x operand
constexpr uint32_t kFMixYIdInIdx{3};             // FMix y operand
constexpr uint32_t kFMixAIdInIdx{4};             // FMix a operand
constexpr uint32_t kStoreObjectInIdx{1};         // object stored by OpStore
39
40
// Some image instructions may contain an "image operands" argument.
41
// Returns the operand index for the "image operands".
42
// Returns -1 if the instruction does not have image operands.
43
270k
int32_t ImageOperandsMaskInOperandIndex(Instruction* inst) {
44
270k
  const auto opcode = inst->opcode();
45
270k
  switch (opcode) {
46
270k
    case spv::Op::OpImageSampleImplicitLod:
47
270k
    case spv::Op::OpImageSampleExplicitLod:
48
270k
    case spv::Op::OpImageSampleProjImplicitLod:
49
270k
    case spv::Op::OpImageSampleProjExplicitLod:
50
270k
    case spv::Op::OpImageFetch:
51
270k
    case spv::Op::OpImageRead:
52
270k
    case spv::Op::OpImageSparseSampleImplicitLod:
53
270k
    case spv::Op::OpImageSparseSampleExplicitLod:
54
270k
    case spv::Op::OpImageSparseSampleProjImplicitLod:
55
270k
    case spv::Op::OpImageSparseSampleProjExplicitLod:
56
270k
    case spv::Op::OpImageSparseFetch:
57
270k
    case spv::Op::OpImageSparseRead:
58
270k
      return inst->NumOperands() > 4 ? 2 : -1;
59
2
    case spv::Op::OpImageSampleDrefImplicitLod:
60
2
    case spv::Op::OpImageSampleDrefExplicitLod:
61
2
    case spv::Op::OpImageSampleProjDrefImplicitLod:
62
2
    case spv::Op::OpImageSampleProjDrefExplicitLod:
63
2
    case spv::Op::OpImageGather:
64
2
    case spv::Op::OpImageDrefGather:
65
2
    case spv::Op::OpImageSparseSampleDrefImplicitLod:
66
2
    case spv::Op::OpImageSparseSampleDrefExplicitLod:
67
2
    case spv::Op::OpImageSparseSampleProjDrefImplicitLod:
68
2
    case spv::Op::OpImageSparseSampleProjDrefExplicitLod:
69
2
    case spv::Op::OpImageSparseGather:
70
2
    case spv::Op::OpImageSparseDrefGather:
71
2
      return inst->NumOperands() > 5 ? 3 : -1;
72
0
    case spv::Op::OpImageWrite:
73
0
      return inst->NumOperands() > 3 ? 3 : -1;
74
0
    default:
75
0
      return -1;
76
270k
  }
77
270k
}
78
79
// Returns the element width of |type|.
80
5.76M
uint32_t ElementWidth(const analysis::Type* type) {
81
5.76M
  if (const analysis::CooperativeVectorNV* coopvec_type =
82
5.76M
          type->AsCooperativeVectorNV()) {
83
0
    return ElementWidth(coopvec_type->component_type());
84
5.76M
  } else if (const analysis::Vector* vec_type = type->AsVector()) {
85
2.13M
    return ElementWidth(vec_type->element_type());
86
3.63M
  } else if (const analysis::Float* float_type = type->AsFloat()) {
87
3.03M
    return float_type->width();
88
3.03M
  } else {
89
600k
    assert(type->AsInteger());
90
600k
    return type->AsInteger()->width();
91
600k
  }
92
5.76M
}
93
94
// Returns true if |type| is Float or a vector of Float.
95
5.94M
bool HasFloatingPoint(const analysis::Type* type) {
96
5.94M
  if (type->AsFloat()) {
97
1.86M
    return true;
98
4.08M
  } else if (const analysis::Vector* vec_type = type->AsVector()) {
99
3.35M
    return vec_type->element_type()->AsFloat() != nullptr;
100
3.35M
  }
101
102
731k
  return false;
103
5.94M
}
104
105
// Returns true if |val| is zero or a normal number, and false if it is NaN,
// infinite or subnormal. Folding helpers use this to decide whether a computed
// value may be turned into a constant.
template <typename T>
bool IsValidResult(T val) {
  const int category = std::fpclassify(val);
  return category != FP_NAN && category != FP_INFINITE &&
         category != FP_SUBNORMAL;
}
Unexecuted instantiation: folding_rules.cpp:bool spvtools::opt::(anonymous namespace)::IsValidResult<double>(double)
folding_rules.cpp:bool spvtools::opt::(anonymous namespace)::IsValidResult<float>(float)
Line
Count
Source
107
176k
bool IsValidResult(T val) {
108
176k
  int classified = std::fpclassify(val);
109
176k
  switch (classified) {
110
6.14k
    case FP_NAN:
111
50.0k
    case FP_INFINITE:
112
53.4k
    case FP_SUBNORMAL:
113
53.4k
      return false;
114
122k
    default:
115
122k
      return true;
116
176k
  }
117
176k
}
118
119
// Returns the first non-null entry among constants[0] and constants[1]
// (preferring constants[0]), or nullptr if both are null. |constants| must
// hold at least two entries.
const analysis::Constant* ConstInput(
    const std::vector<const analysis::Constant*>& constants) {
  const analysis::Constant* first = constants[0];
  return first != nullptr ? first : constants[1];
}
123
124
// Returns the defining instruction of one in-operand of |inst|: in-operand 1
// when |c| is non-null, otherwise in-operand 0.
// NOTE(review): presumably |c| is the constant for in-operand 0 (or null), so
// the returned instruction is the non-constant input; verify against callers.
Instruction* NonConstInput(IRContext* context, const analysis::Constant* c,
                           Instruction* inst) {
  const uint32_t operand_index = (c == nullptr) ? 0u : 1u;
  const uint32_t operand_id = inst->GetSingleWordInOperand(operand_index);
  return context->get_def_use_mgr()->GetDef(operand_id);
}
130
131
0
// Splits |val| into two 32-bit words, least significant word first.
std::vector<uint32_t> ExtractInts(uint64_t val) {
  const uint32_t low = static_cast<uint32_t>(val & 0xFFFFFFFFu);
  const uint32_t high = static_cast<uint32_t>(val >> 32);
  return {low, high};
}
137
138
std::vector<uint32_t> GetWordsFromScalarIntConstant(
139
1.29k
    const analysis::IntConstant* c) {
140
1.29k
  assert(c != nullptr);
141
1.29k
  uint32_t width = c->type()->AsInteger()->width();
142
1.29k
  assert(width == 8 || width == 16 || width == 32 || width == 64);
143
1.29k
  if (width == 64) {
144
0
    uint64_t uval = static_cast<uint64_t>(c->GetU64());
145
0
    return ExtractInts(uval);
146
0
  }
147
  // Section 2.2.1 of the SPIR-V spec guarantees that all integer types
148
  // smaller than 32-bits are automatically zero or sign extended to 32-bits.
149
1.29k
  return {c->GetU32BitValue()};
150
1.29k
}
151
152
std::vector<uint32_t> GetWordsFromScalarFloatConstant(
153
274
    const analysis::FloatConstant* c) {
154
274
  assert(c != nullptr);
155
274
  uint32_t width = c->type()->AsFloat()->width();
156
274
  assert(width == 16 || width == 32 || width == 64);
157
274
  if (width == 64) {
158
0
    utils::FloatProxy<double> result(c->GetDouble());
159
0
    return result.GetWords();
160
0
  }
161
  // Section 2.2.1 of the SPIR-V spec guarantees that all floating-point types
162
  // smaller than 32-bits are automatically zero extended to 32-bits.
163
274
  return {c->GetU32BitValue()};
164
274
}
165
166
std::vector<uint32_t> GetWordsFromNumericScalarOrVectorConstant(
167
2.21k
    analysis::ConstantManager* const_mgr, const analysis::Constant* c) {
168
2.21k
  if (const auto* float_constant = c->AsFloatConstant()) {
169
274
    return GetWordsFromScalarFloatConstant(float_constant);
170
1.93k
  } else if (const auto* int_constant = c->AsIntConstant()) {
171
1.29k
    return GetWordsFromScalarIntConstant(int_constant);
172
1.29k
  } else if (const auto* vec_constant = c->AsVectorConstant()) {
173
86
    std::vector<uint32_t> words;
174
    // Retrieve all the components as 32bit words.
175
308
    for (const auto* comp : vec_constant->GetComponents()) {
176
308
      auto comp_in_words =
177
308
          GetWordsFromNumericScalarOrVectorConstant(const_mgr, comp);
178
308
      words.insert(words.end(), comp_in_words.begin(), comp_in_words.end());
179
308
    }
180
181
86
    if (ElementWidth(c->type()) >= 32) {
182
86
      return words;
183
86
    }
184
    // Check the element width and concactenate if the width is less than 32.
185
0
    if (ElementWidth(c->type()) == 8) {
186
0
      assert(words.size() <= 4);
187
      // Each 32-bit word will comprise 4 8-bit integers.
188
      // reverse the order when compacting.
189
0
      uint32_t compacted_word = 0;
190
0
      for (int32_t i = static_cast<int32_t>(words.size()) - 1; i >= 0; --i) {
191
0
        compacted_word <<= 8;
192
0
        compacted_word |= (words[i] & 0xFF);
193
0
      }
194
0
      return {compacted_word};
195
0
    } else if (ElementWidth(c->type()) == 16) {
196
0
      assert(words.size() <= 4);
197
0
      std::vector<uint32_t> compacted_words;
198
      // Each 32-bit word will comprise 2 16-bit integers.
199
      // reverse the order pair-wise when compacting.
200
0
      for (uint32_t i = 0; i < words.size(); i += 2) {
201
0
        uint32_t word1 = words[i];
202
0
        uint32_t word2 = (i + 1 < words.size()) ? words[i + 1] : 0;
203
0
        uint32_t compacted_word = (word2 << 16) | (word1 & 0xFFFF);
204
0
        compacted_words.push_back(compacted_word);
205
0
      }
206
0
      return compacted_words;
207
0
    }
208
0
    assert(false && "Unhandled element width");
209
552
  } else if (c->AsNullConstant()) {
210
552
    uint32_t num_elements = 1;
211
212
552
    if (const auto* vec_type = c->type()->AsVector()) {
213
0
      num_elements = vec_type->element_count();
214
0
    }
215
216
    // We need to check the element width to determine how many 32-bit words are
217
    // needed.
218
552
    uint32_t element_width = ElementWidth(c->type());
219
552
    if (element_width < 32) {
220
0
      num_elements = (num_elements + 1) / 2;
221
552
    } else if (element_width == 64) {
222
0
      num_elements = num_elements * 2;
223
0
    }
224
552
    return std::vector<uint32_t>(num_elements, 0);
225
552
  }
226
0
  return {};
227
2.21k
}
228
229
// Builds the constant of |type| whose literal data is |words|. Integers of at
// most 32 bits are built from the single word in |words|; wider integers and
// floats go through GetConstant, and vectors through
// GetNumericVectorConstantWithWords. Returns nullptr for any other type.
const analysis::Constant* ConvertWordsToNumericScalarOrVectorConstant(
    analysis::ConstantManager* const_mgr, const std::vector<uint32_t>& words,
    const analysis::Type* type) {
  if (const analysis::Integer* int_type = type->AsInteger()) {
    if (int_type->width() <= 32) {
      assert(words.size() == 1);
      return const_mgr->GenerateIntegerConstant(int_type, words[0]);
    }
    return const_mgr->GetConstant(type, words);
  }
  if (type->AsFloat() != nullptr) {
    return const_mgr->GetConstant(type, words);
  }
  if (const analysis::Vector* vec_type = type->AsVector()) {
    return const_mgr->GetNumericVectorConstantWithWords(vec_type, words);
  }
  return nullptr;
}
244
245
// Returns the negation of |c|. |c| must be a 32 or 64 bit floating point
246
// constant.
247
uint32_t NegateFloatingPointConstant(analysis::ConstantManager* const_mgr,
248
2.50k
                                     const analysis::Constant* c) {
249
2.50k
  assert(c);
250
2.50k
  assert(c->type()->AsFloat());
251
2.50k
  uint32_t width = c->type()->AsFloat()->width();
252
2.50k
  assert(width == 32 || width == 64);
253
2.50k
  std::vector<uint32_t> words;
254
2.50k
  if (width == 64) {
255
0
    utils::FloatProxy<double> result(c->GetDouble() * -1.0);
256
0
    words = result.GetWords();
257
2.50k
  } else {
258
2.50k
    utils::FloatProxy<float> result(c->GetFloat() * -1.0f);
259
2.50k
    words = result.GetWords();
260
2.50k
  }
261
262
2.50k
  const analysis::Constant* negated_const =
263
2.50k
      const_mgr->GetConstant(c->type(), std::move(words));
264
2.50k
  return const_mgr->GetDefiningInstruction(negated_const)->result_id();
265
2.50k
}
266
267
// Negates the integer constant |c|. Returns the id of the defining instruction.
268
uint32_t NegateIntegerConstant(analysis::ConstantManager* const_mgr,
269
616
                               const analysis::Constant* c) {
270
616
  assert(c);
271
616
  assert(c->type()->AsInteger());
272
616
  uint32_t width = c->type()->AsInteger()->width();
273
616
  assert(width == 32 || width == 64);
274
616
  std::vector<uint32_t> words;
275
616
  if (width == 64) {
276
0
    uint64_t uval = static_cast<uint64_t>(0 - c->GetU64());
277
0
    words = ExtractInts(uval);
278
616
  } else {
279
616
    words.push_back(static_cast<uint32_t>(0 - c->GetU32()));
280
616
  }
281
282
616
  const analysis::Constant* negated_const =
283
616
      const_mgr->GetConstant(c->type(), std::move(words));
284
616
  return const_mgr->GetDefiningInstruction(negated_const)->result_id();
285
616
}
286
287
// Negates the vector constant |c|. Returns the id of the defining instruction.
288
uint32_t NegateVectorConstant(analysis::ConstantManager* const_mgr,
289
656
                              const analysis::Constant* c) {
290
656
  assert(const_mgr && c);
291
656
  assert(c->type()->AsVector());
292
656
  if (c->AsNullConstant()) {
293
    // 0.0 vs -0.0 shouldn't matter.
294
0
    return const_mgr->GetDefiningInstruction(c)->result_id();
295
656
  } else {
296
656
    const analysis::Type* component_type =
297
656
        c->AsVectorConstant()->component_type();
298
656
    std::vector<uint32_t> words;
299
1.31k
    for (auto& comp : c->AsVectorConstant()->GetComponents()) {
300
1.31k
      if (component_type->AsFloat()) {
301
1.31k
        words.push_back(NegateFloatingPointConstant(const_mgr, comp));
302
1.31k
      } else {
303
0
        assert(component_type->AsInteger());
304
0
        words.push_back(NegateIntegerConstant(const_mgr, comp));
305
0
      }
306
1.31k
    }
307
308
656
    const analysis::Constant* negated_const =
309
656
        const_mgr->GetConstant(c->type(), std::move(words));
310
656
    return const_mgr->GetDefiningInstruction(negated_const)->result_id();
311
656
  }
312
656
}
313
314
// Negates |c|. Returns the id of the defining instruction.
315
uint32_t NegateConstant(analysis::ConstantManager* const_mgr,
316
2.46k
                        const analysis::Constant* c) {
317
2.46k
  if (c->type()->AsVector()) {
318
656
    return NegateVectorConstant(const_mgr, c);
319
1.80k
  } else if (c->type()->AsFloat()) {
320
1.19k
    return NegateFloatingPointConstant(const_mgr, c);
321
1.19k
  } else {
322
616
    assert(c->type()->AsInteger());
323
616
    return NegateIntegerConstant(const_mgr, c);
324
616
  }
325
2.46k
}
326
327
// Takes the reciprocal of |c|. |c|'s type must be Float or a vector of Float.
328
// Returns 0 if the reciprocal is NaN, infinite or subnormal.
329
uint32_t Reciprocal(analysis::ConstantManager* const_mgr,
330
152k
                    const analysis::Constant* c) {
331
152k
  assert(const_mgr && c);
332
152k
  assert(c->type()->AsFloat());
333
334
152k
  uint32_t width = c->type()->AsFloat()->width();
335
152k
  assert(width == 32 || width == 64);
336
152k
  std::vector<uint32_t> words;
337
338
152k
  if (c->IsZero()) {
339
18.1k
    return 0;
340
18.1k
  }
341
342
133k
  if (width == 64) {
343
0
    spvtools::utils::FloatProxy<double> result(1.0 / c->GetDouble());
344
0
    if (!IsValidResult(result.getAsFloat())) return 0;
345
0
    words = result.GetWords();
346
133k
  } else {
347
133k
    spvtools::utils::FloatProxy<float> result(1.0f / c->GetFloat());
348
133k
    if (!IsValidResult(result.getAsFloat())) return 0;
349
92.4k
    words = result.GetWords();
350
92.4k
  }
351
352
92.4k
  const analysis::Constant* negated_const =
353
92.4k
      const_mgr->GetConstant(c->type(), std::move(words));
354
92.4k
  return const_mgr->GetDefiningInstruction(negated_const)->result_id();
355
133k
}
356
357
// Replaces fdiv where second operand is constant with fmul.
358
16.7k
FoldingRule ReciprocalFDiv() {
359
16.7k
  return [](IRContext* context, Instruction* inst,
360
145k
            const std::vector<const analysis::Constant*>& constants) {
361
145k
    assert(inst->opcode() == spv::Op::OpFDiv);
362
145k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
363
145k
    const analysis::Type* type =
364
145k
        context->get_type_mgr()->GetType(inst->type_id());
365
366
145k
    if (type->IsCooperativeMatrix()) {
367
0
      return false;
368
0
    }
369
370
145k
    if (!inst->IsFloatingPointFoldingAllowed()) return false;
371
372
145k
    uint32_t width = ElementWidth(type);
373
145k
    if (width != 32 && width != 64) return false;
374
375
145k
    if (constants[1] != nullptr) {
376
105k
      uint32_t id = 0;
377
105k
      if (const analysis::VectorConstant* vector_const =
378
105k
              constants[1]->AsVectorConstant()) {
379
94.3k
        std::vector<uint32_t> neg_ids;
380
140k
        for (auto& comp : vector_const->GetComponents()) {
381
140k
          id = Reciprocal(const_mgr, comp);
382
140k
          if (id == 0) return false;
383
83.4k
          neg_ids.push_back(id);
384
83.4k
        }
385
37.2k
        const analysis::Constant* negated_const =
386
37.2k
            const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids));
387
37.2k
        id = const_mgr->GetDefiningInstruction(negated_const)->result_id();
388
37.2k
      } else if (constants[1]->AsFloatConstant()) {
389
11.2k
        id = Reciprocal(const_mgr, constants[1]);
390
11.2k
        if (id == 0) return false;
391
11.2k
      } else {
392
        // Don't fold a null constant.
393
257
        return false;
394
257
      }
395
46.0k
      inst->SetOpcode(spv::Op::OpFMul);
396
46.0k
      inst->SetInOperands(
397
46.0k
          {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0u)}},
398
46.0k
           {SPV_OPERAND_TYPE_ID, {id}}});
399
46.0k
      return true;
400
105k
    }
401
402
40.0k
    return false;
403
145k
  };
404
16.7k
}
405
406
// Elides consecutive negate instructions.
407
33.5k
FoldingRule MergeNegateArithmetic() {
408
33.5k
  return [](IRContext* context, Instruction* inst,
409
33.5k
            const std::vector<const analysis::Constant*>& constants) {
410
7.29k
    assert(inst->opcode() == spv::Op::OpFNegate ||
411
7.29k
           inst->opcode() == spv::Op::OpSNegate);
412
7.29k
    (void)constants;
413
7.29k
    const analysis::Type* type =
414
7.29k
        context->get_type_mgr()->GetType(inst->type_id());
415
7.29k
    if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
416
0
      return false;
417
418
7.29k
    Instruction* op_inst =
419
7.29k
        context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
420
7.29k
    if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
421
0
      return false;
422
423
7.29k
    if (op_inst->opcode() == inst->opcode()) {
424
      // Elide negates.
425
155
      inst->SetOpcode(spv::Op::OpCopyObject);
426
155
      inst->SetInOperands(
427
155
          {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0u)}}});
428
155
      return true;
429
155
    }
430
431
7.13k
    return false;
432
7.29k
  };
433
33.5k
}
434
435
// Merges negate into a mul or div operation if that operation contains a
436
// constant operand.
437
// Cases:
438
// -(x * 2) = x * -2
439
// -(2 * x) = x * -2
440
// -(x / 2) = x / -2
441
// -(2 / x) = -2 / x
442
33.5k
FoldingRule MergeNegateMulDivArithmetic() {
443
33.5k
  return [](IRContext* context, Instruction* inst,
444
33.5k
            const std::vector<const analysis::Constant*>& constants) {
445
6.99k
    assert(inst->opcode() == spv::Op::OpFNegate ||
446
6.99k
           inst->opcode() == spv::Op::OpSNegate);
447
6.99k
    (void)constants;
448
6.99k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
449
6.99k
    const analysis::Type* type =
450
6.99k
        context->get_type_mgr()->GetType(inst->type_id());
451
452
6.99k
    if (type->IsCooperativeMatrix()) {
453
0
      return false;
454
0
    }
455
456
6.99k
    if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
457
0
      return false;
458
459
6.99k
    Instruction* op_inst =
460
6.99k
        context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
461
6.99k
    if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
462
0
      return false;
463
464
6.99k
    uint32_t width = ElementWidth(type);
465
6.99k
    if (width != 32 && width != 64) return false;
466
467
6.99k
    spv::Op opcode = op_inst->opcode();
468
6.99k
    if (opcode != spv::Op::OpFMul && opcode != spv::Op::OpFDiv &&
469
6.27k
        opcode != spv::Op::OpIMul && opcode != spv::Op::OpSDiv) {
470
6.21k
      return false;
471
6.21k
    }
472
473
771
    std::vector<const analysis::Constant*> op_constants =
474
771
        const_mgr->GetOperandConstants(op_inst);
475
    // Merge negate into mul or div if one operand is constant.
476
771
    if (op_constants[0] == nullptr && op_constants[1] == nullptr) {
477
248
      return false;
478
248
    }
479
480
523
    bool zero_is_variable = op_constants[0] == nullptr;
481
523
    const analysis::Constant* c = ConstInput(op_constants);
482
523
    uint32_t neg_id = NegateConstant(const_mgr, c);
483
523
    uint32_t non_const_id = zero_is_variable
484
523
                                ? op_inst->GetSingleWordInOperand(0u)
485
523
                                : op_inst->GetSingleWordInOperand(1u);
486
    // Change this instruction to a mul/div.
487
523
    inst->SetOpcode(op_inst->opcode());
488
523
    if (opcode == spv::Op::OpFDiv || opcode == spv::Op::OpUDiv ||
489
491
        opcode == spv::Op::OpSDiv) {
490
39
      uint32_t op0 = zero_is_variable ? non_const_id : neg_id;
491
39
      uint32_t op1 = zero_is_variable ? neg_id : non_const_id;
492
39
      inst->SetInOperands(
493
39
          {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
494
484
    } else {
495
484
      inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
496
484
                           {SPV_OPERAND_TYPE_ID, {neg_id}}});
497
484
    }
498
523
    return true;
499
771
  };
500
33.5k
}
501
502
// Merges negate into a add or sub operation if that operation contains a
503
// constant operand.
504
// Cases:
505
// -(x + 2) = -2 - x
506
// -(2 + x) = -2 - x
507
// -(x - 2) = 2 - x
508
// -(2 - x) = x - 2
509
33.5k
FoldingRule MergeNegateAddSubArithmetic() {
510
33.5k
  return [](IRContext* context, Instruction* inst,
511
33.5k
            const std::vector<const analysis::Constant*>& constants) {
512
7.11k
    assert(inst->opcode() == spv::Op::OpFNegate ||
513
7.11k
           inst->opcode() == spv::Op::OpSNegate);
514
7.11k
    (void)constants;
515
7.11k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
516
7.11k
    const analysis::Type* type =
517
7.11k
        context->get_type_mgr()->GetType(inst->type_id());
518
519
7.11k
    if (type->IsCooperativeMatrix()) {
520
0
      return false;
521
0
    }
522
523
7.11k
    if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
524
0
      return false;
525
526
7.11k
    Instruction* op_inst =
527
7.11k
        context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
528
7.11k
    if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
529
0
      return false;
530
531
7.11k
    uint32_t width = ElementWidth(type);
532
7.11k
    if (width != 32 && width != 64) return false;
533
534
7.11k
    if (op_inst->opcode() == spv::Op::OpFAdd ||
535
6.92k
        op_inst->opcode() == spv::Op::OpFSub ||
536
6.69k
        op_inst->opcode() == spv::Op::OpIAdd ||
537
6.44k
        op_inst->opcode() == spv::Op::OpISub) {
538
674
      std::vector<const analysis::Constant*> op_constants =
539
674
          const_mgr->GetOperandConstants(op_inst);
540
674
      if (op_constants[0] || op_constants[1]) {
541
236
        bool zero_is_variable = op_constants[0] == nullptr;
542
236
        bool is_add = (op_inst->opcode() == spv::Op::OpFAdd) ||
543
179
                      (op_inst->opcode() == spv::Op::OpIAdd);
544
236
        bool swap_operands = !is_add || zero_is_variable;
545
236
        bool negate_const = is_add;
546
236
        const analysis::Constant* c = ConstInput(op_constants);
547
236
        uint32_t const_id = 0;
548
236
        if (negate_const) {
549
141
          const_id = NegateConstant(const_mgr, c);
550
141
        } else {
551
95
          const_id = zero_is_variable ? op_inst->GetSingleWordInOperand(1u)
552
95
                                      : op_inst->GetSingleWordInOperand(0u);
553
95
        }
554
555
        // Swap operands if necessary and make the instruction a subtraction.
556
236
        uint32_t op0 =
557
236
            zero_is_variable ? op_inst->GetSingleWordInOperand(0u) : const_id;
558
236
        uint32_t op1 =
559
236
            zero_is_variable ? const_id : op_inst->GetSingleWordInOperand(1u);
560
236
        if (swap_operands) std::swap(op0, op1);
561
236
        inst->SetOpcode(HasFloatingPoint(type) ? spv::Op::OpFSub
562
236
                                               : spv::Op::OpISub);
563
236
        inst->SetInOperands(
564
236
            {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
565
236
        return true;
566
236
      }
567
674
    }
568
569
6.87k
    return false;
570
7.11k
  };
571
33.5k
}
572
573
// Returns true if |c| has a zero element.
574
362k
bool HasZero(const analysis::Constant* c) {
575
362k
  if (c->AsNullConstant()) {
576
522
    return true;
577
522
  }
578
362k
  if (const analysis::VectorConstant* vec_const = c->AsVectorConstant()) {
579
124k
    for (auto& comp : vec_const->GetComponents())
580
212k
      if (HasZero(comp)) return true;
581
238k
  } else {
582
238k
    assert(c->AsScalarConstant());
583
238k
    return c->AsScalarConstant()->IsZero();
584
238k
  }
585
586
84.0k
  return false;
587
362k
}
588
589
// Performs |input1| |opcode| |input2| and returns the merged constant result
590
// id. Returns 0 if the result is not a valid value. The input types must be
591
// Float.
592
uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr,
593
                                       spv::Op opcode,
594
                                       const analysis::Constant* input1,
595
42.2k
                                       const analysis::Constant* input2) {
596
42.2k
  const analysis::Type* type = input1->type();
597
42.2k
  assert(type->AsFloat());
598
42.2k
  uint32_t width = type->AsFloat()->width();
599
42.2k
  assert(width == 32 || width == 64);
600
42.2k
  std::vector<uint32_t> words;
601
42.2k
#define FOLD_OP(op)                                                          \
602
42.2k
  if (width == 64) {                                                         \
603
0
    utils::FloatProxy<double> val =                                          \
604
0
        input1->GetDouble() op input2->GetDouble();                          \
605
0
    double dval = val.getAsFloat();                                          \
606
0
    if (!IsValidResult(dval)) return 0;                                      \
607
0
    words = val.GetWords();                                                  \
608
42.2k
  } else {                                                                   \
609
42.2k
    utils::FloatProxy<float> val = input1->GetFloat() op input2->GetFloat(); \
610
42.2k
    float fval = val.getAsFloat();                                           \
611
42.2k
    if (!IsValidResult(fval)) return 0;                                      \
612
42.2k
    words = val.GetWords();                                                  \
613
30.3k
  }                                                                          \
614
42.2k
  static_assert(true, "require extra semicolon")
615
42.2k
  switch (opcode) {
616
6.66k
    case spv::Op::OpFMul:
617
6.66k
      FOLD_OP(*);
618
3.44k
      break;
619
2.49k
    case spv::Op::OpFDiv:
620
2.49k
      if (HasZero(input2)) return 0;
621
2.49k
      FOLD_OP(/);
622
1.66k
      break;
623
26.7k
    case spv::Op::OpFAdd:
624
26.7k
      FOLD_OP(+);
625
20.5k
      break;
626
6.41k
    case spv::Op::OpFSub:
627
6.41k
      FOLD_OP(-);
628
4.61k
      break;
629
0
    default:
630
0
      assert(false && "Unexpected operation");
631
0
      break;
632
42.2k
  }
633
30.3k
#undef FOLD_OP
634
30.3k
  const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
635
30.3k
  return const_mgr->GetDefiningInstruction(merged_const)->result_id();
636
42.2k
}
637
638
// Performs |input1| |opcode| |input2| and returns the merged constant result
639
// id. Returns 0 if the result is not a valid value. The input types must be
640
// Integers.
641
uint32_t PerformIntegerOperation(analysis::ConstantManager* const_mgr,
642
                                 spv::Op opcode,
643
                                 const analysis::Constant* input1,
644
7.69k
                                 const analysis::Constant* input2) {
645
7.69k
  assert(input1->type()->AsInteger());
646
7.69k
  const analysis::Integer* type = input1->type()->AsInteger();
647
7.69k
  uint32_t width = type->AsInteger()->width();
648
7.69k
  assert(width == 32 || width == 64);
649
7.69k
  std::vector<uint32_t> words;
650
  // Regardless of the sign of the constant, folding is performed on an unsigned
651
  // interpretation of the constant data. This avoids signed integer overflow
652
  // while folding, and works because sign is irrelevant for the IAdd, ISub and
653
  // IMul instructions.
654
7.69k
#define FOLD_OP(op)                                      \
655
7.69k
  if (width == 64) {                                     \
656
0
    uint64_t val = input1->GetU64() op input2->GetU64(); \
657
0
    words = ExtractInts(val);                            \
658
7.69k
  } else {                                               \
659
7.69k
    uint32_t val = input1->GetU32() op input2->GetU32(); \
660
7.69k
    words.push_back(val);                                \
661
7.69k
  }                                                      \
662
7.69k
  static_assert(true, "require extra semicolon")
663
7.69k
  switch (opcode) {
664
349
    case spv::Op::OpIMul:
665
349
      FOLD_OP(*);
666
349
      break;
667
0
    case spv::Op::OpSDiv:
668
0
    case spv::Op::OpUDiv:
669
0
      assert(false && "Should not merge integer division");
670
0
      break;
671
3.23k
    case spv::Op::OpIAdd:
672
3.23k
      FOLD_OP(+);
673
3.23k
      break;
674
2.04k
    case spv::Op::OpISub:
675
2.04k
      FOLD_OP(-);
676
2.04k
      break;
677
8
    case spv::Op::OpBitwiseXor:
678
8
      FOLD_OP(^);
679
8
      break;
680
1.02k
    case spv::Op::OpBitwiseOr:
681
1.02k
      FOLD_OP(|);
682
1.02k
      break;
683
1.03k
    case spv::Op::OpBitwiseAnd:
684
1.03k
      FOLD_OP(&);
685
1.03k
      break;
686
0
    default:
687
0
      assert(false && "Unexpected operation");
688
0
      break;
689
7.69k
  }
690
7.69k
#undef FOLD_OP
691
7.69k
  const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
692
7.69k
  return const_mgr->GetDefiningInstruction(merged_const)->result_id();
693
7.69k
}
694
695
// Performs |input1| |opcode| |input2| and returns the merged constant result
696
// id. Returns 0 if the result is not a valid value. The input types must be
697
// Integers, Floats or Vectors of such.
698
uint32_t PerformOperation(analysis::ConstantManager* const_mgr, spv::Op opcode,
699
                          const analysis::Constant* input1,
700
40.3k
                          const analysis::Constant* input2) {
701
40.3k
  assert(input1 && input2);
702
40.3k
  const analysis::Type* type = input1->type();
703
40.3k
  std::vector<uint32_t> words;
704
40.3k
  if (const analysis::Vector* vector_type = type->AsVector()) {
705
12.3k
    const analysis::Type* ele_type = vector_type->element_type();
706
31.0k
    for (uint32_t i = 0; i != vector_type->element_count(); ++i) {
707
22.0k
      uint32_t id = 0;
708
709
22.0k
      const analysis::Constant* input1_comp = nullptr;
710
22.0k
      if (const analysis::VectorConstant* input1_vector =
711
22.0k
              input1->AsVectorConstant()) {
712
22.0k
        input1_comp = input1_vector->GetComponents()[i];
713
22.0k
      } else {
714
0
        assert(input1->AsNullConstant());
715
0
        input1_comp = const_mgr->GetConstant(ele_type, {});
716
0
      }
717
718
22.0k
      const analysis::Constant* input2_comp = nullptr;
719
22.0k
      if (const analysis::VectorConstant* input2_vector =
720
22.0k
              input2->AsVectorConstant()) {
721
22.0k
        input2_comp = input2_vector->GetComponents()[i];
722
22.0k
      } else {
723
0
        assert(input2->AsNullConstant());
724
0
        input2_comp = const_mgr->GetConstant(ele_type, {});
725
0
      }
726
727
22.0k
      if (ele_type->AsFloat()) {
728
22.0k
        id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp,
729
22.0k
                                           input2_comp);
730
22.0k
      } else {
731
0
        assert(ele_type->AsInteger());
732
0
        id = PerformIntegerOperation(const_mgr, opcode, input1_comp,
733
0
                                     input2_comp);
734
0
      }
735
22.0k
      if (id == 0) return 0;
736
18.6k
      words.push_back(id);
737
18.6k
    }
738
9.03k
    const analysis::Constant* merged_const =
739
9.03k
        const_mgr->GetConstant(type, words);
740
9.03k
    return const_mgr->GetDefiningInstruction(merged_const)->result_id();
741
27.9k
  } else if (type->AsFloat()) {
742
20.2k
    return PerformFloatingPointOperation(const_mgr, opcode, input1, input2);
743
20.2k
  } else {
744
7.69k
    assert(type->AsInteger());
745
7.69k
    return PerformIntegerOperation(const_mgr, opcode, input1, input2);
746
7.69k
  }
747
40.3k
}
748
749
// Merges consecutive multiplies where each contains one constant operand.
750
// Cases:
751
// 2 * (x * 2) = x * 4
752
// 2 * (2 * x) = x * 4
753
// (x * 2) * 2 = x * 4
754
// (2 * x) * 2 = x * 4
755
33.5k
FoldingRule MergeMulMulArithmetic() {
756
33.5k
  return [](IRContext* context, Instruction* inst,
757
214k
            const std::vector<const analysis::Constant*>& constants) {
758
214k
    assert(inst->opcode() == spv::Op::OpFMul ||
759
214k
           inst->opcode() == spv::Op::OpIMul);
760
214k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
761
214k
    const analysis::Type* type =
762
214k
        context->get_type_mgr()->GetType(inst->type_id());
763
764
214k
    if (type->IsCooperativeMatrix()) {
765
0
      return false;
766
0
    }
767
768
214k
    if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
769
19
      return false;
770
771
214k
    uint32_t width = ElementWidth(type);
772
214k
    if (width != 32 && width != 64) return false;
773
774
    // Determine the constant input and the variable input in |inst|.
775
214k
    const analysis::Constant* const_input1 = ConstInput(constants);
776
214k
    if (!const_input1) return false;
777
152k
    Instruction* other_inst = NonConstInput(context, constants[0], inst);
778
152k
    if (HasFloatingPoint(type) && !other_inst->IsFloatingPointFoldingAllowed())
779
29
      return false;
780
781
152k
    if (other_inst->opcode() == inst->opcode()) {
782
5.74k
      std::vector<const analysis::Constant*> other_constants =
783
5.74k
          const_mgr->GetOperandConstants(other_inst);
784
5.74k
      const analysis::Constant* const_input2 = ConstInput(other_constants);
785
5.74k
      if (!const_input2) return false;
786
787
4.07k
      bool other_first_is_variable = other_constants[0] == nullptr;
788
4.07k
      uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
789
4.07k
                                            const_input1, const_input2);
790
4.07k
      if (merged_id == 0) return false;
791
792
1.99k
      uint32_t non_const_id = other_first_is_variable
793
1.99k
                                  ? other_inst->GetSingleWordInOperand(0u)
794
1.99k
                                  : other_inst->GetSingleWordInOperand(1u);
795
1.99k
      inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
796
1.99k
                           {SPV_OPERAND_TYPE_ID, {merged_id}}});
797
1.99k
      return true;
798
4.07k
    }
799
800
146k
    return false;
801
152k
  };
802
33.5k
}
803
804
// Merges divides into subsequent multiplies if each instruction contains one
805
// constant operand. Does not support integer operations.
806
// Cases:
807
// 2 * (x / 2) = x * 1
808
// 2 * (2 / x) = 4 / x
809
// (x / 2) * 2 = x * 1
810
// (2 / x) * 2 = 4 / x
811
// (y / x) * x = y
812
// x * (y / x) = y
813
16.7k
FoldingRule MergeMulDivArithmetic() {
814
16.7k
  return [](IRContext* context, Instruction* inst,
815
199k
            const std::vector<const analysis::Constant*>& constants) {
816
199k
    assert(inst->opcode() == spv::Op::OpFMul);
817
199k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
818
199k
    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
819
820
199k
    const analysis::Type* type =
821
199k
        context->get_type_mgr()->GetType(inst->type_id());
822
823
199k
    if (type->IsCooperativeMatrix()) {
824
0
      return false;
825
0
    }
826
827
199k
    if (!inst->IsFloatingPointFoldingAllowed()) return false;
828
829
199k
    uint32_t width = ElementWidth(type);
830
199k
    if (width != 32 && width != 64) return false;
831
832
596k
    for (uint32_t i = 0; i < 2; i++) {
833
398k
      uint32_t op_id = inst->GetSingleWordInOperand(i);
834
398k
      Instruction* op_inst = def_use_mgr->GetDef(op_id);
835
398k
      if (op_inst->opcode() == spv::Op::OpFDiv) {
836
15.1k
        if (op_inst->GetSingleWordInOperand(1) ==
837
15.1k
            inst->GetSingleWordInOperand(1 - i)) {
838
284
          inst->SetOpcode(spv::Op::OpCopyObject);
839
284
          inst->SetInOperands(
840
284
              {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0)}}});
841
284
          return true;
842
284
        }
843
15.1k
      }
844
398k
    }
845
846
198k
    const analysis::Constant* const_input1 = ConstInput(constants);
847
198k
    if (!const_input1) return false;
848
137k
    Instruction* other_inst = NonConstInput(context, constants[0], inst);
849
137k
    if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
850
851
137k
    if (other_inst->opcode() == spv::Op::OpFDiv) {
852
2.80k
      std::vector<const analysis::Constant*> other_constants =
853
2.80k
          const_mgr->GetOperandConstants(other_inst);
854
2.80k
      const analysis::Constant* const_input2 = ConstInput(other_constants);
855
2.80k
      if (!const_input2 || HasZero(const_input2)) return false;
856
857
776
      bool other_first_is_variable = other_constants[0] == nullptr;
858
      // If the variable value is the second operand of the divide, multiply
859
      // the constants together. Otherwise divide the constants.
860
776
      uint32_t merged_id = PerformOperation(
861
776
          const_mgr,
862
776
          other_first_is_variable ? other_inst->opcode() : inst->opcode(),
863
776
          const_input1, const_input2);
864
776
      if (merged_id == 0) return false;
865
866
334
      uint32_t non_const_id = other_first_is_variable
867
334
                                  ? other_inst->GetSingleWordInOperand(0u)
868
334
                                  : other_inst->GetSingleWordInOperand(1u);
869
870
      // If the variable value is on the second operand of the div, then this
871
      // operation is a div. Otherwise it should be a multiply.
872
334
      inst->SetOpcode(other_first_is_variable ? inst->opcode()
873
334
                                              : other_inst->opcode());
874
334
      if (other_first_is_variable) {
875
27
        inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
876
27
                             {SPV_OPERAND_TYPE_ID, {merged_id}}});
877
307
      } else {
878
307
        inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {merged_id}},
879
307
                             {SPV_OPERAND_TYPE_ID, {non_const_id}}});
880
307
      }
881
334
      return true;
882
776
    }
883
884
135k
    return false;
885
137k
  };
886
16.7k
}
887
888
// Merges multiply of constant and negation.
889
// Cases:
890
// (-x) * 2 = x * -2
891
// 2 * (-x) = x * -2
892
33.5k
FoldingRule MergeMulNegateArithmetic() {
893
33.5k
  return [](IRContext* context, Instruction* inst,
894
212k
            const std::vector<const analysis::Constant*>& constants) {
895
212k
    assert(inst->opcode() == spv::Op::OpFMul ||
896
212k
           inst->opcode() == spv::Op::OpIMul);
897
212k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
898
212k
    const analysis::Type* type =
899
212k
        context->get_type_mgr()->GetType(inst->type_id());
900
901
212k
    if (type->IsCooperativeMatrix()) {
902
0
      return false;
903
0
    }
904
905
212k
    bool uses_float = HasFloatingPoint(type);
906
212k
    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
907
908
212k
    uint32_t width = ElementWidth(type);
909
212k
    if (width != 32 && width != 64) return false;
910
911
212k
    const analysis::Constant* const_input1 = ConstInput(constants);
912
212k
    if (!const_input1) return false;
913
149k
    Instruction* other_inst = NonConstInput(context, constants[0], inst);
914
149k
    if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
915
29
      return false;
916
917
149k
    if (other_inst->opcode() == spv::Op::OpFNegate ||
918
149k
        other_inst->opcode() == spv::Op::OpSNegate) {
919
36
      uint32_t neg_id = NegateConstant(const_mgr, const_input1);
920
921
36
      inst->SetInOperands(
922
36
          {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
923
36
           {SPV_OPERAND_TYPE_ID, {neg_id}}});
924
36
      return true;
925
36
    }
926
927
149k
    return false;
928
149k
  };
929
33.5k
}
930
931
// Returns true if |inst| is negation op and is safe to fold.
932
1.92M
static bool IsFoldableNegation(const Instruction* inst) {
933
1.92M
  return (inst->opcode() == spv::Op::OpSNegate ||
934
1.92M
          (inst->opcode() == spv::Op::OpFNegate &&
935
803
           inst->IsFloatingPointFoldingAllowed()));
936
1.92M
}
937
938
// Merges multiplies / divisions of two negations.
939
// Cases:
940
// (-x) * (-y) = x * y
941
// (-x) / (-y) = x / y
942
83.8k
FoldingRule MergeDivMulDoubleNegative() {
943
83.8k
  return [](IRContext* context, Instruction* inst,
944
656k
            const std::vector<const analysis::Constant*>&) {
945
656k
    assert(inst->opcode() == spv::Op::OpFMul ||
946
656k
           inst->opcode() == spv::Op::OpVectorTimesScalar ||
947
656k
           inst->opcode() == spv::Op::OpFDiv ||
948
656k
           inst->opcode() == spv::Op::OpIMul ||
949
656k
           inst->opcode() == spv::Op::OpSDiv);
950
951
656k
    const analysis::Type* type =
952
656k
        context->get_type_mgr()->GetType(inst->type_id());
953
954
656k
    bool uses_float = HasFloatingPoint(type);
955
656k
    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
956
957
655k
    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
958
655k
    Instruction* lhs = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
959
655k
    Instruction* rhs = def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
960
961
655k
    if (IsFoldableNegation(lhs) && IsFoldableNegation(rhs)) {
962
12
      inst->SetInOperands(
963
12
          {{SPV_OPERAND_TYPE_ID, {lhs->GetSingleWordInOperand(0u)}},
964
12
           {SPV_OPERAND_TYPE_ID, {rhs->GetSingleWordInOperand(0u)}}});
965
12
      return true;
966
12
    }
967
655k
    return false;
968
655k
  };
969
83.8k
}
970
971
// Merges consecutive divides if each instruction contains one constant operand.
972
// Does not support integer division.
973
// Cases:
974
// 2 / (x / 2) = 4 / x
975
// 4 / (2 / x) = 2 * x
976
// (4 / x) / 2 = 2 / x
977
// (x / 2) / 2 = x / 4
978
16.7k
FoldingRule MergeDivDivArithmetic() {
979
16.7k
  return [](IRContext* context, Instruction* inst,
980
99.9k
            const std::vector<const analysis::Constant*>& constants) {
981
99.9k
    assert(inst->opcode() == spv::Op::OpFDiv);
982
99.9k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
983
99.9k
    const analysis::Type* type =
984
99.9k
        context->get_type_mgr()->GetType(inst->type_id());
985
986
99.9k
    if (type->IsCooperativeMatrix()) {
987
0
      return false;
988
0
    }
989
990
99.9k
    if (!inst->IsFloatingPointFoldingAllowed()) return false;
991
992
99.9k
    uint32_t width = ElementWidth(type);
993
99.9k
    if (width != 32 && width != 64) return false;
994
995
99.9k
    const analysis::Constant* const_input1 = ConstInput(constants);
996
99.9k
    if (!const_input1 || HasZero(const_input1)) return false;
997
52.4k
    Instruction* other_inst = NonConstInput(context, constants[0], inst);
998
52.4k
    if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
999
1000
52.4k
    bool first_is_variable = constants[0] == nullptr;
1001
52.4k
    if (other_inst->opcode() == inst->opcode()) {
1002
2.05k
      std::vector<const analysis::Constant*> other_constants =
1003
2.05k
          const_mgr->GetOperandConstants(other_inst);
1004
2.05k
      const analysis::Constant* const_input2 = ConstInput(other_constants);
1005
2.05k
      if (!const_input2 || HasZero(const_input2)) return false;
1006
1007
1.65k
      bool other_first_is_variable = other_constants[0] == nullptr;
1008
1009
1.65k
      spv::Op merge_op = inst->opcode();
1010
1.65k
      if (other_first_is_variable) {
1011
        // Constants magnify.
1012
883
        merge_op = spv::Op::OpFMul;
1013
883
      }
1014
1015
      // This is an x / (*) case. Swap the inputs. Doesn't harm multiply
1016
      // because it is commutative.
1017
1.65k
      if (first_is_variable) std::swap(const_input1, const_input2);
1018
1.65k
      uint32_t merged_id =
1019
1.65k
          PerformOperation(const_mgr, merge_op, const_input1, const_input2);
1020
1.65k
      if (merged_id == 0) return false;
1021
1022
726
      uint32_t non_const_id = other_first_is_variable
1023
726
                                  ? other_inst->GetSingleWordInOperand(0u)
1024
726
                                  : other_inst->GetSingleWordInOperand(1u);
1025
1026
726
      spv::Op op = inst->opcode();
1027
726
      if (!first_is_variable && !other_first_is_variable) {
1028
        // Effectively div of 1/x, so change to multiply.
1029
516
        op = spv::Op::OpFMul;
1030
516
      }
1031
1032
726
      uint32_t op1 = merged_id;
1033
726
      uint32_t op2 = non_const_id;
1034
726
      if (first_is_variable && other_first_is_variable) std::swap(op1, op2);
1035
726
      inst->SetOpcode(op);
1036
726
      inst->SetInOperands(
1037
726
          {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1038
726
      return true;
1039
1.65k
    }
1040
1041
50.4k
    return false;
1042
52.4k
  };
1043
16.7k
}
1044
1045
// Fold multiplies succeeded by divides where each instruction contains a
1046
// constant operand. Does not support integer divide.
1047
// Cases:
1048
// 4 / (x * 2) = 2 / x
1049
// 4 / (2 * x) = 2 / x
1050
// (x * 4) / 2 = x * 2
1051
// (4 * x) / 2 = x * 2
1052
// (x * y) / x = y
1053
// (y * x) / x = y
1054
16.7k
FoldingRule MergeDivMulArithmetic() {
1055
16.7k
  return [](IRContext* context, Instruction* inst,
1056
99.2k
            const std::vector<const analysis::Constant*>& constants) {
1057
99.2k
    assert(inst->opcode() == spv::Op::OpFDiv);
1058
99.2k
    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1059
99.2k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1060
1061
99.2k
    const analysis::Type* type =
1062
99.2k
        context->get_type_mgr()->GetType(inst->type_id());
1063
1064
99.2k
    if (type->IsCooperativeMatrix()) {
1065
0
      return false;
1066
0
    }
1067
1068
99.2k
    if (!inst->IsFloatingPointFoldingAllowed()) return false;
1069
1070
99.1k
    uint32_t width = ElementWidth(type);
1071
99.1k
    if (width != 32 && width != 64) return false;
1072
1073
99.1k
    uint32_t op_id = inst->GetSingleWordInOperand(0);
1074
99.1k
    Instruction* op_inst = def_use_mgr->GetDef(op_id);
1075
1076
99.1k
    if (op_inst->opcode() == spv::Op::OpFMul) {
1077
3.31k
      for (uint32_t i = 0; i < 2; i++) {
1078
2.36k
        if (op_inst->GetSingleWordInOperand(i) ==
1079
2.36k
            inst->GetSingleWordInOperand(1)) {
1080
242
          inst->SetOpcode(spv::Op::OpCopyObject);
1081
242
          inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
1082
242
                                {op_inst->GetSingleWordInOperand(1 - i)}}});
1083
242
          return true;
1084
242
        }
1085
2.36k
      }
1086
1.19k
    }
1087
1088
98.9k
    const analysis::Constant* const_input1 = ConstInput(constants);
1089
98.9k
    if (!const_input1 || HasZero(const_input1)) return false;
1090
51.5k
    Instruction* other_inst = NonConstInput(context, constants[0], inst);
1091
51.5k
    if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
1092
1093
51.5k
    bool first_is_variable = constants[0] == nullptr;
1094
51.5k
    if (other_inst->opcode() == spv::Op::OpFMul) {
1095
1.37k
      std::vector<const analysis::Constant*> other_constants =
1096
1.37k
          const_mgr->GetOperandConstants(other_inst);
1097
1.37k
      const analysis::Constant* const_input2 = ConstInput(other_constants);
1098
1.37k
      if (!const_input2) return false;
1099
1100
1.01k
      bool other_first_is_variable = other_constants[0] == nullptr;
1101
1102
      // This is an x / (*) case. Swap the inputs.
1103
1.01k
      if (first_is_variable) std::swap(const_input1, const_input2);
1104
1.01k
      uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
1105
1.01k
                                            const_input1, const_input2);
1106
1.01k
      if (merged_id == 0) return false;
1107
1108
670
      uint32_t non_const_id = other_first_is_variable
1109
670
                                  ? other_inst->GetSingleWordInOperand(0u)
1110
670
                                  : other_inst->GetSingleWordInOperand(1u);
1111
1112
670
      uint32_t op1 = merged_id;
1113
670
      uint32_t op2 = non_const_id;
1114
670
      if (first_is_variable) std::swap(op1, op2);
1115
1116
      // Convert to multiply
1117
670
      if (first_is_variable) inst->SetOpcode(other_inst->opcode());
1118
670
      inst->SetInOperands(
1119
670
          {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1120
670
      return true;
1121
1.01k
    }
1122
1123
50.1k
    return false;
1124
51.5k
  };
1125
16.7k
}
1126
1127
// Fold divides of a constant and a negation.
1128
// Cases:
1129
// (-x) / 2 = x / -2
1130
// 2 / (-x) = -2 / x
1131
16.7k
FoldingRule MergeDivNegateArithmetic() {
1132
16.7k
  return [](IRContext* context, Instruction* inst,
1133
98.2k
            const std::vector<const analysis::Constant*>& constants) {
1134
98.2k
    assert(inst->opcode() == spv::Op::OpFDiv);
1135
98.2k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1136
98.2k
    if (!inst->IsFloatingPointFoldingAllowed()) return false;
1137
1138
98.2k
    const analysis::Constant* const_input1 = ConstInput(constants);
1139
98.2k
    if (!const_input1) return false;
1140
71.1k
    Instruction* other_inst = NonConstInput(context, constants[0], inst);
1141
71.1k
    if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
1142
1143
71.1k
    bool first_is_variable = constants[0] == nullptr;
1144
71.1k
    if (other_inst->opcode() == spv::Op::OpFNegate) {
1145
40
      uint32_t neg_id = NegateConstant(const_mgr, const_input1);
1146
1147
40
      if (first_is_variable) {
1148
26
        inst->SetInOperands(
1149
26
            {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
1150
26
             {SPV_OPERAND_TYPE_ID, {neg_id}}});
1151
26
      } else {
1152
14
        inst->SetInOperands(
1153
14
            {{SPV_OPERAND_TYPE_ID, {neg_id}},
1154
14
             {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
1155
14
      }
1156
40
      return true;
1157
40
    }
1158
1159
71.1k
    return false;
1160
71.1k
  };
1161
16.7k
}
1162
1163
// Folds addition, where one side is a negation.
1164
// (-x) + y = y - x
1165
// y + (-x) = y - x
1166
33.5k
FoldingRule MergeAddNegateArithmetic() {
1167
33.5k
  return [](IRContext* context, Instruction* inst,
1168
528k
            const std::vector<const analysis::Constant*>&) {
1169
528k
    assert(inst->opcode() == spv::Op::OpFAdd ||
1170
528k
           inst->opcode() == spv::Op::OpIAdd);
1171
528k
    const analysis::Type* type =
1172
528k
        context->get_type_mgr()->GetType(inst->type_id());
1173
528k
    bool uses_float = HasFloatingPoint(type);
1174
528k
    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1175
1176
528k
    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1177
528k
    Instruction* lhs = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
1178
528k
    Instruction* rhs = def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
1179
1180
528k
    auto TrySubstitute = [inst, uses_float](Instruction* first,
1181
1.05M
                                            Instruction* second) {
1182
1.05M
      if (IsFoldableNegation(first)) {
1183
55
        inst->SetOpcode(uses_float ? spv::Op::OpFSub : spv::Op::OpISub);
1184
55
        inst->SetInOperands(
1185
55
            {{SPV_OPERAND_TYPE_ID, {second->result_id()}},
1186
55
             {SPV_OPERAND_TYPE_ID, {first->GetSingleWordInOperand(0u)}}});
1187
55
        return true;
1188
55
      }
1189
1.05M
      return false;
1190
1.05M
    };
1191
1192
528k
    return TrySubstitute(lhs, rhs) || TrySubstitute(rhs, lhs);
1193
528k
  };
1194
33.5k
}
1195
1196
// Folds subtraction, where one side is a negation.
1197
// Cases:
1198
// (-x) - 2 = -2 - x
1199
// y - (-x) = x + y
1200
33.5k
FoldingRule MergeSubNegateArithmetic() {
1201
33.5k
  return [](IRContext* context, Instruction* inst,
1202
144k
            const std::vector<const analysis::Constant*>& constants) {
1203
144k
    assert(inst->opcode() == spv::Op::OpFSub ||
1204
144k
           inst->opcode() == spv::Op::OpISub);
1205
144k
    const analysis::Type* type =
1206
144k
        context->get_type_mgr()->GetType(inst->type_id());
1207
1208
144k
    bool uses_float = HasFloatingPoint(type);
1209
144k
    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1210
1211
140k
    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1212
140k
    Instruction* lhs = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
1213
140k
    Instruction* rhs = def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
1214
1215
140k
    if (IsFoldableNegation(rhs)) {
1216
573
      inst->SetOpcode(uses_float ? spv::Op::OpFAdd : spv::Op::OpIAdd);
1217
573
      inst->SetInOperands(
1218
573
          {{SPV_OPERAND_TYPE_ID, {lhs->result_id()}},
1219
573
           {SPV_OPERAND_TYPE_ID, {rhs->GetSingleWordInOperand(0)}}});
1220
573
      return true;
1221
573
    }
1222
1223
139k
    if (type->IsCooperativeMatrix()) {
1224
0
      return false;
1225
0
    }
1226
1227
139k
    uint32_t width = ElementWidth(type);
1228
139k
    if (width != 32 && width != 64) return false;
1229
1230
139k
    if (constants[1] && IsFoldableNegation(lhs)) {
1231
35
      inst->SetInOperands(
1232
35
          {{SPV_OPERAND_TYPE_ID,
1233
35
            {NegateConstant(context->get_constant_mgr(), constants[1])}},
1234
35
           {SPV_OPERAND_TYPE_ID, {lhs->GetSingleWordInOperand(0)}}});
1235
35
      return true;
1236
35
    }
1237
139k
    return false;
1238
139k
  };
1239
33.5k
}
1240
1241
// Folds addition of an addition where each operation has a constant operand.
1242
// Cases:
1243
// (x + 2) + 2 = x + 4
1244
// (2 + x) + 2 = x + 4
1245
// 2 + (x + 2) = x + 4
1246
// 2 + (2 + x) = x + 4
1247
33.5k
FoldingRule MergeAddAddArithmetic() {
1248
33.5k
  return [](IRContext* context, Instruction* inst,
1249
528k
            const std::vector<const analysis::Constant*>& constants) {
1250
528k
    assert(inst->opcode() == spv::Op::OpFAdd ||
1251
528k
           inst->opcode() == spv::Op::OpIAdd);
1252
528k
    const analysis::Type* type =
1253
528k
        context->get_type_mgr()->GetType(inst->type_id());
1254
1255
528k
    if (type->IsCooperativeMatrix()) {
1256
0
      return false;
1257
0
    }
1258
1259
528k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1260
528k
    bool uses_float = HasFloatingPoint(type);
1261
528k
    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1262
1263
528k
    uint32_t width = ElementWidth(type);
1264
528k
    if (width != 32 && width != 64) return false;
1265
1266
528k
    const analysis::Constant* const_input1 = ConstInput(constants);
1267
528k
    if (!const_input1) return false;
1268
139k
    Instruction* other_inst = NonConstInput(context, constants[0], inst);
1269
139k
    if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1270
465
      return false;
1271
1272
139k
    if (other_inst->opcode() == spv::Op::OpFAdd ||
1273
126k
        other_inst->opcode() == spv::Op::OpIAdd) {
1274
16.0k
      std::vector<const analysis::Constant*> other_constants =
1275
16.0k
          const_mgr->GetOperandConstants(other_inst);
1276
16.0k
      const analysis::Constant* const_input2 = ConstInput(other_constants);
1277
16.0k
      if (!const_input2) return false;
1278
1279
13.6k
      Instruction* non_const_input =
1280
13.6k
          NonConstInput(context, other_constants[0], other_inst);
1281
13.6k
      uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
1282
13.6k
                                            const_input1, const_input2);
1283
13.6k
      if (merged_id == 0) return false;
1284
1285
9.93k
      inst->SetInOperands(
1286
9.93k
          {{SPV_OPERAND_TYPE_ID, {non_const_input->result_id()}},
1287
9.93k
           {SPV_OPERAND_TYPE_ID, {merged_id}}});
1288
9.93k
      return true;
1289
13.6k
    }
1290
123k
    return false;
1291
139k
  };
1292
33.5k
}
1293
1294
// Folds addition of a subtraction where each operation has a constant operand.
1295
// Cases:
1296
// (x - 2) + 2 = x + 0
1297
// (2 - x) + 2 = 4 - x
1298
// 2 + (x - 2) = x + 0
1299
// 2 + (2 - x) = 4 - x
1300
33.5k
FoldingRule MergeAddSubArithmetic() {
1301
33.5k
  return [](IRContext* context, Instruction* inst,
1302
518k
            const std::vector<const analysis::Constant*>& constants) {
1303
518k
    assert(inst->opcode() == spv::Op::OpFAdd ||
1304
518k
           inst->opcode() == spv::Op::OpIAdd);
1305
518k
    const analysis::Type* type =
1306
518k
        context->get_type_mgr()->GetType(inst->type_id());
1307
1308
518k
    if (type->IsCooperativeMatrix()) {
1309
0
      return false;
1310
0
    }
1311
1312
518k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1313
518k
    bool uses_float = HasFloatingPoint(type);
1314
518k
    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1315
1316
518k
    uint32_t width = ElementWidth(type);
1317
518k
    if (width != 32 && width != 64) return false;
1318
1319
518k
    const analysis::Constant* const_input1 = ConstInput(constants);
1320
518k
    if (!const_input1) return false;
1321
129k
    Instruction* other_inst = NonConstInput(context, constants[0], inst);
1322
129k
    if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1323
465
      return false;
1324
1325
129k
    if (other_inst->opcode() == spv::Op::OpFSub ||
1326
127k
        other_inst->opcode() == spv::Op::OpISub) {
1327
3.35k
      std::vector<const analysis::Constant*> other_constants =
1328
3.35k
          const_mgr->GetOperandConstants(other_inst);
1329
3.35k
      const analysis::Constant* const_input2 = ConstInput(other_constants);
1330
3.35k
      if (!const_input2) return false;
1331
1332
3.21k
      bool first_is_variable = other_constants[0] == nullptr;
1333
3.21k
      spv::Op op = inst->opcode();
1334
3.21k
      uint32_t op1 = 0;
1335
3.21k
      uint32_t op2 = 0;
1336
3.21k
      if (first_is_variable) {
1337
        // Subtract constants. Non-constant operand is first.
1338
2.84k
        op1 = other_inst->GetSingleWordInOperand(0u);
1339
2.84k
        op2 = PerformOperation(const_mgr, other_inst->opcode(), const_input1,
1340
2.84k
                               const_input2);
1341
2.84k
      } else {
1342
        // Add constants. Constant operand is first. Change the opcode.
1343
373
        op1 = PerformOperation(const_mgr, inst->opcode(), const_input1,
1344
373
                               const_input2);
1345
373
        op2 = other_inst->GetSingleWordInOperand(1u);
1346
373
        op = other_inst->opcode();
1347
373
      }
1348
3.21k
      if (op1 == 0 || op2 == 0) return false;
1349
1350
2.29k
      inst->SetOpcode(op);
1351
2.29k
      inst->SetInOperands(
1352
2.29k
          {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1353
2.29k
      return true;
1354
3.21k
    }
1355
126k
    return false;
1356
129k
  };
1357
33.5k
}
1358
1359
// Folds subtraction of an addition where each operand has a constant operand.
1360
// Cases:
1361
// (x + 2) - 2 = x + 0
1362
// (2 + x) - 2 = x + 0
1363
// 2 - (x + 2) = 0 - x
1364
// 2 - (2 + x) = 0 - x
1365
33.5k
FoldingRule MergeSubAddArithmetic() {
1366
33.5k
  return [](IRContext* context, Instruction* inst,
1367
143k
            const std::vector<const analysis::Constant*>& constants) {
1368
143k
    assert(inst->opcode() == spv::Op::OpFSub ||
1369
143k
           inst->opcode() == spv::Op::OpISub);
1370
143k
    const analysis::Type* type =
1371
143k
        context->get_type_mgr()->GetType(inst->type_id());
1372
1373
143k
    if (type->IsCooperativeMatrix()) {
1374
0
      return false;
1375
0
    }
1376
1377
143k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1378
143k
    bool uses_float = HasFloatingPoint(type);
1379
143k
    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1380
1381
139k
    uint32_t width = ElementWidth(type);
1382
139k
    if (width != 32 && width != 64) return false;
1383
1384
139k
    const analysis::Constant* const_input1 = ConstInput(constants);
1385
139k
    if (!const_input1) return false;
1386
86.7k
    Instruction* other_inst = NonConstInput(context, constants[0], inst);
1387
86.7k
    if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1388
11
      return false;
1389
1390
86.7k
    if (other_inst->opcode() == spv::Op::OpFAdd ||
1391
83.2k
        other_inst->opcode() == spv::Op::OpIAdd) {
1392
11.4k
      std::vector<const analysis::Constant*> other_constants =
1393
11.4k
          const_mgr->GetOperandConstants(other_inst);
1394
11.4k
      const analysis::Constant* const_input2 = ConstInput(other_constants);
1395
11.4k
      if (!const_input2) return false;
1396
1397
2.98k
      Instruction* non_const_input =
1398
2.98k
          NonConstInput(context, other_constants[0], other_inst);
1399
1400
      // If the first operand of the sub is not a constant, swap the constants
1401
      // so the subtraction has the correct operands.
1402
2.98k
      if (constants[0] == nullptr) std::swap(const_input1, const_input2);
1403
      // Subtract the constants.
1404
2.98k
      uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
1405
2.98k
                                            const_input1, const_input2);
1406
2.98k
      spv::Op op = inst->opcode();
1407
2.98k
      uint32_t op1 = 0;
1408
2.98k
      uint32_t op2 = 0;
1409
2.98k
      if (constants[0] == nullptr) {
1410
        // Non-constant operand is first. Change the opcode.
1411
1.71k
        op1 = non_const_input->result_id();
1412
1.71k
        op2 = merged_id;
1413
1.71k
        op = other_inst->opcode();
1414
1.71k
      } else {
1415
        // Constant operand is first.
1416
1.26k
        op1 = merged_id;
1417
1.26k
        op2 = non_const_input->result_id();
1418
1.26k
      }
1419
2.98k
      if (op1 == 0 || op2 == 0) return false;
1420
1421
2.41k
      inst->SetOpcode(op);
1422
2.41k
      inst->SetInOperands(
1423
2.41k
          {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1424
2.41k
      return true;
1425
2.98k
    }
1426
75.2k
    return false;
1427
86.7k
  };
1428
33.5k
}
1429
1430
// Folds subtraction of a subtraction where each operand has a constant operand.
1431
// Cases:
1432
// (x - 2) - 2 = x - 4
1433
// (2 - x) - 2 = 0 - x
1434
// 2 - (x - 2) = 4 - x
1435
// 2 - (2 - x) = x + 0
1436
33.5k
FoldingRule MergeSubSubArithmetic() {
1437
33.5k
  return [](IRContext* context, Instruction* inst,
1438
141k
            const std::vector<const analysis::Constant*>& constants) {
1439
141k
    assert(inst->opcode() == spv::Op::OpFSub ||
1440
141k
           inst->opcode() == spv::Op::OpISub);
1441
141k
    const analysis::Type* type =
1442
141k
        context->get_type_mgr()->GetType(inst->type_id());
1443
1444
141k
    if (type->IsCooperativeMatrix()) {
1445
0
      return false;
1446
0
    }
1447
1448
141k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1449
141k
    bool uses_float = HasFloatingPoint(type);
1450
141k
    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1451
1452
137k
    uint32_t width = ElementWidth(type);
1453
137k
    if (width != 32 && width != 64) return false;
1454
1455
137k
    const analysis::Constant* const_input1 = ConstInput(constants);
1456
137k
    if (!const_input1) return false;
1457
84.3k
    Instruction* other_inst = NonConstInput(context, constants[0], inst);
1458
84.3k
    if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1459
11
      return false;
1460
1461
84.2k
    if (other_inst->opcode() == spv::Op::OpFSub ||
1462
76.6k
        other_inst->opcode() == spv::Op::OpISub) {
1463
9.17k
      std::vector<const analysis::Constant*> other_constants =
1464
9.17k
          const_mgr->GetOperandConstants(other_inst);
1465
9.17k
      const analysis::Constant* const_input2 = ConstInput(other_constants);
1466
9.17k
      if (!const_input2) return false;
1467
1468
8.68k
      Instruction* non_const_input =
1469
8.68k
          NonConstInput(context, other_constants[0], other_inst);
1470
1471
      // Merge the constants.
1472
8.68k
      uint32_t merged_id = 0;
1473
8.68k
      spv::Op merge_op = inst->opcode();
1474
8.68k
      if (other_constants[0] == nullptr) {
1475
7.18k
        merge_op = uses_float ? spv::Op::OpFAdd : spv::Op::OpIAdd;
1476
7.18k
      } else if (constants[0] == nullptr) {
1477
402
        std::swap(const_input1, const_input2);
1478
402
      }
1479
8.68k
      merged_id =
1480
8.68k
          PerformOperation(const_mgr, merge_op, const_input1, const_input2);
1481
8.68k
      if (merged_id == 0) return false;
1482
1483
6.22k
      spv::Op op = inst->opcode();
1484
6.22k
      if (constants[0] != nullptr && other_constants[0] != nullptr) {
1485
        // Change the operation.
1486
809
        op = uses_float ? spv::Op::OpFAdd : spv::Op::OpIAdd;
1487
809
      }
1488
1489
6.22k
      uint32_t op1 = 0;
1490
6.22k
      uint32_t op2 = 0;
1491
6.22k
      if ((constants[0] == nullptr) ^ (other_constants[0] == nullptr)) {
1492
572
        op1 = merged_id;
1493
572
        op2 = non_const_input->result_id();
1494
5.65k
      } else {
1495
5.65k
        op1 = non_const_input->result_id();
1496
5.65k
        op2 = merged_id;
1497
5.65k
      }
1498
1499
6.22k
      inst->SetOpcode(op);
1500
6.22k
      inst->SetInOperands(
1501
6.22k
          {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1502
6.22k
      return true;
1503
8.68k
    }
1504
75.1k
    return false;
1505
84.2k
  };
1506
33.5k
}
1507
1508
// Helper function for MergeGenericAddSubArithmetic. If |addend| and
1509
// subtrahend of |sub| is the same, merge to copy of minuend of |sub|.
1510
1.03M
bool MergeGenericAddendSub(uint32_t addend, uint32_t sub, Instruction* inst) {
1511
1.03M
  IRContext* context = inst->context();
1512
1.03M
  analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1513
1.03M
  Instruction* sub_inst = def_use_mgr->GetDef(sub);
1514
1.03M
  if (sub_inst->opcode() != spv::Op::OpFSub &&
1515
1.02M
      sub_inst->opcode() != spv::Op::OpISub)
1516
1.02M
    return false;
1517
9.64k
  if (sub_inst->opcode() == spv::Op::OpFSub &&
1518
7.78k
      !sub_inst->IsFloatingPointFoldingAllowed())
1519
0
    return false;
1520
9.64k
  if (addend != sub_inst->GetSingleWordInOperand(1)) return false;
1521
1.07k
  inst->SetOpcode(spv::Op::OpCopyObject);
1522
1.07k
  inst->SetInOperands(
1523
1.07k
      {{SPV_OPERAND_TYPE_ID, {sub_inst->GetSingleWordInOperand(0)}}});
1524
1.07k
  context->UpdateDefUse(inst);
1525
1.07k
  return true;
1526
9.64k
}
1527
1528
// Folds addition of a subtraction where the subtrahend is equal to the
1529
// other addend. Return a copy of the minuend. Accepts generic (const and
1530
// non-const) operands.
1531
// Cases:
1532
// (a - b) + b = a
1533
// b + (a - b) = a
1534
33.5k
FoldingRule MergeGenericAddSubArithmetic() {
1535
33.5k
  return [](IRContext* context, Instruction* inst,
1536
516k
            const std::vector<const analysis::Constant*>&) {
1537
516k
    assert(inst->opcode() == spv::Op::OpFAdd ||
1538
516k
           inst->opcode() == spv::Op::OpIAdd);
1539
516k
    const analysis::Type* type =
1540
516k
        context->get_type_mgr()->GetType(inst->type_id());
1541
1542
516k
    if (type->IsCooperativeMatrix()) {
1543
0
      return false;
1544
0
    }
1545
1546
516k
    bool uses_float = HasFloatingPoint(type);
1547
516k
    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1548
1549
515k
    uint32_t width = ElementWidth(type);
1550
515k
    if (width != 32 && width != 64) return false;
1551
1552
515k
    uint32_t add_op0 = inst->GetSingleWordInOperand(0);
1553
515k
    uint32_t add_op1 = inst->GetSingleWordInOperand(1);
1554
515k
    if (MergeGenericAddendSub(add_op0, add_op1, inst)) return true;
1555
515k
    return MergeGenericAddendSub(add_op1, add_op0, inst);
1556
515k
  };
1557
33.5k
}
1558
1559
// Helper function for FactorAddSubMuls.
1560
// If |factor0_0| is the same as |factor1_0|, generate:
1561
//   |factor0_0| * (|factor0_1| + |factor1_1|)
1562
//   |factor0_0| * (|factor0_1| - |factor1_1|)
1563
bool FactorAddSubMulsOpnds(uint32_t factor0_0, uint32_t factor0_1,
1564
                           uint32_t factor1_0, uint32_t factor1_1,
1565
5.37k
                           Instruction* inst) {
1566
5.37k
  IRContext* context = inst->context();
1567
5.37k
  if (factor0_0 != factor1_0) return false;
1568
237
  InstructionBuilder ir_builder(
1569
237
      context, inst,
1570
237
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
1571
237
  Instruction* new_add_inst = ir_builder.AddBinaryOp(
1572
237
      inst->type_id(), inst->opcode(), factor0_1, factor1_1);
1573
237
  if (!new_add_inst) {
1574
0
    return false;
1575
0
  }
1576
1577
237
  bool is_float =
1578
237
      inst->opcode() == spv::Op::OpFAdd || inst->opcode() == spv::Op::OpFSub;
1579
237
  inst->SetOpcode(is_float ? spv::Op::OpFMul : spv::Op::OpIMul);
1580
237
  inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {factor0_0}},
1581
237
                       {SPV_OPERAND_TYPE_ID, {new_add_inst->result_id()}}});
1582
237
  context->UpdateDefUse(inst);
1583
237
  return true;
1584
237
}
1585
1586
// Perform the following factoring identity, handling all operand order
1587
// combinations:
1588
//   (a * b) + (a * c) = a * (b + c)
1589
//   (a * b) - (a * c) = a * (b - c)
1590
67.1k
FoldingRule FactorAddSubMuls() {
1591
67.1k
  return [](IRContext* context, Instruction* inst,
1592
648k
            const std::vector<const analysis::Constant*>&) {
1593
648k
    assert(inst->opcode() == spv::Op::OpFAdd ||
1594
648k
           inst->opcode() == spv::Op::OpFSub ||
1595
648k
           inst->opcode() == spv::Op::OpIAdd ||
1596
648k
           inst->opcode() == spv::Op::OpISub);
1597
648k
    const analysis::Type* type =
1598
648k
        context->get_type_mgr()->GetType(inst->type_id());
1599
648k
    bool uses_float = HasFloatingPoint(type);
1600
648k
    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1601
1602
644k
    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1603
644k
    uint32_t add_op0 = inst->GetSingleWordInOperand(0);
1604
644k
    Instruction* add_op0_inst = def_use_mgr->GetDef(add_op0);
1605
644k
    if (add_op0_inst->opcode() != spv::Op::OpFMul &&
1606
626k
        add_op0_inst->opcode() != spv::Op::OpIMul)
1607
622k
      return false;
1608
21.5k
    uint32_t add_op1 = inst->GetSingleWordInOperand(1);
1609
21.5k
    Instruction* add_op1_inst = def_use_mgr->GetDef(add_op1);
1610
21.5k
    if (add_op1_inst->opcode() != spv::Op::OpFMul &&
1611
19.3k
        add_op1_inst->opcode() != spv::Op::OpIMul)
1612
18.6k
      return false;
1613
1614
    // Only perform this optimization if both of the muls only have one use.
1615
    // Otherwise this is a deoptimization in size and performance.
1616
2.81k
    if (def_use_mgr->NumUses(add_op0_inst) > 1) return false;
1617
1.59k
    if (def_use_mgr->NumUses(add_op1_inst) > 1) return false;
1618
1619
1.46k
    if (add_op0_inst->opcode() == spv::Op::OpFMul &&
1620
1.38k
        (!add_op0_inst->IsFloatingPointFoldingAllowed() ||
1621
1.38k
         !add_op1_inst->IsFloatingPointFoldingAllowed()))
1622
0
      return false;
1623
1624
3.95k
    for (int i = 0; i < 2; i++) {
1625
7.86k
      for (int j = 0; j < 2; j++) {
1626
        // Check if operand i in add_op0_inst matches operand j in add_op1_inst.
1627
5.37k
        if (FactorAddSubMulsOpnds(add_op0_inst->GetSingleWordInOperand(i),
1628
5.37k
                                  add_op0_inst->GetSingleWordInOperand(1 - i),
1629
5.37k
                                  add_op1_inst->GetSingleWordInOperand(j),
1630
5.37k
                                  add_op1_inst->GetSingleWordInOperand(1 - j),
1631
5.37k
                                  inst))
1632
237
          return true;
1633
5.37k
      }
1634
2.72k
    }
1635
1.22k
    return false;
1636
1.46k
  };
1637
67.1k
}
1638
1639
// Reassociate integer instructions where both operands share the same opcode
1640
// and both source instructions contain a constant.
1641
// e.g:
1642
//   (a * C0) * (C1 * b) = (C0 * C1) * (a * b)
1643
//   (a ^ C0) ^ (b ^ C1) = (C0 ^ C1) ^ (a ^ b)
1644
//   (C0 | a) | (b | C1) = (C0 | C1) | (a | b)
1645
//   (a & C0) & (b & C1) = (C0 & C1) & (a & b)
1646
static const constexpr spv::Op ReassociateNestedGenericIntOps[] = {
1647
    spv::Op::OpIMul, spv::Op::OpBitwiseOr, spv::Op::OpBitwiseXor,
1648
    spv::Op::OpBitwiseAnd};
1649
1650
67.1k
// Returns a folding rule for |opcode| instructions whose two operands are also
// |opcode| instructions that each have a constant operand. The rule rewrites
//   (x op C0) op (C1 op y)  into  (C0 op C1) op (x op y)
// by folding the two constants and building a new |opcode| instruction for the
// two remaining operands.
//
// |opcode| must be listed in ReassociateNestedGenericIntOps; this is asserted.
//
// The rule declines (returns false) when:
// - |inst| itself has a constant operand;
// - the type is a cooperative matrix;
// - the element width is not 32 or 64;
// - either step fails.
FoldingRule ReassociateNestedGenericInt(spv::Op opcode) {
  assert(std::find(std::begin(ReassociateNestedGenericIntOps),
                   std::end(ReassociateNestedGenericIntOps),
                   opcode) != std::end(ReassociateNestedGenericIntOps) &&
         "Wrong opcode.");

  return [opcode](IRContext* context, Instruction* inst,
                  const std::vector<const analysis::Constant*>& constants) {
    // Instructions with a constant operand are left to other rules.
    if (constants[0] != nullptr || constants[1] != nullptr) return false;
    if (inst->opcode() != opcode) return false;

    const analysis::Type* result_type =
        context->get_type_mgr()->GetType(inst->type_id());
    if (result_type->IsCooperativeMatrix()) return false;

    const uint32_t width = ElementWidth(result_type);
    if (width != 32 && width != 64) return false;

    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
    Instruction* lhs = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
    Instruction* rhs = def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
    if (lhs->opcode() != opcode || rhs->opcode() != opcode) return false;

    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
    std::vector<const analysis::Constant*> lhs_constants =
        const_mgr->GetOperandConstants(lhs);
    const analysis::Constant* lhs_const = ConstInput(lhs_constants);
    if (lhs_const == nullptr) return false;

    std::vector<const analysis::Constant*> rhs_constants =
        const_mgr->GetOperandConstants(rhs);
    const analysis::Constant* rhs_const = ConstInput(rhs_constants);
    if (rhs_const == nullptr) return false;

    const uint32_t merged_id =
        PerformOperation(const_mgr, opcode, lhs_const, rhs_const);
    if (merged_id == 0) return false;

    InstructionBuilder builder(
        context, inst,
        IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
    const uint32_t lhs_var =
        NonConstInput(context, lhs_constants[0], lhs)->result_id();
    const uint32_t rhs_var =
        NonConstInput(context, rhs_constants[0], rhs)->result_id();
    Instruction* combined =
        builder.AddBinaryOp(inst->type_id(), opcode, lhs_var, rhs_var);
    if (combined == nullptr) return false;

    inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {merged_id}},
                         {SPV_OPERAND_TYPE_ID, {combined->result_id()}}});
    return true;
  };
}
1724
1725
// Reassociate floating point mul/div instructions, which have mul/div inputs,
1726
// both of which contain a constant.
1727
// e.g:
1728
//   (a * C0) / (C1 / b) =  (C0 / C1) * (a * b)
1729
//   (C0 / a) * (b / C1) =  (C0 / C1) * (b / a)
1730
//   (a / C0) / (b * C1) =  (1 / (C0 * C1)) * (a / b)
1731
33.5k
FoldingRule ReassociateNestedMulDivFloat() {
1732
33.5k
  return [](IRContext* context, Instruction* inst,
1733
296k
            const std::vector<const analysis::Constant*>& constants) {
1734
296k
    assert(inst->opcode() == spv::Op::OpFMul ||
1735
296k
           inst->opcode() == spv::Op::OpFDiv);
1736
1737
    // Handled by other folding rules.
1738
296k
    if (constants[0] || constants[1]) {
1739
208k
      return false;
1740
208k
    }
1741
1742
87.8k
    const analysis::Type* type =
1743
87.8k
        context->get_type_mgr()->GetType(inst->type_id());
1744
1745
87.8k
    if (type->IsCooperativeMatrix()) {
1746
0
      return false;
1747
0
    }
1748
1749
87.8k
    uint32_t width = ElementWidth(type);
1750
87.8k
    if (width != 32 && width != 64) return false;
1751
1752
87.8k
    if (!inst->IsFloatingPointFoldingAllowed()) return false;
1753
1754
87.8k
    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1755
87.8k
    Instruction* lhs = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
1756
87.8k
    Instruction* rhs = def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
1757
1758
87.8k
    bool lhs_is_mul = lhs->opcode() == spv::Op::OpFMul;
1759
87.8k
    bool lhs_is_div = lhs->opcode() == spv::Op::OpFDiv;
1760
87.8k
    bool rhs_is_mul = rhs->opcode() == spv::Op::OpFMul;
1761
87.8k
    bool rhs_is_div = rhs->opcode() == spv::Op::OpFDiv;
1762
87.8k
    if (!(lhs_is_mul || lhs_is_div) || !(rhs_is_mul || rhs_is_div)) {
1763
84.4k
      return false;
1764
84.4k
    }
1765
1766
3.38k
    if (!lhs->IsFloatingPointFoldingAllowed() ||
1767
3.38k
        !rhs->IsFloatingPointFoldingAllowed()) {
1768
0
      return false;
1769
0
    }
1770
1771
3.38k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1772
3.38k
    std::vector<const analysis::Constant*> lhs_constants =
1773
3.38k
        const_mgr->GetOperandConstants(lhs);
1774
3.38k
    if (!lhs_constants[0] && !lhs_constants[1]) {
1775
2.42k
      return false;
1776
2.42k
    }
1777
1778
966
    std::vector<const analysis::Constant*> rhs_constants =
1779
966
        const_mgr->GetOperandConstants(rhs);
1780
966
    if (!rhs_constants[0] && !rhs_constants[1]) {
1781
201
      return false;
1782
201
    }
1783
1784
765
    const analysis::Constant* lhs_const =
1785
765
        lhs_constants[0] ? lhs_constants[0] : lhs_constants[1];
1786
765
    const analysis::Constant* rhs_const =
1787
765
        rhs_constants[0] ? rhs_constants[0] : rhs_constants[1];
1788
765
    if (!lhs_const || !rhs_const) return false;
1789
1790
765
    bool const_lhs_rcp = lhs_constants[0] ? false : lhs_is_div;
1791
765
    bool const_rhs_rcp = rhs_constants[0] ? false : rhs_is_div;
1792
1793
765
    uint32_t non_const_lhs = lhs_constants[0] ? lhs->GetSingleWordInOperand(1)
1794
765
                                              : lhs->GetSingleWordInOperand(0);
1795
765
    bool non_const_lhs_rcp = lhs_constants[0] ? lhs_is_div : false;
1796
1797
765
    uint32_t non_const_rhs = rhs_constants[0] ? rhs->GetSingleWordInOperand(1)
1798
765
                                              : rhs->GetSingleWordInOperand(0);
1799
765
    bool non_const_rhs_rcp = rhs_constants[0] ? rhs_is_div : false;
1800
1801
    // Rcp the rhs if we're actually dividing it.
1802
765
    if (inst->opcode() == spv::Op::OpFDiv) {
1803
132
      const_rhs_rcp = !const_rhs_rcp;
1804
132
      non_const_rhs_rcp = !non_const_rhs_rcp;
1805
132
    }
1806
1807
765
    if (const_lhs_rcp) {
1808
19
      lhs_const =
1809
19
          const_mgr->FindDeclaredConstant(Reciprocal(const_mgr, lhs_const));
1810
19
      if (!lhs_const) {
1811
19
        return false;
1812
19
      }
1813
19
    }
1814
746
    if (const_rhs_rcp) {
1815
126
      rhs_const =
1816
126
          const_mgr->FindDeclaredConstant(Reciprocal(const_mgr, rhs_const));
1817
126
      if (!rhs_const) {
1818
32
        return false;
1819
32
      }
1820
126
    }
1821
1822
714
    uint32_t merged_constant =
1823
714
        PerformOperation(const_mgr, spv::Op::OpFMul, lhs_const, rhs_const);
1824
1825
714
    if (!merged_constant) {
1826
251
      return false;
1827
251
    }
1828
1829
463
    spv::Op op = spv::Op::OpNop;
1830
463
    Instruction* new_rhs = nullptr;
1831
1832
463
    InstructionBuilder ir_builder(
1833
463
        context, inst,
1834
463
        IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
1835
1836
    //  a * b => C * (b * a)
1837
463
    if (!non_const_lhs_rcp && !non_const_rhs_rcp) {
1838
279
      new_rhs = ir_builder.AddBinaryOp(inst->type_id(), spv::Op::OpFMul,
1839
279
                                       non_const_lhs, non_const_rhs);
1840
279
      op = spv::Op::OpFMul;
1841
279
    }
1842
    // 1/a * b => C * (b / a)
1843
184
    else if (non_const_lhs_rcp && !non_const_rhs_rcp) {
1844
7
      new_rhs = ir_builder.AddBinaryOp(inst->type_id(), spv::Op::OpFDiv,
1845
7
                                       non_const_rhs, non_const_lhs);
1846
7
      op = spv::Op::OpFMul;
1847
7
    }
1848
    //  a * 1/b => C * (a / b)
1849
177
    else if (!non_const_lhs_rcp && non_const_rhs_rcp) {
1850
55
      new_rhs = ir_builder.AddBinaryOp(inst->type_id(), spv::Op::OpFDiv,
1851
55
                                       non_const_lhs, non_const_rhs);
1852
55
      op = spv::Op::OpFMul;
1853
55
    }
1854
    // 1/a * 1/b => C / (a * b)
1855
122
    else {
1856
122
      new_rhs = ir_builder.AddBinaryOp(inst->type_id(), spv::Op::OpFMul,
1857
122
                                       non_const_lhs, non_const_rhs);
1858
122
      op = spv::Op::OpFDiv;
1859
122
    }
1860
1861
463
    if (!new_rhs) {
1862
0
      return false;
1863
0
    }
1864
1865
463
    inst->SetOpcode(op);
1866
463
    inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {merged_constant}},
1867
463
                         {SPV_OPERAND_TYPE_ID, {new_rhs->result_id()}}});
1868
463
    return true;
1869
463
  };
1870
33.5k
}
1871
1872
// Reassociate add/sub instructions, which have add/sub inputs,
1873
// both of which contain a constant.
1874
// e.g:
1875
//   (a + C0) - (C1 - b) =  (C0 - C1) + (a + b)
1876
//   (C0 - a) + (b - C1) =  (C0 - C1) + (b - a)
1877
//   (a - C0) - (b + C1) = (-C0 - C1) + (a - b)
1878
67.1k
FoldingRule ReassociateNestedAddSub() {
1879
67.1k
  return [](IRContext* context, Instruction* inst,
1880
650k
            const std::vector<const analysis::Constant*>& constants) {
1881
650k
    assert(inst->opcode() == spv::Op::OpFAdd ||
1882
650k
           inst->opcode() == spv::Op::OpIAdd ||
1883
650k
           inst->opcode() == spv::Op::OpFSub ||
1884
650k
           inst->opcode() == spv::Op::OpISub);
1885
1886
    // Handled by other folding rules.
1887
650k
    if (constants[0] || constants[1]) {
1888
206k
      return false;
1889
206k
    }
1890
1891
443k
    const analysis::Type* type =
1892
443k
        context->get_type_mgr()->GetType(inst->type_id());
1893
1894
443k
    if (type->IsCooperativeMatrix()) {
1895
0
      return false;
1896
0
    }
1897
1898
443k
    uint32_t width = ElementWidth(type);
1899
443k
    if (width != 32 && width != 64) return false;
1900
1901
443k
    bool uses_float = HasFloatingPoint(type);
1902
443k
    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1903
1904
440k
    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1905
440k
    Instruction* lhs = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
1906
440k
    Instruction* rhs = def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
1907
1908
440k
    spv::Op add_op = uses_float ? spv::Op::OpFAdd : spv::Op::OpIAdd;
1909
440k
    spv::Op sub_op = uses_float ? spv::Op::OpFSub : spv::Op::OpISub;
1910
1911
440k
    bool lhs_is_add = lhs->opcode() == add_op;
1912
440k
    bool lhs_is_sub = lhs->opcode() == sub_op;
1913
440k
    bool rhs_is_add = rhs->opcode() == add_op;
1914
440k
    bool rhs_is_sub = rhs->opcode() == sub_op;
1915
440k
    if (!(lhs_is_add || lhs_is_sub) || !(rhs_is_add || rhs_is_sub)) {
1916
438k
      return false;
1917
438k
    }
1918
1919
2.65k
    if (uses_float && (!lhs->IsFloatingPointFoldingAllowed() ||
1920
2.35k
                       !rhs->IsFloatingPointFoldingAllowed())) {
1921
0
      return false;
1922
0
    }
1923
1924
2.65k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1925
2.65k
    std::vector<const analysis::Constant*> lhs_constants =
1926
2.65k
        const_mgr->GetOperandConstants(lhs);
1927
2.65k
    if (!lhs_constants[0] && !lhs_constants[1]) {
1928
1.11k
      return false;
1929
1.11k
    }
1930
1931
1.54k
    std::vector<const analysis::Constant*> rhs_constants =
1932
1.54k
        const_mgr->GetOperandConstants(rhs);
1933
1.54k
    if (!rhs_constants[0] && !rhs_constants[1]) {
1934
48
      return false;
1935
48
    }
1936
1937
1.49k
    const analysis::Constant* lhs_const =
1938
1.49k
        lhs_constants[0] ? lhs_constants[0] : lhs_constants[1];
1939
1.49k
    const analysis::Constant* rhs_const =
1940
1.49k
        rhs_constants[0] ? rhs_constants[0] : rhs_constants[1];
1941
1.49k
    if (!lhs_const || !rhs_const) return false;
1942
1943
1.49k
    bool const_lhs_neg = lhs_constants[0] ? false : lhs_is_sub;
1944
1.49k
    bool const_rhs_neg = rhs_constants[0] ? false : rhs_is_sub;
1945
1946
1.49k
    uint32_t non_const_lhs = lhs_constants[0] ? lhs->GetSingleWordInOperand(1)
1947
1.49k
                                              : lhs->GetSingleWordInOperand(0);
1948
1.49k
    bool non_const_lhs_neg = lhs_constants[0] ? lhs_is_sub : false;
1949
1950
1.49k
    uint32_t non_const_rhs = rhs_constants[0] ? rhs->GetSingleWordInOperand(1)
1951
1.49k
                                              : rhs->GetSingleWordInOperand(0);
1952
1.49k
    bool non_const_rhs_neg = rhs_constants[0] ? rhs_is_sub : false;
1953
1954
    // Negate the rhs if we're actually subtracting it.
1955
1.49k
    if (inst->opcode() == spv::Op::OpFSub ||
1956
810
        inst->opcode() == spv::Op::OpISub) {
1957
687
      const_rhs_neg = !const_rhs_neg;
1958
687
      non_const_rhs_neg = !non_const_rhs_neg;
1959
687
    }
1960
1961
1.49k
    if (const_lhs_neg) {
1962
489
      lhs_const =
1963
489
          const_mgr->FindDeclaredConstant(NegateConstant(const_mgr, lhs_const));
1964
489
      if (!lhs_const) {
1965
0
        return false;
1966
0
      }
1967
489
    }
1968
1.49k
    if (const_rhs_neg) {
1969
1.19k
      rhs_const =
1970
1.19k
          const_mgr->FindDeclaredConstant(NegateConstant(const_mgr, rhs_const));
1971
1.19k
      if (!rhs_const) {
1972
0
        return false;
1973
0
      }
1974
1.19k
    }
1975
1976
1.49k
    uint32_t merged_constant =
1977
1.49k
        PerformOperation(const_mgr, add_op, lhs_const, rhs_const);
1978
1979
1.49k
    if (!merged_constant) {
1980
260
      return false;
1981
260
    }
1982
1983
1.23k
    spv::Op op = spv::Op::OpNop;
1984
1.23k
    Instruction* new_rhs = nullptr;
1985
1986
1.23k
    InstructionBuilder ir_builder(
1987
1.23k
        context, inst,
1988
1.23k
        IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
1989
1990
    //  a +  b => C + (b + a)
1991
1.23k
    if (!non_const_lhs_neg && !non_const_rhs_neg) {
1992
599
      new_rhs = ir_builder.AddBinaryOp(inst->type_id(), add_op, non_const_lhs,
1993
599
                                       non_const_rhs);
1994
599
      op = add_op;
1995
599
    }
1996
    // -a +  b => C + (b - a)
1997
638
    else if (non_const_lhs_neg && !non_const_rhs_neg) {
1998
44
      new_rhs = ir_builder.AddBinaryOp(inst->type_id(), sub_op, non_const_rhs,
1999
44
                                       non_const_lhs);
2000
44
      op = add_op;
2001
44
    }
2002
    //  a + -b => C + (a - b)
2003
594
    else if (!non_const_lhs_neg && non_const_rhs_neg) {
2004
9
      new_rhs = ir_builder.AddBinaryOp(inst->type_id(), sub_op, non_const_lhs,
2005
9
                                       non_const_rhs);
2006
9
      op = add_op;
2007
9
    }
2008
    // -a + -b => C - (a + b)
2009
585
    else {
2010
585
      new_rhs = ir_builder.AddBinaryOp(inst->type_id(), add_op, non_const_lhs,
2011
585
                                       non_const_rhs);
2012
585
      op = sub_op;
2013
585
    }
2014
2015
1.23k
    if (!new_rhs) {
2016
0
      return false;
2017
0
    }
2018
2019
1.23k
    inst->SetOpcode(op);
2020
1.23k
    inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {merged_constant}},
2021
1.23k
                         {SPV_OPERAND_TYPE_ID, {new_rhs->result_id()}}});
2022
1.23k
    return true;
2023
1.23k
  };
2024
67.1k
}
2025
2026
16.7k
// Folds an integer multiplication by a scalar integer constant 1 (x * 1 or
// 1 * x) into an OpCopyObject of the other operand. Integer constants are only
// examined at 32- or 64-bit width; any other width makes the rule give up.
FoldingRule IntMultipleBy1() {
  return [](IRContext*, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == spv::Op::OpIMul &&
           "Wrong opcode.  Should be OpIMul.");
    for (uint32_t operand = 0; operand < 2; ++operand) {
      const analysis::IntConstant* int_constant =
          constants[operand] != nullptr ? constants[operand]->AsIntConstant()
                                        : nullptr;
      if (int_constant == nullptr) continue;

      const uint32_t width = ElementWidth(int_constant->type());
      if (width != 32 && width != 64) return false;

      const bool is_one = width == 32 ? int_constant->GetU32BitValue() == 1u
                                      : int_constant->GetU64BitValue() == 1ull;
      if (!is_one) continue;

      inst->SetOpcode(spv::Op::OpCopyObject);
      const uint32_t other_id = inst->GetSingleWordInOperand(1 - operand);
      inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {other_id}}});
      return true;
    }
    return false;
  };
}
2052
2053
// Returns the number of elements that the |index|th in operand in |inst|
2054
// contributes to the result of |inst|.  |inst| must be an
2055
// OpCompositeConstructInstruction.
2056
uint32_t GetNumOfElementsContributedByOperand(IRContext* context,
2057
                                              const Instruction* inst,
2058
17.4k
                                              uint32_t index) {
2059
17.4k
  assert(inst->opcode() == spv::Op::OpCompositeConstruct);
2060
17.4k
  analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
2061
17.4k
  analysis::TypeManager* type_mgr = context->get_type_mgr();
2062
2063
17.4k
  analysis::Vector* result_type =
2064
17.4k
      type_mgr->GetType(inst->type_id())->AsVector();
2065
17.4k
  if (result_type == nullptr) {
2066
    // If the result of the OpCompositeConstruct is not a vector then every
2067
    // operands corresponds to a single element in the result.
2068
0
    return 1;
2069
0
  }
2070
2071
  // If the result type is a vector then the operands are either scalars or
2072
  // vectors. If it is a scalar, then it corresponds to a single element.  If it
2073
  // is a vector, then each element in the vector will be an element in the
2074
  // result.
2075
17.4k
  uint32_t id = inst->GetSingleWordInOperand(index);
2076
17.4k
  Instruction* def = def_use_mgr->GetDef(id);
2077
17.4k
  analysis::Vector* type = type_mgr->GetType(def->type_id())->AsVector();
2078
17.4k
  if (type == nullptr) {
2079
17.4k
    return 1;
2080
17.4k
  }
2081
0
  return type->element_count();
2082
17.4k
}
2083
2084
// Returns the in-operands for an OpCompositeExtract instruction that are needed
2085
// to extract the |result_index|th element in the result of |inst| without using
2086
// the result of |inst|. Returns the empty vector if |result_index| is
2087
// out-of-bounds. |inst| must be an |OpCompositeConstruct| instruction.
2088
std::vector<Operand> GetExtractOperandsForElementOfCompositeConstruct(
2089
39.6k
    IRContext* context, const Instruction* inst, uint32_t result_index) {
2090
39.6k
  assert(inst->opcode() == spv::Op::OpCompositeConstruct);
2091
39.6k
  analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
2092
39.6k
  analysis::TypeManager* type_mgr = context->get_type_mgr();
2093
2094
39.6k
  analysis::Type* result_type = type_mgr->GetType(inst->type_id());
2095
39.6k
  if (result_type->AsVector() == nullptr) {
2096
28.7k
    if (result_index < inst->NumInOperands()) {
2097
28.7k
      uint32_t id = inst->GetSingleWordInOperand(result_index);
2098
28.7k
      return {Operand(SPV_OPERAND_TYPE_ID, {id})};
2099
28.7k
    }
2100
0
    return {};
2101
28.7k
  }
2102
2103
  // If the result type is a vector, then vector operands are concatenated.
2104
10.8k
  uint32_t total_element_count = 0;
2105
17.4k
  for (uint32_t idx = 0; idx < inst->NumInOperands(); ++idx) {
2106
17.4k
    uint32_t element_count =
2107
17.4k
        GetNumOfElementsContributedByOperand(context, inst, idx);
2108
17.4k
    total_element_count += element_count;
2109
17.4k
    if (result_index < total_element_count) {
2110
10.8k
      std::vector<Operand> operands;
2111
10.8k
      uint32_t id = inst->GetSingleWordInOperand(idx);
2112
10.8k
      Instruction* operand_def = def_use_mgr->GetDef(id);
2113
10.8k
      analysis::Type* operand_type = type_mgr->GetType(operand_def->type_id());
2114
2115
10.8k
      operands.push_back({SPV_OPERAND_TYPE_ID, {id}});
2116
10.8k
      if (operand_type->AsVector()) {
2117
0
        uint32_t start_index_of_id = total_element_count - element_count;
2118
0
        uint32_t index_into_id = result_index - start_index_of_id;
2119
0
        operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {index_into_id}});
2120
0
      }
2121
10.8k
      return operands;
2122
10.8k
    }
2123
17.4k
  }
2124
0
  return {};
2125
10.8k
}
2126
2127
bool CompositeConstructFeedingExtract(
2128
    IRContext* context, Instruction* inst,
2129
290k
    const std::vector<const analysis::Constant*>&) {
2130
  // If the input to an OpCompositeExtract is an OpCompositeConstruct,
2131
  // then we can simply use the appropriate element in the construction.
2132
290k
  assert(inst->opcode() == spv::Op::OpCompositeExtract &&
2133
290k
         "Wrong opcode.  Should be OpCompositeExtract.");
2134
290k
  analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
2135
2136
  // If there are no index operands, then this rule cannot do anything.
2137
290k
  if (inst->NumInOperands() <= 1) {
2138
0
    return false;
2139
0
  }
2140
2141
290k
  uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
2142
290k
  Instruction* cinst = def_use_mgr->GetDef(cid);
2143
2144
290k
  if (cinst->opcode() != spv::Op::OpCompositeConstruct) {
2145
250k
    return false;
2146
250k
  }
2147
2148
39.6k
  uint32_t index_into_result = inst->GetSingleWordInOperand(1);
2149
39.6k
  std::vector<Operand> operands =
2150
39.6k
      GetExtractOperandsForElementOfCompositeConstruct(context, cinst,
2151
39.6k
                                                       index_into_result);
2152
2153
39.6k
  if (operands.empty()) {
2154
0
    return false;
2155
0
  }
2156
2157
  // Add the remaining indices for extraction.
2158
39.6k
  for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
2159
34
    operands.push_back(
2160
34
        {SPV_OPERAND_TYPE_LITERAL_INTEGER, {inst->GetSingleWordInOperand(i)}});
2161
34
  }
2162
2163
39.6k
  if (operands.size() == 1) {
2164
    // If there were no extra indices, then we have the final object.  No need
2165
    // to extract any more.
2166
39.5k
    inst->SetOpcode(spv::Op::OpCopyObject);
2167
39.5k
  }
2168
2169
39.6k
  inst->SetInOperands(std::move(operands));
2170
39.6k
  return true;
2171
39.6k
}
2172
2173
// Walks the indexes chain from |start| to |end| of an OpCompositeInsert or
2174
// OpCompositeExtract instruction, and returns the type id of the final element
2175
// being accessed. Returns 0 if a valid type could not be found.
2176
uint32_t GetElementType(uint32_t type_id, Instruction::iterator start,
2177
                        Instruction::iterator end,
2178
96.0k
                        const analysis::DefUseManager* def_use_manager) {
2179
96.0k
  for (auto index : make_range(std::move(start), std::move(end))) {
2180
1.12k
    const Instruction* type_inst = def_use_manager->GetDef(type_id);
2181
1.12k
    assert(index.type == SPV_OPERAND_TYPE_LITERAL_INTEGER &&
2182
1.12k
           index.words.size() == 1);
2183
1.12k
    switch (type_inst->opcode()) {
2184
536
      case spv::Op::OpTypeArray:
2185
536
      case spv::Op::OpTypeMatrix:
2186
536
      case spv::Op::OpTypeVector:
2187
536
      case spv::Op::OpTypeVectorIdEXT:
2188
536
        type_id = type_inst->GetSingleWordInOperand(0);
2189
536
        break;
2190
586
      case spv::Op::OpTypeStruct:
2191
586
        type_id = type_inst->GetSingleWordInOperand(index.words[0]);
2192
586
        break;
2193
0
      default:
2194
0
        return 0;
2195
1.12k
    }
2196
1.12k
  }
2197
96.0k
  return type_id;
2198
96.0k
}
2199
2200
// If the input to an OpCompositeExtract is an OpCopyLogical, then we can
2201
// hoist the extraction before the copy.
2202
bool CopyLogicalFeedingExtract(IRContext* context, Instruction* inst,
2203
245k
                               const std::vector<const analysis::Constant*>&) {
2204
245k
  assert(inst->opcode() == spv::Op::OpCompositeExtract &&
2205
245k
         "Wrong opcode.  Should be OpCompositeExtract.");
2206
2207
245k
  analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
2208
245k
  uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
2209
245k
  Instruction* cinst = def_use_mgr->GetDef(cid);
2210
2211
245k
  if (cinst->opcode() != spv::Op::OpCopyLogical) {
2212
245k
    return false;
2213
245k
  }
2214
2215
0
  uint32_t original_composite_id = cinst->GetSingleWordInOperand(0);
2216
0
  Instruction* original_composite_inst =
2217
0
      def_use_mgr->GetDef(original_composite_id);
2218
2219
0
  std::vector<uint32_t> indices;
2220
0
  for (uint32_t i = 1; i < inst->NumInOperands(); ++i) {
2221
0
    indices.push_back(inst->GetSingleWordInOperand(i));
2222
0
  }
2223
2224
0
  uint32_t original_element_type_id =
2225
0
      GetElementType(original_composite_inst->type_id(), inst->begin() + 3,
2226
0
                     inst->end(), def_use_mgr);
2227
0
  assert(original_element_type_id != 0 &&
2228
0
         "Could not find the element type.  Invalid SPIR-V.");
2229
2230
0
  InstructionBuilder ir_builder(
2231
0
      context, inst,
2232
0
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
2233
2234
0
  Instruction* new_extract = ir_builder.AddCompositeExtract(
2235
0
      original_element_type_id, original_composite_id, indices);
2236
2237
0
  if (original_element_type_id == inst->type_id())
2238
0
    inst->SetOpcode(spv::Op::OpCopyObject);
2239
0
  else
2240
0
    inst->SetOpcode(spv::Op::OpCopyLogical);
2241
0
  inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {new_extract->result_id()}}});
2242
0
  return true;
2243
0
}
2244
2245
// If the input to an OpCompositeExtract is an OpLoad, we can change the
2246
// load into a load of an OpAccessChain.
2247
bool LoadFeedingExtract(IRContext* context, Instruction* inst,
2248
245k
                        const std::vector<const analysis::Constant*>&) {
2249
245k
  assert(inst->opcode() == spv::Op::OpCompositeExtract &&
2250
245k
         "Wrong opcode.  Should be OpCompositeExtract.");
2251
2252
245k
  analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
2253
245k
  uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
2254
245k
  Instruction* cinst = def_use_mgr->GetDef(cid);
2255
2256
245k
  if (cinst->opcode() != spv::Op::OpLoad) {
2257
182k
    return false;
2258
182k
  }
2259
2260
63.0k
  Instruction* composite_type_inst = def_use_mgr->GetDef(cinst->type_id());
2261
63.0k
  if (composite_type_inst->opcode() != spv::Op::OpTypeStruct &&
2262
32.9k
      composite_type_inst->opcode() != spv::Op::OpTypeArray) {
2263
11.6k
    return false;
2264
11.6k
  }
2265
2266
  // Check the memory operands.
2267
51.4k
  if (cinst->NumInOperands() > 1) {
2268
603
    uint32_t memory_access_mask = cinst->GetSingleWordInOperand(1);
2269
603
    if (memory_access_mask & uint32_t(spv::MemoryAccessMask::Volatile)) {
2270
24
      return false;
2271
24
    }
2272
603
  }
2273
2274
51.4k
  uint32_t ptr_id = cinst->GetSingleWordInOperand(0);
2275
51.4k
  Instruction* ptr_inst = def_use_mgr->GetDef(ptr_id);
2276
51.4k
  Instruction* ptr_type_inst = def_use_mgr->GetDef(ptr_inst->type_id());
2277
51.4k
  assert(ptr_type_inst->opcode() == spv::Op::OpTypePointer);
2278
51.4k
  spv::StorageClass storage_class =
2279
51.4k
      static_cast<spv::StorageClass>(ptr_type_inst->GetSingleWordInOperand(0));
2280
2281
  // If the storage class is Function or Private, we do not want to fold.
2282
  // These are the storage classes that the local-access-chain-convert pass
2283
  // works on.
2284
51.4k
  if (storage_class == spv::StorageClass::Function ||
2285
51.4k
      storage_class == spv::StorageClass::Private) {
2286
51.4k
    return false;
2287
51.4k
  }
2288
2289
0
  analysis::ConstantManager* const_mgr = context->get_constant_mgr();
2290
0
  analysis::TypeManager* type_mgr = context->get_type_mgr();
2291
0
  std::vector<uint32_t> index_ids;
2292
0
  for (uint32_t i = 1; i < inst->NumInOperands(); ++i) {
2293
0
    uint32_t index = inst->GetSingleWordInOperand(i);
2294
0
    const analysis::Constant* index_const =
2295
0
        const_mgr->GetConstant(type_mgr->GetUIntType(), {index});
2296
0
    index_ids.push_back(
2297
0
        const_mgr->GetDefiningInstruction(index_const)->result_id());
2298
0
  }
2299
2300
0
  InstructionBuilder ir_builder(
2301
0
      context, cinst,
2302
0
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
2303
2304
0
  uint32_t element_ptr_type_id =
2305
0
      type_mgr->FindPointerToType(inst->type_id(), storage_class);
2306
0
  if (element_ptr_type_id == 0) {
2307
0
    return false;
2308
0
  }
2309
2310
0
  Instruction* access_chain =
2311
0
      ir_builder.AddAccessChain(element_ptr_type_id, ptr_id, index_ids);
2312
0
  std::vector<Operand> load_operands;
2313
0
  load_operands.push_back({SPV_OPERAND_TYPE_ID, {access_chain->result_id()}});
2314
2315
0
  if (cinst->NumInOperands() > 1) {
2316
0
    uint32_t memory_access_mask = cinst->GetSingleWordInOperand(1);
2317
0
    load_operands.push_back(
2318
0
        {SPV_OPERAND_TYPE_MEMORY_ACCESS, {memory_access_mask}});
2319
2320
0
    uint32_t current_operand_index = 2;
2321
0
    if (memory_access_mask & uint32_t(spv::MemoryAccessMask::Aligned)) {
2322
0
      uint32_t original_alignment =
2323
0
          cinst->GetSingleWordInOperand(current_operand_index);
2324
2325
0
      std::vector<uint32_t> extract_indices;
2326
0
      for (uint32_t i = 1; i < inst->NumInOperands(); ++i) {
2327
0
        extract_indices.push_back(inst->GetSingleWordInOperand(i));
2328
0
      }
2329
2330
0
      std::optional<uint32_t> offset =
2331
0
          type_mgr->GetType(cinst->type_id())->GetByteOffset(extract_indices);
2332
0
      if (!offset) {
2333
0
        return false;
2334
0
      }
2335
2336
0
      uint32_t new_alignment = original_alignment;
2337
0
      if (*offset != 0) {
2338
0
        uint32_t offset_alignment = *offset & ~(*offset - 1);
2339
0
        new_alignment = std::min(original_alignment, offset_alignment);
2340
0
      }
2341
2342
0
      load_operands.push_back(
2343
0
          {SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER, {new_alignment}});
2344
0
      current_operand_index++;
2345
0
    }
2346
2347
    // Copy the remaining operands
2348
0
    for (; current_operand_index < cinst->NumInOperands();
2349
0
         ++current_operand_index) {
2350
0
      load_operands.push_back(cinst->GetInOperand(current_operand_index));
2351
0
    }
2352
0
  }
2353
2354
0
  uint32_t load_result_id = context->TakeNextId();
2355
0
  if (load_result_id == 0) return false;
2356
2357
0
  std::unique_ptr<Instruction> new_load_inst(
2358
0
      new Instruction(context, spv::Op::OpLoad, inst->type_id(), load_result_id,
2359
0
                      load_operands));
2360
0
  Instruction* new_load = ir_builder.AddInstruction(std::move(new_load_inst));
2361
2362
0
  inst->SetOpcode(spv::Op::OpCopyObject);
2363
0
  inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {new_load->result_id()}}});
2364
2365
0
  return true;
2366
0
}
2367
2368
// Returns true of |inst_1| and |inst_2| have the same indexes that will be used
2369
// to index into a composite object, excluding the last index.  The two
2370
// instructions must have the same opcode, and be either OpCompositeExtract or
2371
// OpCompositeInsert instructions.
2372
401k
bool HaveSameIndexesExceptForLast(Instruction* inst_1, Instruction* inst_2) {
2373
401k
  assert(inst_1->opcode() == inst_2->opcode() &&
2374
401k
         "Expecting the opcodes to be the same.");
2375
401k
  assert((inst_1->opcode() == spv::Op::OpCompositeInsert ||
2376
401k
          inst_1->opcode() == spv::Op::OpCompositeExtract) &&
2377
401k
         "Instructions must be OpCompositeInsert or OpCompositeExtract.");
2378
2379
401k
  if (inst_1->NumInOperands() != inst_2->NumInOperands()) {
2380
2.08k
    return false;
2381
2.08k
  }
2382
2383
399k
  uint32_t first_index_position =
2384
399k
      (inst_1->opcode() == spv::Op::OpCompositeInsert ? 2 : 1);
2385
409k
  for (uint32_t i = first_index_position; i < inst_1->NumInOperands() - 1;
2386
399k
       i++) {
2387
10.1k
    if (inst_1->GetSingleWordInOperand(i) !=
2388
10.1k
        inst_2->GetSingleWordInOperand(i)) {
2389
492
      return false;
2390
492
    }
2391
10.1k
  }
2392
398k
  return true;
2393
399k
}
2394
2395
// If the OpCompositeConstruct is simply putting back together elements that
2396
// where extracted from the same source, we can simply reuse the source.
2397
//
2398
// This is a common code pattern because of the way that scalar replacement
2399
// works.
2400
bool CompositeExtractFeedingConstruct(
2401
    IRContext* context, Instruction* inst,
2402
97.5k
    const std::vector<const analysis::Constant*>&) {
2403
97.5k
  assert(inst->opcode() == spv::Op::OpCompositeConstruct &&
2404
97.5k
         "Wrong opcode.  Should be OpCompositeConstruct.");
2405
97.5k
  analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
2406
97.5k
  uint32_t original_id = 0;
2407
2408
97.5k
  if (inst->NumInOperands() == 0) {
2409
    // The struct being constructed has no members.
2410
0
    return false;
2411
0
  }
2412
2413
  // Check each element to make sure they are:
2414
  // - extractions
2415
  // - extracting the same position they are inserting
2416
  // - all extract from the same id.
2417
97.5k
  Instruction* first_element_inst = nullptr;
2418
126k
  for (uint32_t i = 0; i < inst->NumInOperands(); ++i) {
2419
124k
    const uint32_t element_id = inst->GetSingleWordInOperand(i);
2420
124k
    Instruction* element_inst = def_use_mgr->GetDef(element_id);
2421
124k
    if (first_element_inst == nullptr) {
2422
97.5k
      first_element_inst = element_inst;
2423
97.5k
    }
2424
2425
124k
    if (element_inst->opcode() != spv::Op::OpCompositeExtract) {
2426
92.0k
      return false;
2427
92.0k
    }
2428
2429
32.6k
    if (!HaveSameIndexesExceptForLast(element_inst, first_element_inst)) {
2430
0
      return false;
2431
0
    }
2432
2433
32.6k
    if (element_inst->GetSingleWordInOperand(element_inst->NumInOperands() -
2434
32.6k
                                             1) != i) {
2435
3.65k
      return false;
2436
3.65k
    }
2437
2438
29.0k
    if (i == 0) {
2439
9.86k
      original_id =
2440
9.86k
          element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
2441
19.1k
    } else if (original_id !=
2442
19.1k
               element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx)) {
2443
438
      return false;
2444
438
    }
2445
29.0k
  }
2446
97.5k
  assert(first_element_inst != nullptr);
2447
2448
  // The last check it to see that the object being extracted from is the
2449
  // correct type.
2450
1.40k
  Instruction* original_inst = def_use_mgr->GetDef(original_id);
2451
1.40k
  uint32_t original_type_id =
2452
1.40k
      GetElementType(original_inst->type_id(), first_element_inst->begin() + 3,
2453
1.40k
                     first_element_inst->end() - 1, def_use_mgr);
2454
2455
1.40k
  if (inst->type_id() != original_type_id) {
2456
191
    return false;
2457
191
  }
2458
2459
1.21k
  if (first_element_inst->NumInOperands() == 2) {
2460
    // Simplify by using the original object.
2461
1.21k
    inst->SetOpcode(spv::Op::OpCopyObject);
2462
1.21k
    inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}});
2463
1.21k
    return true;
2464
1.21k
  }
2465
2466
  // Copies the original id and all indexes except for the last to the new
2467
  // extract instruction.
2468
0
  inst->SetOpcode(spv::Op::OpCompositeExtract);
2469
0
  inst->SetInOperands(std::vector<Operand>(first_element_inst->begin() + 2,
2470
0
                                           first_element_inst->end() - 1));
2471
0
  return true;
2472
1.21k
}
2473
2474
16.7k
FoldingRule InsertFeedingExtract() {
  // Folds an OpCompositeExtract whose composite is an OpCompositeInsert by
  // comparing the extract's index path with the insert's.
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>&) {
    assert(inst->opcode() == spv::Op::OpCompositeExtract &&
           "Wrong opcode.  Should be OpCompositeExtract.");
    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
    Instruction* insert = def_use_mgr->GetDef(
        inst->GetSingleWordInOperand(kExtractCompositeIdInIdx));

    if (insert->opcode() != spv::Op::OpCompositeInsert) {
      return false;
    }

    const uint32_t extract_count = inst->NumInOperands();
    const uint32_t insert_count = insert->NumInOperands();

    // Extract in-operand k lines up with insert in-operand k + 1. Advance over
    // the common prefix of the two index paths.
    uint32_t i = 1;
    while (i < extract_count && i + 1 < insert_count &&
           inst->GetSingleWordInOperand(i) ==
               insert->GetSingleWordInOperand(i + 1)) {
      ++i;
    }

    const bool extract_path_exhausted = (i == extract_count);
    const bool insert_path_exhausted = (i + 1 == insert_count);

    // The extract reads exactly the inserted element.
    if (extract_path_exhausted && insert_path_exhausted) {
      inst->SetOpcode(spv::Op::OpCopyObject);
      inst->SetInOperands(
          {{SPV_OPERAND_TYPE_ID,
            {insert->GetSingleWordInOperand(kInsertObjectIdInIdx)}}});
      return true;
    }

    // The extract reads a composite containing the inserted value alongside
    // values from the base composite; nothing can be simplified.
    if (extract_path_exhausted) {
      return false;
    }

    std::vector<Operand> operands;
    if (insert_path_exhausted) {
      // The extract reads part of the inserted object: extract from that
      // object using the remaining indexes.
      operands.push_back(
          {SPV_OPERAND_TYPE_ID,
           {insert->GetSingleWordInOperand(kInsertObjectIdInIdx)}});
      for (uint32_t k = i; k < extract_count; ++k) {
        operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
                            {inst->GetSingleWordInOperand(k)}});
      }
    } else {
      // The paths diverge, so the extracted value is disjoint from the
      // inserted one: read it from the insert's base composite instead.
      operands.push_back(
          {SPV_OPERAND_TYPE_ID,
           {insert->GetSingleWordInOperand(kInsertCompositeIdInIdx)}});
      for (uint32_t k = 1; k < extract_count; ++k) {
        operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
                            {inst->GetSingleWordInOperand(k)}});
      }
    }
    inst->SetInOperands(std::move(operands));
    return true;
  };
}
2545
2546
// When a VectorShuffle is feeding an Extract, we can extract from one of the
2547
// operands of the VectorShuffle.  We just need to adjust the index in the
2548
// extract instruction.
2549
16.7k
FoldingRule VectorShuffleFeedingExtract() {
2550
16.7k
  return [](IRContext* context, Instruction* inst,
2551
250k
            const std::vector<const analysis::Constant*>&) {
2552
250k
    assert(inst->opcode() == spv::Op::OpCompositeExtract &&
2553
250k
           "Wrong opcode.  Should be OpCompositeExtract.");
2554
250k
    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
2555
250k
    analysis::TypeManager* type_mgr = context->get_type_mgr();
2556
250k
    uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
2557
250k
    Instruction* cinst = def_use_mgr->GetDef(cid);
2558
2559
250k
    if (cinst->opcode() != spv::Op::OpVectorShuffle) {
2560
246k
      return false;
2561
246k
    }
2562
2563
    // Find the size of the first vector operand of the VectorShuffle
2564
4.64k
    Instruction* first_input =
2565
4.64k
        def_use_mgr->GetDef(cinst->GetSingleWordInOperand(0));
2566
4.64k
    analysis::Type* first_input_type =
2567
4.64k
        type_mgr->GetType(first_input->type_id());
2568
4.64k
    assert(first_input_type->AsVector() &&
2569
4.64k
           "Input to vector shuffle should be vectors.");
2570
4.64k
    uint32_t first_input_size = first_input_type->AsVector()->element_count();
2571
2572
    // Get index of the element the vector shuffle is placing in the position
2573
    // being extracted.
2574
4.64k
    uint32_t new_index =
2575
4.64k
        cinst->GetSingleWordInOperand(2 + inst->GetSingleWordInOperand(1));
2576
2577
    // Extracting an undefined value so fold this extract into an undef.
2578
4.64k
    const uint32_t undef_literal_value = 0xffffffff;
2579
4.64k
    if (new_index == undef_literal_value) {
2580
187
      inst->SetOpcode(spv::Op::OpUndef);
2581
187
      inst->SetInOperands({});
2582
187
      return true;
2583
187
    }
2584
2585
    // Get the id of the of the vector the elemtent comes from, and update the
2586
    // index if needed.
2587
4.46k
    uint32_t new_vector = 0;
2588
4.46k
    if (new_index < first_input_size) {
2589
2.76k
      new_vector = cinst->GetSingleWordInOperand(0);
2590
2.76k
    } else {
2591
1.69k
      new_vector = cinst->GetSingleWordInOperand(1);
2592
1.69k
      new_index -= first_input_size;
2593
1.69k
    }
2594
2595
    // Update the extract instruction.
2596
4.46k
    inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
2597
4.46k
    inst->SetInOperand(1, {new_index});
2598
4.46k
    return true;
2599
4.64k
  };
2600
16.7k
}
2601
2602
// When an FMix with is feeding an Extract that extracts an element whose
2603
// corresponding |a| in the FMix is 0 or 1, we can extract from one of the
2604
// operands of the FMix.
2605
16.7k
FoldingRule FMixFeedingExtract() {
2606
16.7k
  return [](IRContext* context, Instruction* inst,
2607
246k
            const std::vector<const analysis::Constant*>&) {
2608
246k
    assert(inst->opcode() == spv::Op::OpCompositeExtract &&
2609
246k
           "Wrong opcode.  Should be OpCompositeExtract.");
2610
246k
    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
2611
246k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
2612
2613
246k
    uint32_t composite_id =
2614
246k
        inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
2615
246k
    Instruction* composite_inst = def_use_mgr->GetDef(composite_id);
2616
2617
246k
    if (composite_inst->opcode() != spv::Op::OpExtInst) {
2618
225k
      return false;
2619
225k
    }
2620
2621
20.4k
    uint32_t inst_set_id =
2622
20.4k
        context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
2623
2624
20.4k
    if (composite_inst->GetSingleWordInOperand(kExtInstSetIdInIdx) !=
2625
20.4k
            inst_set_id ||
2626
20.4k
        composite_inst->GetSingleWordInOperand(kExtInstInstructionInIdx) !=
2627
20.4k
            GLSLstd450FMix) {
2628
15.3k
      return false;
2629
15.3k
    }
2630
2631
    // Get the |a| for the FMix instruction.
2632
5.03k
    uint32_t a_id = composite_inst->GetSingleWordInOperand(kFMixAIdInIdx);
2633
5.03k
    std::unique_ptr<Instruction> a(inst->Clone(context));
2634
5.03k
    a->SetInOperand(kExtractCompositeIdInIdx, {a_id});
2635
5.03k
    context->get_instruction_folder().FoldInstruction(a.get());
2636
2637
5.03k
    if (a->opcode() != spv::Op::OpCopyObject) {
2638
1.31k
      return false;
2639
1.31k
    }
2640
2641
3.71k
    const analysis::Constant* a_const =
2642
3.71k
        const_mgr->FindDeclaredConstant(a->GetSingleWordInOperand(0));
2643
2644
3.71k
    if (!a_const) {
2645
1.86k
      return false;
2646
1.86k
    }
2647
2648
1.84k
    bool use_x = false;
2649
2650
1.84k
    assert(a_const->type()->AsFloat());
2651
2652
1.84k
    const analysis::Type* type =
2653
1.84k
        context->get_type_mgr()->GetType(inst->type_id());
2654
1.84k
    uint32_t width = ElementWidth(type);
2655
1.84k
    if (width != 32 && width != 64) {
2656
      // We won't support folding half float values.
2657
0
      return false;
2658
0
    }
2659
2660
1.84k
    double element_value = a_const->GetValueAsDouble();
2661
1.84k
    if (element_value == 0.0) {
2662
67
      use_x = true;
2663
1.77k
    } else if (element_value == 1.0) {
2664
28
      use_x = false;
2665
1.75k
    } else {
2666
1.75k
      return false;
2667
1.75k
    }
2668
2669
    // Get the id of the of the vector the element comes from.
2670
95
    uint32_t new_vector = 0;
2671
95
    if (use_x) {
2672
67
      new_vector = composite_inst->GetSingleWordInOperand(kFMixXIdInIdx);
2673
67
    } else {
2674
28
      new_vector = composite_inst->GetSingleWordInOperand(kFMixYIdInIdx);
2675
28
    }
2676
2677
    // Update the extract instruction.
2678
95
    inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
2679
95
    return true;
2680
1.84k
  };
2681
16.7k
}
2682
2683
// Returns the number of elements in the composite type |type|.  Returns 0 if
2684
// |type| is a scalar value. Return UINT32_MAX when the size is unknown at
2685
// compile time.
2686
94.6k
uint32_t GetNumberOfElements(const analysis::Type* type) {
2687
94.6k
  if (auto* vector_type = type->AsVector()) {
2688
86.2k
    return vector_type->element_count();
2689
86.2k
  }
2690
8.46k
  if (auto* matrix_type = type->AsMatrix()) {
2691
0
    return matrix_type->element_count();
2692
0
  }
2693
8.46k
  if (auto* struct_type = type->AsStruct()) {
2694
3.61k
    return static_cast<uint32_t>(struct_type->element_types().size());
2695
3.61k
  }
2696
4.84k
  if (auto* array_type = type->AsArray()) {
2697
4.84k
    if (array_type->length_info().words[0] ==
2698
4.84k
            analysis::Array::LengthInfo::kConstant &&
2699
4.84k
        array_type->length_info().words.size() == 2) {
2700
4.84k
      return array_type->length_info().words[1];
2701
4.84k
    }
2702
0
    return UINT32_MAX;
2703
4.84k
  }
2704
0
  return 0;
2705
4.84k
}
2706
2707
// Returns a map with the set of values that were inserted into an object by
2708
// the chain of OpCompositeInsertInstruction starting with |inst|.
2709
// The map will map the index to the value inserted at that index. An empty map
2710
// will be returned if the map could not be properly generated.
2711
94.6k
std::map<uint32_t, uint32_t> GetInsertedValues(Instruction* inst) {
2712
94.6k
  analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr();
2713
94.6k
  std::map<uint32_t, uint32_t> values_inserted;
2714
94.6k
  Instruction* current_inst = inst;
2715
463k
  while (current_inst->opcode() == spv::Op::OpCompositeInsert) {
2716
369k
    if (current_inst->NumInOperands() > inst->NumInOperands()) {
2717
      // This is to catch the case
2718
      //   %2 = OpCompositeInsert %m2x2int %v2int_1_0 %m2x2int_undef 0
2719
      //   %3 = OpCompositeInsert %m2x2int %int_4 %2 0 0
2720
      //   %4 = OpCompositeInsert %m2x2int %v2int_2_3 %3 1
2721
      // In this case we cannot do a single construct to get the matrix.
2722
1.05k
      uint32_t partially_inserted_element_index =
2723
1.05k
          current_inst->GetSingleWordInOperand(inst->NumInOperands() - 1);
2724
1.05k
      if (values_inserted.count(partially_inserted_element_index) == 0)
2725
231
        return {};
2726
1.05k
    }
2727
368k
    if (HaveSameIndexesExceptForLast(inst, current_inst)) {
2728
366k
      values_inserted.insert(
2729
366k
          {current_inst->GetSingleWordInOperand(current_inst->NumInOperands() -
2730
366k
                                                1),
2731
366k
           current_inst->GetSingleWordInOperand(kInsertObjectIdInIdx)});
2732
366k
    }
2733
368k
    current_inst = def_use_mgr->GetDef(
2734
368k
        current_inst->GetSingleWordInOperand(kInsertCompositeIdInIdx));
2735
368k
  }
2736
94.4k
  return values_inserted;
2737
94.6k
}
2738
2739
// Returns true of there is an entry in |values_inserted| for every element of
2740
// |Type|.
2741
bool DoInsertedValuesCoverEntireObject(
2742
94.6k
    const analysis::Type* type, std::map<uint32_t, uint32_t>& values_inserted) {
2743
94.6k
  uint32_t container_size = GetNumberOfElements(type);
2744
94.6k
  if (container_size != values_inserted.size()) {
2745
83.6k
    return false;
2746
83.6k
  }
2747
2748
11.0k
  if (values_inserted.rbegin()->first >= container_size) {
2749
0
    return false;
2750
0
  }
2751
11.0k
  return true;
2752
11.0k
}
2753
2754
// Returns id of the type of the element that immediately contains the element
2755
// being inserted by the OpCompositeInsert instruction |inst|. Returns 0 if it
2756
// could not be found.
2757
94.6k
uint32_t GetContainerTypeId(Instruction* inst) {
2758
94.6k
  assert(inst->opcode() == spv::Op::OpCompositeInsert);
2759
94.6k
  analysis::DefUseManager* def_use_manager = inst->context()->get_def_use_mgr();
2760
94.6k
  uint32_t container_type_id = GetElementType(
2761
94.6k
      inst->type_id(), inst->begin() + 4, inst->end() - 1, def_use_manager);
2762
94.6k
  return container_type_id;
2763
94.6k
}
2764
2765
// Returns an OpCompositeConstruct instruction that build an object with
2766
// |type_id| out of the values in |values_inserted|.  Each value will be
2767
// placed at the index corresponding to the value.  The new instruction will
2768
// be placed before |insert_before|.
2769
Instruction* BuildCompositeConstruct(
2770
    uint32_t type_id, const std::map<uint32_t, uint32_t>& values_inserted,
2771
11.0k
    Instruction* insert_before) {
2772
11.0k
  InstructionBuilder ir_builder(
2773
11.0k
      insert_before->context(), insert_before,
2774
11.0k
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
2775
2776
11.0k
  std::vector<uint32_t> ids_in_order;
2777
22.5k
  for (auto it : values_inserted) {
2778
22.5k
    ids_in_order.push_back(it.second);
2779
22.5k
  }
2780
11.0k
  Instruction* construct =
2781
11.0k
      ir_builder.AddCompositeConstruct(type_id, ids_in_order);
2782
11.0k
  return construct;
2783
11.0k
}
2784
2785
// Replaces the OpCompositeInsert |inst| that inserts |construct| into the same
2786
// object as |inst| with final index removed.  If the resulting
2787
// OpCompositeInsert instruction would have no remaining indexes, the
2788
// instruction is replaced with an OpCopyObject instead.
2789
11.0k
void InsertConstructedObject(Instruction* inst, const Instruction* construct) {
2790
11.0k
  if (inst->NumInOperands() == 3) {
2791
10.9k
    inst->SetOpcode(spv::Op::OpCopyObject);
2792
10.9k
    inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {construct->result_id()}}});
2793
10.9k
  } else {
2794
54
    inst->SetInOperand(kInsertObjectIdInIdx, {construct->result_id()});
2795
54
    inst->RemoveOperand(inst->NumOperands() - 1);
2796
54
  }
2797
11.0k
}
2798
2799
// Replaces a series of |OpCompositeInsert| instruction that cover the entire
2800
// object with an |OpCompositeConstruct|.
2801
bool CompositeInsertToCompositeConstruct(
2802
    IRContext* context, Instruction* inst,
2803
94.6k
    const std::vector<const analysis::Constant*>&) {
2804
94.6k
  assert(inst->opcode() == spv::Op::OpCompositeInsert &&
2805
94.6k
         "Wrong opcode.  Should be OpCompositeInsert.");
2806
94.6k
  if (inst->NumInOperands() < 3) return false;
2807
2808
94.6k
  std::map<uint32_t, uint32_t> values_inserted = GetInsertedValues(inst);
2809
94.6k
  uint32_t container_type_id = GetContainerTypeId(inst);
2810
94.6k
  if (container_type_id == 0) {
2811
0
    return false;
2812
0
  }
2813
2814
94.6k
  analysis::TypeManager* type_mgr = context->get_type_mgr();
2815
94.6k
  const analysis::Type* container_type = type_mgr->GetType(container_type_id);
2816
94.6k
  assert(container_type && "GetContainerTypeId returned a bad id.");
2817
94.6k
  if (!DoInsertedValuesCoverEntireObject(container_type, values_inserted)) {
2818
83.6k
    return false;
2819
83.6k
  }
2820
2821
11.0k
  Instruction* construct =
2822
11.0k
      BuildCompositeConstruct(container_type_id, values_inserted, inst);
2823
11.0k
  InsertConstructedObject(inst, construct);
2824
11.0k
  return true;
2825
94.6k
}
2826
2827
16.7k
FoldingRule RedundantPhi() {
2828
  // An OpPhi instruction where all values are the same or the result of the phi
2829
  // itself, can be replaced by the value itself.
2830
16.7k
  return [](IRContext*, Instruction* inst,
2831
350k
            const std::vector<const analysis::Constant*>&) {
2832
350k
    assert(inst->opcode() == spv::Op::OpPhi &&
2833
350k
           "Wrong opcode.  Should be OpPhi.");
2834
2835
350k
    uint32_t incoming_value = 0;
2836
2837
813k
    for (uint32_t i = 0; i < inst->NumInOperands(); i += 2) {
2838
721k
      uint32_t op_id = inst->GetSingleWordInOperand(i);
2839
721k
      if (op_id == inst->result_id()) {
2840
71.6k
        continue;
2841
71.6k
      }
2842
2843
650k
      if (incoming_value == 0) {
2844
350k
        incoming_value = op_id;
2845
350k
      } else if (op_id != incoming_value) {
2846
        // Found two possible value.  Can't simplify.
2847
258k
        return false;
2848
258k
      }
2849
650k
    }
2850
2851
91.7k
    if (incoming_value == 0) {
2852
      // Code looks invalid.  Don't do anything.
2853
0
      return false;
2854
0
    }
2855
2856
    // We have a single incoming value.  Simplify using that value.
2857
91.7k
    inst->SetOpcode(spv::Op::OpCopyObject);
2858
91.7k
    inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {incoming_value}}});
2859
91.7k
    return true;
2860
91.7k
  };
2861
16.7k
}
2862
2863
16.7k
FoldingRule BitCastScalarOrVector() {
2864
16.7k
  return [](IRContext* context, Instruction* inst,
2865
16.7k
            const std::vector<const analysis::Constant*>& constants) {
2866
2.41k
    assert(inst->opcode() == spv::Op::OpBitcast && constants.size() == 1);
2867
2.41k
    if (constants[0] == nullptr) return false;
2868
2869
1.41k
    const analysis::Type* type =
2870
1.41k
        context->get_type_mgr()->GetType(inst->type_id());
2871
1.41k
    if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
2872
8
      return false;
2873
2874
1.40k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
2875
1.40k
    std::vector<uint32_t> words =
2876
1.40k
        GetWordsFromNumericScalarOrVectorConstant(const_mgr, constants[0]);
2877
1.40k
    if (words.size() == 0) return false;
2878
2879
1.40k
    const analysis::Constant* bitcasted_constant =
2880
1.40k
        ConvertWordsToNumericScalarOrVectorConstant(const_mgr, words, type);
2881
1.40k
    if (!bitcasted_constant) return false;
2882
2883
1.40k
    auto new_feeder_id =
2884
1.40k
        const_mgr->GetDefiningInstruction(bitcasted_constant, inst->type_id())
2885
1.40k
            ->result_id();
2886
1.40k
    inst->SetOpcode(spv::Op::OpCopyObject);
2887
1.40k
    inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {new_feeder_id}}});
2888
1.40k
    return true;
2889
1.40k
  };
2890
16.7k
}
2891
2892
// Remove indirect bitcasts which have no effect.
2893
//   uint32 x;  asuint32(x)            => x
2894
//   uint32 x;  asuint32(asint32(x))   => x
2895
//   float32 x; asuint32(asint32(x))   => asuint32(x)
2896
16.7k
FoldingRule RedundantBitcast() {
2897
16.7k
  return [](IRContext* context, Instruction* inst,
2898
16.7k
            const std::vector<const analysis::Constant*>&) {
2899
1.00k
    assert(inst->opcode() == spv::Op::OpBitcast);
2900
2901
1.00k
    analysis::DefUseManager* def_mgr = context->get_def_use_mgr();
2902
1.00k
    Instruction* child = def_mgr->GetDef(inst->GetSingleWordInOperand(0));
2903
2904
1.00k
    if (inst->type_id() == child->type_id()) {
2905
165
      inst->SetOpcode(spv::Op::OpCopyObject);
2906
165
      inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {child->result_id()}}});
2907
165
      return true;
2908
165
    }
2909
2910
843
    if (child->opcode() != spv::Op::OpBitcast) {
2911
843
      return false;
2912
843
    }
2913
2914
0
    if (def_mgr->GetDef(child->GetSingleWordInOperand(0))->type_id() ==
2915
0
        inst->type_id()) {
2916
0
      inst->SetOpcode(spv::Op::OpCopyObject);
2917
0
    }
2918
0
    inst->SetInOperands(
2919
0
        {{SPV_OPERAND_TYPE_ID, {child->GetSingleWordInOperand(0)}}});
2920
2921
0
    return true;
2922
843
  };
2923
16.7k
}
2924
2925
16.7k
FoldingRule BitReverseScalarOrVector() {
2926
16.7k
  return [](IRContext* context, Instruction* inst,
2927
16.7k
            const std::vector<const analysis::Constant*>& constants) {
2928
750
    assert(inst->opcode() == spv::Op::OpBitReverse && constants.size() == 1);
2929
750
    if (constants[0] == nullptr) return false;
2930
2931
501
    const analysis::Type* type =
2932
501
        context->get_type_mgr()->GetType(inst->type_id());
2933
501
    assert(!HasFloatingPoint(type) &&
2934
501
           "BitReverse cannot be applied to floating point types.");
2935
501
    assert((type->AsInteger() || type->AsVector()) &&
2936
501
           "BitReverse can only be applied to integer scalars or vectors.");
2937
501
    assert((ElementWidth(type) == 32) &&
2938
501
           "BitReverse can only be applied to integer types of width 32");
2939
2940
501
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
2941
501
    std::vector<uint32_t> words =
2942
501
        GetWordsFromNumericScalarOrVectorConstant(const_mgr, constants[0]);
2943
501
    if (words.size() == 0) return false;
2944
2945
504
    for (uint32_t& word : words) {
2946
      // Reverse the bits in each word.
2947
504
      word = ((word & 0x55555555) << 1) | ((word >> 1) & 0x55555555);
2948
504
      word = ((word & 0x33333333) << 2) | ((word >> 2) & 0x33333333);
2949
504
      word = ((word & 0x0F0F0F0F) << 4) | ((word >> 4) & 0x0F0F0F0F);
2950
504
      word = ((word & 0x00FF00FF) << 8) | ((word >> 8) & 0x00FF00FF);
2951
504
      word = (word << 16) | (word >> 16);
2952
504
    }
2953
2954
501
    const analysis::Constant* bitreversed_constant =
2955
501
        ConvertWordsToNumericScalarOrVectorConstant(const_mgr, words, type);
2956
501
    if (!bitreversed_constant) return false;
2957
2958
501
    auto new_feeder_id =
2959
501
        const_mgr->GetDefiningInstruction(bitreversed_constant, inst->type_id())
2960
501
            ->result_id();
2961
501
    inst->SetOpcode(spv::Op::OpCopyObject);
2962
501
    inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {new_feeder_id}}});
2963
501
    return true;
2964
501
  };
2965
16.7k
}
2966
2967
16.7k
FoldingRule RedundantSelect() {
2968
  // An OpSelect instruction where both values are the same or the condition is
2969
  // constant can be replaced by one of the values
2970
16.7k
  return [](IRContext*, Instruction* inst,
2971
17.1k
            const std::vector<const analysis::Constant*>& constants) {
2972
17.1k
    assert(inst->opcode() == spv::Op::OpSelect &&
2973
17.1k
           "Wrong opcode.  Should be OpSelect.");
2974
17.1k
    assert(inst->NumInOperands() == 3);
2975
17.1k
    assert(constants.size() == 3);
2976
2977
17.1k
    uint32_t true_id = inst->GetSingleWordInOperand(1);
2978
17.1k
    uint32_t false_id = inst->GetSingleWordInOperand(2);
2979
2980
17.1k
    if (true_id == false_id) {
2981
      // Both results are the same, condition doesn't matter
2982
117
      inst->SetOpcode(spv::Op::OpCopyObject);
2983
117
      inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
2984
117
      return true;
2985
17.0k
    } else if (constants[0]) {
2986
1.41k
      const analysis::Type* type = constants[0]->type();
2987
1.41k
      if (type->AsBool()) {
2988
        // Scalar constant value, select the corresponding value.
2989
959
        inst->SetOpcode(spv::Op::OpCopyObject);
2990
959
        if (constants[0]->AsNullConstant() ||
2991
959
            !constants[0]->AsBoolConstant()->value()) {
2992
650
          inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
2993
650
        } else {
2994
309
          inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
2995
309
        }
2996
959
        return true;
2997
959
      } else {
2998
454
        assert(type->AsVector());
2999
454
        if (constants[0]->AsNullConstant()) {
3000
          // All values come from false id.
3001
0
          inst->SetOpcode(spv::Op::OpCopyObject);
3002
0
          inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
3003
0
          return true;
3004
454
        } else {
3005
          // Convert to a vector shuffle.
3006
454
          std::vector<Operand> ops;
3007
454
          ops.push_back({SPV_OPERAND_TYPE_ID, {true_id}});
3008
454
          ops.push_back({SPV_OPERAND_TYPE_ID, {false_id}});
3009
454
          const analysis::VectorConstant* vector_const =
3010
454
              constants[0]->AsVectorConstant();
3011
454
          uint32_t size =
3012
454
              static_cast<uint32_t>(vector_const->GetComponents().size());
3013
1.36k
          for (uint32_t i = 0; i != size; ++i) {
3014
910
            const analysis::Constant* component =
3015
910
                vector_const->GetComponents()[i];
3016
910
            if (component->AsNullConstant() ||
3017
910
                !component->AsBoolConstant()->value()) {
3018
              // Selecting from the false vector which is the second input
3019
              // vector to the shuffle. Offset the index by |size|.
3020
27
              ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i + size}});
3021
883
            } else {
3022
              // Selecting from true vector which is the first input vector to
3023
              // the shuffle.
3024
883
              ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}});
3025
883
            }
3026
910
          }
3027
3028
454
          inst->SetOpcode(spv::Op::OpVectorShuffle);
3029
454
          inst->SetInOperands(std::move(ops));
3030
454
          return true;
3031
454
        }
3032
454
      }
3033
1.41k
    }
3034
3035
15.6k
    return false;
3036
17.1k
  };
3037
16.7k
}
3038
3039
10.8k
std::optional<bool> GetBoolConstantKind(const analysis::Constant* c) {
3040
10.8k
  if (!c) {
3041
3.64k
    return {};
3042
3.64k
  }
3043
7.19k
  if (auto composite = c->AsCompositeConstant()) {
3044
0
    auto& components = composite->GetComponents();
3045
0
    if (components.empty()) {
3046
0
      return {};
3047
0
    }
3048
0
    auto first = GetBoolConstantKind(components[0]);
3049
0
    if (!first) {
3050
0
      return {};
3051
0
    }
3052
0
    if (std::all_of(std::begin(components) + 1, std::end(components),
3053
0
                    [first](const analysis::Constant* c2) {
3054
0
                      return GetBoolConstantKind(c2) == first;
3055
0
                    })) {
3056
0
      return first;
3057
0
    }
3058
0
    return {};
3059
7.19k
  } else if (c->AsNullConstant()) {
3060
5
    return false;
3061
7.19k
  } else if (c->AsBoolConstant()) {
3062
7.19k
    return c->AsBoolConstant()->value();
3063
7.19k
  }
3064
0
  return {};
3065
7.19k
}
3066
3067
// Fold OpSelect instructions which have constant booleans as their result.
3068
//   x ? true  : false =  x
3069
//   x ? false : true  = !x
3070
16.7k
FoldingRule FoldConstantBooleanSelect() {
3071
16.7k
  return [](IRContext* context, Instruction* inst,
3072
16.7k
            const std::vector<const analysis::Constant*>& constants) {
3073
15.6k
    assert(inst->opcode() == spv::Op::OpSelect);
3074
15.6k
    assert(inst->NumInOperands() == 3);
3075
15.6k
    assert(constants.size() == 3);
3076
3077
15.6k
    if (!constants[1] || !constants[2]) {
3078
10.5k
      return false;
3079
10.5k
    }
3080
3081
5.07k
    analysis::DefUseManager* def_mgr = context->get_def_use_mgr();
3082
5.07k
    if (inst->type_id() !=
3083
5.07k
        def_mgr->GetDef(inst->GetSingleWordInOperand(0))->type_id()) {
3084
5.00k
      return false;
3085
5.00k
    }
3086
3087
70
    std::optional<bool> uniform_true = GetBoolConstantKind(constants[1]);
3088
70
    std::optional<bool> uniform_false = GetBoolConstantKind(constants[2]);
3089
3090
70
    if (!uniform_true || !uniform_false) {
3091
0
      return false;
3092
0
    }
3093
3094
70
    if (uniform_true.value() && !uniform_false.value()) {
3095
28
      inst->SetOpcode(spv::Op::OpCopyObject);
3096
28
      inst->SetInOperands(
3097
28
          {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
3098
28
      return true;
3099
42
    } else if (!uniform_true.value() && uniform_false.value()) {
3100
42
      inst->SetOpcode(spv::Op::OpLogicalNot);
3101
42
      inst->SetInOperands(
3102
42
          {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
3103
42
      return true;
3104
42
    }
3105
0
    return false;
3106
70
  };
3107
16.7k
}
3108
3109
// Fold OpLogicalAnd instructions which have a constant true on one side.
3110
//   x && true = x
3111
//   true && x = x
3112
16.7k
FoldingRule RedundantLogicalAnd() {
3113
16.7k
  return [](IRContext* context, Instruction* inst,
3114
16.7k
            const std::vector<const analysis::Constant*>& constants) {
3115
8.55k
    assert(inst->opcode() == spv::Op::OpLogicalAnd);
3116
3117
8.55k
    if (GetBoolConstantKind(ConstInput(constants)) ==
3118
8.55k
        std::optional<bool>(true)) {
3119
5.28k
      Instruction* other_inst = NonConstInput(context, constants[0], inst);
3120
5.28k
      inst->SetOpcode(spv::Op::OpCopyObject);
3121
5.28k
      inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {other_inst->result_id()}}});
3122
5.28k
      return true;
3123
5.28k
    }
3124
3.27k
    return false;
3125
8.55k
  };
3126
16.7k
}
3127
3128
// Fold OpLogicalOr instructions which have a constant false on one side.
3129
//   x || false = x
3130
//   false || x = x
3131
16.7k
FoldingRule RedundantLogicalOr() {
3132
16.7k
  return [](IRContext* context, Instruction* inst,
3133
16.7k
            const std::vector<const analysis::Constant*>& constants) {
3134
490
    assert(inst->opcode() == spv::Op::OpLogicalOr);
3135
3136
490
    if (GetBoolConstantKind(ConstInput(constants)) ==
3137
490
        std::optional<bool>(false)) {
3138
115
      Instruction* other_inst = NonConstInput(context, constants[0], inst);
3139
115
      inst->SetOpcode(spv::Op::OpCopyObject);
3140
115
      inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {other_inst->result_id()}}});
3141
115
      return true;
3142
115
    }
3143
375
    return false;
3144
490
  };
3145
16.7k
}
3146
3147
// Fold concurrent OpLogicalNot instructions:
3148
//   !!x = x
3149
16.7k
FoldingRule RedundantLogicalNot() {
3150
16.7k
  return [](IRContext* context, Instruction* inst,
3151
16.7k
            const std::vector<const analysis::Constant*>&) {
3152
15.2k
    assert(inst->opcode() == spv::Op::OpLogicalNot);
3153
15.2k
    Instruction* child =
3154
15.2k
        context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0));
3155
15.2k
    if (child->opcode() == spv::Op::OpLogicalNot) {
3156
51
      inst->SetOpcode(spv::Op::OpCopyObject);
3157
51
      inst->SetInOperands(
3158
51
          {{SPV_OPERAND_TYPE_ID, {child->GetSingleWordInOperand(0)}}});
3159
51
      return true;
3160
51
    }
3161
15.2k
    return false;
3162
15.2k
  };
3163
16.7k
}
3164
3165
// Cases handled:
3166
//  ((a ? C0 : C1) == C2)  =  ((a ? (C0 == C2) : (C1 == C2))
3167
//  ((a ? C0 : C1) != C2)  =  ((a ? (C0 != C2) : (C1 != C2))
3168
//  ((a ? C0 : C1) <  C2)  =  ((a ? (C0 <  C2) : (C1 <  C2))
3169
//  ((a ? C0 : C1) <= C2)  =  ((a ? (C0 <= C2) : (C1 <= C2))
3170
//  ((a ? C0 : C1) >  C2)  =  ((a ? (C0 >  C2) : (C1 >  C2))
3171
//  ((a ? C0 : C1) >= C2)  =  ((a ? (C0 >= C2) : (C1 >= C2))
3172
//  ((a ? C0 : C1) || C2)  =  ((a ? (C0 || C2) : (C1 || C2))
3173
//  ((a ? C0 : C1) && C2)  =  ((a ? (C0 && C2) : (C1 && C2))
3174
//  ((a ? C0 : C1) +  C2)  =  ((a ? (C0 +  C2) : (C1 +  C2))
3175
//  ((a ? C0 : C1) -  C2)  =  ((a ? (C0 -  C2) : (C1 -  C2))
3176
//  ((a ? C0 : C1) *  C2)  =  ((a ? (C0 *  C2) : (C1 *  C2))
3177
//  ((a ? C0 : C1) /  C2)  =  ((a ? (C0 /  C2) : (C1 /  C2))
3178
//  ((a ? C0 : C1) >> C2)  =  ((a ? (C0 >> C2) : (C1 >> C2))
3179
//  ((a ? C0 : C1) << C2)  =  ((a ? (C0 << C2) : (C1 << C2))
3180
//  ((a ? C0 : C1) ^  C2)  =  ((a ? (C0 ^  C2) : (C1 ^  C2))
3181
//  ((a ? C0 : C1) |  C2)  =  ((a ? (C0 |  C2) : (C1 |  C2))
3182
//  ((a ? C0 : C1) &  C2)  =  ((a ? (C0 &  C2) : (C1 &  C2))
3183
static const constexpr spv::Op MergeBinaryOpSelectOps[] = {
3184
    spv::Op::OpLogicalEqual,
3185
    spv::Op::OpLogicalNotEqual,
3186
    spv::Op::OpLogicalAnd,
3187
    spv::Op::OpLogicalOr,
3188
    spv::Op::OpIEqual,
3189
    spv::Op::OpINotEqual,
3190
    spv::Op::OpUGreaterThan,
3191
    spv::Op::OpSGreaterThan,
3192
    spv::Op::OpUGreaterThanEqual,
3193
    spv::Op::OpSGreaterThanEqual,
3194
    spv::Op::OpULessThan,
3195
    spv::Op::OpSLessThan,
3196
    spv::Op::OpULessThanEqual,
3197
    spv::Op::OpSLessThanEqual,
3198
    spv::Op::OpFOrdEqual,
3199
    spv::Op::OpFUnordEqual,
3200
    spv::Op::OpFOrdNotEqual,
3201
    spv::Op::OpFUnordNotEqual,
3202
    spv::Op::OpFOrdLessThan,
3203
    spv::Op::OpFUnordLessThan,
3204
    spv::Op::OpFOrdGreaterThan,
3205
    spv::Op::OpFUnordGreaterThan,
3206
    spv::Op::OpFOrdLessThanEqual,
3207
    spv::Op::OpFUnordLessThanEqual,
3208
    spv::Op::OpFOrdGreaterThanEqual,
3209
    spv::Op::OpFUnordGreaterThanEqual,
3210
    spv::Op::OpIAdd,
3211
    spv::Op::OpFAdd,
3212
    spv::Op::OpISub,
3213
    spv::Op::OpFSub,
3214
    spv::Op::OpIMul,
3215
    spv::Op::OpFMul,
3216
    spv::Op::OpUDiv,
3217
    spv::Op::OpSDiv,
3218
    spv::Op::OpFDiv,
3219
    spv::Op::OpVectorTimesScalar,
3220
    spv::Op::OpShiftRightLogical,
3221
    spv::Op::OpShiftRightArithmetic,
3222
    spv::Op::OpShiftLeftLogical,
3223
    spv::Op::OpBitwiseXor,
3224
    spv::Op::OpBitwiseOr,
3225
    spv::Op::OpBitwiseAnd};
3226
3227
704k
FoldingRule MergeBinaryOpSelect(spv::Op opcode) {
3228
704k
  assert(std::find(std::begin(MergeBinaryOpSelectOps),
3229
704k
                   std::end(MergeBinaryOpSelectOps),
3230
704k
                   opcode) != std::end(MergeBinaryOpSelectOps) &&
3231
704k
         "Wrong opcode.");
3232
3233
704k
  return [opcode](IRContext* context, Instruction* inst,
3234
1.73M
                  const std::vector<const analysis::Constant*>& constants) {
3235
1.73M
    const analysis::Constant* const_input = ConstInput(constants);
3236
1.73M
    if (!const_input) {
3237
887k
      return false;
3238
887k
    }
3239
848k
    Instruction* non_const = NonConstInput(context, constants[0], inst);
3240
848k
    if (non_const->opcode() != spv::Op::OpSelect) {
3241
848k
      return false;
3242
848k
    }
3243
901
    std::vector<const analysis::Constant*> select_constants =
3244
901
        context->get_constant_mgr()->GetOperandConstants(non_const);
3245
901
    if (!select_constants[1] || !select_constants[2]) {
3246
223
      return false;
3247
223
    }
3248
3249
678
    InstructionBuilder ir_builder(
3250
678
        context, inst,
3251
678
        IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
3252
3253
678
    Instruction *lhs, *rhs;
3254
678
    if (constants[0]) {
3255
354
      lhs = ir_builder.AddBinaryOp(inst->type_id(), opcode,
3256
354
                                   inst->GetSingleWordInOperand(0),
3257
354
                                   non_const->GetSingleWordInOperand(1));
3258
354
      rhs = ir_builder.AddBinaryOp(inst->type_id(), opcode,
3259
354
                                   inst->GetSingleWordInOperand(0),
3260
354
                                   non_const->GetSingleWordInOperand(2));
3261
354
    } else {
3262
324
      lhs = ir_builder.AddBinaryOp(inst->type_id(), opcode,
3263
324
                                   non_const->GetSingleWordInOperand(1),
3264
324
                                   inst->GetSingleWordInOperand(1));
3265
324
      rhs = ir_builder.AddBinaryOp(inst->type_id(), opcode,
3266
324
                                   non_const->GetSingleWordInOperand(2),
3267
324
                                   inst->GetSingleWordInOperand(1));
3268
324
    }
3269
3270
678
    if (!lhs || !rhs) {
3271
0
      return false;
3272
0
    }
3273
3274
678
    if (context->get_instruction_folder().FoldInstruction(lhs)) {
3275
678
      context->AnalyzeDefUse(lhs);
3276
1.35k
      while (lhs->opcode() == spv::Op::OpCopyObject) {
3277
678
        lhs =
3278
678
            context->get_def_use_mgr()->GetDef(lhs->GetSingleWordInOperand(0));
3279
678
      }
3280
678
    }
3281
678
    if (context->get_instruction_folder().FoldInstruction(rhs)) {
3282
678
      context->AnalyzeDefUse(rhs);
3283
1.35k
      while (rhs->opcode() == spv::Op::OpCopyObject) {
3284
678
        rhs =
3285
678
            context->get_def_use_mgr()->GetDef(rhs->GetSingleWordInOperand(0));
3286
678
      }
3287
678
    }
3288
678
    inst->SetOpcode(spv::Op::OpSelect);
3289
678
    inst->SetInOperands(
3290
678
        {{SPV_OPERAND_TYPE_ID, {non_const->GetSingleWordInOperand(0)}},
3291
678
         {SPV_OPERAND_TYPE_ID, {lhs->result_id()}},
3292
678
         {SPV_OPERAND_TYPE_ID, {rhs->result_id()}}});
3293
678
    return true;
3294
678
  };
3295
704k
}
3296
3297
// Fold OpLogicalNot instructions that follow a comparison,
3298
// if the comparison is only used by that instruction.
3299
//
3300
// !(a == b) = (a != b)
3301
// !(a != b) = (a == b)
3302
// !(a < b)  = (a >= b)
3303
// !(a >= b) = (a < b)
3304
// !(a > b)  = (a <= b)
3305
// !(a <= b) = (a > b)
3306
16.7k
FoldingRule FoldLogicalNotComparison() {
3307
16.7k
  return [](IRContext* context, Instruction* inst,
3308
16.7k
            const std::vector<const analysis::Constant*>&) {
3309
15.2k
    assert(inst->opcode() == spv::Op::OpLogicalNot);
3310
15.2k
    analysis::DefUseManager* def_mgr = context->get_def_use_mgr();
3311
15.2k
    Instruction* child =
3312
15.2k
        context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0));
3313
3314
15.2k
    if (def_mgr->NumUses(child) > 1) {
3315
8.07k
      return false;
3316
8.07k
    }
3317
3318
7.16k
    spv::Op new_opcode = spv::Op::OpNop;
3319
7.16k
    switch (child->opcode()) {
3320
      // (a == b) <=> (a != b)
3321
2
      case spv::Op::OpIEqual:
3322
2
        new_opcode = spv::Op::OpINotEqual;
3323
2
        break;
3324
29
      case spv::Op::OpINotEqual:
3325
29
        new_opcode = spv::Op::OpIEqual;
3326
29
        break;
3327
138
      case spv::Op::OpFOrdEqual:
3328
138
        new_opcode = spv::Op::OpFUnordNotEqual;
3329
138
        break;
3330
61
      case spv::Op::OpFOrdNotEqual:
3331
61
        new_opcode = spv::Op::OpFUnordEqual;
3332
61
        break;
3333
112
      case spv::Op::OpFUnordEqual:
3334
112
        new_opcode = spv::Op::OpFOrdNotEqual;
3335
112
        break;
3336
47
      case spv::Op::OpFUnordNotEqual:
3337
47
        new_opcode = spv::Op::OpFOrdEqual;
3338
47
        break;
3339
6
      case spv::Op::OpLogicalEqual:
3340
6
        new_opcode = spv::Op::OpLogicalNotEqual;
3341
6
        break;
3342
9
      case spv::Op::OpLogicalNotEqual:
3343
9
        new_opcode = spv::Op::OpLogicalEqual;
3344
9
        break;
3345
3346
      // (a > b) <=> (a <= b)
3347
5
      case spv::Op::OpUGreaterThan:
3348
5
        new_opcode = spv::Op::OpULessThanEqual;
3349
5
        break;
3350
4
      case spv::Op::OpULessThanEqual:
3351
4
        new_opcode = spv::Op::OpUGreaterThan;
3352
4
        break;
3353
13
      case spv::Op::OpSGreaterThan:
3354
13
        new_opcode = spv::Op::OpSLessThanEqual;
3355
13
        break;
3356
2
      case spv::Op::OpSLessThanEqual:
3357
2
        new_opcode = spv::Op::OpSGreaterThan;
3358
2
        break;
3359
2.48k
      case spv::Op::OpFOrdGreaterThan:
3360
2.48k
        new_opcode = spv::Op::OpFUnordLessThanEqual;
3361
2.48k
        break;
3362
123
      case spv::Op::OpFOrdLessThanEqual:
3363
123
        new_opcode = spv::Op::OpFUnordGreaterThan;
3364
123
        break;
3365
114
      case spv::Op::OpFUnordGreaterThan:
3366
114
        new_opcode = spv::Op::OpFOrdLessThanEqual;
3367
114
        break;
3368
37
      case spv::Op::OpFUnordLessThanEqual:
3369
37
        new_opcode = spv::Op::OpFOrdGreaterThan;
3370
37
        break;
3371
3372
      // (a < b) <=> (a >= b)
3373
1
      case spv::Op::OpULessThan:
3374
1
        new_opcode = spv::Op::OpUGreaterThanEqual;
3375
1
        break;
3376
1
      case spv::Op::OpUGreaterThanEqual:
3377
1
        new_opcode = spv::Op::OpULessThan;
3378
1
        break;
3379
2
      case spv::Op::OpSLessThan:
3380
2
        new_opcode = spv::Op::OpSGreaterThanEqual;
3381
2
        break;
3382
2
      case spv::Op::OpSGreaterThanEqual:
3383
2
        new_opcode = spv::Op::OpSLessThan;
3384
2
        break;
3385
2.81k
      case spv::Op::OpFOrdLessThan:
3386
2.81k
        new_opcode = spv::Op::OpFUnordGreaterThanEqual;
3387
2.81k
        break;
3388
67
      case spv::Op::OpFOrdGreaterThanEqual:
3389
67
        new_opcode = spv::Op::OpFUnordLessThan;
3390
67
        break;
3391
30
      case spv::Op::OpFUnordLessThan:
3392
30
        new_opcode = spv::Op::OpFOrdGreaterThanEqual;
3393
30
        break;
3394
39
      case spv::Op::OpFUnordGreaterThanEqual:
3395
39
        new_opcode = spv::Op::OpFOrdLessThan;
3396
39
        break;
3397
3398
1.01k
      default:
3399
1.01k
        break;
3400
7.16k
    }
3401
3402
7.16k
    if (new_opcode == spv::Op::OpNop) {
3403
1.01k
      return false;
3404
1.01k
    }
3405
3406
6.14k
    inst->SetOpcode(new_opcode);
3407
6.14k
    inst->SetInOperands(
3408
6.14k
        {{SPV_OPERAND_TYPE_ID, {child->GetSingleWordInOperand(0)}},
3409
6.14k
         {SPV_OPERAND_TYPE_ID, {child->GetSingleWordInOperand(1)}}});
3410
3411
6.14k
    return true;
3412
7.16k
  };
3413
16.7k
}
3414
3415
// (a == true)  =  a
3416
// (a == false) = !a
3417
// (a != true)  = !a
3418
// (a != false) =  a
3419
33.5k
FoldingRule RedundantLogicalEqual() {
3420
33.5k
  return [](IRContext* context, Instruction* inst,
3421
33.5k
            const std::vector<const analysis::Constant*>& constants) {
3422
1.96k
    assert(inst->opcode() == spv::Op::OpLogicalEqual ||
3423
1.96k
           inst->opcode() == spv::Op::OpLogicalNotEqual);
3424
3425
1.96k
    const analysis::Constant* const_input = ConstInput(constants);
3426
1.96k
    if (!const_input) {
3427
310
      return false;
3428
310
    }
3429
3430
1.65k
    analysis::DefUseManager* def_mgr = context->get_def_use_mgr();
3431
1.65k
    if (inst->type_id() !=
3432
1.65k
        def_mgr->GetDef(inst->GetSingleWordInOperand(0))->type_id()) {
3433
0
      return false;
3434
0
    }
3435
3436
1.65k
    std::optional<bool> uniform_const = GetBoolConstantKind(const_input);
3437
1.65k
    if (!uniform_const) {
3438
0
      return false;
3439
0
    }
3440
3441
1.65k
    bool direct_copy = inst->opcode() == spv::Op::OpLogicalEqual
3442
1.65k
                           ? uniform_const.value()
3443
1.65k
                           : !uniform_const.value();
3444
3445
1.65k
    inst->SetOpcode(direct_copy ? spv::Op::OpCopyObject
3446
1.65k
                                : spv::Op::OpLogicalNot);
3447
1.65k
    inst->SetInOperands(
3448
1.65k
        {{SPV_OPERAND_TYPE_ID,
3449
1.65k
          {NonConstInput(context, constants[0], inst)->result_id()}}});
3450
1.65k
    return true;
3451
1.65k
  };
3452
33.5k
}
3453
3454
enum class FloatConstantKind { Unknown, Zero, One };
3455
3456
2.47M
FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) {
3457
2.47M
  if (constant == nullptr) {
3458
1.42M
    return FloatConstantKind::Unknown;
3459
1.42M
  }
3460
3461
2.47M
  assert(HasFloatingPoint(constant->type()) && "Unexpected constant type");
3462
3463
1.05M
  if (constant->AsNullConstant()) {
3464
4.72k
    return FloatConstantKind::Zero;
3465
1.04M
  } else if (const analysis::VectorConstant* vc =
3466
1.04M
                 constant->AsVectorConstant()) {
3467
272k
    const std::vector<const analysis::Constant*>& components =
3468
272k
        vc->GetComponents();
3469
272k
    assert(!components.empty());
3470
3471
272k
    FloatConstantKind kind = getFloatConstantKind(components[0]);
3472
3473
518k
    for (size_t i = 1; i < components.size(); ++i) {
3474
309k
      if (getFloatConstantKind(components[i]) != kind) {
3475
63.1k
        return FloatConstantKind::Unknown;
3476
63.1k
      }
3477
309k
    }
3478
3479
209k
    return kind;
3480
776k
  } else if (const analysis::FloatConstant* fc = constant->AsFloatConstant()) {
3481
776k
    if (fc->IsZero()) return FloatConstantKind::Zero;
3482
3483
706k
    uint32_t width = fc->type()->AsFloat()->width();
3484
706k
    if (width != 32 && width != 64) return FloatConstantKind::Unknown;
3485
3486
706k
    double value = (width == 64) ? fc->GetDoubleValue() : fc->GetFloatValue();
3487
3488
706k
    if (value == 0.0) {
3489
34.1k
      return FloatConstantKind::Zero;
3490
672k
    } else if (value == 1.0) {
3491
31.2k
      return FloatConstantKind::One;
3492
641k
    } else {
3493
641k
      return FloatConstantKind::Unknown;
3494
641k
    }
3495
706k
  } else {
3496
0
    return FloatConstantKind::Unknown;
3497
0
  }
3498
1.05M
}
3499
3500
16.7k
// x + 0 = 0 + x = x, applied only when floating-point folding is allowed.
FoldingRule RedundantFAdd() {
  return [](IRContext*, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == spv::Op::OpFAdd &&
           "Wrong opcode.  Should be OpFAdd.");
    assert(constants.size() == 2);

    if (!inst->IsFloatingPointFoldingAllowed()) return false;

    const bool lhs_is_zero =
        getFloatConstantKind(constants[0]) == FloatConstantKind::Zero;
    const bool rhs_is_zero =
        getFloatConstantKind(constants[1]) == FloatConstantKind::Zero;
    if (!lhs_is_zero && !rhs_is_zero) return false;

    // Forward the addend that is not the zero (operand 1 when the lhs is 0).
    const uint32_t kept_index = lhs_is_zero ? 1u : 0u;
    inst->SetOpcode(spv::Op::OpCopyObject);
    inst->SetInOperands(
        {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(kept_index)}}});
    return true;
  };
}
3525
3526
16.7k
// 0 - x = -x and x - 0 = x, applied only when floating-point folding is
// allowed. A zero minuend takes precedence over a zero subtrahend.
FoldingRule RedundantFSub() {
  return [](IRContext*, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == spv::Op::OpFSub &&
           "Wrong opcode.  Should be OpFSub.");
    assert(constants.size() == 2);

    if (!inst->IsFloatingPointFoldingAllowed()) return false;

    const FloatConstantKind lhs = getFloatConstantKind(constants[0]);
    const FloatConstantKind rhs = getFloatConstantKind(constants[1]);

    spv::Op replacement = spv::Op::OpNop;
    uint32_t kept_index = 0;
    if (lhs == FloatConstantKind::Zero) {
      replacement = spv::Op::OpFNegate;
      kept_index = 1;
    } else if (rhs == FloatConstantKind::Zero) {
      replacement = spv::Op::OpCopyObject;
      kept_index = 0;
    } else {
      return false;
    }

    inst->SetOpcode(replacement);
    inst->SetInOperands(
        {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(kept_index)}}});
    return true;
  };
}
3557
3558
16.7k
// x * 0 = 0 and x * 1 = x (either operand order), applied only when
// floating-point folding is allowed. Zero operands are checked before ones.
FoldingRule RedundantFMul() {
  return [](IRContext*, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == spv::Op::OpFMul &&
           "Wrong opcode.  Should be OpFMul.");
    assert(constants.size() == 2);

    if (!inst->IsFloatingPointFoldingAllowed()) return false;

    const FloatConstantKind lhs = getFloatConstantKind(constants[0]);
    const FloatConstantKind rhs = getFloatConstantKind(constants[1]);

    // For a zero factor, forward the zero itself; for a unit factor, forward
    // the other operand.
    uint32_t kept_index = 0;
    if (lhs == FloatConstantKind::Zero) {
      kept_index = 0;
    } else if (rhs == FloatConstantKind::Zero) {
      kept_index = 1;
    } else if (lhs == FloatConstantKind::One) {
      kept_index = 1;
    } else if (rhs == FloatConstantKind::One) {
      kept_index = 0;
    } else {
      return false;
    }

    inst->SetOpcode(spv::Op::OpCopyObject);
    inst->SetInOperands(
        {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(kept_index)}}});
    return true;
  };
}
3591
3592
16.7k
// 0 / x = 0 and x / 1 = x, applied only when floating-point folding is
// allowed. In both cases the numerator (operand 0) is the result.
FoldingRule RedundantFDiv() {
  return [](IRContext*, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == spv::Op::OpFDiv &&
           "Wrong opcode.  Should be OpFDiv.");
    assert(constants.size() == 2);

    if (!inst->IsFloatingPointFoldingAllowed()) return false;

    const bool zero_numerator =
        getFloatConstantKind(constants[0]) == FloatConstantKind::Zero;
    const bool unit_denominator =
        getFloatConstantKind(constants[1]) == FloatConstantKind::One;
    if (!zero_numerator && !unit_denominator) return false;

    inst->SetOpcode(spv::Op::OpCopyObject);
    inst->SetInOperands(
        {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
    return true;
  };
}
3616
3617
16.7k
// 0 mod x = 0, applied only when floating-point folding is allowed. The zero
// dividend (operand 0) is forwarded as the result.
FoldingRule RedundantFMod() {
  return [](IRContext*, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == spv::Op::OpFMod &&
           "Wrong opcode.  Should be OpFMod.");
    assert(constants.size() == 2);

    if (!inst->IsFloatingPointFoldingAllowed()) return false;

    const bool zero_dividend =
        getFloatConstantKind(constants[0]) == FloatConstantKind::Zero;
    if (!zero_dividend) return false;

    inst->SetOpcode(spv::Op::OpCopyObject);
    inst->SetInOperands(
        {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
    return true;
  };
}
3640
3641
10.0k
// GLSL.std.450 FMix with a constant blend factor:
//   mix(x, y, 0) = x
//   mix(x, y, 1) = y
// Applied only when floating-point folding is allowed. Other extended
// instructions are left untouched.
FoldingRule RedundantFMix() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == spv::Op::OpExtInst &&
           "Wrong opcode.  Should be OpExtInst.");

    if (!inst->IsFloatingPointFoldingAllowed()) return false;

    const uint32_t glsl_set_id =
        context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
    const bool is_glsl_fmix =
        inst->GetSingleWordInOperand(kExtInstSetIdInIdx) == glsl_set_id &&
        inst->GetSingleWordInOperand(kExtInstInstructionInIdx) ==
            GLSLstd450FMix;
    if (!is_glsl_fmix) return false;

    assert(constants.size() == 5);

    const FloatConstantKind blend =
        getFloatConstantKind(constants[kFMixAIdInIdx]);
    if (blend != FloatConstantKind::Zero && blend != FloatConstantKind::One) {
      return false;
    }

    const uint32_t picked_index =
        blend == FloatConstantKind::Zero ? kFMixXIdInIdx : kFMixYIdInIdx;
    inst->SetOpcode(spv::Op::OpCopyObject);
    inst->SetInOperands(
        {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(picked_index)}}});
    return true;
  };
}
3675
3676
// Returns a folding rule that folds the instruction to operand |foldToArg|
3677
// (0 or 1) if operand |arg| (0 or 1) is a zero constant.
3678
285k
FoldingRule RedundantBinaryOpWithZeroOperand(uint32_t arg, uint32_t foldToArg) {
3679
285k
  return [arg, foldToArg](
3680
285k
             IRContext* context, Instruction* inst,
3681
333k
             const std::vector<const analysis::Constant*>& constants) {
3682
333k
    assert(constants.size() == 2);
3683
3684
333k
    if (constants[arg] && constants[arg]->IsZero()) {
3685
7.79k
      auto operand = inst->GetSingleWordInOperand(foldToArg);
3686
7.79k
      auto operand_type = constants[arg]->type();
3687
3688
7.79k
      const analysis::Type* inst_type =
3689
7.79k
          context->get_type_mgr()->GetType(inst->type_id());
3690
7.79k
      if (inst_type->IsSame(operand_type)) {
3691
7.68k
        inst->SetOpcode(spv::Op::OpCopyObject);
3692
7.68k
      } else {
3693
105
        inst->SetOpcode(spv::Op::OpBitcast);
3694
105
      }
3695
7.79k
      inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {operand}}});
3696
7.79k
      return true;
3697
7.79k
    }
3698
325k
    return false;
3699
333k
  };
3700
285k
}
3701
3702
// This rule handles any of RedundantBinaryRhs0Ops with a 0 or vector 0 on the
3703
// right-hand side (a | 0 => a).
3704
static const constexpr spv::Op RedundantBinaryRhs0Ops[] = {
3705
    spv::Op::OpBitwiseOr,
3706
    spv::Op::OpBitwiseXor,
3707
    spv::Op::OpShiftRightLogical,
3708
    spv::Op::OpShiftRightArithmetic,
3709
    spv::Op::OpShiftLeftLogical,
3710
    spv::Op::OpIAdd,
3711
    spv::Op::OpISub};
3712
117k
FoldingRule RedundantBinaryRhs0(spv::Op op) {
3713
117k
  assert(std::find(std::begin(RedundantBinaryRhs0Ops),
3714
117k
                   std::end(RedundantBinaryRhs0Ops),
3715
117k
                   op) != std::end(RedundantBinaryRhs0Ops) &&
3716
117k
         "Wrong opcode.");
3717
117k
  (void)op;
3718
117k
  return RedundantBinaryOpWithZeroOperand(1, 0);
3719
117k
}
3720
3721
// This rule handles any of RedundantBinaryLhs0Ops with a 0 or vector 0 on the
3722
// left-hand side (0 | a => a).
3723
static const constexpr spv::Op RedundantBinaryLhs0Ops[] = {
3724
    spv::Op::OpBitwiseOr, spv::Op::OpBitwiseXor, spv::Op::OpIAdd};
3725
50.3k
FoldingRule RedundantBinaryLhs0(spv::Op op) {
3726
50.3k
  assert(std::find(std::begin(RedundantBinaryLhs0Ops),
3727
50.3k
                   std::end(RedundantBinaryLhs0Ops),
3728
50.3k
                   op) != std::end(RedundantBinaryLhs0Ops) &&
3729
50.3k
         "Wrong opcode.");
3730
50.3k
  (void)op;
3731
50.3k
  return RedundantBinaryOpWithZeroOperand(0, 1);
3732
50.3k
}
3733
3734
// This rule handles shifts and divisions of 0 or vector 0 by any amount
3735
// (0 >> a => 0).
3736
static const constexpr spv::Op RedundantBinaryLhs0To0Ops[] = {
3737
    spv::Op::OpShiftRightLogical,
3738
    spv::Op::OpShiftRightArithmetic,
3739
    spv::Op::OpShiftLeftLogical,
3740
    spv::Op::OpSDiv,
3741
    spv::Op::OpUDiv,
3742
    spv::Op::OpSMod,
3743
    spv::Op::OpUMod};
3744
117k
FoldingRule RedundantBinaryLhs0To0(spv::Op op) {
3745
117k
  assert(std::find(std::begin(RedundantBinaryLhs0To0Ops),
3746
117k
                   std::end(RedundantBinaryLhs0To0Ops),
3747
117k
                   op) != std::end(RedundantBinaryLhs0To0Ops) &&
3748
117k
         "Wrong opcode.");
3749
117k
  (void)op;
3750
117k
  return RedundantBinaryOpWithZeroOperand(0, 0);
3751
117k
}
3752
3753
50.3k
FoldingRule ReassociateCommutiveOp() {
3754
50.3k
  return [](IRContext* context, Instruction* inst,
3755
50.4k
            const std::vector<const analysis::Constant*>& constants) {
3756
50.4k
    const analysis::Type* type =
3757
50.4k
        context->get_type_mgr()->GetType(inst->type_id());
3758
50.4k
    uint32_t width = ElementWidth(type);
3759
50.4k
    if (width != 32) return false;
3760
3761
50.4k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
3762
50.4k
    const analysis::Constant* const_input1 = ConstInput(constants);
3763
50.4k
    if (!const_input1) return false;
3764
20.6k
    Instruction* other_inst = NonConstInput(context, constants[0], inst);
3765
3766
20.6k
    if (other_inst->opcode() == inst->opcode()) {
3767
7.10k
      std::vector<const analysis::Constant*> other_constants =
3768
7.10k
          const_mgr->GetOperandConstants(other_inst);
3769
7.10k
      const analysis::Constant* const_input2 = ConstInput(other_constants);
3770
7.10k
      if (!const_input2) return false;
3771
3772
1.86k
      Instruction* non_const_input =
3773
1.86k
          NonConstInput(context, other_constants[0], other_inst);
3774
1.86k
      uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
3775
1.86k
                                            const_input1, const_input2);
3776
3777
1.86k
      if (merged_id == 0) return false;
3778
1.86k
      inst->SetInOperands(
3779
1.86k
          {{SPV_OPERAND_TYPE_ID, {non_const_input->result_id()}},
3780
1.86k
           {SPV_OPERAND_TYPE_ID, {merged_id}}});
3781
1.86k
      return true;
3782
1.86k
    }
3783
3784
13.5k
    return false;
3785
20.6k
  };
3786
50.3k
}
3787
3788
// A | (b | C) = b | (A | C)
3789
// A ^ (b ^ C) = b ^ (A ^ C)
3790
// A & (b & C) = b & (A & C)
3791
// Where A and C are constants
3792
static const constexpr spv::Op ReassociateCommutiveBitwiseOps[] = {
3793
    spv::Op::OpBitwiseOr, spv::Op::OpBitwiseXor, spv::Op::OpBitwiseAnd};
3794
50.3k
FoldingRule ReassociateCommutiveBitwise(spv::Op op) {
3795
50.3k
  assert(std::find(std::begin(ReassociateCommutiveBitwiseOps),
3796
50.3k
                   std::end(ReassociateCommutiveBitwiseOps),
3797
50.3k
                   op) != std::end(ReassociateCommutiveBitwiseOps) &&
3798
50.3k
         "Wrong opcode.");
3799
50.3k
  (void)op;
3800
50.3k
  return ReassociateCommutiveOp();
3801
50.3k
}
3802
3803
// Returns true if all elements in |c| are 1.
3804
16.8k
bool IsAllInt1(const analysis::Constant* c) {
3805
16.8k
  if (auto composite = c->AsCompositeConstant()) {
3806
0
    auto& components = composite->GetComponents();
3807
0
    return std::all_of(std::begin(components), std::end(components), IsAllInt1);
3808
16.8k
  } else if (c->AsIntConstant()) {
3809
16.7k
    return c->GetSignExtendedValue() == 1;
3810
16.7k
  }
3811
3812
27
  return false;
3813
16.8k
}
3814
3815
// This rule handles divisions by 1 or vector 1 (a / 1 => a).
3816
33.5k
FoldingRule RedundantSUDiv() {
3817
33.5k
  return [](IRContext* context, Instruction* inst,
3818
33.5k
            const std::vector<const analysis::Constant*>& constants) {
3819
12.7k
    assert(constants.size() == 2);
3820
12.7k
    assert((inst->opcode() == spv::Op::OpUDiv ||
3821
12.7k
            inst->opcode() == spv::Op::OpSDiv) &&
3822
12.7k
           "Wrong opcode.");
3823
3824
12.7k
    if (constants[1] && IsAllInt1(constants[1])) {
3825
1.06k
      auto operand = inst->GetSingleWordInOperand(0);
3826
1.06k
      auto operand_type = constants[1]->type();
3827
3828
1.06k
      const analysis::Type* inst_type =
3829
1.06k
          context->get_type_mgr()->GetType(inst->type_id());
3830
1.06k
      if (inst_type->IsSame(operand_type)) {
3831
786
        inst->SetOpcode(spv::Op::OpCopyObject);
3832
786
      } else {
3833
274
        inst->SetOpcode(spv::Op::OpBitcast);
3834
274
      }
3835
1.06k
      inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {operand}}});
3836
1.06k
      return true;
3837
1.06k
    }
3838
11.7k
    return false;
3839
12.7k
  };
3840
33.5k
}
3841
3842
// This rule handles modulo from division by 1 or vector 1 (a % 1 => 0).
3843
33.5k
FoldingRule RedundantSUMod() {
3844
33.5k
  return [](IRContext* context, Instruction* inst,
3845
33.5k
            const std::vector<const analysis::Constant*>& constants) {
3846
6.48k
    assert(constants.size() == 2);
3847
6.48k
    assert((inst->opcode() == spv::Op::OpUMod ||
3848
6.48k
            inst->opcode() == spv::Op::OpSMod) &&
3849
6.48k
           "Wrong opcode.");
3850
3851
6.48k
    if (constants[1] && IsAllInt1(constants[1])) {
3852
807
      auto type = context->get_type_mgr()->GetType(inst->type_id());
3853
807
      auto zero_id = context->get_constant_mgr()->GetNullConstId(type);
3854
3855
807
      inst->SetOpcode(spv::Op::OpCopyObject);
3856
807
      inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {zero_id}}});
3857
807
      return true;
3858
807
    }
3859
5.68k
    return false;
3860
6.48k
  };
3861
33.5k
}
3862
3863
// Utility function for applying |callback| to |input1| and |input2|.
3864
// If they are vectors it applies element wise.
3865
// The constants |input1| and |input2| must be integers or a vector of integers.
3866
template <typename Callback>
3867
void ForEachIntegerConstantPair(analysis::ConstantManager* const_mgr,
3868
                                const analysis::Constant* input1,
3869
                                const analysis::Constant* input2,
3870
1.55k
                                Callback&& callback) {
3871
1.55k
  assert(input1 && input2);
3872
3873
1.55k
  auto Dispatch = [&callback](const analysis::Constant* lhs,
3874
1.55k
                              const analysis::Constant* rhs) {
3875
1.55k
    assert(lhs->type()->AsInteger());
3876
1.55k
    const analysis::Integer* type = lhs->type()->AsInteger();
3877
1.55k
    uint32_t width = type->AsInteger()->width();
3878
1.55k
    assert(width == 32 || width == 64);
3879
1.55k
    if (width == 32) {
3880
1.55k
      callback(lhs->GetU32(), rhs->GetU32());
3881
1.55k
    } else {
3882
0
      callback(lhs->GetU64(), rhs->GetU64());
3883
0
    }
3884
1.55k
  };
folding_rules.cpp:spvtools::opt::(anonymous namespace)::ForEachIntegerConstantPair<spvtools::opt::(anonymous namespace)::RedundantAndOrXor()::$_0::operator()(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&) const::{lambda(auto:1, auto:2)#1}>(spvtools::opt::analysis::ConstantManager*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*, spvtools::opt::(anonymous namespace)::RedundantAndOrXor()::$_0::operator()(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&) const::{lambda(auto:1, auto:2)#1}&&)::{lambda(spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*)#1}::operator()(spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*) const
Line
Count
Source
3874
812
                              const analysis::Constant* rhs) {
3875
812
    assert(lhs->type()->AsInteger());
3876
812
    const analysis::Integer* type = lhs->type()->AsInteger();
3877
812
    uint32_t width = type->AsInteger()->width();
3878
812
    assert(width == 32 || width == 64);
3879
812
    if (width == 32) {
3880
812
      callback(lhs->GetU32(), rhs->GetU32());
3881
812
    } else {
3882
0
      callback(lhs->GetU64(), rhs->GetU64());
3883
0
    }
3884
812
  };
folding_rules.cpp:spvtools::opt::(anonymous namespace)::ForEachIntegerConstantPair<spvtools::opt::(anonymous namespace)::RedundantAndAddSub()::$_0::operator()(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&) const::{lambda(auto:1, auto:2)#1}>(spvtools::opt::analysis::ConstantManager*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*, spvtools::opt::(anonymous namespace)::RedundantAndAddSub()::$_0::operator()(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&) const::{lambda(auto:1, auto:2)#1}&&)::{lambda(spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*)#1}::operator()(spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*) const
Line
Count
Source
3874
744
                              const analysis::Constant* rhs) {
3875
744
    assert(lhs->type()->AsInteger());
3876
744
    const analysis::Integer* type = lhs->type()->AsInteger();
3877
744
    uint32_t width = type->AsInteger()->width();
3878
744
    assert(width == 32 || width == 64);
3879
744
    if (width == 32) {
3880
744
      callback(lhs->GetU32(), rhs->GetU32());
3881
744
    } else {
3882
0
      callback(lhs->GetU64(), rhs->GetU64());
3883
0
    }
3884
744
  };
3885
3886
1.55k
  const analysis::Type* type = input1->type();
3887
1.55k
  if (const analysis::Vector* vector_type = type->AsVector()) {
3888
0
    const analysis::Type* ele_type = vector_type->element_type();
3889
0
    assert(ele_type->AsInteger());
3890
0
    for (uint32_t i = 0; i != vector_type->element_count(); ++i) {
3891
0
      const analysis::Constant* input1_comp = nullptr;
3892
0
      if (const analysis::VectorConstant* input1_vector =
3893
0
              input1->AsVectorConstant()) {
3894
0
        input1_comp = input1_vector->GetComponents()[i];
3895
0
      } else {
3896
0
        assert(input1->AsNullConstant());
3897
0
        input1_comp = const_mgr->GetConstant(ele_type, {});
3898
0
      }
3899
3900
0
      const analysis::Constant* input2_comp = nullptr;
3901
0
      if (const analysis::VectorConstant* input2_vector =
3902
0
              input2->AsVectorConstant()) {
3903
0
        input2_comp = input2_vector->GetComponents()[i];
3904
0
      } else {
3905
0
        assert(input2->AsNullConstant());
3906
0
        input2_comp = const_mgr->GetConstant(ele_type, {});
3907
0
      }
3908
3909
0
      assert(ele_type->AsInteger());
3910
0
      Dispatch(input1_comp, input2_comp);
3911
0
    }
3912
3913
1.55k
  } else {
3914
1.55k
    assert(type->AsInteger());
3915
1.55k
    Dispatch(input1, input2);
3916
1.55k
  }
3917
1.55k
}
folding_rules.cpp:void spvtools::opt::(anonymous namespace)::ForEachIntegerConstantPair<spvtools::opt::(anonymous namespace)::RedundantAndOrXor()::$_0::operator()(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&) const::{lambda(auto:1, auto:2)#1}>(spvtools::opt::analysis::ConstantManager*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*, spvtools::opt::(anonymous namespace)::RedundantAndOrXor()::$_0::operator()(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&) const::{lambda(auto:1, auto:2)#1}&&)
Line
Count
Source
3870
812
                                Callback&& callback) {
3871
812
  assert(input1 && input2);
3872
3873
812
  auto Dispatch = [&callback](const analysis::Constant* lhs,
3874
812
                              const analysis::Constant* rhs) {
3875
812
    assert(lhs->type()->AsInteger());
3876
812
    const analysis::Integer* type = lhs->type()->AsInteger();
3877
812
    uint32_t width = type->AsInteger()->width();
3878
812
    assert(width == 32 || width == 64);
3879
812
    if (width == 32) {
3880
812
      callback(lhs->GetU32(), rhs->GetU32());
3881
812
    } else {
3882
812
      callback(lhs->GetU64(), rhs->GetU64());
3883
812
    }
3884
812
  };
3885
3886
812
  const analysis::Type* type = input1->type();
3887
812
  if (const analysis::Vector* vector_type = type->AsVector()) {
3888
0
    const analysis::Type* ele_type = vector_type->element_type();
3889
0
    assert(ele_type->AsInteger());
3890
0
    for (uint32_t i = 0; i != vector_type->element_count(); ++i) {
3891
0
      const analysis::Constant* input1_comp = nullptr;
3892
0
      if (const analysis::VectorConstant* input1_vector =
3893
0
              input1->AsVectorConstant()) {
3894
0
        input1_comp = input1_vector->GetComponents()[i];
3895
0
      } else {
3896
0
        assert(input1->AsNullConstant());
3897
0
        input1_comp = const_mgr->GetConstant(ele_type, {});
3898
0
      }
3899
3900
0
      const analysis::Constant* input2_comp = nullptr;
3901
0
      if (const analysis::VectorConstant* input2_vector =
3902
0
              input2->AsVectorConstant()) {
3903
0
        input2_comp = input2_vector->GetComponents()[i];
3904
0
      } else {
3905
0
        assert(input2->AsNullConstant());
3906
0
        input2_comp = const_mgr->GetConstant(ele_type, {});
3907
0
      }
3908
3909
0
      assert(ele_type->AsInteger());
3910
0
      Dispatch(input1_comp, input2_comp);
3911
0
    }
3912
3913
812
  } else {
3914
812
    assert(type->AsInteger());
3915
812
    Dispatch(input1, input2);
3916
812
  }
3917
812
}
folding_rules.cpp:void spvtools::opt::(anonymous namespace)::ForEachIntegerConstantPair<spvtools::opt::(anonymous namespace)::RedundantAndAddSub()::$_0::operator()(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&) const::{lambda(auto:1, auto:2)#1}>(spvtools::opt::analysis::ConstantManager*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*, spvtools::opt::(anonymous namespace)::RedundantAndAddSub()::$_0::operator()(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&) const::{lambda(auto:1, auto:2)#1}&&)
Line
Count
Source
3870
744
                                Callback&& callback) {
3871
744
  assert(input1 && input2);
3872
3873
744
  auto Dispatch = [&callback](const analysis::Constant* lhs,
3874
744
                              const analysis::Constant* rhs) {
3875
744
    assert(lhs->type()->AsInteger());
3876
744
    const analysis::Integer* type = lhs->type()->AsInteger();
3877
744
    uint32_t width = type->AsInteger()->width();
3878
744
    assert(width == 32 || width == 64);
3879
744
    if (width == 32) {
3880
744
      callback(lhs->GetU32(), rhs->GetU32());
3881
744
    } else {
3882
744
      callback(lhs->GetU64(), rhs->GetU64());
3883
744
    }
3884
744
  };
3885
3886
744
  const analysis::Type* type = input1->type();
3887
744
  if (const analysis::Vector* vector_type = type->AsVector()) {
3888
0
    const analysis::Type* ele_type = vector_type->element_type();
3889
0
    assert(ele_type->AsInteger());
3890
0
    for (uint32_t i = 0; i != vector_type->element_count(); ++i) {
3891
0
      const analysis::Constant* input1_comp = nullptr;
3892
0
      if (const analysis::VectorConstant* input1_vector =
3893
0
              input1->AsVectorConstant()) {
3894
0
        input1_comp = input1_vector->GetComponents()[i];
3895
0
      } else {
3896
0
        assert(input1->AsNullConstant());
3897
0
        input1_comp = const_mgr->GetConstant(ele_type, {});
3898
0
      }
3899
3900
0
      const analysis::Constant* input2_comp = nullptr;
3901
0
      if (const analysis::VectorConstant* input2_vector =
3902
0
              input2->AsVectorConstant()) {
3903
0
        input2_comp = input2_vector->GetComponents()[i];
3904
0
      } else {
3905
0
        assert(input2->AsNullConstant());
3906
0
        input2_comp = const_mgr->GetConstant(ele_type, {});
3907
0
      }
3908
3909
0
      assert(ele_type->AsInteger());
3910
0
      Dispatch(input1_comp, input2_comp);
3911
0
    }
3912
3913
744
  } else {
3914
744
    assert(type->AsInteger());
3915
744
    Dispatch(input1, input2);
3916
744
  }
3917
744
}
3918
3919
// Folds redundant xor and or ops that are part of an and.
3920
// Cases handled:
3921
// 0b1110 & (a | 0b0001) = a & 0b1110
3922
// 0b1110 & (a ^ 0b0001) = a & 0b1110
3923
// 0b0110 & (a | 0b1110) = 0b0110
3924
16.7k
FoldingRule RedundantAndOrXor() {
3925
16.7k
  return [](IRContext* context, Instruction* inst,
3926
16.7k
            const std::vector<const analysis::Constant*>& constants) {
3927
12.8k
    assert(inst->opcode() == spv::Op::OpBitwiseAnd && "Wrong opcode.");
3928
12.8k
    const analysis::Type* type =
3929
12.8k
        context->get_type_mgr()->GetType(inst->type_id());
3930
12.8k
    uint32_t width = ElementWidth(type);
3931
12.8k
    if ((width != 32) && (width != 64)) return false;
3932
3933
12.8k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
3934
12.8k
    const analysis::Constant* const_input1 = ConstInput(constants);
3935
12.8k
    if (!const_input1) return false;
3936
9.23k
    Instruction* other_inst = NonConstInput(context, constants[0], inst);
3937
3938
9.23k
    if (other_inst->opcode() == spv::Op::OpBitwiseOr ||
3939
7.74k
        other_inst->opcode() == spv::Op::OpBitwiseXor) {
3940
1.60k
      std::vector<const analysis::Constant*> other_constants =
3941
1.60k
          const_mgr->GetOperandConstants(other_inst);
3942
1.60k
      const analysis::Constant* const_input2 = ConstInput(other_constants);
3943
1.60k
      if (!const_input2) return false;
3944
3945
812
      bool can_convert_to_const = other_inst->opcode() == spv::Op::OpBitwiseOr;
3946
812
      bool can_remove_inner = true;
3947
3948
812
      ForEachIntegerConstantPair(
3949
812
          const_mgr, const_input1, const_input2,
3950
812
          [&can_remove_inner, &can_convert_to_const](auto lhs, auto rhs) {
3951
            // Only convert to const if 'and' is a subset of 'or'
3952
812
            can_convert_to_const = can_convert_to_const && ((lhs & rhs) == lhs);
3953
            // Only remove 'xor'/'or' if no bits intersect with 'and'
3954
812
            can_remove_inner = can_remove_inner && ((lhs & rhs) == 0);
3955
812
          });
folding_rules.cpp:auto spvtools::opt::(anonymous namespace)::RedundantAndOrXor()::$_0::operator()(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&) const::{lambda(auto:1, auto:2)#1}::operator()<unsigned int, unsigned int>(unsigned int, unsigned int) const
Line
Count
Source
3950
812
          [&can_remove_inner, &can_convert_to_const](auto lhs, auto rhs) {
3951
            // Only convert to const if 'and' is a subset of 'or'
3952
812
            can_convert_to_const = can_convert_to_const && ((lhs & rhs) == lhs);
3953
            // Only remove 'xor'/'or' if no bits intersect with 'and'
3954
812
            can_remove_inner = can_remove_inner && ((lhs & rhs) == 0);
3955
812
          });
Unexecuted instantiation: folding_rules.cpp:auto spvtools::opt::(anonymous namespace)::RedundantAndOrXor()::$_0::operator()(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&) const::{lambda(auto:1, auto:2)#1}::operator()<unsigned long, unsigned long>(unsigned long, unsigned long) const
3956
3957
812
      if (can_convert_to_const) {
3958
63
        Instruction* const_inst =
3959
63
            const_mgr->GetDefiningInstruction(const_input1);
3960
63
        inst->SetOpcode(spv::Op::OpCopyObject);
3961
63
        inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {const_inst->result_id()}}});
3962
63
        return true;
3963
749
      } else if (can_remove_inner) {
3964
59
        Instruction* non_const_input =
3965
59
            NonConstInput(context, other_constants[0], other_inst);
3966
59
        Instruction* const_inst =
3967
59
            const_mgr->GetDefiningInstruction(const_input1);
3968
59
        inst->SetInOperands(
3969
59
            {{SPV_OPERAND_TYPE_ID, {non_const_input->result_id()}},
3970
59
             {SPV_OPERAND_TYPE_ID, {const_inst->result_id()}}});
3971
59
        return true;
3972
59
      }
3973
812
    }
3974
8.31k
    return false;
3975
9.23k
  };
3976
16.7k
}
3977
3978
// Folds redundant add and sub ops that are part of an and.
3979
// Cases handled:
3980
// 1 & (b + 2) = b & 1
3981
// 1 & (b - 2) = b & 1
3982
16.7k
FoldingRule RedundantAndAddSub() {
3983
16.7k
  return [](IRContext* context, Instruction* inst,
3984
16.7k
            const std::vector<const analysis::Constant*>& constants) {
3985
12.7k
    assert(inst->opcode() == spv::Op::OpBitwiseAnd && "Wrong opcode.");
3986
12.7k
    const analysis::Type* type =
3987
12.7k
        context->get_type_mgr()->GetType(inst->type_id());
3988
12.7k
    uint32_t width = ElementWidth(type);
3989
12.7k
    if ((width != 32) && (width != 64)) return false;
3990
3991
12.7k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
3992
12.7k
    const analysis::Constant* const_input1 = ConstInput(constants);
3993
12.7k
    if (!const_input1) return false;
3994
9.10k
    Instruction* other_inst = NonConstInput(context, constants[0], inst);
3995
3996
9.10k
    if (other_inst->opcode() != spv::Op::OpIAdd &&
3997
8.35k
        other_inst->opcode() != spv::Op::OpISub) {
3998
8.20k
      return false;
3999
8.20k
    }
4000
902
    std::vector<const analysis::Constant*> other_constants =
4001
902
        const_mgr->GetOperandConstants(other_inst);
4002
902
    const analysis::Constant* const_input2 = ConstInput(other_constants);
4003
902
    if (!const_input2) return false;
4004
4005
    // Only valid for subtraction if const is on the right
4006
774
    if ((other_inst->opcode() == spv::Op::OpISub) && other_constants[0]) {
4007
30
      return false;
4008
30
    }
4009
4010
744
    bool can_remove_inner = true;
4011
744
    ForEachIntegerConstantPair(const_mgr, const_input1, const_input2,
4012
744
                               [&can_remove_inner](auto and_op, auto add_op) {
4013
744
                                 if (can_remove_inner) {
4014
                                   // Only valid if no bits from the +/- could
4015
                                   // affect bits from the & operation.
4016
744
                                   can_remove_inner =
4017
744
                                       utils::LSB(add_op) > and_op;
4018
744
                                 }
4019
744
                               });
folding_rules.cpp:auto spvtools::opt::(anonymous namespace)::RedundantAndAddSub()::$_0::operator()(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&) const::{lambda(auto:1, auto:2)#1}::operator()<unsigned int, unsigned int>(unsigned int, unsigned int) const
Line
Count
Source
4012
744
                               [&can_remove_inner](auto and_op, auto add_op) {
4013
744
                                 if (can_remove_inner) {
4014
                                   // Only valid if no bits from the +/- could
4015
                                   // affect bits from the & operation.
4016
744
                                   can_remove_inner =
4017
744
                                       utils::LSB(add_op) > and_op;
4018
744
                                 }
4019
744
                               });
Unexecuted instantiation: folding_rules.cpp:auto spvtools::opt::(anonymous namespace)::RedundantAndAddSub()::$_0::operator()(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&) const::{lambda(auto:1, auto:2)#1}::operator()<unsigned long, unsigned long>(unsigned long, unsigned long) const
4020
4021
744
    if (can_remove_inner) {
4022
22
      Instruction* non_const_input =
4023
22
          NonConstInput(context, other_constants[0], other_inst);
4024
22
      Instruction* const_inst = const_mgr->GetDefiningInstruction(const_input1);
4025
22
      inst->SetInOperands(
4026
22
          {{SPV_OPERAND_TYPE_ID, {non_const_input->result_id()}},
4027
22
           {SPV_OPERAND_TYPE_ID, {const_inst->result_id()}}});
4028
22
      return true;
4029
22
    }
4030
722
    return false;
4031
744
  };
4032
16.7k
}
4033
4034
// Folds redundant shift ops that are part of an and.
4035
// Cases handled:
4036
// 1 & (b << 1) = 0
4037
// 0x80000000 & (b >> 1) = 0
4038
16.7k
FoldingRule RedundantAndShift() {
4039
16.7k
  return [](IRContext* context, Instruction* inst,
4040
16.7k
            const std::vector<const analysis::Constant*>& constants) {
4041
12.6k
    assert(inst->opcode() == spv::Op::OpBitwiseAnd && "Wrong opcode.");
4042
12.6k
    const analysis::Type* type =
4043
12.6k
        context->get_type_mgr()->GetType(inst->type_id());
4044
12.6k
    uint32_t width = ElementWidth(type);
4045
12.6k
    if (width != 8 && width != 16 && width != 32 && width != 64) return false;
4046
12.6k
    const uint64_t width_mask =
4047
12.6k
        (width == 64) ? ~0ull : ((1ull << width) - 1ull);
4048
4049
12.6k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
4050
12.6k
    const analysis::Constant* const_input1 = ConstInput(constants);
4051
12.6k
    if (!const_input1) return false;
4052
9.08k
    Instruction* other_inst = NonConstInput(context, constants[0], inst);
4053
4054
9.08k
    spv::Op other_op = other_inst->opcode();
4055
9.08k
    if (other_op != spv::Op::OpShiftLeftLogical &&
4056
8.90k
        other_op != spv::Op::OpShiftRightLogical) {
4057
8.65k
      return false;
4058
8.65k
    }
4059
4060
431
    std::vector<const analysis::Constant*> other_constants =
4061
431
        const_mgr->GetOperandConstants(other_inst);
4062
4063
    // Only valid if const is on the right.
4064
431
    if (other_constants[0]) return false;
4065
352
    const analysis::Constant* const_input2 = other_constants[1];
4066
352
    if (!const_input2) return false;
4067
4068
138
    auto get_value_u64 =
4069
276
        [](const analysis::Constant* c) -> std::optional<uint64_t> {
4070
276
      if (!c) return std::nullopt;
4071
276
      const analysis::Integer* int_t = c->type()->AsInteger();
4072
276
      if (!int_t) return std::nullopt;
4073
276
      return c->GetZeroExtendedValue();
4074
276
    };
4075
4076
138
    auto can_fold_component =
4077
138
        [&](const analysis::Constant* mask_const,
4078
138
            const analysis::Constant* shift_const) -> std::optional<bool> {
4079
138
      auto lhs = get_value_u64(mask_const);
4080
138
      auto rhs = get_value_u64(shift_const);
4081
138
      if (!lhs || !rhs) return std::nullopt;
4082
138
      if (*rhs >= width) return false;
4083
138
      uint64_t lhs_masked = *lhs & width_mask;
4084
138
      if (other_op == spv::Op::OpShiftRightLogical) {
4085
69
        return ((lhs_masked << *rhs) & width_mask) == 0;
4086
69
      }
4087
69
      return ((lhs_masked >> *rhs) & width_mask) == 0;
4088
138
    };
4089
4090
138
    if (const analysis::Vector* mask_vec = type->AsVector()) {
4091
0
      const analysis::Vector* shift_vec = const_input2->type()->AsVector();
4092
0
      if (!shift_vec ||
4093
0
          shift_vec->element_count() != mask_vec->element_count()) {
4094
0
        return false;
4095
0
      }
4096
0
      const auto mask_components = const_input1->GetVectorComponents(const_mgr);
4097
0
      const auto shift_components =
4098
0
          const_input2->GetVectorComponents(const_mgr);
4099
0
      for (uint32_t i = 0; i != mask_vec->element_count(); ++i) {
4100
0
        auto result =
4101
0
            can_fold_component(mask_components[i], shift_components[i]);
4102
0
        if (!result || !*result) return false;
4103
0
      }
4104
138
    } else {
4105
138
      if (const_input2->type()->AsVector()) return false;
4106
138
      auto result = can_fold_component(const_input1, const_input2);
4107
138
      if (!result || !*result) return false;
4108
138
    }
4109
4110
33
    auto zero_id = context->get_constant_mgr()->GetNullConstId(type);
4111
33
    inst->SetOpcode(spv::Op::OpCopyObject);
4112
33
    inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {zero_id}}});
4113
33
    return true;
4114
138
  };
4115
16.7k
}
4116
4117
// This rule look for a dot with a constant vector containing a single 1 and
4118
// the rest 0s.  This is the same as doing an extract.
4119
16.7k
FoldingRule DotProductDoingExtract() {
4120
16.7k
  return [](IRContext* context, Instruction* inst,
4121
16.7k
            const std::vector<const analysis::Constant*>& constants) {
4122
51
    assert(inst->opcode() == spv::Op::OpDot &&
4123
51
           "Wrong opcode.  Should be OpDot.");
4124
4125
51
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
4126
4127
51
    if (!inst->IsFloatingPointFoldingAllowed()) {
4128
0
      return false;
4129
0
    }
4130
4131
153
    for (int i = 0; i < 2; ++i) {
4132
102
      if (!constants[i]) {
4133
60
        continue;
4134
60
      }
4135
4136
42
      const analysis::Vector* vector_type = constants[i]->type()->AsVector();
4137
42
      assert(vector_type && "Inputs to OpDot must be vectors.");
4138
42
      const analysis::Float* element_type =
4139
42
          vector_type->element_type()->AsFloat();
4140
42
      assert(element_type && "Inputs to OpDot must be vectors of floats.");
4141
42
      uint32_t element_width = element_type->width();
4142
42
      if (element_width != 32 && element_width != 64) {
4143
0
        return false;
4144
0
      }
4145
4146
42
      std::vector<const analysis::Constant*> components;
4147
42
      components = constants[i]->GetVectorComponents(const_mgr);
4148
4149
42
      constexpr uint32_t kNotFound = std::numeric_limits<uint32_t>::max();
4150
4151
42
      uint32_t component_with_one = kNotFound;
4152
42
      bool all_others_zero = true;
4153
43
      for (uint32_t j = 0; j < components.size(); ++j) {
4154
43
        const analysis::Constant* element = components[j];
4155
43
        double value =
4156
43
            (element_width == 32 ? element->GetFloat() : element->GetDouble());
4157
43
        if (value == 0.0) {
4158
1
          continue;
4159
42
        } else if (value == 1.0) {
4160
0
          if (component_with_one == kNotFound) {
4161
0
            component_with_one = j;
4162
0
          } else {
4163
0
            component_with_one = kNotFound;
4164
0
            break;
4165
0
          }
4166
42
        } else {
4167
42
          all_others_zero = false;
4168
42
          break;
4169
42
        }
4170
43
      }
4171
4172
42
      if (!all_others_zero || component_with_one == kNotFound) {
4173
42
        continue;
4174
42
      }
4175
4176
0
      std::vector<Operand> operands;
4177
0
      operands.push_back(
4178
0
          {SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1u - i)}});
4179
0
      operands.push_back(
4180
0
          {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_with_one}});
4181
4182
0
      inst->SetOpcode(spv::Op::OpCompositeExtract);
4183
0
      inst->SetInOperands(std::move(operands));
4184
0
      return true;
4185
42
    }
4186
51
    return false;
4187
51
  };
4188
16.7k
}
4189
4190
// If we are storing an undef, then we can remove the store.
4191
//
4192
// TODO: We can do something similar for OpImageWrite, but checking for volatile
4193
// is complicated.  Waiting to see if it is needed.
4194
16.7k
FoldingRule StoringUndef() {
4195
16.7k
  return [](IRContext* context, Instruction* inst,
4196
1.04M
            const std::vector<const analysis::Constant*>&) {
4197
1.04M
    assert(inst->opcode() == spv::Op::OpStore &&
4198
1.04M
           "Wrong opcode.  Should be OpStore.");
4199
4200
1.04M
    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
4201
4202
    // If this is a volatile store, the store cannot be removed.
4203
1.04M
    if (inst->NumInOperands() == 3) {
4204
8.97k
      if (inst->GetSingleWordInOperand(2) &
4205
8.97k
          uint32_t(spv::MemoryAccessMask::Volatile)) {
4206
5.95k
        return false;
4207
5.95k
      }
4208
8.97k
    }
4209
4210
1.04M
    uint32_t object_id = inst->GetSingleWordInOperand(kStoreObjectInIdx);
4211
1.04M
    Instruction* object_inst = def_use_mgr->GetDef(object_id);
4212
1.04M
    if (object_inst->opcode() == spv::Op::OpUndef) {
4213
21.4k
      inst->ToNop();
4214
21.4k
      return true;
4215
21.4k
    }
4216
1.01M
    return false;
4217
1.04M
  };
4218
16.7k
}
4219
4220
16.7k
FoldingRule VectorShuffleFeedingShuffle() {
4221
16.7k
  return [](IRContext* context, Instruction* inst,
4222
26.0k
            const std::vector<const analysis::Constant*>&) {
4223
26.0k
    assert(inst->opcode() == spv::Op::OpVectorShuffle &&
4224
26.0k
           "Wrong opcode.  Should be OpVectorShuffle.");
4225
4226
26.0k
    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
4227
26.0k
    analysis::TypeManager* type_mgr = context->get_type_mgr();
4228
4229
26.0k
    Instruction* feeding_shuffle_inst =
4230
26.0k
        def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
4231
26.0k
    analysis::Vector* op0_type =
4232
26.0k
        type_mgr->GetType(feeding_shuffle_inst->type_id())->AsVector();
4233
26.0k
    uint32_t op0_length = op0_type->element_count();
4234
4235
26.0k
    bool feeder_is_op0 = true;
4236
26.0k
    if (feeding_shuffle_inst->opcode() != spv::Op::OpVectorShuffle) {
4237
25.8k
      feeding_shuffle_inst =
4238
25.8k
          def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
4239
25.8k
      feeder_is_op0 = false;
4240
25.8k
    }
4241
4242
26.0k
    if (feeding_shuffle_inst->opcode() != spv::Op::OpVectorShuffle) {
4243
25.3k
      return false;
4244
25.3k
    }
4245
4246
671
    Instruction* feeder2 =
4247
671
        def_use_mgr->GetDef(feeding_shuffle_inst->GetSingleWordInOperand(0));
4248
671
    analysis::Vector* feeder_op0_type =
4249
671
        type_mgr->GetType(feeder2->type_id())->AsVector();
4250
671
    uint32_t feeder_op0_length = feeder_op0_type->element_count();
4251
4252
671
    uint32_t new_feeder_id = 0;
4253
671
    std::vector<Operand> new_operands;
4254
671
    new_operands.resize(
4255
671
        2, {SPV_OPERAND_TYPE_ID, {0}});  // Place holders for vector operands.
4256
671
    const uint32_t undef_literal = 0xffffffff;
4257
2.22k
    for (uint32_t op = 2; op < inst->NumInOperands(); ++op) {
4258
1.61k
      uint32_t component_index = inst->GetSingleWordInOperand(op);
4259
4260
      // Do not interpret the undefined value literal as coming from operand 1.
4261
1.61k
      if (component_index != undef_literal &&
4262
1.51k
          feeder_is_op0 == (component_index < op0_length)) {
4263
        // This component comes from the feeding_shuffle_inst.  Update
4264
        // |component_index| to be the index into the operand of the feeder.
4265
4266
        // Adjust component_index to get the index into the operands of the
4267
        // feeding_shuffle_inst.
4268
828
        if (component_index >= op0_length) {
4269
484
          component_index -= op0_length;
4270
484
        }
4271
828
        component_index =
4272
828
            feeding_shuffle_inst->GetSingleWordInOperand(component_index + 2);
4273
4274
        // Check if we are using a component from the first or second operand of
4275
        // the feeding instruction.
4276
828
        if (component_index < feeder_op0_length) {
4277
643
          if (new_feeder_id == 0) {
4278
            // First time through, save the id of the operand the element comes
4279
            // from.
4280
374
            new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(0);
4281
374
          } else if (new_feeder_id !=
4282
269
                     feeding_shuffle_inst->GetSingleWordInOperand(0)) {
4283
            // We need both elements of the feeding_shuffle_inst, so we cannot
4284
            // fold.
4285
41
            return false;
4286
41
          }
4287
643
        } else if (component_index != undef_literal) {
4288
122
          if (new_feeder_id == 0) {
4289
            // First time through, save the id of the operand the element comes
4290
            // from.
4291
83
            new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(1);
4292
83
          } else if (new_feeder_id !=
4293
39
                     feeding_shuffle_inst->GetSingleWordInOperand(1)) {
4294
            // We need both elements of the feeding_shuffle_inst, so we cannot
4295
            // fold.
4296
22
            return false;
4297
22
          }
4298
100
          component_index -= feeder_op0_length;
4299
100
        }
4300
4301
765
        if (!feeder_is_op0 && component_index != undef_literal) {
4302
456
          component_index += op0_length;
4303
456
        }
4304
765
      }
4305
1.55k
      new_operands.push_back(
4306
1.55k
          {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_index}});
4307
1.55k
    }
4308
4309
608
    if (new_feeder_id == 0) {
4310
214
      analysis::ConstantManager* const_mgr = context->get_constant_mgr();
4311
214
      const analysis::Type* type =
4312
214
          type_mgr->GetType(feeding_shuffle_inst->type_id());
4313
214
      const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
4314
214
      new_feeder_id =
4315
214
          const_mgr->GetDefiningInstruction(null_const, 0)->result_id();
4316
214
    }
4317
4318
608
    if (feeder_is_op0) {
4319
      // If the size of the first vector operand changed then the indices
4320
      // referring to the second operand need to be adjusted.
4321
169
      Instruction* new_feeder_inst = def_use_mgr->GetDef(new_feeder_id);
4322
169
      analysis::Type* new_feeder_type =
4323
169
          type_mgr->GetType(new_feeder_inst->type_id());
4324
169
      uint32_t new_op0_size = new_feeder_type->AsVector()->element_count();
4325
169
      int32_t adjustment = op0_length - new_op0_size;
4326
4327
169
      if (adjustment != 0) {
4328
250
        for (uint32_t i = 2; i < new_operands.size(); i++) {
4329
168
          uint32_t operand = inst->GetSingleWordInOperand(i);
4330
168
          if (operand >= op0_length && operand != undef_literal) {
4331
47
            new_operands[i].words[0] -= adjustment;
4332
47
          }
4333
168
        }
4334
82
      }
4335
4336
169
      new_operands[0].words[0] = new_feeder_id;
4337
169
      new_operands[1] = inst->GetInOperand(1);
4338
439
    } else {
4339
439
      new_operands[1].words[0] = new_feeder_id;
4340
439
      new_operands[0] = inst->GetInOperand(0);
4341
439
    }
4342
4343
608
    inst->SetInOperands(std::move(new_operands));
4344
608
    return true;
4345
671
  };
4346
16.7k
}
4347
4348
// Removes duplicate ids from the interface list of an OpEntryPoint
4349
// instruction.
4350
16.7k
FoldingRule RemoveRedundantOperands() {
4351
16.7k
  return [](IRContext*, Instruction* inst,
4352
16.7k
            const std::vector<const analysis::Constant*>&) {
4353
0
    assert(inst->opcode() == spv::Op::OpEntryPoint &&
4354
0
           "Wrong opcode.  Should be OpEntryPoint.");
4355
0
    bool has_redundant_operand = false;
4356
0
    std::unordered_set<uint32_t> seen_operands;
4357
0
    std::vector<Operand> new_operands;
4358
4359
0
    new_operands.emplace_back(inst->GetOperand(0));
4360
0
    new_operands.emplace_back(inst->GetOperand(1));
4361
0
    new_operands.emplace_back(inst->GetOperand(2));
4362
0
    for (uint32_t i = 3; i < inst->NumOperands(); ++i) {
4363
0
      if (seen_operands.insert(inst->GetSingleWordOperand(i)).second) {
4364
0
        new_operands.emplace_back(inst->GetOperand(i));
4365
0
      } else {
4366
0
        has_redundant_operand = true;
4367
0
      }
4368
0
    }
4369
4370
0
    if (!has_redundant_operand) {
4371
0
      return false;
4372
0
    }
4373
4374
0
    inst->SetInOperands(std::move(new_operands));
4375
0
    return true;
4376
0
  };
4377
16.7k
}
4378
4379
// If an image instruction's operand is a constant, updates the image operand
4380
// flag from Offset to ConstOffset.
4381
419k
FoldingRule UpdateImageOperands() {
4382
419k
  return [](IRContext*, Instruction* inst,
4383
419k
            const std::vector<const analysis::Constant*>& constants) {
4384
270k
    const auto opcode = inst->opcode();
4385
270k
    (void)opcode;
4386
270k
    assert((opcode == spv::Op::OpImageSampleImplicitLod ||
4387
270k
            opcode == spv::Op::OpImageSampleExplicitLod ||
4388
270k
            opcode == spv::Op::OpImageSampleDrefImplicitLod ||
4389
270k
            opcode == spv::Op::OpImageSampleDrefExplicitLod ||
4390
270k
            opcode == spv::Op::OpImageSampleProjImplicitLod ||
4391
270k
            opcode == spv::Op::OpImageSampleProjExplicitLod ||
4392
270k
            opcode == spv::Op::OpImageSampleProjDrefImplicitLod ||
4393
270k
            opcode == spv::Op::OpImageSampleProjDrefExplicitLod ||
4394
270k
            opcode == spv::Op::OpImageFetch ||
4395
270k
            opcode == spv::Op::OpImageGather ||
4396
270k
            opcode == spv::Op::OpImageDrefGather ||
4397
270k
            opcode == spv::Op::OpImageRead || opcode == spv::Op::OpImageWrite ||
4398
270k
            opcode == spv::Op::OpImageSparseSampleImplicitLod ||
4399
270k
            opcode == spv::Op::OpImageSparseSampleExplicitLod ||
4400
270k
            opcode == spv::Op::OpImageSparseSampleDrefImplicitLod ||
4401
270k
            opcode == spv::Op::OpImageSparseSampleDrefExplicitLod ||
4402
270k
            opcode == spv::Op::OpImageSparseSampleProjImplicitLod ||
4403
270k
            opcode == spv::Op::OpImageSparseSampleProjExplicitLod ||
4404
270k
            opcode == spv::Op::OpImageSparseSampleProjDrefImplicitLod ||
4405
270k
            opcode == spv::Op::OpImageSparseSampleProjDrefExplicitLod ||
4406
270k
            opcode == spv::Op::OpImageSparseFetch ||
4407
270k
            opcode == spv::Op::OpImageSparseGather ||
4408
270k
            opcode == spv::Op::OpImageSparseDrefGather ||
4409
270k
            opcode == spv::Op::OpImageSparseRead) &&
4410
270k
           "Wrong opcode.  Should be an image instruction.");
4411
4412
270k
    int32_t operand_index = ImageOperandsMaskInOperandIndex(inst);
4413
270k
    if (operand_index >= 0) {
4414
12
      auto image_operands = inst->GetSingleWordInOperand(operand_index);
4415
12
      if (image_operands & uint32_t(spv::ImageOperandsMask::Offset)) {
4416
0
        uint32_t offset_operand_index = operand_index + 1;
4417
0
        if (image_operands & uint32_t(spv::ImageOperandsMask::Bias))
4418
0
          offset_operand_index++;
4419
0
        if (image_operands & uint32_t(spv::ImageOperandsMask::Lod))
4420
0
          offset_operand_index++;
4421
0
        if (image_operands & uint32_t(spv::ImageOperandsMask::Grad))
4422
0
          offset_operand_index += 2;
4423
0
        assert(((image_operands &
4424
0
                 uint32_t(spv::ImageOperandsMask::ConstOffset)) == 0) &&
4425
0
               "Offset and ConstOffset may not be used together");
4426
0
        if (offset_operand_index < inst->NumOperands()) {
4427
0
          if (constants[offset_operand_index]) {
4428
0
            if (constants[offset_operand_index]->IsZero()) {
4429
0
              inst->RemoveInOperand(offset_operand_index);
4430
0
            } else {
4431
0
              image_operands = image_operands |
4432
0
                               uint32_t(spv::ImageOperandsMask::ConstOffset);
4433
0
            }
4434
0
            image_operands =
4435
0
                image_operands & ~uint32_t(spv::ImageOperandsMask::Offset);
4436
0
            inst->SetInOperand(operand_index, {image_operands});
4437
0
            return true;
4438
0
          }
4439
0
        }
4440
0
      }
4441
12
    }
4442
4443
270k
    return false;
4444
270k
  };
4445
419k
}
4446
4447
}  // namespace
4448
4449
16.7k
void FoldingRules::AddFoldingRules() {
4450
  // Add all folding rules to the list for the opcodes to which they apply.
4451
  // Note that the order in which rules are added to the list matters. If a rule
4452
  // applies to the instruction, the rest of the rules will not be attempted.
4453
  // Take that into consideration.
4454
16.7k
  for (auto op : RedundantBinaryRhs0Ops)
4455
117k
    rules_[op].push_back(RedundantBinaryRhs0(op));
4456
16.7k
  for (auto op : RedundantBinaryLhs0Ops)
4457
50.3k
    rules_[op].push_back(RedundantBinaryLhs0(op));
4458
16.7k
  for (auto op : RedundantBinaryLhs0To0Ops)
4459
117k
    rules_[op].push_back(RedundantBinaryLhs0To0(op));
4460
16.7k
  for (auto op : ReassociateCommutiveBitwiseOps)
4461
50.3k
    rules_[op].push_back(ReassociateCommutiveBitwise(op));
4462
16.7k
  for (auto op : ReassociateNestedGenericIntOps)
4463
67.1k
    rules_[op].push_back(ReassociateNestedGenericInt(op));
4464
16.7k
  for (auto op : MergeBinaryOpSelectOps)
4465
704k
    rules_[op].push_back(MergeBinaryOpSelect(op));
4466
16.7k
  rules_[spv::Op::OpSDiv].push_back(RedundantSUDiv());
4467
16.7k
  rules_[spv::Op::OpUDiv].push_back(RedundantSUDiv());
4468
16.7k
  rules_[spv::Op::OpSMod].push_back(RedundantSUMod());
4469
16.7k
  rules_[spv::Op::OpUMod].push_back(RedundantSUMod());
4470
4471
16.7k
  rules_[spv::Op::OpBitcast].push_back(BitCastScalarOrVector());
4472
16.7k
  rules_[spv::Op::OpBitcast].push_back(RedundantBitcast());
4473
4474
16.7k
  rules_[spv::Op::OpBitReverse].push_back(BitReverseScalarOrVector());
4475
4476
16.7k
  rules_[spv::Op::OpCompositeConstruct].push_back(
4477
16.7k
      CompositeExtractFeedingConstruct);
4478
4479
16.7k
  rules_[spv::Op::OpCompositeExtract].push_back(InsertFeedingExtract());
4480
16.7k
  rules_[spv::Op::OpCompositeExtract].push_back(
4481
16.7k
      CompositeConstructFeedingExtract);
4482
16.7k
  rules_[spv::Op::OpCompositeExtract].push_back(VectorShuffleFeedingExtract());
4483
16.7k
  rules_[spv::Op::OpCompositeExtract].push_back(FMixFeedingExtract());
4484
16.7k
  rules_[spv::Op::OpCompositeExtract].push_back(CopyLogicalFeedingExtract);
4485
16.7k
  rules_[spv::Op::OpCompositeExtract].push_back(LoadFeedingExtract);
4486
4487
16.7k
  rules_[spv::Op::OpCompositeInsert].push_back(
4488
16.7k
      CompositeInsertToCompositeConstruct);
4489
4490
16.7k
  rules_[spv::Op::OpDot].push_back(DotProductDoingExtract());
4491
4492
16.7k
  rules_[spv::Op::OpEntryPoint].push_back(RemoveRedundantOperands());
4493
4494
16.7k
  rules_[spv::Op::OpFAdd].push_back(RedundantFAdd());
4495
16.7k
  rules_[spv::Op::OpFAdd].push_back(MergeAddNegateArithmetic());
4496
16.7k
  rules_[spv::Op::OpFAdd].push_back(MergeAddAddArithmetic());
4497
16.7k
  rules_[spv::Op::OpFAdd].push_back(MergeAddSubArithmetic());
4498
16.7k
  rules_[spv::Op::OpFAdd].push_back(MergeGenericAddSubArithmetic());
4499
16.7k
  rules_[spv::Op::OpFAdd].push_back(ReassociateNestedAddSub());
4500
16.7k
  rules_[spv::Op::OpFAdd].push_back(FactorAddSubMuls());
4501
4502
16.7k
  rules_[spv::Op::OpFDiv].push_back(RedundantFDiv());
4503
16.7k
  rules_[spv::Op::OpFDiv].push_back(ReciprocalFDiv());
4504
16.7k
  rules_[spv::Op::OpFDiv].push_back(MergeDivDivArithmetic());
4505
16.7k
  rules_[spv::Op::OpFDiv].push_back(MergeDivMulArithmetic());
4506
16.7k
  rules_[spv::Op::OpFDiv].push_back(MergeDivNegateArithmetic());
4507
16.7k
  rules_[spv::Op::OpFDiv].push_back(MergeDivMulDoubleNegative());
4508
16.7k
  rules_[spv::Op::OpFDiv].push_back(ReassociateNestedMulDivFloat());
4509
4510
16.7k
  rules_[spv::Op::OpFMod].push_back(RedundantFMod());
4511
4512
16.7k
  rules_[spv::Op::OpFMul].push_back(RedundantFMul());
4513
16.7k
  rules_[spv::Op::OpFMul].push_back(MergeMulMulArithmetic());
4514
16.7k
  rules_[spv::Op::OpFMul].push_back(MergeMulDivArithmetic());
4515
16.7k
  rules_[spv::Op::OpFMul].push_back(MergeMulNegateArithmetic());
4516
16.7k
  rules_[spv::Op::OpFMul].push_back(MergeDivMulDoubleNegative());
4517
16.7k
  rules_[spv::Op::OpFMul].push_back(ReassociateNestedMulDivFloat());
4518
4519
16.7k
  rules_[spv::Op::OpVectorTimesScalar].push_back(MergeDivMulDoubleNegative());
4520
4521
16.7k
  rules_[spv::Op::OpFNegate].push_back(MergeNegateArithmetic());
4522
16.7k
  rules_[spv::Op::OpFNegate].push_back(MergeNegateAddSubArithmetic());
4523
16.7k
  rules_[spv::Op::OpFNegate].push_back(MergeNegateMulDivArithmetic());
4524
4525
16.7k
  rules_[spv::Op::OpFSub].push_back(RedundantFSub());
4526
16.7k
  rules_[spv::Op::OpFSub].push_back(MergeSubNegateArithmetic());
4527
16.7k
  rules_[spv::Op::OpFSub].push_back(MergeSubAddArithmetic());
4528
16.7k
  rules_[spv::Op::OpFSub].push_back(MergeSubSubArithmetic());
4529
16.7k
  rules_[spv::Op::OpFSub].push_back(ReassociateNestedAddSub());
4530
16.7k
  rules_[spv::Op::OpFSub].push_back(FactorAddSubMuls());
4531
4532
16.7k
  rules_[spv::Op::OpIAdd].push_back(MergeAddNegateArithmetic());
4533
16.7k
  rules_[spv::Op::OpIAdd].push_back(MergeAddAddArithmetic());
4534
16.7k
  rules_[spv::Op::OpIAdd].push_back(MergeAddSubArithmetic());
4535
16.7k
  rules_[spv::Op::OpIAdd].push_back(MergeGenericAddSubArithmetic());
4536
16.7k
  rules_[spv::Op::OpIAdd].push_back(ReassociateNestedAddSub());
4537
16.7k
  rules_[spv::Op::OpIAdd].push_back(FactorAddSubMuls());
4538
4539
16.7k
  rules_[spv::Op::OpSDiv].push_back(MergeDivMulDoubleNegative());
4540
4541
16.7k
  rules_[spv::Op::OpIMul].push_back(IntMultipleBy1());
4542
16.7k
  rules_[spv::Op::OpIMul].push_back(MergeMulMulArithmetic());
4543
16.7k
  rules_[spv::Op::OpIMul].push_back(MergeMulNegateArithmetic());
4544
16.7k
  rules_[spv::Op::OpIMul].push_back(MergeDivMulDoubleNegative());
4545
4546
16.7k
  rules_[spv::Op::OpISub].push_back(MergeSubNegateArithmetic());
4547
16.7k
  rules_[spv::Op::OpISub].push_back(MergeSubAddArithmetic());
4548
16.7k
  rules_[spv::Op::OpISub].push_back(MergeSubSubArithmetic());
4549
16.7k
  rules_[spv::Op::OpISub].push_back(ReassociateNestedAddSub());
4550
16.7k
  rules_[spv::Op::OpISub].push_back(FactorAddSubMuls());
4551
4552
16.7k
  rules_[spv::Op::OpBitwiseAnd].push_back(RedundantAndOrXor());
4553
16.7k
  rules_[spv::Op::OpBitwiseAnd].push_back(RedundantAndAddSub());
4554
16.7k
  rules_[spv::Op::OpBitwiseAnd].push_back(RedundantAndShift());
4555
4556
16.7k
  rules_[spv::Op::OpPhi].push_back(RedundantPhi());
4557
4558
16.7k
  rules_[spv::Op::OpSNegate].push_back(MergeNegateArithmetic());
4559
16.7k
  rules_[spv::Op::OpSNegate].push_back(MergeNegateMulDivArithmetic());
4560
16.7k
  rules_[spv::Op::OpSNegate].push_back(MergeNegateAddSubArithmetic());
4561
4562
16.7k
  rules_[spv::Op::OpSelect].push_back(RedundantSelect());
4563
16.7k
  rules_[spv::Op::OpSelect].push_back(FoldConstantBooleanSelect());
4564
4565
16.7k
  rules_[spv::Op::OpLogicalAnd].push_back(RedundantLogicalAnd());
4566
4567
16.7k
  rules_[spv::Op::OpLogicalOr].push_back(RedundantLogicalOr());
4568
4569
16.7k
  rules_[spv::Op::OpLogicalNot].push_back(RedundantLogicalNot());
4570
16.7k
  rules_[spv::Op::OpLogicalNot].push_back(FoldLogicalNotComparison());
4571
4572
16.7k
  rules_[spv::Op::OpLogicalEqual].push_back(RedundantLogicalEqual());
4573
16.7k
  rules_[spv::Op::OpLogicalNotEqual].push_back(RedundantLogicalEqual());
4574
4575
16.7k
  rules_[spv::Op::OpStore].push_back(StoringUndef());
4576
4577
16.7k
  rules_[spv::Op::OpVectorShuffle].push_back(VectorShuffleFeedingShuffle());
4578
4579
16.7k
  rules_[spv::Op::OpImageSampleImplicitLod].push_back(UpdateImageOperands());
4580
16.7k
  rules_[spv::Op::OpImageSampleExplicitLod].push_back(UpdateImageOperands());
4581
16.7k
  rules_[spv::Op::OpImageSampleDrefImplicitLod].push_back(
4582
16.7k
      UpdateImageOperands());
4583
16.7k
  rules_[spv::Op::OpImageSampleDrefExplicitLod].push_back(
4584
16.7k
      UpdateImageOperands());
4585
16.7k
  rules_[spv::Op::OpImageSampleProjImplicitLod].push_back(
4586
16.7k
      UpdateImageOperands());
4587
16.7k
  rules_[spv::Op::OpImageSampleProjExplicitLod].push_back(
4588
16.7k
      UpdateImageOperands());
4589
16.7k
  rules_[spv::Op::OpImageSampleProjDrefImplicitLod].push_back(
4590
16.7k
      UpdateImageOperands());
4591
16.7k
  rules_[spv::Op::OpImageSampleProjDrefExplicitLod].push_back(
4592
16.7k
      UpdateImageOperands());
4593
16.7k
  rules_[spv::Op::OpImageFetch].push_back(UpdateImageOperands());
4594
16.7k
  rules_[spv::Op::OpImageGather].push_back(UpdateImageOperands());
4595
16.7k
  rules_[spv::Op::OpImageDrefGather].push_back(UpdateImageOperands());
4596
16.7k
  rules_[spv::Op::OpImageRead].push_back(UpdateImageOperands());
4597
16.7k
  rules_[spv::Op::OpImageWrite].push_back(UpdateImageOperands());
4598
16.7k
  rules_[spv::Op::OpImageSparseSampleImplicitLod].push_back(
4599
16.7k
      UpdateImageOperands());
4600
16.7k
  rules_[spv::Op::OpImageSparseSampleExplicitLod].push_back(
4601
16.7k
      UpdateImageOperands());
4602
16.7k
  rules_[spv::Op::OpImageSparseSampleDrefImplicitLod].push_back(
4603
16.7k
      UpdateImageOperands());
4604
16.7k
  rules_[spv::Op::OpImageSparseSampleDrefExplicitLod].push_back(
4605
16.7k
      UpdateImageOperands());
4606
16.7k
  rules_[spv::Op::OpImageSparseSampleProjImplicitLod].push_back(
4607
16.7k
      UpdateImageOperands());
4608
16.7k
  rules_[spv::Op::OpImageSparseSampleProjExplicitLod].push_back(
4609
16.7k
      UpdateImageOperands());
4610
16.7k
  rules_[spv::Op::OpImageSparseSampleProjDrefImplicitLod].push_back(
4611
16.7k
      UpdateImageOperands());
4612
16.7k
  rules_[spv::Op::OpImageSparseSampleProjDrefExplicitLod].push_back(
4613
16.7k
      UpdateImageOperands());
4614
16.7k
  rules_[spv::Op::OpImageSparseFetch].push_back(UpdateImageOperands());
4615
16.7k
  rules_[spv::Op::OpImageSparseGather].push_back(UpdateImageOperands());
4616
16.7k
  rules_[spv::Op::OpImageSparseDrefGather].push_back(UpdateImageOperands());
4617
16.7k
  rules_[spv::Op::OpImageSparseRead].push_back(UpdateImageOperands());
4618
4619
16.7k
  FeatureManager* feature_manager = context_->get_feature_mgr();
4620
  // Add rules for GLSLstd450
4621
16.7k
  uint32_t ext_inst_glslstd450_id =
4622
16.7k
      feature_manager->GetExtInstImportId_GLSLstd450();
4623
16.7k
  if (ext_inst_glslstd450_id != 0) {
4624
10.0k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(
4625
10.0k
        RedundantFMix());
4626
10.0k
  }
4627
16.7k
}
4628
}  // namespace opt
4629
}  // namespace spvtools