Coverage Report

Created: 2024-09-11 07:09

/src/spirv-tools/source/opt/folding_rules.cpp
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) 2018 Google LLC
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
#include "source/opt/folding_rules.h"
16
17
#include <limits>
18
#include <memory>
19
#include <utility>
20
21
#include "ir_builder.h"
22
#include "source/latest_version_glsl_std_450_header.h"
23
#include "source/opt/ir_context.h"
24
25
namespace spvtools {
26
namespace opt {
27
namespace {
28
29
// Fixed in-operand positions used when inspecting particular instructions.
// Each name identifies the instruction family and the operand it indexes.
constexpr uint32_t kExtractCompositeIdInIdx{0};
constexpr uint32_t kInsertObjectIdInIdx{0};
constexpr uint32_t kInsertCompositeIdInIdx{1};
constexpr uint32_t kExtInstSetIdInIdx{0};
constexpr uint32_t kExtInstInstructionInIdx{1};
constexpr uint32_t kFMixXIdInIdx{2};
constexpr uint32_t kFMixYIdInIdx{3};
constexpr uint32_t kFMixAIdInIdx{4};
constexpr uint32_t kStoreObjectInIdx{1};
38
39
// Some image instructions may contain an "image operands" argument.
40
// Returns the operand index for the "image operands".
41
// Returns -1 if the instruction does not have image operands.
42
638k
int32_t ImageOperandsMaskInOperandIndex(Instruction* inst) {
43
638k
  const auto opcode = inst->opcode();
44
638k
  switch (opcode) {
45
638k
    case spv::Op::OpImageSampleImplicitLod:
46
638k
    case spv::Op::OpImageSampleExplicitLod:
47
638k
    case spv::Op::OpImageSampleProjImplicitLod:
48
638k
    case spv::Op::OpImageSampleProjExplicitLod:
49
638k
    case spv::Op::OpImageFetch:
50
638k
    case spv::Op::OpImageRead:
51
638k
    case spv::Op::OpImageSparseSampleImplicitLod:
52
638k
    case spv::Op::OpImageSparseSampleExplicitLod:
53
638k
    case spv::Op::OpImageSparseSampleProjImplicitLod:
54
638k
    case spv::Op::OpImageSparseSampleProjExplicitLod:
55
638k
    case spv::Op::OpImageSparseFetch:
56
638k
    case spv::Op::OpImageSparseRead:
57
638k
      return inst->NumOperands() > 4 ? 2 : -1;
58
0
    case spv::Op::OpImageSampleDrefImplicitLod:
59
0
    case spv::Op::OpImageSampleDrefExplicitLod:
60
0
    case spv::Op::OpImageSampleProjDrefImplicitLod:
61
0
    case spv::Op::OpImageSampleProjDrefExplicitLod:
62
0
    case spv::Op::OpImageGather:
63
0
    case spv::Op::OpImageDrefGather:
64
0
    case spv::Op::OpImageSparseSampleDrefImplicitLod:
65
0
    case spv::Op::OpImageSparseSampleDrefExplicitLod:
66
0
    case spv::Op::OpImageSparseSampleProjDrefImplicitLod:
67
0
    case spv::Op::OpImageSparseSampleProjDrefExplicitLod:
68
0
    case spv::Op::OpImageSparseGather:
69
0
    case spv::Op::OpImageSparseDrefGather:
70
0
      return inst->NumOperands() > 5 ? 3 : -1;
71
0
    case spv::Op::OpImageWrite:
72
0
      return inst->NumOperands() > 3 ? 3 : -1;
73
0
    default:
74
0
      return -1;
75
638k
  }
76
638k
}
77
78
// Returns the element width of |type|.
79
10.7M
uint32_t ElementWidth(const analysis::Type* type) {
80
10.7M
  if (const analysis::Vector* vec_type = type->AsVector()) {
81
3.85M
    return ElementWidth(vec_type->element_type());
82
6.89M
  } else if (const analysis::Float* float_type = type->AsFloat()) {
83
6.64M
    return float_type->width();
84
6.64M
  } else {
85
251k
    assert(type->AsInteger());
86
251k
    return type->AsInteger()->width();
87
251k
  }
88
10.7M
}
89
90
// Returns true if |type| is Float or a vector of Float.
91
12.3M
bool HasFloatingPoint(const analysis::Type* type) {
92
12.3M
  if (type->AsFloat()) {
93
5.84M
    return true;
94
6.49M
  } else if (const analysis::Vector* vec_type = type->AsVector()) {
95
6.13M
    return vec_type->element_type()->AsFloat() != nullptr;
96
6.13M
  }
97
98
354k
  return false;
99
12.3M
}
100
101
// Returns false if |val| is NaN, infinite or subnormal. Returns true for any
// other classification (zero or normal).
template <typename T>
bool IsValidResult(T val) {
  const int category = std::fpclassify(val);
  const bool rejected = category == FP_NAN || category == FP_INFINITE ||
                        category == FP_SUBNORMAL;
  return !rejected;
}
114
115
// Returns true if `type` is a cooperative matrix.
116
6.88M
bool IsCooperativeMatrix(const analysis::Type* type) {
117
6.88M
  return type->kind() == analysis::Type::kCooperativeMatrixKHR ||
118
6.88M
         type->kind() == analysis::Type::kCooperativeMatrixNV;
119
6.88M
}
120
121
// Returns the constant operand of a binary instruction given its per-operand
// |constants|. In-operand 0 is preferred when both are constant. Returns
// nullptr when neither operand is constant.
const analysis::Constant* ConstInput(
    const std::vector<const analysis::Constant*>& constants) {
  const analysis::Constant* first = constants[0];
  return first != nullptr ? first : constants[1];
}
125
126
// Returns the defining instruction of the operand of |inst| that is treated
// as non-constant. When |c| is non-null the variable operand is in-operand 1;
// otherwise it is in-operand 0. Callers in this file pass the constant of
// in-operand 0 as |c|.
Instruction* NonConstInput(IRContext* context, const analysis::Constant* c,
                           Instruction* inst) {
  const uint32_t var_operand_index = (c == nullptr) ? 0u : 1u;
  const uint32_t var_id = inst->GetSingleWordInOperand(var_operand_index);
  return context->get_def_use_mgr()->GetDef(var_id);
}
132
133
0
// Splits the 64-bit value |val| into two 32-bit words, low-order word first.
std::vector<uint32_t> ExtractInts(uint64_t val) {
  const auto low_word = static_cast<uint32_t>(val & 0xFFFFFFFFull);
  const auto high_word = static_cast<uint32_t>(val >> 32);
  return {low_word, high_word};
}
139
140
std::vector<uint32_t> GetWordsFromScalarIntConstant(
141
4.50k
    const analysis::IntConstant* c) {
142
4.50k
  assert(c != nullptr);
143
4.50k
  uint32_t width = c->type()->AsInteger()->width();
144
4.50k
  assert(width == 8 || width == 16 || width == 32 || width == 64);
145
4.50k
  if (width == 64) {
146
0
    uint64_t uval = static_cast<uint64_t>(c->GetU64());
147
0
    return ExtractInts(uval);
148
0
  }
149
  // Section 2.2.1 of the SPIR-V spec guarantees that all integer types
150
  // smaller than 32-bits are automatically zero or sign extended to 32-bits.
151
4.50k
  return {c->GetU32BitValue()};
152
4.50k
}
153
154
std::vector<uint32_t> GetWordsFromScalarFloatConstant(
155
12
    const analysis::FloatConstant* c) {
156
12
  assert(c != nullptr);
157
12
  uint32_t width = c->type()->AsFloat()->width();
158
12
  assert(width == 16 || width == 32 || width == 64);
159
12
  if (width == 64) {
160
0
    utils::FloatProxy<double> result(c->GetDouble());
161
0
    return result.GetWords();
162
0
  }
163
  // Section 2.2.1 of the SPIR-V spec guarantees that all floating-point types
164
  // smaller than 32-bits are automatically zero extended to 32-bits.
165
12
  return {c->GetU32BitValue()};
166
12
}
167
168
std::vector<uint32_t> GetWordsFromNumericScalarOrVectorConstant(
169
4.52k
    analysis::ConstantManager* const_mgr, const analysis::Constant* c) {
170
4.52k
  if (const auto* float_constant = c->AsFloatConstant()) {
171
12
    return GetWordsFromScalarFloatConstant(float_constant);
172
4.51k
  } else if (const auto* int_constant = c->AsIntConstant()) {
173
4.50k
    return GetWordsFromScalarIntConstant(int_constant);
174
4.50k
  } else if (const auto* vec_constant = c->AsVectorConstant()) {
175
3
    std::vector<uint32_t> words;
176
12
    for (const auto* comp : vec_constant->GetComponents()) {
177
12
      auto comp_in_words =
178
12
          GetWordsFromNumericScalarOrVectorConstant(const_mgr, comp);
179
12
      words.insert(words.end(), comp_in_words.begin(), comp_in_words.end());
180
12
    }
181
3
    return words;
182
3
  }
183
8
  return {};
184
4.52k
}
185
186
// Builds the constant of |type| described by the literal |words|.
// - Integers of at most 32 bits take exactly one word (GenerateIntegerConstant).
// - Wider integers and floats go through GetConstant.
// - Vectors go through GetNumericVectorConstantWithWords.
// Any other type yields nullptr.
const analysis::Constant* ConvertWordsToNumericScalarOrVectorConstant(
    analysis::ConstantManager* const_mgr, const std::vector<uint32_t>& words,
    const analysis::Type* type) {
  if (const analysis::Integer* int_type = type->AsInteger()) {
    if (int_type->width() > 32) {
      return const_mgr->GetConstant(type, words);
    }
    assert(words.size() == 1);
    return const_mgr->GenerateIntegerConstant(int_type, words[0]);
  }
  if (type->AsFloat()) {
    return const_mgr->GetConstant(type, words);
  }
  if (const auto* vec_type = type->AsVector()) {
    return const_mgr->GetNumericVectorConstantWithWords(vec_type, words);
  }
  return nullptr;
}
201
202
// Returns the negation of |c|. |c| must be a 32 or 64 bit floating point
203
// constant.
204
uint32_t NegateFloatingPointConstant(analysis::ConstantManager* const_mgr,
205
45
                                     const analysis::Constant* c) {
206
45
  assert(c);
207
45
  assert(c->type()->AsFloat());
208
45
  uint32_t width = c->type()->AsFloat()->width();
209
45
  assert(width == 32 || width == 64);
210
45
  std::vector<uint32_t> words;
211
45
  if (width == 64) {
212
0
    utils::FloatProxy<double> result(c->GetDouble() * -1.0);
213
0
    words = result.GetWords();
214
45
  } else {
215
45
    utils::FloatProxy<float> result(c->GetFloat() * -1.0f);
216
45
    words = result.GetWords();
217
45
  }
218
219
45
  const analysis::Constant* negated_const =
220
45
      const_mgr->GetConstant(c->type(), std::move(words));
221
45
  return const_mgr->GetDefiningInstruction(negated_const)->result_id();
222
45
}
223
224
// Negates the integer constant |c|. Returns the id of the defining instruction.
225
uint32_t NegateIntegerConstant(analysis::ConstantManager* const_mgr,
226
31
                               const analysis::Constant* c) {
227
31
  assert(c);
228
31
  assert(c->type()->AsInteger());
229
31
  uint32_t width = c->type()->AsInteger()->width();
230
31
  assert(width == 32 || width == 64);
231
31
  std::vector<uint32_t> words;
232
31
  if (width == 64) {
233
0
    uint64_t uval = static_cast<uint64_t>(0 - c->GetU64());
234
0
    words = ExtractInts(uval);
235
31
  } else {
236
31
    words.push_back(static_cast<uint32_t>(0 - c->GetU32()));
237
31
  }
238
239
31
  const analysis::Constant* negated_const =
240
31
      const_mgr->GetConstant(c->type(), std::move(words));
241
31
  return const_mgr->GetDefiningInstruction(negated_const)->result_id();
242
31
}
243
244
// Negates the vector constant |c|. Returns the id of the defining instruction.
245
uint32_t NegateVectorConstant(analysis::ConstantManager* const_mgr,
246
0
                              const analysis::Constant* c) {
247
0
  assert(const_mgr && c);
248
0
  assert(c->type()->AsVector());
249
0
  if (c->AsNullConstant()) {
250
    // 0.0 vs -0.0 shouldn't matter.
251
0
    return const_mgr->GetDefiningInstruction(c)->result_id();
252
0
  } else {
253
0
    const analysis::Type* component_type =
254
0
        c->AsVectorConstant()->component_type();
255
0
    std::vector<uint32_t> words;
256
0
    for (auto& comp : c->AsVectorConstant()->GetComponents()) {
257
0
      if (component_type->AsFloat()) {
258
0
        words.push_back(NegateFloatingPointConstant(const_mgr, comp));
259
0
      } else {
260
0
        assert(component_type->AsInteger());
261
0
        words.push_back(NegateIntegerConstant(const_mgr, comp));
262
0
      }
263
0
    }
264
265
0
    const analysis::Constant* negated_const =
266
0
        const_mgr->GetConstant(c->type(), std::move(words));
267
0
    return const_mgr->GetDefiningInstruction(negated_const)->result_id();
268
0
  }
269
0
}
270
271
// Negates |c|. Returns the id of the defining instruction.
272
uint32_t NegateConstant(analysis::ConstantManager* const_mgr,
273
76
                        const analysis::Constant* c) {
274
76
  if (c->type()->AsVector()) {
275
0
    return NegateVectorConstant(const_mgr, c);
276
76
  } else if (c->type()->AsFloat()) {
277
45
    return NegateFloatingPointConstant(const_mgr, c);
278
45
  } else {
279
31
    assert(c->type()->AsInteger());
280
31
    return NegateIntegerConstant(const_mgr, c);
281
31
  }
282
76
}
283
284
// Takes the reciprocal of |c|. |c|'s type must be Float or a vector of Float.
285
// Returns 0 if the reciprocal is NaN, infinite or subnormal.
286
uint32_t Reciprocal(analysis::ConstantManager* const_mgr,
287
137k
                    const analysis::Constant* c) {
288
137k
  assert(const_mgr && c);
289
137k
  assert(c->type()->AsFloat());
290
291
137k
  uint32_t width = c->type()->AsFloat()->width();
292
137k
  assert(width == 32 || width == 64);
293
137k
  std::vector<uint32_t> words;
294
295
137k
  if (c->IsZero()) {
296
16.8k
    return 0;
297
16.8k
  }
298
299
120k
  if (width == 64) {
300
0
    spvtools::utils::FloatProxy<double> result(1.0 / c->GetDouble());
301
0
    if (!IsValidResult(result.getAsFloat())) return 0;
302
0
    words = result.GetWords();
303
120k
  } else {
304
120k
    spvtools::utils::FloatProxy<float> result(1.0f / c->GetFloat());
305
120k
    if (!IsValidResult(result.getAsFloat())) return 0;
306
83.0k
    words = result.GetWords();
307
83.0k
  }
308
309
83.0k
  const analysis::Constant* negated_const =
310
83.0k
      const_mgr->GetConstant(c->type(), std::move(words));
311
83.0k
  return const_mgr->GetDefiningInstruction(negated_const)->result_id();
312
120k
}
313
314
// Replaces fdiv where second operand is constant with fmul.
315
8.16k
FoldingRule ReciprocalFDiv() {
316
8.16k
  return [](IRContext* context, Instruction* inst,
317
108k
            const std::vector<const analysis::Constant*>& constants) {
318
108k
    assert(inst->opcode() == spv::Op::OpFDiv);
319
108k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
320
108k
    const analysis::Type* type =
321
108k
        context->get_type_mgr()->GetType(inst->type_id());
322
323
108k
    if (IsCooperativeMatrix(type)) {
324
0
      return false;
325
0
    }
326
327
108k
    if (!inst->IsFloatingPointFoldingAllowed()) return false;
328
329
108k
    uint32_t width = ElementWidth(type);
330
108k
    if (width != 32 && width != 64) return false;
331
332
108k
    if (constants[1] != nullptr) {
333
87.3k
      uint32_t id = 0;
334
87.3k
      if (const analysis::VectorConstant* vector_const =
335
87.3k
              constants[1]->AsVectorConstant()) {
336
83.6k
        std::vector<uint32_t> neg_ids;
337
133k
        for (auto& comp : vector_const->GetComponents()) {
338
133k
          id = Reciprocal(const_mgr, comp);
339
133k
          if (id == 0) return false;
340
81.5k
          neg_ids.push_back(id);
341
81.5k
        }
342
31.2k
        const analysis::Constant* negated_const =
343
31.2k
            const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids));
344
31.2k
        id = const_mgr->GetDefiningInstruction(negated_const)->result_id();
345
31.2k
      } else if (constants[1]->AsFloatConstant()) {
346
3.74k
        id = Reciprocal(const_mgr, constants[1]);
347
3.74k
        if (id == 0) return false;
348
3.74k
      } else {
349
        // Don't fold a null constant.
350
0
        return false;
351
0
      }
352
32.7k
      inst->SetOpcode(spv::Op::OpFMul);
353
32.7k
      inst->SetInOperands(
354
32.7k
          {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0u)}},
355
32.7k
           {SPV_OPERAND_TYPE_ID, {id}}});
356
32.7k
      return true;
357
87.3k
    }
358
359
21.5k
    return false;
360
108k
  };
361
8.16k
}
362
363
// Elides consecutive negate instructions.
364
16.3k
FoldingRule MergeNegateArithmetic() {
365
16.3k
  return [](IRContext* context, Instruction* inst,
366
24.7k
            const std::vector<const analysis::Constant*>& constants) {
367
24.7k
    assert(inst->opcode() == spv::Op::OpFNegate ||
368
24.7k
           inst->opcode() == spv::Op::OpSNegate);
369
24.7k
    (void)constants;
370
24.7k
    const analysis::Type* type =
371
24.7k
        context->get_type_mgr()->GetType(inst->type_id());
372
24.7k
    if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
373
0
      return false;
374
375
24.7k
    Instruction* op_inst =
376
24.7k
        context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
377
24.7k
    if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
378
0
      return false;
379
380
24.7k
    if (op_inst->opcode() == inst->opcode()) {
381
      // Elide negates.
382
0
      inst->SetOpcode(spv::Op::OpCopyObject);
383
0
      inst->SetInOperands(
384
0
          {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0u)}}});
385
0
      return true;
386
0
    }
387
388
24.7k
    return false;
389
24.7k
  };
390
16.3k
}
391
392
// Merges negate into a mul or div operation if that operation contains a
393
// constant operand.
394
// Cases:
395
// -(x * 2) = x * -2
396
// -(2 * x) = x * -2
397
// -(x / 2) = x / -2
398
// -(2 / x) = -2 / x
399
16.3k
FoldingRule MergeNegateMulDivArithmetic() {
400
16.3k
  return [](IRContext* context, Instruction* inst,
401
24.6k
            const std::vector<const analysis::Constant*>& constants) {
402
24.6k
    assert(inst->opcode() == spv::Op::OpFNegate ||
403
24.6k
           inst->opcode() == spv::Op::OpSNegate);
404
24.6k
    (void)constants;
405
24.6k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
406
24.6k
    const analysis::Type* type =
407
24.6k
        context->get_type_mgr()->GetType(inst->type_id());
408
409
24.6k
    if (IsCooperativeMatrix(type)) {
410
0
      return false;
411
0
    }
412
413
24.6k
    if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
414
0
      return false;
415
416
24.6k
    Instruction* op_inst =
417
24.6k
        context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
418
24.6k
    if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
419
0
      return false;
420
421
24.6k
    uint32_t width = ElementWidth(type);
422
24.6k
    if (width != 32 && width != 64) return false;
423
424
24.6k
    spv::Op opcode = op_inst->opcode();
425
24.6k
    if (opcode == spv::Op::OpFMul || opcode == spv::Op::OpFDiv ||
426
24.6k
        opcode == spv::Op::OpIMul || opcode == spv::Op::OpSDiv ||
427
24.6k
        opcode == spv::Op::OpUDiv) {
428
92
      std::vector<const analysis::Constant*> op_constants =
429
92
          const_mgr->GetOperandConstants(op_inst);
430
      // Merge negate into mul or div if one operand is constant.
431
92
      if (op_constants[0] || op_constants[1]) {
432
53
        bool zero_is_variable = op_constants[0] == nullptr;
433
53
        const analysis::Constant* c = ConstInput(op_constants);
434
53
        uint32_t neg_id = NegateConstant(const_mgr, c);
435
53
        uint32_t non_const_id = zero_is_variable
436
53
                                    ? op_inst->GetSingleWordInOperand(0u)
437
53
                                    : op_inst->GetSingleWordInOperand(1u);
438
        // Change this instruction to a mul/div.
439
53
        inst->SetOpcode(op_inst->opcode());
440
53
        if (opcode == spv::Op::OpFDiv || opcode == spv::Op::OpUDiv ||
441
53
            opcode == spv::Op::OpSDiv) {
442
4
          uint32_t op0 = zero_is_variable ? non_const_id : neg_id;
443
4
          uint32_t op1 = zero_is_variable ? neg_id : non_const_id;
444
4
          inst->SetInOperands(
445
4
              {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
446
49
        } else {
447
49
          inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
448
49
                               {SPV_OPERAND_TYPE_ID, {neg_id}}});
449
49
        }
450
53
        return true;
451
53
      }
452
92
    }
453
454
24.6k
    return false;
455
24.6k
  };
456
16.3k
}
457
458
// Merges negate into a add or sub operation if that operation contains a
459
// constant operand.
460
// Cases:
461
// -(x + 2) = -2 - x
462
// -(2 + x) = -2 - x
463
// -(x - 2) = 2 - x
464
// -(2 - x) = x - 2
465
16.3k
FoldingRule MergeNegateAddSubArithmetic() {
466
16.3k
  return [](IRContext* context, Instruction* inst,
467
24.7k
            const std::vector<const analysis::Constant*>& constants) {
468
24.7k
    assert(inst->opcode() == spv::Op::OpFNegate ||
469
24.7k
           inst->opcode() == spv::Op::OpSNegate);
470
24.7k
    (void)constants;
471
24.7k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
472
24.7k
    const analysis::Type* type =
473
24.7k
        context->get_type_mgr()->GetType(inst->type_id());
474
475
24.7k
    if (IsCooperativeMatrix(type)) {
476
0
      return false;
477
0
    }
478
479
24.7k
    if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
480
0
      return false;
481
482
24.7k
    Instruction* op_inst =
483
24.7k
        context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
484
24.7k
    if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
485
0
      return false;
486
487
24.7k
    uint32_t width = ElementWidth(type);
488
24.7k
    if (width != 32 && width != 64) return false;
489
490
24.7k
    if (op_inst->opcode() == spv::Op::OpFAdd ||
491
24.7k
        op_inst->opcode() == spv::Op::OpFSub ||
492
24.7k
        op_inst->opcode() == spv::Op::OpIAdd ||
493
24.7k
        op_inst->opcode() == spv::Op::OpISub) {
494
115
      std::vector<const analysis::Constant*> op_constants =
495
115
          const_mgr->GetOperandConstants(op_inst);
496
115
      if (op_constants[0] || op_constants[1]) {
497
47
        bool zero_is_variable = op_constants[0] == nullptr;
498
47
        bool is_add = (op_inst->opcode() == spv::Op::OpFAdd) ||
499
47
                      (op_inst->opcode() == spv::Op::OpIAdd);
500
47
        bool swap_operands = !is_add || zero_is_variable;
501
47
        bool negate_const = is_add;
502
47
        const analysis::Constant* c = ConstInput(op_constants);
503
47
        uint32_t const_id = 0;
504
47
        if (negate_const) {
505
18
          const_id = NegateConstant(const_mgr, c);
506
29
        } else {
507
29
          const_id = zero_is_variable ? op_inst->GetSingleWordInOperand(1u)
508
29
                                      : op_inst->GetSingleWordInOperand(0u);
509
29
        }
510
511
        // Swap operands if necessary and make the instruction a subtraction.
512
47
        uint32_t op0 =
513
47
            zero_is_variable ? op_inst->GetSingleWordInOperand(0u) : const_id;
514
47
        uint32_t op1 =
515
47
            zero_is_variable ? const_id : op_inst->GetSingleWordInOperand(1u);
516
47
        if (swap_operands) std::swap(op0, op1);
517
47
        inst->SetOpcode(HasFloatingPoint(type) ? spv::Op::OpFSub
518
47
                                               : spv::Op::OpISub);
519
47
        inst->SetInOperands(
520
47
            {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
521
47
        return true;
522
47
      }
523
115
    }
524
525
24.6k
    return false;
526
24.7k
  };
527
16.3k
}
528
529
// Returns true if |c| has a zero element.
530
290k
bool HasZero(const analysis::Constant* c) {
531
290k
  if (c->AsNullConstant()) {
532
0
    return true;
533
0
  }
534
290k
  if (const analysis::VectorConstant* vec_const = c->AsVectorConstant()) {
535
105k
    for (auto& comp : vec_const->GetComponents())
536
178k
      if (HasZero(comp)) return true;
537
185k
  } else {
538
185k
    assert(c->AsScalarConstant());
539
185k
    return c->AsScalarConstant()->IsZero();
540
185k
  }
541
542
69.7k
  return false;
543
290k
}
544
545
// Performs |input1| |opcode| |input2| and returns the merged constant result
546
// id. Returns 0 if the result is not a valid value. The input types must be
547
// Float.
548
uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr,
549
                                       spv::Op opcode,
550
                                       const analysis::Constant* input1,
551
13.1k
                                       const analysis::Constant* input2) {
552
13.1k
  const analysis::Type* type = input1->type();
553
13.1k
  assert(type->AsFloat());
554
13.1k
  uint32_t width = type->AsFloat()->width();
555
13.1k
  assert(width == 32 || width == 64);
556
13.1k
  std::vector<uint32_t> words;
557
13.1k
#define FOLD_OP(op)                                                          \
558
13.1k
  if (width == 64) {                                                         \
559
0
    utils::FloatProxy<double> val =                                          \
560
0
        input1->GetDouble() op input2->GetDouble();                          \
561
0
    double dval = val.getAsFloat();                                          \
562
0
    if (!IsValidResult(dval)) return 0;                                      \
563
0
    words = val.GetWords();                                                  \
564
13.1k
  } else {                                                                   \
565
13.1k
    utils::FloatProxy<float> val = input1->GetFloat() op input2->GetFloat(); \
566
13.1k
    float fval = val.getAsFloat();                                           \
567
13.1k
    if (!IsValidResult(fval)) return 0;                                      \
568
13.1k
    words = val.GetWords();                                                  \
569
9.51k
  }                                                                          \
570
13.1k
  static_assert(true, "require extra semicolon")
571
13.1k
  switch (opcode) {
572
2.58k
    case spv::Op::OpFMul:
573
2.58k
      FOLD_OP(*);
574
1.18k
      break;
575
69
    case spv::Op::OpFDiv:
576
69
      if (HasZero(input2)) return 0;
577
69
      FOLD_OP(/);
578
36
      break;
579
8.07k
    case spv::Op::OpFAdd:
580
8.07k
      FOLD_OP(+);
581
6.02k
      break;
582
2.42k
    case spv::Op::OpFSub:
583
2.42k
      FOLD_OP(-);
584
2.27k
      break;
585
0
    default:
586
0
      assert(false && "Unexpected operation");
587
0
      break;
588
13.1k
  }
589
9.51k
#undef FOLD_OP
590
9.51k
  const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
591
9.51k
  return const_mgr->GetDefiningInstruction(merged_const)->result_id();
592
13.1k
}
593
594
// Performs |input1| |opcode| |input2| and returns the merged constant result
595
// id. Returns 0 if the result is not a valid value. The input types must be
596
// Integers.
597
uint32_t PerformIntegerOperation(analysis::ConstantManager* const_mgr,
598
                                 spv::Op opcode,
599
                                 const analysis::Constant* input1,
600
3.03k
                                 const analysis::Constant* input2) {
601
3.03k
  assert(input1->type()->AsInteger());
602
3.03k
  const analysis::Integer* type = input1->type()->AsInteger();
603
3.03k
  uint32_t width = type->AsInteger()->width();
604
3.03k
  assert(width == 32 || width == 64);
605
3.03k
  std::vector<uint32_t> words;
606
  // Regardless of the sign of the constant, folding is performed on an unsigned
607
  // interpretation of the constant data. This avoids signed integer overflow
608
  // while folding, and works because sign is irrelevant for the IAdd, ISub and
609
  // IMul instructions.
610
3.03k
#define FOLD_OP(op)                                      \
611
3.03k
  if (width == 64) {                                     \
612
0
    uint64_t val = input1->GetU64() op input2->GetU64(); \
613
0
    words = ExtractInts(val);                            \
614
3.03k
  } else {                                               \
615
3.03k
    uint32_t val = input1->GetU32() op input2->GetU32(); \
616
3.03k
    words.push_back(val);                                \
617
3.03k
  }                                                      \
618
3.03k
  static_assert(true, "require extra semicolon")
619
3.03k
  switch (opcode) {
620
110
    case spv::Op::OpIMul:
621
110
      FOLD_OP(*);
622
110
      break;
623
0
    case spv::Op::OpSDiv:
624
0
    case spv::Op::OpUDiv:
625
0
      assert(false && "Should not merge integer division");
626
0
      break;
627
733
    case spv::Op::OpIAdd:
628
733
      FOLD_OP(+);
629
733
      break;
630
2.19k
    case spv::Op::OpISub:
631
2.19k
      FOLD_OP(-);
632
2.19k
      break;
633
0
    default:
634
0
      assert(false && "Unexpected operation");
635
0
      break;
636
3.03k
  }
637
3.03k
#undef FOLD_OP
638
3.03k
  const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
639
3.03k
  return const_mgr->GetDefiningInstruction(merged_const)->result_id();
640
3.03k
}
641
642
// Performs |input1| |opcode| |input2| and returns the merged constant result
643
// id. Returns 0 if the result is not a valid value. The input types must be
644
// Integers, Floats or Vectors of such.
645
uint32_t PerformOperation(analysis::ConstantManager* const_mgr, spv::Op opcode,
646
                          const analysis::Constant* input1,
647
13.0k
                          const analysis::Constant* input2) {
648
13.0k
  assert(input1 && input2);
649
13.0k
  const analysis::Type* type = input1->type();
650
13.0k
  std::vector<uint32_t> words;
651
13.0k
  if (const analysis::Vector* vector_type = type->AsVector()) {
652
4.56k
    const analysis::Type* ele_type = vector_type->element_type();
653
10.4k
    for (uint32_t i = 0; i != vector_type->element_count(); ++i) {
654
7.71k
      uint32_t id = 0;
655
656
7.71k
      const analysis::Constant* input1_comp = nullptr;
657
7.71k
      if (const analysis::VectorConstant* input1_vector =
658
7.71k
              input1->AsVectorConstant()) {
659
7.71k
        input1_comp = input1_vector->GetComponents()[i];
660
7.71k
      } else {
661
0
        assert(input1->AsNullConstant());
662
0
        input1_comp = const_mgr->GetConstant(ele_type, {});
663
0
      }
664
665
7.71k
      const analysis::Constant* input2_comp = nullptr;
666
7.71k
      if (const analysis::VectorConstant* input2_vector =
667
7.71k
              input2->AsVectorConstant()) {
668
7.71k
        input2_comp = input2_vector->GetComponents()[i];
669
7.71k
      } else {
670
0
        assert(input2->AsNullConstant());
671
0
        input2_comp = const_mgr->GetConstant(ele_type, {});
672
0
      }
673
674
7.71k
      if (ele_type->AsFloat()) {
675
7.71k
        id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp,
676
7.71k
                                           input2_comp);
677
7.71k
      } else {
678
0
        assert(ele_type->AsInteger());
679
0
        id = PerformIntegerOperation(const_mgr, opcode, input1_comp,
680
0
                                     input2_comp);
681
0
      }
682
7.71k
      if (id == 0) return 0;
683
5.89k
      words.push_back(id);
684
5.89k
    }
685
2.73k
    const analysis::Constant* merged_const =
686
2.73k
        const_mgr->GetConstant(type, words);
687
2.73k
    return const_mgr->GetDefiningInstruction(merged_const)->result_id();
688
8.47k
  } else if (type->AsFloat()) {
689
5.43k
    return PerformFloatingPointOperation(const_mgr, opcode, input1, input2);
690
5.43k
  } else {
691
3.03k
    assert(type->AsInteger());
692
3.03k
    return PerformIntegerOperation(const_mgr, opcode, input1, input2);
693
3.03k
  }
694
13.0k
}
695
696
// Merges consecutive multiplies where each contains one constant operand.
697
// Cases:
698
// 2 * (x * 2) = x * 4
699
// 2 * (2 * x) = x * 4
700
// (x * 2) * 2 = x * 4
701
// (2 * x) * 2 = x * 4
702
16.3k
FoldingRule MergeMulMulArithmetic() {
703
16.3k
  return [](IRContext* context, Instruction* inst,
704
180k
            const std::vector<const analysis::Constant*>& constants) {
705
180k
    assert(inst->opcode() == spv::Op::OpFMul ||
706
180k
           inst->opcode() == spv::Op::OpIMul);
707
180k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
708
180k
    const analysis::Type* type =
709
180k
        context->get_type_mgr()->GetType(inst->type_id());
710
711
180k
    if (IsCooperativeMatrix(type)) {
712
0
      return false;
713
0
    }
714
715
180k
    if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
716
1
      return false;
717
718
180k
    uint32_t width = ElementWidth(type);
719
180k
    if (width != 32 && width != 64) return false;
720
721
    // Determine the constant input and the variable input in |inst|.
722
180k
    const analysis::Constant* const_input1 = ConstInput(constants);
723
180k
    if (!const_input1) return false;
724
138k
    Instruction* other_inst = NonConstInput(context, constants[0], inst);
725
138k
    if (HasFloatingPoint(type) && !other_inst->IsFloatingPointFoldingAllowed())
726
6
      return false;
727
728
137k
    if (other_inst->opcode() == inst->opcode()) {
729
2.89k
      std::vector<const analysis::Constant*> other_constants =
730
2.89k
          const_mgr->GetOperandConstants(other_inst);
731
2.89k
      const analysis::Constant* const_input2 = ConstInput(other_constants);
732
2.89k
      if (!const_input2) return false;
733
734
2.25k
      bool other_first_is_variable = other_constants[0] == nullptr;
735
2.25k
      uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
736
2.25k
                                            const_input1, const_input2);
737
2.25k
      if (merged_id == 0) return false;
738
739
1.07k
      uint32_t non_const_id = other_first_is_variable
740
1.07k
                                  ? other_inst->GetSingleWordInOperand(0u)
741
1.07k
                                  : other_inst->GetSingleWordInOperand(1u);
742
1.07k
      inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
743
1.07k
                           {SPV_OPERAND_TYPE_ID, {merged_id}}});
744
1.07k
      return true;
745
2.25k
    }
746
747
135k
    return false;
748
137k
  };
749
16.3k
}
750
751
// Merges divides into subsequent multiplies if each instruction contains one
752
// constant operand. Does not support integer operations.
753
// Cases:
754
// 2 * (x / 2) = x * 1
755
// 2 * (2 / x) = 4 / x
756
// (x / 2) * 2 = x * 1
757
// (2 / x) * 2 = 4 / x
758
// (y / x) * x = y
759
// x * (y / x) = y
760
8.16k
FoldingRule MergeMulDivArithmetic() {
761
8.16k
  return [](IRContext* context, Instruction* inst,
762
170k
            const std::vector<const analysis::Constant*>& constants) {
763
170k
    assert(inst->opcode() == spv::Op::OpFMul);
764
170k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
765
170k
    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
766
767
170k
    const analysis::Type* type =
768
170k
        context->get_type_mgr()->GetType(inst->type_id());
769
770
170k
    if (IsCooperativeMatrix(type)) {
771
0
      return false;
772
0
    }
773
774
170k
    if (!inst->IsFloatingPointFoldingAllowed()) return false;
775
776
170k
    uint32_t width = ElementWidth(type);
777
170k
    if (width != 32 && width != 64) return false;
778
779
511k
    for (uint32_t i = 0; i < 2; i++) {
780
341k
      uint32_t op_id = inst->GetSingleWordInOperand(i);
781
341k
      Instruction* op_inst = def_use_mgr->GetDef(op_id);
782
341k
      if (op_inst->opcode() == spv::Op::OpFDiv) {
783
3.47k
        if (op_inst->GetSingleWordInOperand(1) ==
784
3.47k
            inst->GetSingleWordInOperand(1 - i)) {
785
0
          inst->SetOpcode(spv::Op::OpCopyObject);
786
0
          inst->SetInOperands(
787
0
              {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0)}}});
788
0
          return true;
789
0
        }
790
3.47k
      }
791
341k
    }
792
793
170k
    const analysis::Constant* const_input1 = ConstInput(constants);
794
170k
    if (!const_input1) return false;
795
129k
    Instruction* other_inst = NonConstInput(context, constants[0], inst);
796
129k
    if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
797
798
129k
    if (other_inst->opcode() == spv::Op::OpFDiv) {
799
209
      std::vector<const analysis::Constant*> other_constants =
800
209
          const_mgr->GetOperandConstants(other_inst);
801
209
      const analysis::Constant* const_input2 = ConstInput(other_constants);
802
209
      if (!const_input2 || HasZero(const_input2)) return false;
803
804
18
      bool other_first_is_variable = other_constants[0] == nullptr;
805
      // If the variable value is the second operand of the divide, multiply
806
      // the constants together. Otherwise divide the constants.
807
18
      uint32_t merged_id = PerformOperation(
808
18
          const_mgr,
809
18
          other_first_is_variable ? other_inst->opcode() : inst->opcode(),
810
18
          const_input1, const_input2);
811
18
      if (merged_id == 0) return false;
812
813
5
      uint32_t non_const_id = other_first_is_variable
814
5
                                  ? other_inst->GetSingleWordInOperand(0u)
815
5
                                  : other_inst->GetSingleWordInOperand(1u);
816
817
      // If the variable value is on the second operand of the div, then this
818
      // operation is a div. Otherwise it should be a multiply.
819
5
      inst->SetOpcode(other_first_is_variable ? inst->opcode()
820
5
                                              : other_inst->opcode());
821
5
      if (other_first_is_variable) {
822
5
        inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
823
5
                             {SPV_OPERAND_TYPE_ID, {merged_id}}});
824
5
      } else {
825
0
        inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {merged_id}},
826
0
                             {SPV_OPERAND_TYPE_ID, {non_const_id}}});
827
0
      }
828
5
      return true;
829
18
    }
830
831
129k
    return false;
832
129k
  };
833
8.16k
}
834
835
// Merges multiply of constant and negation.
836
// Cases:
837
// (-x) * 2 = x * -2
838
// 2 * (-x) = x * -2
839
16.3k
FoldingRule MergeMulNegateArithmetic() {
840
16.3k
  return [](IRContext* context, Instruction* inst,
841
179k
            const std::vector<const analysis::Constant*>& constants) {
842
179k
    assert(inst->opcode() == spv::Op::OpFMul ||
843
179k
           inst->opcode() == spv::Op::OpIMul);
844
179k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
845
179k
    const analysis::Type* type =
846
179k
        context->get_type_mgr()->GetType(inst->type_id());
847
848
179k
    if (IsCooperativeMatrix(type)) {
849
0
      return false;
850
0
    }
851
852
179k
    bool uses_float = HasFloatingPoint(type);
853
179k
    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
854
855
179k
    uint32_t width = ElementWidth(type);
856
179k
    if (width != 32 && width != 64) return false;
857
858
179k
    const analysis::Constant* const_input1 = ConstInput(constants);
859
179k
    if (!const_input1) return false;
860
136k
    Instruction* other_inst = NonConstInput(context, constants[0], inst);
861
136k
    if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
862
6
      return false;
863
864
136k
    if (other_inst->opcode() == spv::Op::OpFNegate ||
865
136k
        other_inst->opcode() == spv::Op::OpSNegate) {
866
1
      uint32_t neg_id = NegateConstant(const_mgr, const_input1);
867
868
1
      inst->SetInOperands(
869
1
          {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
870
1
           {SPV_OPERAND_TYPE_ID, {neg_id}}});
871
1
      return true;
872
1
    }
873
874
136k
    return false;
875
136k
  };
876
16.3k
}
877
878
// Merges consecutive divides if each instruction contains one constant operand.
879
// Does not support integer division.
880
// Cases:
881
// 2 / (x / 2) = 4 / x
882
// 4 / (2 / x) = 2 * x
883
// (4 / x) / 2 = 2 / x
884
// (x / 2) / 2 = x / 4
885
8.16k
FoldingRule MergeDivDivArithmetic() {
886
8.16k
  return [](IRContext* context, Instruction* inst,
887
76.1k
            const std::vector<const analysis::Constant*>& constants) {
888
76.1k
    assert(inst->opcode() == spv::Op::OpFDiv);
889
76.1k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
890
76.1k
    const analysis::Type* type =
891
76.1k
        context->get_type_mgr()->GetType(inst->type_id());
892
893
76.1k
    if (IsCooperativeMatrix(type)) {
894
0
      return false;
895
0
    }
896
897
76.1k
    if (!inst->IsFloatingPointFoldingAllowed()) return false;
898
899
76.1k
    uint32_t width = ElementWidth(type);
900
76.1k
    if (width != 32 && width != 64) return false;
901
902
76.1k
    const analysis::Constant* const_input1 = ConstInput(constants);
903
76.1k
    if (!const_input1 || HasZero(const_input1)) return false;
904
37.8k
    Instruction* other_inst = NonConstInput(context, constants[0], inst);
905
37.8k
    if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
906
907
37.8k
    bool first_is_variable = constants[0] == nullptr;
908
37.8k
    if (other_inst->opcode() == inst->opcode()) {
909
334
      std::vector<const analysis::Constant*> other_constants =
910
334
          const_mgr->GetOperandConstants(other_inst);
911
334
      const analysis::Constant* const_input2 = ConstInput(other_constants);
912
334
      if (!const_input2 || HasZero(const_input2)) return false;
913
914
262
      bool other_first_is_variable = other_constants[0] == nullptr;
915
916
262
      spv::Op merge_op = inst->opcode();
917
262
      if (other_first_is_variable) {
918
        // Constants magnify.
919
249
        merge_op = spv::Op::OpFMul;
920
249
      }
921
922
      // This is an x / (*) case. Swap the inputs. Doesn't harm multiply
923
      // because it is commutative.
924
262
      if (first_is_variable) std::swap(const_input1, const_input2);
925
262
      uint32_t merged_id =
926
262
          PerformOperation(const_mgr, merge_op, const_input1, const_input2);
927
262
      if (merged_id == 0) return false;
928
929
44
      uint32_t non_const_id = other_first_is_variable
930
44
                                  ? other_inst->GetSingleWordInOperand(0u)
931
44
                                  : other_inst->GetSingleWordInOperand(1u);
932
933
44
      spv::Op op = inst->opcode();
934
44
      if (!first_is_variable && !other_first_is_variable) {
935
        // Effectively div of 1/x, so change to multiply.
936
7
        op = spv::Op::OpFMul;
937
7
      }
938
939
44
      uint32_t op1 = merged_id;
940
44
      uint32_t op2 = non_const_id;
941
44
      if (first_is_variable && other_first_is_variable) std::swap(op1, op2);
942
44
      inst->SetOpcode(op);
943
44
      inst->SetInOperands(
944
44
          {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
945
44
      return true;
946
262
    }
947
948
37.4k
    return false;
949
37.8k
  };
950
8.16k
}
951
952
// Fold multiplies succeeded by divides where each instruction contains a
953
// constant operand. Does not support integer divide.
954
// Cases:
955
// 4 / (x * 2) = 2 / x
956
// 4 / (2 * x) = 2 / x
957
// (x * 4) / 2 = x * 2
958
// (4 * x) / 2 = x * 2
959
// (x * y) / x = y
960
// (y * x) / x = y
961
8.16k
FoldingRule MergeDivMulArithmetic() {
962
8.16k
  return [](IRContext* context, Instruction* inst,
963
76.1k
            const std::vector<const analysis::Constant*>& constants) {
964
76.1k
    assert(inst->opcode() == spv::Op::OpFDiv);
965
76.1k
    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
966
76.1k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
967
968
76.1k
    const analysis::Type* type =
969
76.1k
        context->get_type_mgr()->GetType(inst->type_id());
970
971
76.1k
    if (IsCooperativeMatrix(type)) {
972
0
      return false;
973
0
    }
974
975
76.1k
    if (!inst->IsFloatingPointFoldingAllowed()) return false;
976
977
76.0k
    uint32_t width = ElementWidth(type);
978
76.0k
    if (width != 32 && width != 64) return false;
979
980
76.0k
    uint32_t op_id = inst->GetSingleWordInOperand(0);
981
76.0k
    Instruction* op_inst = def_use_mgr->GetDef(op_id);
982
983
76.0k
    if (op_inst->opcode() == spv::Op::OpFMul) {
984
565
      for (uint32_t i = 0; i < 2; i++) {
985
377
        if (op_inst->GetSingleWordInOperand(i) ==
986
377
            inst->GetSingleWordInOperand(1)) {
987
1
          inst->SetOpcode(spv::Op::OpCopyObject);
988
1
          inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
989
1
                                {op_inst->GetSingleWordInOperand(1 - i)}}});
990
1
          return true;
991
1
        }
992
377
      }
993
189
    }
994
995
76.0k
    const analysis::Constant* const_input1 = ConstInput(constants);
996
76.0k
    if (!const_input1 || HasZero(const_input1)) return false;
997
37.7k
    Instruction* other_inst = NonConstInput(context, constants[0], inst);
998
37.7k
    if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
999
1000
37.7k
    bool first_is_variable = constants[0] == nullptr;
1001
37.7k
    if (other_inst->opcode() == spv::Op::OpFMul) {
1002
38
      std::vector<const analysis::Constant*> other_constants =
1003
38
          const_mgr->GetOperandConstants(other_inst);
1004
38
      const analysis::Constant* const_input2 = ConstInput(other_constants);
1005
38
      if (!const_input2) return false;
1006
1007
36
      bool other_first_is_variable = other_constants[0] == nullptr;
1008
1009
      // This is an x / (*) case. Swap the inputs.
1010
36
      if (first_is_variable) std::swap(const_input1, const_input2);
1011
36
      uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
1012
36
                                            const_input1, const_input2);
1013
36
      if (merged_id == 0) return false;
1014
1015
17
      uint32_t non_const_id = other_first_is_variable
1016
17
                                  ? other_inst->GetSingleWordInOperand(0u)
1017
17
                                  : other_inst->GetSingleWordInOperand(1u);
1018
1019
17
      uint32_t op1 = merged_id;
1020
17
      uint32_t op2 = non_const_id;
1021
17
      if (first_is_variable) std::swap(op1, op2);
1022
1023
      // Convert to multiply
1024
17
      if (first_is_variable) inst->SetOpcode(other_inst->opcode());
1025
17
      inst->SetInOperands(
1026
17
          {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1027
17
      return true;
1028
36
    }
1029
1030
37.7k
    return false;
1031
37.7k
  };
1032
8.16k
}
1033
1034
// Fold divides of a constant and a negation.
1035
// Cases:
1036
// (-x) / 2 = x / -2
1037
// 2 / (-x) = -2 / x
1038
8.16k
FoldingRule MergeDivNegateArithmetic() {
1039
8.16k
  return [](IRContext* context, Instruction* inst,
1040
76.1k
            const std::vector<const analysis::Constant*>& constants) {
1041
76.1k
    assert(inst->opcode() == spv::Op::OpFDiv);
1042
76.1k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1043
76.1k
    if (!inst->IsFloatingPointFoldingAllowed()) return false;
1044
1045
76.0k
    const analysis::Constant* const_input1 = ConstInput(constants);
1046
76.0k
    if (!const_input1) return false;
1047
55.4k
    Instruction* other_inst = NonConstInput(context, constants[0], inst);
1048
55.4k
    if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
1049
1050
55.4k
    bool first_is_variable = constants[0] == nullptr;
1051
55.4k
    if (other_inst->opcode() == spv::Op::OpFNegate) {
1052
0
      uint32_t neg_id = NegateConstant(const_mgr, const_input1);
1053
1054
0
      if (first_is_variable) {
1055
0
        inst->SetInOperands(
1056
0
            {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
1057
0
             {SPV_OPERAND_TYPE_ID, {neg_id}}});
1058
0
      } else {
1059
0
        inst->SetInOperands(
1060
0
            {{SPV_OPERAND_TYPE_ID, {neg_id}},
1061
0
             {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
1062
0
      }
1063
0
      return true;
1064
0
    }
1065
1066
55.4k
    return false;
1067
55.4k
  };
1068
8.16k
}
1069
1070
// Folds addition of a constant and a negation.
1071
// Cases:
1072
// (-x) + 2 = 2 - x
1073
// 2 + (-x) = 2 - x
1074
16.3k
FoldingRule MergeAddNegateArithmetic() {
1075
16.3k
  return [](IRContext* context, Instruction* inst,
1076
1.81M
            const std::vector<const analysis::Constant*>& constants) {
1077
1.81M
    assert(inst->opcode() == spv::Op::OpFAdd ||
1078
1.81M
           inst->opcode() == spv::Op::OpIAdd);
1079
1.81M
    const analysis::Type* type =
1080
1.81M
        context->get_type_mgr()->GetType(inst->type_id());
1081
1.81M
    bool uses_float = HasFloatingPoint(type);
1082
1.81M
    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1083
1084
1.81M
    const analysis::Constant* const_input1 = ConstInput(constants);
1085
1.81M
    if (!const_input1) return false;
1086
428k
    Instruction* other_inst = NonConstInput(context, constants[0], inst);
1087
428k
    if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1088
13.8k
      return false;
1089
1090
415k
    if (other_inst->opcode() == spv::Op::OpSNegate ||
1091
415k
        other_inst->opcode() == spv::Op::OpFNegate) {
1092
1
      inst->SetOpcode(HasFloatingPoint(type) ? spv::Op::OpFSub
1093
1
                                             : spv::Op::OpISub);
1094
1
      uint32_t const_id = constants[0] ? inst->GetSingleWordInOperand(0u)
1095
1
                                       : inst->GetSingleWordInOperand(1u);
1096
1
      inst->SetInOperands(
1097
1
          {{SPV_OPERAND_TYPE_ID, {const_id}},
1098
1
           {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
1099
1
      return true;
1100
1
    }
1101
415k
    return false;
1102
415k
  };
1103
16.3k
}
1104
1105
// Folds subtraction of a constant and a negation.
1106
// Cases:
1107
// (-x) - 2 = -2 - x
1108
// 2 - (-x) = x + 2
1109
16.3k
FoldingRule MergeSubNegateArithmetic() {
1110
16.3k
  return [](IRContext* context, Instruction* inst,
1111
200k
            const std::vector<const analysis::Constant*>& constants) {
1112
200k
    assert(inst->opcode() == spv::Op::OpFSub ||
1113
200k
           inst->opcode() == spv::Op::OpISub);
1114
200k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1115
200k
    const analysis::Type* type =
1116
200k
        context->get_type_mgr()->GetType(inst->type_id());
1117
1118
200k
    if (IsCooperativeMatrix(type)) {
1119
0
      return false;
1120
0
    }
1121
1122
200k
    bool uses_float = HasFloatingPoint(type);
1123
200k
    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1124
1125
200k
    uint32_t width = ElementWidth(type);
1126
200k
    if (width != 32 && width != 64) return false;
1127
1128
200k
    const analysis::Constant* const_input1 = ConstInput(constants);
1129
200k
    if (!const_input1) return false;
1130
50.0k
    Instruction* other_inst = NonConstInput(context, constants[0], inst);
1131
50.0k
    if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1132
5
      return false;
1133
1134
50.0k
    if (other_inst->opcode() == spv::Op::OpSNegate ||
1135
50.0k
        other_inst->opcode() == spv::Op::OpFNegate) {
1136
13
      uint32_t op1 = 0;
1137
13
      uint32_t op2 = 0;
1138
13
      spv::Op opcode = inst->opcode();
1139
13
      if (constants[0] != nullptr) {
1140
9
        op1 = other_inst->GetSingleWordInOperand(0u);
1141
9
        op2 = inst->GetSingleWordInOperand(0u);
1142
9
        opcode = HasFloatingPoint(type) ? spv::Op::OpFAdd : spv::Op::OpIAdd;
1143
9
      } else {
1144
4
        op1 = NegateConstant(const_mgr, const_input1);
1145
4
        op2 = other_inst->GetSingleWordInOperand(0u);
1146
4
      }
1147
1148
13
      inst->SetOpcode(opcode);
1149
13
      inst->SetInOperands(
1150
13
          {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1151
13
      return true;
1152
13
    }
1153
50.0k
    return false;
1154
50.0k
  };
1155
16.3k
}
1156
1157
// Folds addition of an addition where each operation has a constant operand.
1158
// Cases:
1159
// (x + 2) + 2 = x + 4
1160
// (2 + x) + 2 = x + 4
1161
// 2 + (x + 2) = x + 4
1162
// 2 + (2 + x) = x + 4
1163
16.3k
FoldingRule MergeAddAddArithmetic() {
1164
16.3k
  return [](IRContext* context, Instruction* inst,
1165
1.81M
            const std::vector<const analysis::Constant*>& constants) {
1166
1.81M
    assert(inst->opcode() == spv::Op::OpFAdd ||
1167
1.81M
           inst->opcode() == spv::Op::OpIAdd);
1168
1.81M
    const analysis::Type* type =
1169
1.81M
        context->get_type_mgr()->GetType(inst->type_id());
1170
1171
1.81M
    if (IsCooperativeMatrix(type)) {
1172
0
      return false;
1173
0
    }
1174
1175
1.81M
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1176
1.81M
    bool uses_float = HasFloatingPoint(type);
1177
1.81M
    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1178
1179
1.81M
    uint32_t width = ElementWidth(type);
1180
1.81M
    if (width != 32 && width != 64) return false;
1181
1182
1.81M
    const analysis::Constant* const_input1 = ConstInput(constants);
1183
1.81M
    if (!const_input1) return false;
1184
428k
    Instruction* other_inst = NonConstInput(context, constants[0], inst);
1185
428k
    if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1186
13.8k
      return false;
1187
1188
415k
    if (other_inst->opcode() == spv::Op::OpFAdd ||
1189
415k
        other_inst->opcode() == spv::Op::OpIAdd) {
1190
5.83k
      std::vector<const analysis::Constant*> other_constants =
1191
5.83k
          const_mgr->GetOperandConstants(other_inst);
1192
5.83k
      const analysis::Constant* const_input2 = ConstInput(other_constants);
1193
5.83k
      if (!const_input2) return false;
1194
1195
3.89k
      Instruction* non_const_input =
1196
3.89k
          NonConstInput(context, other_constants[0], other_inst);
1197
3.89k
      uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
1198
3.89k
                                            const_input1, const_input2);
1199
3.89k
      if (merged_id == 0) return false;
1200
1201
2.70k
      inst->SetInOperands(
1202
2.70k
          {{SPV_OPERAND_TYPE_ID, {non_const_input->result_id()}},
1203
2.70k
           {SPV_OPERAND_TYPE_ID, {merged_id}}});
1204
2.70k
      return true;
1205
3.89k
    }
1206
409k
    return false;
1207
415k
  };
1208
16.3k
}
1209
1210
// Folds addition of a subtraction where each operation has a constant operand.
1211
// Cases:
1212
// (x - 2) + 2 = x + 0
1213
// (2 - x) + 2 = 4 - x
1214
// 2 + (x - 2) = x + 0
1215
// 2 + (2 - x) = 4 - x
1216
16.3k
FoldingRule MergeAddSubArithmetic() {
1217
16.3k
  return [](IRContext* context, Instruction* inst,
1218
1.81M
            const std::vector<const analysis::Constant*>& constants) {
1219
1.81M
    assert(inst->opcode() == spv::Op::OpFAdd ||
1220
1.81M
           inst->opcode() == spv::Op::OpIAdd);
1221
1.81M
    const analysis::Type* type =
1222
1.81M
        context->get_type_mgr()->GetType(inst->type_id());
1223
1224
1.81M
    if (IsCooperativeMatrix(type)) {
1225
0
      return false;
1226
0
    }
1227
1228
1.81M
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1229
1.81M
    bool uses_float = HasFloatingPoint(type);
1230
1.81M
    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1231
1232
1.81M
    uint32_t width = ElementWidth(type);
1233
1.81M
    if (width != 32 && width != 64) return false;
1234
1235
1.81M
    const analysis::Constant* const_input1 = ConstInput(constants);
1236
1.81M
    if (!const_input1) return false;
1237
426k
    Instruction* other_inst = NonConstInput(context, constants[0], inst);
1238
426k
    if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1239
13.8k
      return false;
1240
1241
412k
    if (other_inst->opcode() == spv::Op::OpFSub ||
1242
412k
        other_inst->opcode() == spv::Op::OpISub) {
1243
2.35k
      std::vector<const analysis::Constant*> other_constants =
1244
2.35k
          const_mgr->GetOperandConstants(other_inst);
1245
2.35k
      const analysis::Constant* const_input2 = ConstInput(other_constants);
1246
2.35k
      if (!const_input2) return false;
1247
1248
2.26k
      bool first_is_variable = other_constants[0] == nullptr;
1249
2.26k
      spv::Op op = inst->opcode();
1250
2.26k
      uint32_t op1 = 0;
1251
2.26k
      uint32_t op2 = 0;
1252
2.26k
      if (first_is_variable) {
1253
        // Subtract constants. Non-constant operand is first.
1254
2.25k
        op1 = other_inst->GetSingleWordInOperand(0u);
1255
2.25k
        op2 = PerformOperation(const_mgr, other_inst->opcode(), const_input1,
1256
2.25k
                               const_input2);
1257
2.25k
      } else {
1258
        // Add constants. Constant operand is first. Change the opcode.
1259
12
        op1 = PerformOperation(const_mgr, inst->opcode(), const_input1,
1260
12
                               const_input2);
1261
12
        op2 = other_inst->GetSingleWordInOperand(1u);
1262
12
        op = other_inst->opcode();
1263
12
      }
1264
2.26k
      if (op1 == 0 || op2 == 0) return false;
1265
1266
2.23k
      inst->SetOpcode(op);
1267
2.23k
      inst->SetInOperands(
1268
2.23k
          {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1269
2.23k
      return true;
1270
2.26k
    }
1271
410k
    return false;
1272
412k
  };
1273
16.3k
}
1274
1275
// Folds subtraction of an addition where each operand has a constant operand.
1276
// Cases:
1277
// (x + 2) - 2 = x + 0
1278
// (2 + x) - 2 = x + 0
1279
// 2 - (x + 2) = 0 - x
1280
// 2 - (2 + x) = 0 - x
1281
16.3k
FoldingRule MergeSubAddArithmetic() {
1282
16.3k
  return [](IRContext* context, Instruction* inst,
1283
200k
            const std::vector<const analysis::Constant*>& constants) {
1284
200k
    assert(inst->opcode() == spv::Op::OpFSub ||
1285
200k
           inst->opcode() == spv::Op::OpISub);
1286
200k
    const analysis::Type* type =
1287
200k
        context->get_type_mgr()->GetType(inst->type_id());
1288
1289
200k
    if (IsCooperativeMatrix(type)) {
1290
0
      return false;
1291
0
    }
1292
1293
200k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1294
200k
    bool uses_float = HasFloatingPoint(type);
1295
200k
    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1296
1297
200k
    uint32_t width = ElementWidth(type);
1298
200k
    if (width != 32 && width != 64) return false;
1299
1300
200k
    const analysis::Constant* const_input1 = ConstInput(constants);
1301
200k
    if (!const_input1) return false;
1302
50.0k
    Instruction* other_inst = NonConstInput(context, constants[0], inst);
1303
50.0k
    if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1304
5
      return false;
1305
1306
50.0k
    if (other_inst->opcode() == spv::Op::OpFAdd ||
1307
50.0k
        other_inst->opcode() == spv::Op::OpIAdd) {
1308
6.57k
      std::vector<const analysis::Constant*> other_constants =
1309
6.57k
          const_mgr->GetOperandConstants(other_inst);
1310
6.57k
      const analysis::Constant* const_input2 = ConstInput(other_constants);
1311
6.57k
      if (!const_input2) return false;
1312
1313
1.33k
      Instruction* non_const_input =
1314
1.33k
          NonConstInput(context, other_constants[0], other_inst);
1315
1316
      // If the first operand of the sub is not a constant, swap the constants
1317
      // so the subtraction has the correct operands.
1318
1.33k
      if (constants[0] == nullptr) std::swap(const_input1, const_input2);
1319
      // Subtract the constants.
1320
1.33k
      uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
1321
1.33k
                                            const_input1, const_input2);
1322
1.33k
      spv::Op op = inst->opcode();
1323
1.33k
      uint32_t op1 = 0;
1324
1.33k
      uint32_t op2 = 0;
1325
1.33k
      if (constants[0] == nullptr) {
1326
        // Non-constant operand is first. Change the opcode.
1327
1.30k
        op1 = non_const_input->result_id();
1328
1.30k
        op2 = merged_id;
1329
1.30k
        op = other_inst->opcode();
1330
1.30k
      } else {
1331
        // Constant operand is first.
1332
38
        op1 = merged_id;
1333
38
        op2 = non_const_input->result_id();
1334
38
      }
1335
1.33k
      if (op1 == 0 || op2 == 0) return false;
1336
1337
1.22k
      inst->SetOpcode(op);
1338
1.22k
      inst->SetInOperands(
1339
1.22k
          {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1340
1.22k
      return true;
1341
1.33k
    }
1342
43.4k
    return false;
1343
50.0k
  };
1344
16.3k
}
1345
1346
// Folds subtraction of a subtraction where each operand has a constant operand.
1347
// Cases:
1348
// (x - 2) - 2 = x - 4
1349
// (2 - x) - 2 = 0 - x
1350
// 2 - (x - 2) = 4 - x
1351
// 2 - (2 - x) = x + 0
1352
16.3k
FoldingRule MergeSubSubArithmetic() {
1353
16.3k
  return [](IRContext* context, Instruction* inst,
1354
199k
            const std::vector<const analysis::Constant*>& constants) {
1355
199k
    assert(inst->opcode() == spv::Op::OpFSub ||
1356
199k
           inst->opcode() == spv::Op::OpISub);
1357
199k
    const analysis::Type* type =
1358
199k
        context->get_type_mgr()->GetType(inst->type_id());
1359
1360
199k
    if (IsCooperativeMatrix(type)) {
1361
0
      return false;
1362
0
    }
1363
1364
199k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1365
199k
    bool uses_float = HasFloatingPoint(type);
1366
199k
    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1367
1368
199k
    uint32_t width = ElementWidth(type);
1369
199k
    if (width != 32 && width != 64) return false;
1370
1371
199k
    const analysis::Constant* const_input1 = ConstInput(constants);
1372
199k
    if (!const_input1) return false;
1373
48.8k
    Instruction* other_inst = NonConstInput(context, constants[0], inst);
1374
48.8k
    if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1375
5
      return false;
1376
1377
48.7k
    if (other_inst->opcode() == spv::Op::OpFSub ||
1378
48.7k
        other_inst->opcode() == spv::Op::OpISub) {
1379
3.16k
      std::vector<const analysis::Constant*> other_constants =
1380
3.16k
          const_mgr->GetOperandConstants(other_inst);
1381
3.16k
      const analysis::Constant* const_input2 = ConstInput(other_constants);
1382
3.16k
      if (!const_input2) return false;
1383
1384
2.96k
      Instruction* non_const_input =
1385
2.96k
          NonConstInput(context, other_constants[0], other_inst);
1386
1387
      // Merge the constants.
1388
2.96k
      uint32_t merged_id = 0;
1389
2.96k
      spv::Op merge_op = inst->opcode();
1390
2.96k
      if (other_constants[0] == nullptr) {
1391
2.81k
        merge_op = uses_float ? spv::Op::OpFAdd : spv::Op::OpIAdd;
1392
2.81k
      } else if (constants[0] == nullptr) {
1393
31
        std::swap(const_input1, const_input2);
1394
31
      }
1395
2.96k
      merged_id =
1396
2.96k
          PerformOperation(const_mgr, merge_op, const_input1, const_input2);
1397
2.96k
      if (merged_id == 0) return false;
1398
1399
2.09k
      spv::Op op = inst->opcode();
1400
2.09k
      if (constants[0] != nullptr && other_constants[0] != nullptr) {
1401
        // Change the operation.
1402
120
        op = uses_float ? spv::Op::OpFAdd : spv::Op::OpIAdd;
1403
120
      }
1404
1405
2.09k
      uint32_t op1 = 0;
1406
2.09k
      uint32_t op2 = 0;
1407
2.09k
      if ((constants[0] == nullptr) ^ (other_constants[0] == nullptr)) {
1408
22
        op1 = merged_id;
1409
22
        op2 = non_const_input->result_id();
1410
2.07k
      } else {
1411
2.07k
        op1 = non_const_input->result_id();
1412
2.07k
        op2 = merged_id;
1413
2.07k
      }
1414
1415
2.09k
      inst->SetOpcode(op);
1416
2.09k
      inst->SetInOperands(
1417
2.09k
          {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1418
2.09k
      return true;
1419
2.96k
    }
1420
45.6k
    return false;
1421
48.7k
  };
1422
16.3k
}
1423
1424
// Helper function for MergeGenericAddSubArithmetic. If |addend| and
1425
// subtrahend of |sub| is the same, merge to copy of minuend of |sub|.
1426
3.62M
bool MergeGenericAddendSub(uint32_t addend, uint32_t sub, Instruction* inst) {
1427
3.62M
  IRContext* context = inst->context();
1428
3.62M
  analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1429
3.62M
  Instruction* sub_inst = def_use_mgr->GetDef(sub);
1430
3.62M
  if (sub_inst->opcode() != spv::Op::OpFSub &&
1431
3.62M
      sub_inst->opcode() != spv::Op::OpISub)
1432
3.62M
    return false;
1433
2.66k
  if (sub_inst->opcode() == spv::Op::OpFSub &&
1434
2.66k
      !sub_inst->IsFloatingPointFoldingAllowed())
1435
0
    return false;
1436
2.66k
  if (addend != sub_inst->GetSingleWordInOperand(1)) return false;
1437
5
  inst->SetOpcode(spv::Op::OpCopyObject);
1438
5
  inst->SetInOperands(
1439
5
      {{SPV_OPERAND_TYPE_ID, {sub_inst->GetSingleWordInOperand(0)}}});
1440
5
  context->UpdateDefUse(inst);
1441
5
  return true;
1442
2.66k
}
1443
1444
// Folds addition of a subtraction where the subtrahend is equal to the
1445
// other addend. Return a copy of the minuend. Accepts generic (const and
1446
// non-const) operands.
1447
// Cases:
1448
// (a - b) + b = a
1449
// b + (a - b) = a
1450
16.3k
FoldingRule MergeGenericAddSubArithmetic() {
1451
16.3k
  return [](IRContext* context, Instruction* inst,
1452
1.81M
            const std::vector<const analysis::Constant*>&) {
1453
1.81M
    assert(inst->opcode() == spv::Op::OpFAdd ||
1454
1.81M
           inst->opcode() == spv::Op::OpIAdd);
1455
1.81M
    const analysis::Type* type =
1456
1.81M
        context->get_type_mgr()->GetType(inst->type_id());
1457
1458
1.81M
    if (IsCooperativeMatrix(type)) {
1459
0
      return false;
1460
0
    }
1461
1462
1.81M
    bool uses_float = HasFloatingPoint(type);
1463
1.81M
    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1464
1465
1.81M
    uint32_t width = ElementWidth(type);
1466
1.81M
    if (width != 32 && width != 64) return false;
1467
1468
1.81M
    uint32_t add_op0 = inst->GetSingleWordInOperand(0);
1469
1.81M
    uint32_t add_op1 = inst->GetSingleWordInOperand(1);
1470
1.81M
    if (MergeGenericAddendSub(add_op0, add_op1, inst)) return true;
1471
1.81M
    return MergeGenericAddendSub(add_op1, add_op0, inst);
1472
1.81M
  };
1473
16.3k
}
1474
1475
// Helper function for FactorAddMuls. If |factor0_0| is the same as |factor1_0|,
1476
// generate |factor0_0| * (|factor0_1| + |factor1_1|).
1477
bool FactorAddMulsOpnds(uint32_t factor0_0, uint32_t factor0_1,
1478
                        uint32_t factor1_0, uint32_t factor1_1,
1479
1.32k
                        Instruction* inst) {
1480
1.32k
  IRContext* context = inst->context();
1481
1.32k
  if (factor0_0 != factor1_0) return false;
1482
15
  InstructionBuilder ir_builder(
1483
15
      context, inst,
1484
15
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
1485
15
  Instruction* new_add_inst = ir_builder.AddBinaryOp(
1486
15
      inst->type_id(), inst->opcode(), factor0_1, factor1_1);
1487
15
  inst->SetOpcode(inst->opcode() == spv::Op::OpFAdd ? spv::Op::OpFMul
1488
15
                                                    : spv::Op::OpIMul);
1489
15
  inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {factor0_0}},
1490
15
                       {SPV_OPERAND_TYPE_ID, {new_add_inst->result_id()}}});
1491
15
  context->UpdateDefUse(inst);
1492
15
  return true;
1493
1.32k
}
1494
1495
// Perform the following factoring identity, handling all operand order
1496
// combinations: (a * b) + (a * c) = a * (b + c)
1497
16.3k
FoldingRule FactorAddMuls() {
1498
16.3k
  return [](IRContext* context, Instruction* inst,
1499
1.81M
            const std::vector<const analysis::Constant*>&) {
1500
1.81M
    assert(inst->opcode() == spv::Op::OpFAdd ||
1501
1.81M
           inst->opcode() == spv::Op::OpIAdd);
1502
1.81M
    const analysis::Type* type =
1503
1.81M
        context->get_type_mgr()->GetType(inst->type_id());
1504
1.81M
    bool uses_float = HasFloatingPoint(type);
1505
1.81M
    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1506
1507
1.81M
    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1508
1.81M
    uint32_t add_op0 = inst->GetSingleWordInOperand(0);
1509
1.81M
    Instruction* add_op0_inst = def_use_mgr->GetDef(add_op0);
1510
1.81M
    if (add_op0_inst->opcode() != spv::Op::OpFMul &&
1511
1.81M
        add_op0_inst->opcode() != spv::Op::OpIMul)
1512
1.80M
      return false;
1513
4.60k
    uint32_t add_op1 = inst->GetSingleWordInOperand(1);
1514
4.60k
    Instruction* add_op1_inst = def_use_mgr->GetDef(add_op1);
1515
4.60k
    if (add_op1_inst->opcode() != spv::Op::OpFMul &&
1516
4.60k
        add_op1_inst->opcode() != spv::Op::OpIMul)
1517
3.63k
      return false;
1518
1519
    // Only perform this optimization if both of the muls only have one use.
1520
    // Otherwise this is a deoptimization in size and performance.
1521
972
    if (def_use_mgr->NumUses(add_op0_inst) > 1) return false;
1522
363
    if (def_use_mgr->NumUses(add_op1_inst) > 1) return false;
1523
1524
342
    if (add_op0_inst->opcode() == spv::Op::OpFMul &&
1525
342
        (!add_op0_inst->IsFloatingPointFoldingAllowed() ||
1526
305
         !add_op1_inst->IsFloatingPointFoldingAllowed()))
1527
0
      return false;
1528
1529
996
    for (int i = 0; i < 2; i++) {
1530
1.97k
      for (int j = 0; j < 2; j++) {
1531
        // Check if operand i in add_op0_inst matches operand j in add_op1_inst.
1532
1.32k
        if (FactorAddMulsOpnds(add_op0_inst->GetSingleWordInOperand(i),
1533
1.32k
                               add_op0_inst->GetSingleWordInOperand(1 - i),
1534
1.32k
                               add_op1_inst->GetSingleWordInOperand(j),
1535
1.32k
                               add_op1_inst->GetSingleWordInOperand(1 - j),
1536
1.32k
                               inst))
1537
15
          return true;
1538
1.32k
      }
1539
669
    }
1540
327
    return false;
1541
342
  };
1542
16.3k
}
1543
1544
8.16k
// Folding rule for OpIMul: when either operand is a 32- or 64-bit integer
// constant equal to 1, the multiplication becomes an OpCopyObject of the other
// operand.  The rule gives up if a constant operand has any other width.
FoldingRule IntMultipleBy1() {
  return [](IRContext*, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == spv::Op::OpIMul &&
           "Wrong opcode.  Should be OpIMul.");
    for (uint32_t operand = 0; operand < 2; ++operand) {
      const analysis::Constant* c = constants[operand];
      const analysis::IntConstant* int_const =
          (c != nullptr) ? c->AsIntConstant() : nullptr;
      if (int_const == nullptr) continue;

      const uint32_t width = ElementWidth(int_const->type());
      bool is_one = false;
      if (width == 32) {
        is_one = int_const->GetU32BitValue() == 1u;
      } else if (width == 64) {
        is_one = int_const->GetU64BitValue() == 1ull;
      } else {
        return false;
      }
      if (!is_one) continue;

      // x * 1 == x: forward the non-constant side.
      const uint32_t other = inst->GetSingleWordInOperand(1 - operand);
      inst->SetOpcode(spv::Op::OpCopyObject);
      inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {other}}});
      return true;
    }
    return false;
  };
}
1570
1571
// Returns the number of elements that the |index|th in operand in |inst|
1572
// contributes to the result of |inst|.  |inst| must be an
1573
// OpCompositeConstructInstruction.
1574
uint32_t GetNumOfElementsContributedByOperand(IRContext* context,
1575
                                              const Instruction* inst,
1576
6.48k
                                              uint32_t index) {
1577
6.48k
  assert(inst->opcode() == spv::Op::OpCompositeConstruct);
1578
6.48k
  analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1579
6.48k
  analysis::TypeManager* type_mgr = context->get_type_mgr();
1580
1581
6.48k
  analysis::Vector* result_type =
1582
6.48k
      type_mgr->GetType(inst->type_id())->AsVector();
1583
6.48k
  if (result_type == nullptr) {
1584
    // If the result of the OpCompositeConstruct is not a vector then every
1585
    // operands corresponds to a single element in the result.
1586
0
    return 1;
1587
0
  }
1588
1589
  // If the result type is a vector then the operands are either scalars or
1590
  // vectors. If it is a scalar, then it corresponds to a single element.  If it
1591
  // is a vector, then each element in the vector will be an element in the
1592
  // result.
1593
6.48k
  uint32_t id = inst->GetSingleWordInOperand(index);
1594
6.48k
  Instruction* def = def_use_mgr->GetDef(id);
1595
6.48k
  analysis::Vector* type = type_mgr->GetType(def->type_id())->AsVector();
1596
6.48k
  if (type == nullptr) {
1597
6.48k
    return 1;
1598
6.48k
  }
1599
0
  return type->element_count();
1600
6.48k
}
1601
1602
// Returns the in-operands for an OpCompositeExtract instruction that are needed
1603
// to extract the |result_index|th element in the result of |inst| without using
1604
// the result of |inst|. Returns the empty vector if |result_index| is
1605
// out-of-bounds. |inst| must be an |OpCompositeConstruct| instruction.
1606
std::vector<Operand> GetExtractOperandsForElementOfCompositeConstruct(
1607
27.3k
    IRContext* context, const Instruction* inst, uint32_t result_index) {
1608
27.3k
  assert(inst->opcode() == spv::Op::OpCompositeConstruct);
1609
27.3k
  analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1610
27.3k
  analysis::TypeManager* type_mgr = context->get_type_mgr();
1611
1612
27.3k
  analysis::Type* result_type = type_mgr->GetType(inst->type_id());
1613
27.3k
  if (result_type->AsVector() == nullptr) {
1614
23.7k
    if (result_index < inst->NumInOperands()) {
1615
23.7k
      uint32_t id = inst->GetSingleWordInOperand(result_index);
1616
23.7k
      return {Operand(SPV_OPERAND_TYPE_ID, {id})};
1617
23.7k
    }
1618
5
    return {};
1619
23.7k
  }
1620
1621
  // If the result type is a vector, then vector operands are concatenated.
1622
3.57k
  uint32_t total_element_count = 0;
1623
6.48k
  for (uint32_t idx = 0; idx < inst->NumInOperands(); ++idx) {
1624
6.48k
    uint32_t element_count =
1625
6.48k
        GetNumOfElementsContributedByOperand(context, inst, idx);
1626
6.48k
    total_element_count += element_count;
1627
6.48k
    if (result_index < total_element_count) {
1628
3.57k
      std::vector<Operand> operands;
1629
3.57k
      uint32_t id = inst->GetSingleWordInOperand(idx);
1630
3.57k
      Instruction* operand_def = def_use_mgr->GetDef(id);
1631
3.57k
      analysis::Type* operand_type = type_mgr->GetType(operand_def->type_id());
1632
1633
3.57k
      operands.push_back({SPV_OPERAND_TYPE_ID, {id}});
1634
3.57k
      if (operand_type->AsVector()) {
1635
0
        uint32_t start_index_of_id = total_element_count - element_count;
1636
0
        uint32_t index_into_id = result_index - start_index_of_id;
1637
0
        operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {index_into_id}});
1638
0
      }
1639
3.57k
      return operands;
1640
3.57k
    }
1641
6.48k
  }
1642
0
  return {};
1643
3.57k
}
1644
1645
bool CompositeConstructFeedingExtract(
1646
    IRContext* context, Instruction* inst,
1647
150k
    const std::vector<const analysis::Constant*>&) {
1648
  // If the input to an OpCompositeExtract is an OpCompositeConstruct,
1649
  // then we can simply use the appropriate element in the construction.
1650
150k
  assert(inst->opcode() == spv::Op::OpCompositeExtract &&
1651
150k
         "Wrong opcode.  Should be OpCompositeExtract.");
1652
150k
  analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1653
1654
  // If there are no index operands, then this rule cannot do anything.
1655
150k
  if (inst->NumInOperands() <= 1) {
1656
0
    return false;
1657
0
  }
1658
1659
150k
  uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1660
150k
  Instruction* cinst = def_use_mgr->GetDef(cid);
1661
1662
150k
  if (cinst->opcode() != spv::Op::OpCompositeConstruct) {
1663
122k
    return false;
1664
122k
  }
1665
1666
27.3k
  uint32_t index_into_result = inst->GetSingleWordInOperand(1);
1667
27.3k
  std::vector<Operand> operands =
1668
27.3k
      GetExtractOperandsForElementOfCompositeConstruct(context, cinst,
1669
27.3k
                                                       index_into_result);
1670
1671
27.3k
  if (operands.empty()) {
1672
5
    return false;
1673
5
  }
1674
1675
  // Add the remaining indices for extraction.
1676
27.2k
  for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
1677
0
    operands.push_back(
1678
0
        {SPV_OPERAND_TYPE_LITERAL_INTEGER, {inst->GetSingleWordInOperand(i)}});
1679
0
  }
1680
1681
27.2k
  if (operands.size() == 1) {
1682
    // If there were no extra indices, then we have the final object.  No need
1683
    // to extract any more.
1684
27.2k
    inst->SetOpcode(spv::Op::OpCopyObject);
1685
27.2k
  }
1686
1687
27.2k
  inst->SetInOperands(std::move(operands));
1688
27.2k
  return true;
1689
27.3k
}
1690
1691
// Walks the indexes chain from |start| to |end| of an OpCompositeInsert or
1692
// OpCompositeExtract instruction, and returns the type id of the final element
1693
// being accessed. Returns 0 if a valid type could not be found.
1694
uint32_t GetElementType(uint32_t type_id, Instruction::iterator start,
1695
                        Instruction::iterator end,
1696
186k
                        const analysis::DefUseManager* def_use_manager) {
1697
186k
  for (auto index : make_range(std::move(start), std::move(end))) {
1698
72
    const Instruction* type_inst = def_use_manager->GetDef(type_id);
1699
72
    assert(index.type == SPV_OPERAND_TYPE_LITERAL_INTEGER &&
1700
72
           index.words.size() == 1);
1701
72
    if (type_inst->opcode() == spv::Op::OpTypeArray) {
1702
72
      type_id = type_inst->GetSingleWordInOperand(0);
1703
72
    } else if (type_inst->opcode() == spv::Op::OpTypeMatrix) {
1704
0
      type_id = type_inst->GetSingleWordInOperand(0);
1705
0
    } else if (type_inst->opcode() == spv::Op::OpTypeStruct) {
1706
0
      type_id = type_inst->GetSingleWordInOperand(index.words[0]);
1707
0
    } else {
1708
0
      return 0;
1709
0
    }
1710
72
  }
1711
186k
  return type_id;
1712
186k
}
1713
1714
// Returns true of |inst_1| and |inst_2| have the same indexes that will be used
1715
// to index into a composite object, excluding the last index.  The two
1716
// instructions must have the same opcode, and be either OpCompositeExtract or
1717
// OpCompositeInsert instructions.
1718
104M
bool HaveSameIndexesExceptForLast(Instruction* inst_1, Instruction* inst_2) {
1719
104M
  assert(inst_1->opcode() == inst_2->opcode() &&
1720
104M
         "Expecting the opcodes to be the same.");
1721
104M
  assert((inst_1->opcode() == spv::Op::OpCompositeInsert ||
1722
104M
          inst_1->opcode() == spv::Op::OpCompositeExtract) &&
1723
104M
         "Instructions must be OpCompositeInsert or OpCompositeExtract.");
1724
1725
104M
  if (inst_1->NumInOperands() != inst_2->NumInOperands()) {
1726
108
    return false;
1727
108
  }
1728
1729
104M
  uint32_t first_index_position =
1730
104M
      (inst_1->opcode() == spv::Op::OpCompositeInsert ? 2 : 1);
1731
104M
  for (uint32_t i = first_index_position; i < inst_1->NumInOperands() - 1;
1732
104M
       i++) {
1733
165
    if (inst_1->GetSingleWordInOperand(i) !=
1734
165
        inst_2->GetSingleWordInOperand(i)) {
1735
0
      return false;
1736
0
    }
1737
165
  }
1738
104M
  return true;
1739
104M
}
1740
1741
// If the OpCompositeConstruct is simply putting back together elements that
1742
// where extracted from the same source, we can simply reuse the source.
1743
//
1744
// This is a common code pattern because of the way that scalar replacement
1745
// works.
1746
bool CompositeExtractFeedingConstruct(
1747
    IRContext* context, Instruction* inst,
1748
135k
    const std::vector<const analysis::Constant*>&) {
1749
135k
  assert(inst->opcode() == spv::Op::OpCompositeConstruct &&
1750
135k
         "Wrong opcode.  Should be OpCompositeConstruct.");
1751
135k
  analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1752
135k
  uint32_t original_id = 0;
1753
1754
135k
  if (inst->NumInOperands() == 0) {
1755
    // The struct being constructed has no members.
1756
0
    return false;
1757
0
  }
1758
1759
  // Check each element to make sure they are:
1760
  // - extractions
1761
  // - extracting the same position they are inserting
1762
  // - all extract from the same id.
1763
135k
  Instruction* first_element_inst = nullptr;
1764
147k
  for (uint32_t i = 0; i < inst->NumInOperands(); ++i) {
1765
147k
    const uint32_t element_id = inst->GetSingleWordInOperand(i);
1766
147k
    Instruction* element_inst = def_use_mgr->GetDef(element_id);
1767
147k
    if (first_element_inst == nullptr) {
1768
135k
      first_element_inst = element_inst;
1769
135k
    }
1770
1771
147k
    if (element_inst->opcode() != spv::Op::OpCompositeExtract) {
1772
133k
      return false;
1773
133k
    }
1774
1775
13.4k
    if (!HaveSameIndexesExceptForLast(element_inst, first_element_inst)) {
1776
0
      return false;
1777
0
    }
1778
1779
13.4k
    if (element_inst->GetSingleWordInOperand(element_inst->NumInOperands() -
1780
13.4k
                                             1) != i) {
1781
1.73k
      return false;
1782
1.73k
    }
1783
1784
11.6k
    if (i == 0) {
1785
4.08k
      original_id =
1786
4.08k
          element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1787
7.60k
    } else if (original_id !=
1788
7.60k
               element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx)) {
1789
248
      return false;
1790
248
    }
1791
11.6k
  }
1792
1793
  // The last check it to see that the object being extracted from is the
1794
  // correct type.
1795
162
  Instruction* original_inst = def_use_mgr->GetDef(original_id);
1796
162
  uint32_t original_type_id =
1797
162
      GetElementType(original_inst->type_id(), first_element_inst->begin() + 3,
1798
162
                     first_element_inst->end() - 1, def_use_mgr);
1799
1800
162
  if (inst->type_id() != original_type_id) {
1801
36
    return false;
1802
36
  }
1803
1804
126
  if (first_element_inst->NumInOperands() == 2) {
1805
    // Simplify by using the original object.
1806
126
    inst->SetOpcode(spv::Op::OpCopyObject);
1807
126
    inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}});
1808
126
    return true;
1809
126
  }
1810
1811
  // Copies the original id and all indexes except for the last to the new
1812
  // extract instruction.
1813
0
  inst->SetOpcode(spv::Op::OpCompositeExtract);
1814
0
  inst->SetInOperands(std::vector<Operand>(first_element_inst->begin() + 2,
1815
0
                                           first_element_inst->end() - 1));
1816
0
  return true;
1817
126
}
1818
1819
8.16k
// Folding rule for an OpCompositeExtract whose composite is an
// OpCompositeInsert.  It compares the extract's indexes with the insert's:
//  - identical: the extract becomes a copy of the inserted object;
//  - extract indexes are a strict prefix: nothing can be done;
//  - insert indexes are a strict prefix: extract from the inserted object
//    with the remaining indexes;
//  - otherwise the positions are disjoint: extract from the insert's
//    composite input with the original indexes.
FoldingRule InsertFeedingExtract() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>&) {
    assert(inst->opcode() == spv::Op::OpCompositeExtract &&
           "Wrong opcode.  Should be OpCompositeExtract.");
    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
    Instruction* insert = def_use_mgr->GetDef(
        inst->GetSingleWordInOperand(kExtractCompositeIdInIdx));
    if (insert->opcode() != spv::Op::OpCompositeInsert) return false;

    const uint32_t extract_operands = inst->NumInOperands();
    const uint32_t insert_operands = insert->NumInOperands();

    // Extract indexes start at in-operand 1 and insert indexes at in-operand
    // 2.  Advance |pos| (an extract operand position) over the common prefix.
    uint32_t pos = 1;
    while (pos < extract_operands && pos + 1 < insert_operands &&
           inst->GetSingleWordInOperand(pos) ==
               insert->GetSingleWordInOperand(pos + 1)) {
      ++pos;
    }
    const bool extract_exhausted = pos == extract_operands;
    const bool insert_exhausted = pos + 1 == insert_operands;

    if (extract_exhausted && insert_exhausted) {
      // Reading back exactly the element that was inserted.
      inst->SetOpcode(spv::Op::OpCopyObject);
      inst->SetInOperands(
          {{SPV_OPERAND_TYPE_ID,
            {insert->GetSingleWordInOperand(kInsertObjectIdInIdx)}}});
      return true;
    }

    // The extracted value mixes the inserted object with parts of the base
    // composite; it cannot be simplified here.
    if (extract_exhausted) return false;

    std::vector<Operand> new_operands;
    uint32_t first_kept_index = 1;
    if (insert_exhausted) {
      // Reading part of the inserted object: extract from it directly using
      // only the indexes past the common prefix.
      new_operands.push_back(
          {SPV_OPERAND_TYPE_ID,
           {insert->GetSingleWordInOperand(kInsertObjectIdInIdx)}});
      first_kept_index = pos;
    } else {
      // Reading a position disjoint from the insertion: read the insert's
      // composite input with all of the original indexes.
      new_operands.push_back(
          {SPV_OPERAND_TYPE_ID,
           {insert->GetSingleWordInOperand(kInsertCompositeIdInIdx)}});
    }
    for (uint32_t k = first_kept_index; k < extract_operands; ++k) {
      new_operands.push_back(
          {SPV_OPERAND_TYPE_LITERAL_INTEGER, {inst->GetSingleWordInOperand(k)}});
    }
    inst->SetInOperands(std::move(new_operands));
    return true;
  };
}
1890
1891
// When a VectorShuffle is feeding an Extract, we can extract from one of the
1892
// operands of the VectorShuffle.  We just need to adjust the index in the
1893
// extract instruction.
1894
8.16k
FoldingRule VectorShuffleFeedingExtract() {
1895
8.16k
  return [](IRContext* context, Instruction* inst,
1896
123k
            const std::vector<const analysis::Constant*>&) {
1897
123k
    assert(inst->opcode() == spv::Op::OpCompositeExtract &&
1898
123k
           "Wrong opcode.  Should be OpCompositeExtract.");
1899
123k
    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1900
123k
    analysis::TypeManager* type_mgr = context->get_type_mgr();
1901
123k
    uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1902
123k
    Instruction* cinst = def_use_mgr->GetDef(cid);
1903
1904
123k
    if (cinst->opcode() != spv::Op::OpVectorShuffle) {
1905
121k
      return false;
1906
121k
    }
1907
1908
    // Find the size of the first vector operand of the VectorShuffle
1909
1.01k
    Instruction* first_input =
1910
1.01k
        def_use_mgr->GetDef(cinst->GetSingleWordInOperand(0));
1911
1.01k
    analysis::Type* first_input_type =
1912
1.01k
        type_mgr->GetType(first_input->type_id());
1913
1.01k
    assert(first_input_type->AsVector() &&
1914
1.01k
           "Input to vector shuffle should be vectors.");
1915
1.01k
    uint32_t first_input_size = first_input_type->AsVector()->element_count();
1916
1917
    // Get index of the element the vector shuffle is placing in the position
1918
    // being extracted.
1919
1.01k
    uint32_t new_index =
1920
1.01k
        cinst->GetSingleWordInOperand(2 + inst->GetSingleWordInOperand(1));
1921
1922
    // Extracting an undefined value so fold this extract into an undef.
1923
1.01k
    const uint32_t undef_literal_value = 0xffffffff;
1924
1.01k
    if (new_index == undef_literal_value) {
1925
8
      inst->SetOpcode(spv::Op::OpUndef);
1926
8
      inst->SetInOperands({});
1927
8
      return true;
1928
8
    }
1929
1930
    // Get the id of the of the vector the elemtent comes from, and update the
1931
    // index if needed.
1932
1.00k
    uint32_t new_vector = 0;
1933
1.00k
    if (new_index < first_input_size) {
1934
700
      new_vector = cinst->GetSingleWordInOperand(0);
1935
700
    } else {
1936
308
      new_vector = cinst->GetSingleWordInOperand(1);
1937
308
      new_index -= first_input_size;
1938
308
    }
1939
1940
    // Update the extract instruction.
1941
1.00k
    inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
1942
1.00k
    inst->SetInOperand(1, {new_index});
1943
1.00k
    return true;
1944
1.01k
  };
1945
8.16k
}
1946
1947
// When an FMix with is feeding an Extract that extracts an element whose
1948
// corresponding |a| in the FMix is 0 or 1, we can extract from one of the
1949
// operands of the FMix.
1950
8.16k
FoldingRule FMixFeedingExtract() {
1951
8.16k
  return [](IRContext* context, Instruction* inst,
1952
121k
            const std::vector<const analysis::Constant*>&) {
1953
121k
    assert(inst->opcode() == spv::Op::OpCompositeExtract &&
1954
121k
           "Wrong opcode.  Should be OpCompositeExtract.");
1955
121k
    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1956
121k
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1957
1958
121k
    uint32_t composite_id =
1959
121k
        inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1960
121k
    Instruction* composite_inst = def_use_mgr->GetDef(composite_id);
1961
1962
121k
    if (composite_inst->opcode() != spv::Op::OpExtInst) {
1963
110k
      return false;
1964
110k
    }
1965
1966
11.2k
    uint32_t inst_set_id =
1967
11.2k
        context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
1968
1969
11.2k
    if (composite_inst->GetSingleWordInOperand(kExtInstSetIdInIdx) !=
1970
11.2k
            inst_set_id ||
1971
11.2k
        composite_inst->GetSingleWordInOperand(kExtInstInstructionInIdx) !=
1972
11.2k
            GLSLstd450FMix) {
1973
10.0k
      return false;
1974
10.0k
    }
1975
1976
    // Get the |a| for the FMix instruction.
1977
1.11k
    uint32_t a_id = composite_inst->GetSingleWordInOperand(kFMixAIdInIdx);
1978
1.11k
    std::unique_ptr<Instruction> a(inst->Clone(context));
1979
1.11k
    a->SetInOperand(kExtractCompositeIdInIdx, {a_id});
1980
1.11k
    context->get_instruction_folder().FoldInstruction(a.get());
1981
1982
1.11k
    if (a->opcode() != spv::Op::OpCopyObject) {
1983
0
      return false;
1984
0
    }
1985
1986
1.11k
    const analysis::Constant* a_const =
1987
1.11k
        const_mgr->FindDeclaredConstant(a->GetSingleWordInOperand(0));
1988
1989
1.11k
    if (!a_const) {
1990
805
      return false;
1991
805
    }
1992
1993
314
    bool use_x = false;
1994
1995
314
    assert(a_const->type()->AsFloat());
1996
314
    double element_value = a_const->GetValueAsDouble();
1997
314
    if (element_value == 0.0) {
1998
0
      use_x = true;
1999
314
    } else if (element_value == 1.0) {
2000
0
      use_x = false;
2001
314
    } else {
2002
314
      return false;
2003
314
    }
2004
2005
    // Get the id of the of the vector the element comes from.
2006
0
    uint32_t new_vector = 0;
2007
0
    if (use_x) {
2008
0
      new_vector = composite_inst->GetSingleWordInOperand(kFMixXIdInIdx);
2009
0
    } else {
2010
0
      new_vector = composite_inst->GetSingleWordInOperand(kFMixYIdInIdx);
2011
0
    }
2012
2013
    // Update the extract instruction.
2014
0
    inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
2015
0
    return true;
2016
314
  };
2017
8.16k
}
2018
2019
// Returns the number of elements in the composite type |type|.  Returns 0 if
2020
// |type| is a scalar value. Return UINT32_MAX when the size is unknown at
2021
// compile time.
2022
186k
uint32_t GetNumberOfElements(const analysis::Type* type) {
2023
186k
  if (auto* vector_type = type->AsVector()) {
2024
40.7k
    return vector_type->element_count();
2025
40.7k
  }
2026
145k
  if (auto* matrix_type = type->AsMatrix()) {
2027
0
    return matrix_type->element_count();
2028
0
  }
2029
145k
  if (auto* struct_type = type->AsStruct()) {
2030
750
    return static_cast<uint32_t>(struct_type->element_types().size());
2031
750
  }
2032
145k
  if (auto* array_type = type->AsArray()) {
2033
145k
    if (array_type->length_info().words[0] ==
2034
145k
            analysis::Array::LengthInfo::kConstant &&
2035
145k
        array_type->length_info().words.size() == 2) {
2036
145k
      return array_type->length_info().words[1];
2037
145k
    }
2038
0
    return UINT32_MAX;
2039
145k
  }
2040
0
  return 0;
2041
145k
}
2042
2043
// Returns a map with the set of values that were inserted into an object by
2044
// the chain of OpCompositeInsertInstruction starting with |inst|.
2045
// The map will map the index to the value inserted at that index. An empty map
2046
// will be returned if the map could not be properly generated.
2047
186k
std::map<uint32_t, uint32_t> GetInsertedValues(Instruction* inst) {
2048
186k
  analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr();
2049
186k
  std::map<uint32_t, uint32_t> values_inserted;
2050
186k
  Instruction* current_inst = inst;
2051
104M
  while (current_inst->opcode() == spv::Op::OpCompositeInsert) {
2052
104M
    if (current_inst->NumInOperands() > inst->NumInOperands()) {
2053
      // This is to catch the case
2054
      //   %2 = OpCompositeInsert %m2x2int %v2int_1_0 %m2x2int_undef 0
2055
      //   %3 = OpCompositeInsert %m2x2int %int_4 %2 0 0
2056
      //   %4 = OpCompositeInsert %m2x2int %v2int_2_3 %3 1
2057
      // In this case we cannot do a single construct to get the matrix.
2058
24
      uint32_t partially_inserted_element_index =
2059
24
          current_inst->GetSingleWordInOperand(inst->NumInOperands() - 1);
2060
24
      if (values_inserted.count(partially_inserted_element_index) == 0)
2061
24
        return {};
2062
24
    }
2063
104M
    if (HaveSameIndexesExceptForLast(inst, current_inst)) {
2064
104M
      values_inserted.insert(
2065
104M
          {current_inst->GetSingleWordInOperand(current_inst->NumInOperands() -
2066
104M
                                                1),
2067
104M
           current_inst->GetSingleWordInOperand(kInsertObjectIdInIdx)});
2068
104M
    }
2069
104M
    current_inst = def_use_mgr->GetDef(
2070
104M
        current_inst->GetSingleWordInOperand(kInsertCompositeIdInIdx));
2071
104M
  }
2072
186k
  return values_inserted;
2073
186k
}
2074
2075
// Returns true of there is an entry in |values_inserted| for every element of
2076
// |Type|.
2077
bool DoInsertedValuesCoverEntireObject(
2078
186k
    const analysis::Type* type, std::map<uint32_t, uint32_t>& values_inserted) {
2079
186k
  uint32_t container_size = GetNumberOfElements(type);
2080
186k
  if (container_size != values_inserted.size()) {
2081
183k
    return false;
2082
183k
  }
2083
2084
3.44k
  if (values_inserted.rbegin()->first >= container_size) {
2085
0
    return false;
2086
0
  }
2087
3.44k
  return true;
2088
3.44k
}
2089
2090
// Returns id of the type of the element that immediately contains the element
2091
// being inserted by the OpCompositeInsert instruction |inst|. Returns 0 if it
2092
// could not be found.
2093
186k
uint32_t GetContainerTypeId(Instruction* inst) {
2094
186k
  assert(inst->opcode() == spv::Op::OpCompositeInsert);
2095
186k
  analysis::DefUseManager* def_use_manager = inst->context()->get_def_use_mgr();
2096
186k
  uint32_t container_type_id = GetElementType(
2097
186k
      inst->type_id(), inst->begin() + 4, inst->end() - 1, def_use_manager);
2098
186k
  return container_type_id;
2099
186k
}
2100
2101
// Returns an OpCompositeConstruct instruction that build an object with
2102
// |type_id| out of the values in |values_inserted|.  Each value will be
2103
// placed at the index corresponding to the value.  The new instruction will
2104
// be placed before |insert_before|.
2105
Instruction* BuildCompositeConstruct(
2106
    uint32_t type_id, const std::map<uint32_t, uint32_t>& values_inserted,
2107
3.44k
    Instruction* insert_before) {
2108
3.44k
  InstructionBuilder ir_builder(
2109
3.44k
      insert_before->context(), insert_before,
2110
3.44k
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
2111
2112
3.44k
  std::vector<uint32_t> ids_in_order;
2113
6.89k
  for (auto it : values_inserted) {
2114
6.89k
    ids_in_order.push_back(it.second);
2115
6.89k
  }
2116
3.44k
  Instruction* construct =
2117
3.44k
      ir_builder.AddCompositeConstruct(type_id, ids_in_order);
2118
3.44k
  return construct;
2119
3.44k
}
2120
2121
// Replaces the OpCompositeInsert |inst| that inserts |construct| into the same
2122
// object as |inst| with final index removed.  If the resulting
2123
// OpCompositeInsert instruction would have no remaining indexes, the
2124
// instruction is replaced with an OpCopyObject instead.
2125
3.44k
void InsertConstructedObject(Instruction* inst, const Instruction* construct) {
2126
3.44k
  if (inst->NumInOperands() == 3) {
2127
3.44k
    inst->SetOpcode(spv::Op::OpCopyObject);
2128
3.44k
    inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {construct->result_id()}}});
2129
3.44k
  } else {
2130
0
    inst->SetInOperand(kInsertObjectIdInIdx, {construct->result_id()});
2131
0
    inst->RemoveOperand(inst->NumOperands() - 1);
2132
0
  }
2133
3.44k
}
2134
2135
// Replaces a series of |OpCompositeInsert| instruction that cover the entire
2136
// object with an |OpCompositeConstruct|.
2137
bool CompositeInsertToCompositeConstruct(
2138
    IRContext* context, Instruction* inst,
2139
186k
    const std::vector<const analysis::Constant*>&) {
2140
186k
  assert(inst->opcode() == spv::Op::OpCompositeInsert &&
2141
186k
         "Wrong opcode.  Should be OpCompositeInsert.");
2142
186k
  if (inst->NumInOperands() < 3) return false;
2143
2144
186k
  std::map<uint32_t, uint32_t> values_inserted = GetInsertedValues(inst);
2145
186k
  uint32_t container_type_id = GetContainerTypeId(inst);
2146
186k
  if (container_type_id == 0) {
2147
0
    return false;
2148
0
  }
2149
2150
186k
  analysis::TypeManager* type_mgr = context->get_type_mgr();
2151
186k
  const analysis::Type* container_type = type_mgr->GetType(container_type_id);
2152
186k
  assert(container_type && "GetContainerTypeId returned a bad id.");
2153
186k
  if (!DoInsertedValuesCoverEntireObject(container_type, values_inserted)) {
2154
183k
    return false;
2155
183k
  }
2156
2157
3.44k
  Instruction* construct =
2158
3.44k
      BuildCompositeConstruct(container_type_id, values_inserted, inst);
2159
3.44k
  InsertConstructedObject(inst, construct);
2160
3.44k
  return true;
2161
186k
}
2162
2163
8.16k
FoldingRule RedundantPhi() {
  // Folds an OpPhi whose incoming values are all the same id into an
  // OpCopyObject of that id.  Incoming values equal to the phi's own result
  // id are ignored when comparing.
  return [](IRContext*, Instruction* inst,
            const std::vector<const analysis::Constant*>&) {
    assert(inst->opcode() == spv::Op::OpPhi &&
           "Wrong opcode.  Should be OpPhi.");

    const uint32_t phi_id = inst->result_id();
    // 0 means "no candidate seen yet"; SPIR-V ids are never 0.
    uint32_t candidate = 0;

    // In-operands come in pairs; only the value id at each even position is
    // inspected (the other entry of an OpPhi pair is the parent block).
    for (uint32_t idx = 0; idx < inst->NumInOperands(); idx += 2) {
      const uint32_t value_id = inst->GetSingleWordInOperand(idx);
      if (value_id == phi_id) continue;
      if (candidate != 0 && candidate != value_id) {
        // At least two distinct incoming values: cannot simplify.
        return false;
      }
      candidate = value_id;
    }

    if (candidate == 0) {
      // No incoming value other than the phi itself; the code looks invalid,
      // so leave it untouched.
      return false;
    }

    // Exactly one distinct incoming value: forward it.
    inst->SetOpcode(spv::Op::OpCopyObject);
    inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {candidate}}});
    return true;
  };
}
2198
2199
8.16k
FoldingRule BitCastScalarOrVector() {
  // Folds an OpBitcast of a numeric scalar or vector constant.  The
  // constant's words are reinterpreted as a constant of the bitcast's result
  // type, and the instruction becomes an OpCopyObject of that new constant.
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == spv::Op::OpBitcast && constants.size() == 1);
    if (constants[0] == nullptr) return false;

    const analysis::Type* type =
        context->get_type_mgr()->GetType(inst->type_id());
    if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
      return false;

    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
    std::vector<uint32_t> words =
        GetWordsFromNumericScalarOrVectorConstant(const_mgr, constants[0]);
    if (words.empty()) return false;

    const analysis::Constant* bitcasted_constant =
        ConvertWordsToNumericScalarOrVectorConstant(const_mgr, words, type);
    if (!bitcasted_constant) return false;

    // GetDefiningInstruction may fail to materialize the constant (presumably
    // when no new id can be allocated).  Leave |inst| unchanged in that case
    // instead of dereferencing a null pointer.
    Instruction* feeder_inst =
        const_mgr->GetDefiningInstruction(bitcasted_constant, inst->type_id());
    if (feeder_inst == nullptr) return false;

    const uint32_t new_feeder_id = feeder_inst->result_id();
    inst->SetOpcode(spv::Op::OpCopyObject);
    inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {new_feeder_id}}});
    return true;
  };
}
2227
2228
8.16k
FoldingRule RedundantSelect() {
  // Folds an OpSelect whose result does not depend on a runtime condition:
  //   * both value operands are the same id      -> OpCopyObject of that id;
  //   * the condition is a constant bool         -> OpCopyObject of the
  //                                                 selected value;
  //   * the condition is a null bool vector      -> OpCopyObject of the false
  //                                                 value;
  //   * the condition is a constant bool vector  -> OpVectorShuffle taking
  //     each lane from the true (first) or false (second) vector.
  return [](IRContext*, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == spv::Op::OpSelect &&
           "Wrong opcode.  Should be OpSelect.");
    assert(inst->NumInOperands() == 3);
    assert(constants.size() == 3);

    const uint32_t true_id = inst->GetSingleWordInOperand(1);
    const uint32_t false_id = inst->GetSingleWordInOperand(2);

    // A boolean constant counts as false when it is a null constant or a bool
    // constant holding false.
    const auto is_false = [](const analysis::Constant* c) {
      return c->AsNullConstant() != nullptr || !c->AsBoolConstant()->value();
    };

    // Turns |inst| into a copy of |id|.
    const auto become_copy_of = [inst](uint32_t id) {
      inst->SetOpcode(spv::Op::OpCopyObject);
      inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {id}}});
    };

    if (true_id == false_id) {
      // Both results are the same, so the condition doesn't matter.
      become_copy_of(true_id);
      return true;
    }

    const analysis::Constant* condition = constants[0];
    if (condition == nullptr) {
      return false;
    }

    if (condition->type()->AsBool()) {
      become_copy_of(is_false(condition) ? false_id : true_id);
      return true;
    }

    assert(condition->type()->AsVector());
    if (condition->AsNullConstant()) {
      // Every lane of a null condition vector selects the false value.
      become_copy_of(false_id);
      return true;
    }

    const std::vector<const analysis::Constant*>& lanes =
        condition->AsVectorConstant()->GetComponents();
    const uint32_t lane_count = static_cast<uint32_t>(lanes.size());

    std::vector<Operand> shuffle_operands;
    shuffle_operands.push_back({SPV_OPERAND_TYPE_ID, {true_id}});
    shuffle_operands.push_back({SPV_OPERAND_TYPE_ID, {false_id}});
    for (uint32_t lane = 0; lane != lane_count; ++lane) {
      // Shuffle indices >= |lane_count| address the second (false) vector.
      const uint32_t source = is_false(lanes[lane]) ? lane + lane_count : lane;
      shuffle_operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {source}});
    }

    inst->SetOpcode(spv::Op::OpVectorShuffle);
    inst->SetInOperands(std::move(shuffle_operands));
    return true;
  };
}
2299
2300
enum class FloatConstantKind { Unknown, Zero, One };
2301
2302
5.84M
FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) {
2303
5.84M
  if (constant == nullptr) {
2304
3.83M
    return FloatConstantKind::Unknown;
2305
3.83M
  }
2306
2307
2.01M
  assert(HasFloatingPoint(constant->type()) && "Unexpected constant type");
2308
2309
2.01M
  if (constant->AsNullConstant()) {
2310
928
    return FloatConstantKind::Zero;
2311
2.01M
  } else if (const analysis::VectorConstant* vc =
2312
2.01M
                 constant->AsVectorConstant()) {
2313
622k
    const std::vector<const analysis::Constant*>& components =
2314
622k
        vc->GetComponents();
2315
622k
    assert(!components.empty());
2316
2317
622k
    FloatConstantKind kind = getFloatConstantKind(components[0]);
2318
2319
1.19M
    for (size_t i = 1; i < components.size(); ++i) {
2320
722k
      if (getFloatConstantKind(components[i]) != kind) {
2321
151k
        return FloatConstantKind::Unknown;
2322
151k
      }
2323
722k
    }
2324
2325
470k
    return kind;
2326
1.39M
  } else if (const analysis::FloatConstant* fc = constant->AsFloatConstant()) {
2327
1.39M
    if (fc->IsZero()) return FloatConstantKind::Zero;
2328
2329
1.17M
    uint32_t width = fc->type()->AsFloat()->width();
2330
1.17M
    if (width != 32 && width != 64) return FloatConstantKind::Unknown;
2331
2332
1.17M
    double value = (width == 64) ? fc->GetDoubleValue() : fc->GetFloatValue();
2333
2334
1.17M
    if (value == 0.0) {
2335
15.8k
      return FloatConstantKind::Zero;
2336
1.15M
    } else if (value == 1.0) {
2337
9.08k
      return FloatConstantKind::One;
2338
1.14M
    } else {
2339
1.14M
      return FloatConstantKind::Unknown;
2340
1.14M
    }
2341
1.17M
  } else {
2342
0
    return FloatConstantKind::Unknown;
2343
0
  }
2344
2.01M
}
2345
2346
8.16k
FoldingRule RedundantFAdd() {
  // Folds "x + 0" and "0 + x" into a copy of x.  Only applied when
  // floating-point folding is allowed for the instruction.
  return [](IRContext*, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == spv::Op::OpFAdd &&
           "Wrong opcode.  Should be OpFAdd.");
    assert(constants.size() == 2);

    if (!inst->IsFloatingPointFoldingAllowed()) return false;

    const bool lhs_is_zero =
        getFloatConstantKind(constants[0]) == FloatConstantKind::Zero;
    const bool rhs_is_zero =
        getFloatConstantKind(constants[1]) == FloatConstantKind::Zero;
    if (!lhs_is_zero && !rhs_is_zero) return false;

    // Keep whichever operand is not the zero (operand 0 if both are zero).
    const uint32_t kept_id =
        inst->GetSingleWordInOperand(lhs_is_zero ? 1 : 0);
    inst->SetOpcode(spv::Op::OpCopyObject);
    inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {kept_id}}});
    return true;
  };
}
2371
2372
8.16k
FoldingRule RedundantFSub() {
  // Folds floating-point subtractions involving a constant zero:
  //   0 - x  ->  OpFNegate x
  //   x - 0  ->  OpCopyObject x
  // Only applied when floating-point folding is allowed for the instruction.
  return [](IRContext*, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == spv::Op::OpFSub &&
           "Wrong opcode.  Should be OpFSub.");
    assert(constants.size() == 2);

    if (!inst->IsFloatingPointFoldingAllowed()) return false;

    const bool minuend_is_zero =
        getFloatConstantKind(constants[0]) == FloatConstantKind::Zero;
    const bool subtrahend_is_zero =
        getFloatConstantKind(constants[1]) == FloatConstantKind::Zero;

    spv::Op new_opcode = spv::Op::OpNop;
    uint32_t kept_in_idx = 0;
    if (minuend_is_zero) {
      new_opcode = spv::Op::OpFNegate;
      kept_in_idx = 1;
    } else if (subtrahend_is_zero) {
      new_opcode = spv::Op::OpCopyObject;
      kept_in_idx = 0;
    } else {
      return false;
    }

    const uint32_t kept_id = inst->GetSingleWordInOperand(kept_in_idx);
    inst->SetOpcode(new_opcode);
    inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {kept_id}}});
    return true;
  };
}
2403
2404
8.16k
FoldingRule RedundantFMul() {
  // Folds floating-point multiplications by a constant zero or one:
  //   x * 0, 0 * x  ->  copy of the zero operand
  //   x * 1, 1 * x  ->  copy of x
  // A zero operand takes precedence over a one operand.  Only applied when
  // floating-point folding is allowed for the instruction.
  return [](IRContext*, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == spv::Op::OpFMul &&
           "Wrong opcode.  Should be OpFMul.");
    assert(constants.size() == 2);

    if (!inst->IsFloatingPointFoldingAllowed()) return false;

    const FloatConstantKind lhs = getFloatConstantKind(constants[0]);
    const FloatConstantKind rhs = getFloatConstantKind(constants[1]);

    uint32_t kept_in_idx = 0;
    if (lhs == FloatConstantKind::Zero || rhs == FloatConstantKind::Zero) {
      // Keep the zero operand itself.
      kept_in_idx = (lhs == FloatConstantKind::Zero) ? 0u : 1u;
    } else if (lhs == FloatConstantKind::One ||
               rhs == FloatConstantKind::One) {
      // Keep the operand that is not the one.
      kept_in_idx = (lhs == FloatConstantKind::One) ? 1u : 0u;
    } else {
      return false;
    }

    const uint32_t kept_id = inst->GetSingleWordInOperand(kept_in_idx);
    inst->SetOpcode(spv::Op::OpCopyObject);
    inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {kept_id}}});
    return true;
  };
}
2437
2438
8.16k
FoldingRule RedundantFDiv() {
  // Folds floating-point divisions whose result is the numerator:
  //   0 / x  ->  copy of the numerator (the zero)
  //   x / 1  ->  copy of the numerator (x)
  // Only applied when floating-point folding is allowed for the instruction.
  return [](IRContext*, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == spv::Op::OpFDiv &&
           "Wrong opcode.  Should be OpFDiv.");
    assert(constants.size() == 2);

    if (!inst->IsFloatingPointFoldingAllowed()) return false;

    const FloatConstantKind numerator = getFloatConstantKind(constants[0]);
    const FloatConstantKind denominator = getFloatConstantKind(constants[1]);
    const bool keeps_numerator = numerator == FloatConstantKind::Zero ||
                                 denominator == FloatConstantKind::One;
    if (!keeps_numerator) return false;

    const uint32_t numerator_id = inst->GetSingleWordInOperand(0);
    inst->SetOpcode(spv::Op::OpCopyObject);
    inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {numerator_id}}});
    return true;
  };
}
2469
2470
4.69k
FoldingRule RedundantFMix() {
  // Folds GLSL.std.450 FMix(x, y, a) when the weight |a| is a constant:
  //   a == 0  ->  copy of x
  //   a == 1  ->  copy of y
  // Only applied when floating-point folding is allowed for the instruction.
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == spv::Op::OpExtInst &&
           "Wrong opcode.  Should be OpExtInst.");

    if (!inst->IsFloatingPointFoldingAllowed()) {
      return false;
    }

    const uint32_t glsl_set_id =
        context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
    const bool is_glsl_fmix =
        inst->GetSingleWordInOperand(kExtInstSetIdInIdx) == glsl_set_id &&
        inst->GetSingleWordInOperand(kExtInstInstructionInIdx) ==
            GLSLstd450FMix;
    if (!is_glsl_fmix) {
      return false;
    }

    assert(constants.size() == 5);

    // |constants| is indexed by in-operand, so the weight lives at the same
    // index as the FMix |a| operand.
    uint32_t selected_in_idx = 0;
    switch (getFloatConstantKind(constants[kFMixAIdInIdx])) {
      case FloatConstantKind::Zero:
        selected_in_idx = kFMixXIdInIdx;
        break;
      case FloatConstantKind::One:
        selected_in_idx = kFMixYIdInIdx;
        break;
      default:
        return false;
    }

    const uint32_t forwarded_id = inst->GetSingleWordInOperand(selected_in_idx);
    inst->SetOpcode(spv::Op::OpCopyObject);
    inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {forwarded_id}}});
    return true;
  };
}
2504
2505
// This rule handles addition of zero for integers.
2506
8.16k
FoldingRule RedundantIAdd() {
2507
8.16k
  return [](IRContext* context, Instruction* inst,
2508
47.9k
            const std::vector<const analysis::Constant*>& constants) {
2509
47.9k
    assert(inst->opcode() == spv::Op::OpIAdd &&
2510
47.9k
           "Wrong opcode. Should be OpIAdd.");
2511
2512
47.9k
    uint32_t operand = std::numeric_limits<uint32_t>::max();
2513
47.9k
    const analysis::Type* operand_type = nullptr;
2514
47.9k
    if (constants[0] && constants[0]->IsZero()) {
2515
271
      operand = inst->GetSingleWordInOperand(1);
2516
271
      operand_type = constants[0]->type();
2517
47.6k
    } else if (constants[1] && constants[1]->IsZero()) {
2518
1.60k
      operand = inst->GetSingleWordInOperand(0);
2519
1.60k
      operand_type = constants[1]->type();
2520
1.60k
    }
2521
2522
47.9k
    if (operand != std::numeric_limits<uint32_t>::max()) {
2523
1.87k
      const analysis::Type* inst_type =
2524
1.87k
          context->get_type_mgr()->GetType(inst->type_id());
2525
1.87k
      if (inst_type->IsSame(operand_type)) {
2526
1.86k
        inst->SetOpcode(spv::Op::OpCopyObject);
2527
1.86k
      } else {
2528
15
        inst->SetOpcode(spv::Op::OpBitcast);
2529
15
      }
2530
1.87k
      inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {operand}}});
2531
1.87k
      return true;
2532
1.87k
    }
2533
46.0k
    return false;
2534
47.9k
  };
2535
8.16k
}
2536
2537
// This rule look for a dot with a constant vector containing a single 1 and
2538
// the rest 0s.  This is the same as doing an extract.
2539
8.16k
FoldingRule DotProductDoingExtract() {
2540
8.16k
  return [](IRContext* context, Instruction* inst,
2541
8.16k
            const std::vector<const analysis::Constant*>& constants) {
2542
73
    assert(inst->opcode() == spv::Op::OpDot &&
2543
73
           "Wrong opcode.  Should be OpDot.");
2544
2545
73
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
2546
2547
73
    if (!inst->IsFloatingPointFoldingAllowed()) {
2548
0
      return false;
2549
0
    }
2550
2551
219
    for (int i = 0; i < 2; ++i) {
2552
146
      if (!constants[i]) {
2553
109
        continue;
2554
109
      }
2555
2556
37
      const analysis::Vector* vector_type = constants[i]->type()->AsVector();
2557
37
      assert(vector_type && "Inputs to OpDot must be vectors.");
2558
37
      const analysis::Float* element_type =
2559
37
          vector_type->element_type()->AsFloat();
2560
37
      assert(element_type && "Inputs to OpDot must be vectors of floats.");
2561
37
      uint32_t element_width = element_type->width();
2562
37
      if (element_width != 32 && element_width != 64) {
2563
0
        return false;
2564
0
      }
2565
2566
37
      std::vector<const analysis::Constant*> components;
2567
37
      components = constants[i]->GetVectorComponents(const_mgr);
2568
2569
37
      constexpr uint32_t kNotFound = std::numeric_limits<uint32_t>::max();
2570
2571
37
      uint32_t component_with_one = kNotFound;
2572
37
      bool all_others_zero = true;
2573
37
      for (uint32_t j = 0; j < components.size(); ++j) {
2574
37
        const analysis::Constant* element = components[j];
2575
37
        double value =
2576
37
            (element_width == 32 ? element->GetFloat() : element->GetDouble());
2577
37
        if (value == 0.0) {
2578
0
          continue;
2579
37
        } else if (value == 1.0) {
2580
0
          if (component_with_one == kNotFound) {
2581
0
            component_with_one = j;
2582
0
          } else {
2583
0
            component_with_one = kNotFound;
2584
0
            break;
2585
0
          }
2586
37
        } else {
2587
37
          all_others_zero = false;
2588
37
          break;
2589
37
        }
2590
37
      }
2591
2592
37
      if (!all_others_zero || component_with_one == kNotFound) {
2593
37
        continue;
2594
37
      }
2595
2596
0
      std::vector<Operand> operands;
2597
0
      operands.push_back(
2598
0
          {SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1u - i)}});
2599
0
      operands.push_back(
2600
0
          {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_with_one}});
2601
2602
0
      inst->SetOpcode(spv::Op::OpCompositeExtract);
2603
0
      inst->SetInOperands(std::move(operands));
2604
0
      return true;
2605
37
    }
2606
73
    return false;
2607
73
  };
2608
8.16k
}
2609
2610
// If we are storing an undef, then we can remove the store.
2611
//
2612
// TODO: We can do something similar for OpImageWrite, but checking for volatile
2613
// is complicated.  Waiting to see if it is needed.
2614
8.16k
FoldingRule StoringUndef() {
2615
8.16k
  return [](IRContext* context, Instruction* inst,
2616
2.64M
            const std::vector<const analysis::Constant*>&) {
2617
2.64M
    assert(inst->opcode() == spv::Op::OpStore &&
2618
2.64M
           "Wrong opcode.  Should be OpStore.");
2619
2620
2.64M
    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
2621
2622
    // If this is a volatile store, the store cannot be removed.
2623
2.64M
    if (inst->NumInOperands() == 3) {
2624
5.89k
      if (inst->GetSingleWordInOperand(2) &
2625
5.89k
          uint32_t(spv::MemoryAccessMask::Volatile)) {
2626
4.12k
        return false;
2627
4.12k
      }
2628
5.89k
    }
2629
2630
2.63M
    uint32_t object_id = inst->GetSingleWordInOperand(kStoreObjectInIdx);
2631
2.63M
    Instruction* object_inst = def_use_mgr->GetDef(object_id);
2632
2.63M
    if (object_inst->opcode() == spv::Op::OpUndef) {
2633
16.0k
      inst->ToNop();
2634
16.0k
      return true;
2635
16.0k
    }
2636
2.62M
    return false;
2637
2.63M
  };
2638
8.16k
}
2639
2640
8.16k
// Folds an OpVectorShuffle one of whose vector operands is produced by another
// OpVectorShuffle (the "feeding" shuffle).  If every component taken from the
// feeding shuffle comes from just one of its operands, the outer shuffle is
// rewritten to read that operand directly, with component indexes remapped.
// If no defined component comes from the feeding shuffle, a null constant of
// the feeding shuffle's type is used as the replacement operand.  Operand 0 is
// examined first; operand 1 is considered only when operand 0 is not a shuffle.
FoldingRule VectorShuffleFeedingShuffle() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>&) {
    assert(inst->opcode() == spv::Op::OpVectorShuffle &&
           "Wrong opcode.  Should be OpVectorShuffle.");

    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
    analysis::TypeManager* type_mgr = context->get_type_mgr();

    Instruction* feeding_shuffle_inst =
        def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
    analysis::Vector* op0_type =
        type_mgr->GetType(feeding_shuffle_inst->type_id())->AsVector();
    uint32_t op0_length = op0_type->element_count();

    bool feeder_is_op0 = true;
    if (feeding_shuffle_inst->opcode() != spv::Op::OpVectorShuffle) {
      feeding_shuffle_inst =
          def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
      feeder_is_op0 = false;
    }

    if (feeding_shuffle_inst->opcode() != spv::Op::OpVectorShuffle) {
      return false;
    }

    Instruction* feeder2 =
        def_use_mgr->GetDef(feeding_shuffle_inst->GetSingleWordInOperand(0));
    analysis::Vector* feeder_op0_type =
        type_mgr->GetType(feeder2->type_id())->AsVector();
    uint32_t feeder_op0_length = feeder_op0_type->element_count();

    uint32_t new_feeder_id = 0;
    std::vector<Operand> new_operands;
    new_operands.resize(
        2, {SPV_OPERAND_TYPE_ID, {0}});  // Place holders for vector operands.
    const uint32_t undef_literal = 0xffffffff;
    for (uint32_t op = 2; op < inst->NumInOperands(); ++op) {
      uint32_t component_index = inst->GetSingleWordInOperand(op);

      // Do not interpret the undefined value literal as coming from operand 1.
      if (component_index != undef_literal &&
          feeder_is_op0 == (component_index < op0_length)) {
        // This component comes from the feeding_shuffle_inst.  Update
        // |component_index| to be the index into the operand of the feeder.

        // Adjust component_index to get the index into the operands of the
        // feeding_shuffle_inst.
        if (component_index >= op0_length) {
          component_index -= op0_length;
        }
        component_index =
            feeding_shuffle_inst->GetSingleWordInOperand(component_index + 2);

        // Check if we are using a component from the first or second operand of
        // the feeding instruction.
        if (component_index < feeder_op0_length) {
          if (new_feeder_id == 0) {
            // First time through, save the id of the operand the element comes
            // from.
            new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(0);
          } else if (new_feeder_id !=
                     feeding_shuffle_inst->GetSingleWordInOperand(0)) {
            // We need both elements of the feeding_shuffle_inst, so we cannot
            // fold.
            return false;
          }
        } else if (component_index != undef_literal) {
          if (new_feeder_id == 0) {
            // First time through, save the id of the operand the element comes
            // from.
            new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(1);
          } else if (new_feeder_id !=
                     feeding_shuffle_inst->GetSingleWordInOperand(1)) {
            // We need both elements of the feeding_shuffle_inst, so we cannot
            // fold.
            return false;
          }
          component_index -= feeder_op0_length;
        }

        if (!feeder_is_op0 && component_index != undef_literal) {
          component_index += op0_length;
        }
      }
      new_operands.push_back(
          {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_index}});
    }

    if (new_feeder_id == 0) {
      // No defined component comes from the feeding shuffle; substitute a null
      // constant of the feeding shuffle's type.
      analysis::ConstantManager* const_mgr = context->get_constant_mgr();
      const analysis::Type* type =
          type_mgr->GetType(feeding_shuffle_inst->type_id());
      const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
      Instruction* null_inst = const_mgr->GetDefiningInstruction(null_const, 0);
      if (null_inst == nullptr) {
        // The constant could not be materialized; |inst| has not been modified
        // yet, so it is safe to give up here.
        return false;
      }
      new_feeder_id = null_inst->result_id();
    }

    if (feeder_is_op0) {
      // If the size of the first vector operand changed then the indices
      // referring to the second operand need to be adjusted.
      Instruction* new_feeder_inst = def_use_mgr->GetDef(new_feeder_id);
      analysis::Type* new_feeder_type =
          type_mgr->GetType(new_feeder_inst->type_id());
      uint32_t new_op0_size = new_feeder_type->AsVector()->element_count();
      int32_t adjustment = op0_length - new_op0_size;

      if (adjustment != 0) {
        for (uint32_t i = 2; i < new_operands.size(); i++) {
          uint32_t operand = inst->GetSingleWordInOperand(i);
          if (operand >= op0_length && operand != undef_literal) {
            new_operands[i].words[0] -= adjustment;
          }
        }
      }

      new_operands[0].words[0] = new_feeder_id;
      new_operands[1] = inst->GetInOperand(1);
    } else {
      new_operands[1].words[0] = new_feeder_id;
      new_operands[0] = inst->GetInOperand(0);
    }

    inst->SetInOperands(std::move(new_operands));
    return true;
  };
}
2767
2768
// Removes duplicate ids from the interface list of an OpEntryPoint
2769
// instruction.
2770
8.16k
FoldingRule RemoveRedundantOperands() {
2771
8.16k
  return [](IRContext*, Instruction* inst,
2772
8.16k
            const std::vector<const analysis::Constant*>&) {
2773
0
    assert(inst->opcode() == spv::Op::OpEntryPoint &&
2774
0
           "Wrong opcode.  Should be OpEntryPoint.");
2775
0
    bool has_redundant_operand = false;
2776
0
    std::unordered_set<uint32_t> seen_operands;
2777
0
    std::vector<Operand> new_operands;
2778
2779
0
    new_operands.emplace_back(inst->GetOperand(0));
2780
0
    new_operands.emplace_back(inst->GetOperand(1));
2781
0
    new_operands.emplace_back(inst->GetOperand(2));
2782
0
    for (uint32_t i = 3; i < inst->NumOperands(); ++i) {
2783
0
      if (seen_operands.insert(inst->GetSingleWordOperand(i)).second) {
2784
0
        new_operands.emplace_back(inst->GetOperand(i));
2785
0
      } else {
2786
0
        has_redundant_operand = true;
2787
0
      }
2788
0
    }
2789
2790
0
    if (!has_redundant_operand) {
2791
0
      return false;
2792
0
    }
2793
2794
0
    inst->SetInOperands(std::move(new_operands));
2795
0
    return true;
2796
0
  };
2797
8.16k
}
2798
2799
// If an image instruction's operand is a constant, updates the image operand
2800
// flag from Offset to ConstOffset.
2801
204k
FoldingRule UpdateImageOperands() {
2802
204k
  return [](IRContext*, Instruction* inst,
2803
638k
            const std::vector<const analysis::Constant*>& constants) {
2804
638k
    const auto opcode = inst->opcode();
2805
638k
    (void)opcode;
2806
638k
    assert((opcode == spv::Op::OpImageSampleImplicitLod ||
2807
638k
            opcode == spv::Op::OpImageSampleExplicitLod ||
2808
638k
            opcode == spv::Op::OpImageSampleDrefImplicitLod ||
2809
638k
            opcode == spv::Op::OpImageSampleDrefExplicitLod ||
2810
638k
            opcode == spv::Op::OpImageSampleProjImplicitLod ||
2811
638k
            opcode == spv::Op::OpImageSampleProjExplicitLod ||
2812
638k
            opcode == spv::Op::OpImageSampleProjDrefImplicitLod ||
2813
638k
            opcode == spv::Op::OpImageSampleProjDrefExplicitLod ||
2814
638k
            opcode == spv::Op::OpImageFetch ||
2815
638k
            opcode == spv::Op::OpImageGather ||
2816
638k
            opcode == spv::Op::OpImageDrefGather ||
2817
638k
            opcode == spv::Op::OpImageRead || opcode == spv::Op::OpImageWrite ||
2818
638k
            opcode == spv::Op::OpImageSparseSampleImplicitLod ||
2819
638k
            opcode == spv::Op::OpImageSparseSampleExplicitLod ||
2820
638k
            opcode == spv::Op::OpImageSparseSampleDrefImplicitLod ||
2821
638k
            opcode == spv::Op::OpImageSparseSampleDrefExplicitLod ||
2822
638k
            opcode == spv::Op::OpImageSparseSampleProjImplicitLod ||
2823
638k
            opcode == spv::Op::OpImageSparseSampleProjExplicitLod ||
2824
638k
            opcode == spv::Op::OpImageSparseSampleProjDrefImplicitLod ||
2825
638k
            opcode == spv::Op::OpImageSparseSampleProjDrefExplicitLod ||
2826
638k
            opcode == spv::Op::OpImageSparseFetch ||
2827
638k
            opcode == spv::Op::OpImageSparseGather ||
2828
638k
            opcode == spv::Op::OpImageSparseDrefGather ||
2829
638k
            opcode == spv::Op::OpImageSparseRead) &&
2830
638k
           "Wrong opcode.  Should be an image instruction.");
2831
2832
638k
    int32_t operand_index = ImageOperandsMaskInOperandIndex(inst);
2833
638k
    if (operand_index >= 0) {
2834
4
      auto image_operands = inst->GetSingleWordInOperand(operand_index);
2835
4
      if (image_operands & uint32_t(spv::ImageOperandsMask::Offset)) {
2836
0
        uint32_t offset_operand_index = operand_index + 1;
2837
0
        if (image_operands & uint32_t(spv::ImageOperandsMask::Bias))
2838
0
          offset_operand_index++;
2839
0
        if (image_operands & uint32_t(spv::ImageOperandsMask::Lod))
2840
0
          offset_operand_index++;
2841
0
        if (image_operands & uint32_t(spv::ImageOperandsMask::Grad))
2842
0
          offset_operand_index += 2;
2843
0
        assert(((image_operands &
2844
0
                 uint32_t(spv::ImageOperandsMask::ConstOffset)) == 0) &&
2845
0
               "Offset and ConstOffset may not be used together");
2846
0
        if (offset_operand_index < inst->NumOperands()) {
2847
0
          if (constants[offset_operand_index]) {
2848
0
            if (constants[offset_operand_index]->IsZero()) {
2849
0
              inst->RemoveInOperand(offset_operand_index);
2850
0
            } else {
2851
0
              image_operands = image_operands |
2852
0
                               uint32_t(spv::ImageOperandsMask::ConstOffset);
2853
0
            }
2854
0
            image_operands =
2855
0
                image_operands & ~uint32_t(spv::ImageOperandsMask::Offset);
2856
0
            inst->SetInOperand(operand_index, {image_operands});
2857
0
            return true;
2858
0
          }
2859
0
        }
2860
0
      }
2861
4
    }
2862
2863
638k
    return false;
2864
638k
  };
2865
204k
}
2866
2867
}  // namespace
2868
2869
8.16k
void FoldingRules::AddFoldingRules() {
2870
  // Add all folding rules to the list for the opcodes to which they apply.
2871
  // Note that the order in which rules are added to the list matters. If a rule
2872
  // applies to the instruction, the rest of the rules will not be attempted.
2873
  // Take that into consideration.
2874
8.16k
  rules_[spv::Op::OpBitcast].push_back(BitCastScalarOrVector());
2875
2876
8.16k
  rules_[spv::Op::OpCompositeConstruct].push_back(
2877
8.16k
      CompositeExtractFeedingConstruct);
2878
2879
8.16k
  rules_[spv::Op::OpCompositeExtract].push_back(InsertFeedingExtract());
2880
8.16k
  rules_[spv::Op::OpCompositeExtract].push_back(
2881
8.16k
      CompositeConstructFeedingExtract);
2882
8.16k
  rules_[spv::Op::OpCompositeExtract].push_back(VectorShuffleFeedingExtract());
2883
8.16k
  rules_[spv::Op::OpCompositeExtract].push_back(FMixFeedingExtract());
2884
2885
8.16k
  rules_[spv::Op::OpCompositeInsert].push_back(
2886
8.16k
      CompositeInsertToCompositeConstruct);
2887
2888
8.16k
  rules_[spv::Op::OpDot].push_back(DotProductDoingExtract());
2889
2890
8.16k
  rules_[spv::Op::OpEntryPoint].push_back(RemoveRedundantOperands());
2891
2892
8.16k
  rules_[spv::Op::OpFAdd].push_back(RedundantFAdd());
2893
8.16k
  rules_[spv::Op::OpFAdd].push_back(MergeAddNegateArithmetic());
2894
8.16k
  rules_[spv::Op::OpFAdd].push_back(MergeAddAddArithmetic());
2895
8.16k
  rules_[spv::Op::OpFAdd].push_back(MergeAddSubArithmetic());
2896
8.16k
  rules_[spv::Op::OpFAdd].push_back(MergeGenericAddSubArithmetic());
2897
8.16k
  rules_[spv::Op::OpFAdd].push_back(FactorAddMuls());
2898
2899
8.16k
  rules_[spv::Op::OpFDiv].push_back(RedundantFDiv());
2900
8.16k
  rules_[spv::Op::OpFDiv].push_back(ReciprocalFDiv());
2901
8.16k
  rules_[spv::Op::OpFDiv].push_back(MergeDivDivArithmetic());
2902
8.16k
  rules_[spv::Op::OpFDiv].push_back(MergeDivMulArithmetic());
2903
8.16k
  rules_[spv::Op::OpFDiv].push_back(MergeDivNegateArithmetic());
2904
2905
8.16k
  rules_[spv::Op::OpFMul].push_back(RedundantFMul());
2906
8.16k
  rules_[spv::Op::OpFMul].push_back(MergeMulMulArithmetic());
2907
8.16k
  rules_[spv::Op::OpFMul].push_back(MergeMulDivArithmetic());
2908
8.16k
  rules_[spv::Op::OpFMul].push_back(MergeMulNegateArithmetic());
2909
2910
8.16k
  rules_[spv::Op::OpFNegate].push_back(MergeNegateArithmetic());
2911
8.16k
  rules_[spv::Op::OpFNegate].push_back(MergeNegateAddSubArithmetic());
2912
8.16k
  rules_[spv::Op::OpFNegate].push_back(MergeNegateMulDivArithmetic());
2913
2914
8.16k
  rules_[spv::Op::OpFSub].push_back(RedundantFSub());
2915
8.16k
  rules_[spv::Op::OpFSub].push_back(MergeSubNegateArithmetic());
2916
8.16k
  rules_[spv::Op::OpFSub].push_back(MergeSubAddArithmetic());
2917
8.16k
  rules_[spv::Op::OpFSub].push_back(MergeSubSubArithmetic());
2918
2919
8.16k
  rules_[spv::Op::OpIAdd].push_back(RedundantIAdd());
2920
8.16k
  rules_[spv::Op::OpIAdd].push_back(MergeAddNegateArithmetic());
2921
8.16k
  rules_[spv::Op::OpIAdd].push_back(MergeAddAddArithmetic());
2922
8.16k
  rules_[spv::Op::OpIAdd].push_back(MergeAddSubArithmetic());
2923
8.16k
  rules_[spv::Op::OpIAdd].push_back(MergeGenericAddSubArithmetic());
2924
8.16k
  rules_[spv::Op::OpIAdd].push_back(FactorAddMuls());
2925
2926
8.16k
  rules_[spv::Op::OpIMul].push_back(IntMultipleBy1());
2927
8.16k
  rules_[spv::Op::OpIMul].push_back(MergeMulMulArithmetic());
2928
8.16k
  rules_[spv::Op::OpIMul].push_back(MergeMulNegateArithmetic());
2929
2930
8.16k
  rules_[spv::Op::OpISub].push_back(MergeSubNegateArithmetic());
2931
8.16k
  rules_[spv::Op::OpISub].push_back(MergeSubAddArithmetic());
2932
8.16k
  rules_[spv::Op::OpISub].push_back(MergeSubSubArithmetic());
2933
2934
8.16k
  rules_[spv::Op::OpPhi].push_back(RedundantPhi());
2935
2936
8.16k
  rules_[spv::Op::OpSNegate].push_back(MergeNegateArithmetic());
2937
8.16k
  rules_[spv::Op::OpSNegate].push_back(MergeNegateMulDivArithmetic());
2938
8.16k
  rules_[spv::Op::OpSNegate].push_back(MergeNegateAddSubArithmetic());
2939
2940
8.16k
  rules_[spv::Op::OpSelect].push_back(RedundantSelect());
2941
2942
8.16k
  rules_[spv::Op::OpStore].push_back(StoringUndef());
2943
2944
8.16k
  rules_[spv::Op::OpVectorShuffle].push_back(VectorShuffleFeedingShuffle());
2945
2946
8.16k
  rules_[spv::Op::OpImageSampleImplicitLod].push_back(UpdateImageOperands());
2947
8.16k
  rules_[spv::Op::OpImageSampleExplicitLod].push_back(UpdateImageOperands());
2948
8.16k
  rules_[spv::Op::OpImageSampleDrefImplicitLod].push_back(
2949
8.16k
      UpdateImageOperands());
2950
8.16k
  rules_[spv::Op::OpImageSampleDrefExplicitLod].push_back(
2951
8.16k
      UpdateImageOperands());
2952
8.16k
  rules_[spv::Op::OpImageSampleProjImplicitLod].push_back(
2953
8.16k
      UpdateImageOperands());
2954
8.16k
  rules_[spv::Op::OpImageSampleProjExplicitLod].push_back(
2955
8.16k
      UpdateImageOperands());
2956
8.16k
  rules_[spv::Op::OpImageSampleProjDrefImplicitLod].push_back(
2957
8.16k
      UpdateImageOperands());
2958
8.16k
  rules_[spv::Op::OpImageSampleProjDrefExplicitLod].push_back(
2959
8.16k
      UpdateImageOperands());
2960
8.16k
  rules_[spv::Op::OpImageFetch].push_back(UpdateImageOperands());
2961
8.16k
  rules_[spv::Op::OpImageGather].push_back(UpdateImageOperands());
2962
8.16k
  rules_[spv::Op::OpImageDrefGather].push_back(UpdateImageOperands());
2963
8.16k
  rules_[spv::Op::OpImageRead].push_back(UpdateImageOperands());
2964
8.16k
  rules_[spv::Op::OpImageWrite].push_back(UpdateImageOperands());
2965
8.16k
  rules_[spv::Op::OpImageSparseSampleImplicitLod].push_back(
2966
8.16k
      UpdateImageOperands());
2967
8.16k
  rules_[spv::Op::OpImageSparseSampleExplicitLod].push_back(
2968
8.16k
      UpdateImageOperands());
2969
8.16k
  rules_[spv::Op::OpImageSparseSampleDrefImplicitLod].push_back(
2970
8.16k
      UpdateImageOperands());
2971
8.16k
  rules_[spv::Op::OpImageSparseSampleDrefExplicitLod].push_back(
2972
8.16k
      UpdateImageOperands());
2973
8.16k
  rules_[spv::Op::OpImageSparseSampleProjImplicitLod].push_back(
2974
8.16k
      UpdateImageOperands());
2975
8.16k
  rules_[spv::Op::OpImageSparseSampleProjExplicitLod].push_back(
2976
8.16k
      UpdateImageOperands());
2977
8.16k
  rules_[spv::Op::OpImageSparseSampleProjDrefImplicitLod].push_back(
2978
8.16k
      UpdateImageOperands());
2979
8.16k
  rules_[spv::Op::OpImageSparseSampleProjDrefExplicitLod].push_back(
2980
8.16k
      UpdateImageOperands());
2981
8.16k
  rules_[spv::Op::OpImageSparseFetch].push_back(UpdateImageOperands());
2982
8.16k
  rules_[spv::Op::OpImageSparseGather].push_back(UpdateImageOperands());
2983
8.16k
  rules_[spv::Op::OpImageSparseDrefGather].push_back(UpdateImageOperands());
2984
8.16k
  rules_[spv::Op::OpImageSparseRead].push_back(UpdateImageOperands());
2985
2986
8.16k
  FeatureManager* feature_manager = context_->get_feature_mgr();
2987
  // Add rules for GLSLstd450
2988
8.16k
  uint32_t ext_inst_glslstd450_id =
2989
8.16k
      feature_manager->GetExtInstImportId_GLSLstd450();
2990
8.16k
  if (ext_inst_glslstd450_id != 0) {
2991
4.69k
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(
2992
4.69k
        RedundantFMix());
2993
4.69k
  }
2994
8.16k
}
2995
}  // namespace opt
2996
}  // namespace spvtools