/src/spirv-tools/source/opt/folding_rules.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | // Copyright (c) 2018 Google LLC |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | #include "source/opt/folding_rules.h" |
16 | | |
17 | | #include <limits> |
18 | | #include <memory> |
19 | | #include <utility> |
20 | | |
21 | | #include "ir_builder.h" |
22 | | #include "source/latest_version_glsl_std_450_header.h" |
23 | | #include "source/opt/ir_context.h" |
24 | | |
25 | | namespace spvtools { |
26 | | namespace opt { |
27 | | namespace { |
28 | | |
// In-operand indices (which exclude the result type and result id) of
// specific operands, following the SPIR-V instruction layouts.
// OpCompositeExtract: the composite being extracted from.
constexpr uint32_t kExtractCompositeIdInIdx = 0;
// OpCompositeInsert: the object being inserted and the composite it is
// inserted into.
constexpr uint32_t kInsertObjectIdInIdx = 0;
constexpr uint32_t kInsertCompositeIdInIdx = 1;
// OpExtInst: the extended instruction set id and the instruction number.
constexpr uint32_t kExtInstSetIdInIdx = 0;
constexpr uint32_t kExtInstInstructionInIdx = 1;
// OpExtInst FMix: the x, y and a arguments, which follow the set id and the
// instruction number.
constexpr uint32_t kFMixXIdInIdx = 2;
constexpr uint32_t kFMixYIdInIdx = 3;
constexpr uint32_t kFMixAIdInIdx = 4;
// OpStore: the object being stored (in-operand 0 is the pointer).
constexpr uint32_t kStoreObjectInIdx = 1;
38 | | |
39 | | // Some image instructions may contain an "image operands" argument. |
40 | | // Returns the operand index for the "image operands". |
41 | | // Returns -1 if the instruction does not have image operands. |
42 | 638k | int32_t ImageOperandsMaskInOperandIndex(Instruction* inst) { |
43 | 638k | const auto opcode = inst->opcode(); |
44 | 638k | switch (opcode) { |
45 | 638k | case spv::Op::OpImageSampleImplicitLod: |
46 | 638k | case spv::Op::OpImageSampleExplicitLod: |
47 | 638k | case spv::Op::OpImageSampleProjImplicitLod: |
48 | 638k | case spv::Op::OpImageSampleProjExplicitLod: |
49 | 638k | case spv::Op::OpImageFetch: |
50 | 638k | case spv::Op::OpImageRead: |
51 | 638k | case spv::Op::OpImageSparseSampleImplicitLod: |
52 | 638k | case spv::Op::OpImageSparseSampleExplicitLod: |
53 | 638k | case spv::Op::OpImageSparseSampleProjImplicitLod: |
54 | 638k | case spv::Op::OpImageSparseSampleProjExplicitLod: |
55 | 638k | case spv::Op::OpImageSparseFetch: |
56 | 638k | case spv::Op::OpImageSparseRead: |
57 | 638k | return inst->NumOperands() > 4 ? 2 : -1; |
58 | 0 | case spv::Op::OpImageSampleDrefImplicitLod: |
59 | 0 | case spv::Op::OpImageSampleDrefExplicitLod: |
60 | 0 | case spv::Op::OpImageSampleProjDrefImplicitLod: |
61 | 0 | case spv::Op::OpImageSampleProjDrefExplicitLod: |
62 | 0 | case spv::Op::OpImageGather: |
63 | 0 | case spv::Op::OpImageDrefGather: |
64 | 0 | case spv::Op::OpImageSparseSampleDrefImplicitLod: |
65 | 0 | case spv::Op::OpImageSparseSampleDrefExplicitLod: |
66 | 0 | case spv::Op::OpImageSparseSampleProjDrefImplicitLod: |
67 | 0 | case spv::Op::OpImageSparseSampleProjDrefExplicitLod: |
68 | 0 | case spv::Op::OpImageSparseGather: |
69 | 0 | case spv::Op::OpImageSparseDrefGather: |
70 | 0 | return inst->NumOperands() > 5 ? 3 : -1; |
71 | 0 | case spv::Op::OpImageWrite: |
72 | 0 | return inst->NumOperands() > 3 ? 3 : -1; |
73 | 0 | default: |
74 | 0 | return -1; |
75 | 638k | } |
76 | 638k | } |
77 | | |
78 | | // Returns the element width of |type|. |
79 | 10.7M | uint32_t ElementWidth(const analysis::Type* type) { |
80 | 10.7M | if (const analysis::Vector* vec_type = type->AsVector()) { |
81 | 3.85M | return ElementWidth(vec_type->element_type()); |
82 | 6.89M | } else if (const analysis::Float* float_type = type->AsFloat()) { |
83 | 6.64M | return float_type->width(); |
84 | 6.64M | } else { |
85 | 251k | assert(type->AsInteger()); |
86 | 251k | return type->AsInteger()->width(); |
87 | 251k | } |
88 | 10.7M | } |
89 | | |
90 | | // Returns true if |type| is Float or a vector of Float. |
91 | 12.3M | bool HasFloatingPoint(const analysis::Type* type) { |
92 | 12.3M | if (type->AsFloat()) { |
93 | 5.84M | return true; |
94 | 6.49M | } else if (const analysis::Vector* vec_type = type->AsVector()) { |
95 | 6.13M | return vec_type->element_type()->AsFloat() != nullptr; |
96 | 6.13M | } |
97 | | |
98 | 354k | return false; |
99 | 12.3M | } |
100 | | |
// Returns false if |val| is NaN, infinite or subnormal; returns true for every
// other classification (normal values and zeros).
template <typename T>
bool IsValidResult(T val) {
  const int category = std::fpclassify(val);
  return category != FP_NAN && category != FP_INFINITE &&
         category != FP_SUBNORMAL;
}
|
114 | | |
115 | | // Returns true if `type` is a cooperative matrix. |
116 | 6.88M | bool IsCooperativeMatrix(const analysis::Type* type) { |
117 | 6.88M | return type->kind() == analysis::Type::kCooperativeMatrixKHR || |
118 | 6.88M | type->kind() == analysis::Type::kCooperativeMatrixNV; |
119 | 6.88M | } |
120 | | |
121 | | const analysis::Constant* ConstInput( |
122 | 6.83M | const std::vector<const analysis::Constant*>& constants) { |
123 | 6.83M | return constants[0] ? constants[0] : constants[1]; |
124 | 6.83M | } |
125 | | |
126 | | Instruction* NonConstInput(IRContext* context, const analysis::Constant* c, |
127 | 1.97M | Instruction* inst) { |
128 | 1.97M | uint32_t in_op = c ? 1u : 0u; |
129 | 1.97M | return context->get_def_use_mgr()->GetDef( |
130 | 1.97M | inst->GetSingleWordInOperand(in_op)); |
131 | 1.97M | } |
132 | | |
// Splits |val| into two 32-bit words, least-significant word first.
std::vector<uint32_t> ExtractInts(uint64_t val) {
  const uint32_t low_word = static_cast<uint32_t>(val);
  const uint32_t high_word = static_cast<uint32_t>(val >> 32);
  return {low_word, high_word};
}
139 | | |
140 | | std::vector<uint32_t> GetWordsFromScalarIntConstant( |
141 | 4.50k | const analysis::IntConstant* c) { |
142 | 4.50k | assert(c != nullptr); |
143 | 4.50k | uint32_t width = c->type()->AsInteger()->width(); |
144 | 4.50k | assert(width == 8 || width == 16 || width == 32 || width == 64); |
145 | 4.50k | if (width == 64) { |
146 | 0 | uint64_t uval = static_cast<uint64_t>(c->GetU64()); |
147 | 0 | return ExtractInts(uval); |
148 | 0 | } |
149 | | // Section 2.2.1 of the SPIR-V spec guarantees that all integer types |
150 | | // smaller than 32-bits are automatically zero or sign extended to 32-bits. |
151 | 4.50k | return {c->GetU32BitValue()}; |
152 | 4.50k | } |
153 | | |
154 | | std::vector<uint32_t> GetWordsFromScalarFloatConstant( |
155 | 12 | const analysis::FloatConstant* c) { |
156 | 12 | assert(c != nullptr); |
157 | 12 | uint32_t width = c->type()->AsFloat()->width(); |
158 | 12 | assert(width == 16 || width == 32 || width == 64); |
159 | 12 | if (width == 64) { |
160 | 0 | utils::FloatProxy<double> result(c->GetDouble()); |
161 | 0 | return result.GetWords(); |
162 | 0 | } |
163 | | // Section 2.2.1 of the SPIR-V spec guarantees that all floating-point types |
164 | | // smaller than 32-bits are automatically zero extended to 32-bits. |
165 | 12 | return {c->GetU32BitValue()}; |
166 | 12 | } |
167 | | |
168 | | std::vector<uint32_t> GetWordsFromNumericScalarOrVectorConstant( |
169 | 4.52k | analysis::ConstantManager* const_mgr, const analysis::Constant* c) { |
170 | 4.52k | if (const auto* float_constant = c->AsFloatConstant()) { |
171 | 12 | return GetWordsFromScalarFloatConstant(float_constant); |
172 | 4.51k | } else if (const auto* int_constant = c->AsIntConstant()) { |
173 | 4.50k | return GetWordsFromScalarIntConstant(int_constant); |
174 | 4.50k | } else if (const auto* vec_constant = c->AsVectorConstant()) { |
175 | 3 | std::vector<uint32_t> words; |
176 | 12 | for (const auto* comp : vec_constant->GetComponents()) { |
177 | 12 | auto comp_in_words = |
178 | 12 | GetWordsFromNumericScalarOrVectorConstant(const_mgr, comp); |
179 | 12 | words.insert(words.end(), comp_in_words.begin(), comp_in_words.end()); |
180 | 12 | } |
181 | 3 | return words; |
182 | 3 | } |
183 | 8 | return {}; |
184 | 4.52k | } |
185 | | |
186 | | const analysis::Constant* ConvertWordsToNumericScalarOrVectorConstant( |
187 | | analysis::ConstantManager* const_mgr, const std::vector<uint32_t>& words, |
188 | 4.51k | const analysis::Type* type) { |
189 | 4.51k | const spvtools::opt::analysis::Integer* int_type = type->AsInteger(); |
190 | | |
191 | 4.51k | if (int_type && int_type->width() <= 32) { |
192 | 225 | assert(words.size() == 1); |
193 | 225 | return const_mgr->GenerateIntegerConstant(int_type, words[0]); |
194 | 225 | } |
195 | | |
196 | 4.29k | if (int_type || type->AsFloat()) return const_mgr->GetConstant(type, words); |
197 | 3.25k | if (const auto* vec_type = type->AsVector()) |
198 | 3 | return const_mgr->GetNumericVectorConstantWithWords(vec_type, words); |
199 | 3.24k | return nullptr; |
200 | 3.25k | } |
201 | | |
202 | | // Returns the negation of |c|. |c| must be a 32 or 64 bit floating point |
203 | | // constant. |
204 | | uint32_t NegateFloatingPointConstant(analysis::ConstantManager* const_mgr, |
205 | 45 | const analysis::Constant* c) { |
206 | 45 | assert(c); |
207 | 45 | assert(c->type()->AsFloat()); |
208 | 45 | uint32_t width = c->type()->AsFloat()->width(); |
209 | 45 | assert(width == 32 || width == 64); |
210 | 45 | std::vector<uint32_t> words; |
211 | 45 | if (width == 64) { |
212 | 0 | utils::FloatProxy<double> result(c->GetDouble() * -1.0); |
213 | 0 | words = result.GetWords(); |
214 | 45 | } else { |
215 | 45 | utils::FloatProxy<float> result(c->GetFloat() * -1.0f); |
216 | 45 | words = result.GetWords(); |
217 | 45 | } |
218 | | |
219 | 45 | const analysis::Constant* negated_const = |
220 | 45 | const_mgr->GetConstant(c->type(), std::move(words)); |
221 | 45 | return const_mgr->GetDefiningInstruction(negated_const)->result_id(); |
222 | 45 | } |
223 | | |
224 | | // Negates the integer constant |c|. Returns the id of the defining instruction. |
225 | | uint32_t NegateIntegerConstant(analysis::ConstantManager* const_mgr, |
226 | 31 | const analysis::Constant* c) { |
227 | 31 | assert(c); |
228 | 31 | assert(c->type()->AsInteger()); |
229 | 31 | uint32_t width = c->type()->AsInteger()->width(); |
230 | 31 | assert(width == 32 || width == 64); |
231 | 31 | std::vector<uint32_t> words; |
232 | 31 | if (width == 64) { |
233 | 0 | uint64_t uval = static_cast<uint64_t>(0 - c->GetU64()); |
234 | 0 | words = ExtractInts(uval); |
235 | 31 | } else { |
236 | 31 | words.push_back(static_cast<uint32_t>(0 - c->GetU32())); |
237 | 31 | } |
238 | | |
239 | 31 | const analysis::Constant* negated_const = |
240 | 31 | const_mgr->GetConstant(c->type(), std::move(words)); |
241 | 31 | return const_mgr->GetDefiningInstruction(negated_const)->result_id(); |
242 | 31 | } |
243 | | |
244 | | // Negates the vector constant |c|. Returns the id of the defining instruction. |
245 | | uint32_t NegateVectorConstant(analysis::ConstantManager* const_mgr, |
246 | 0 | const analysis::Constant* c) { |
247 | 0 | assert(const_mgr && c); |
248 | 0 | assert(c->type()->AsVector()); |
249 | 0 | if (c->AsNullConstant()) { |
250 | | // 0.0 vs -0.0 shouldn't matter. |
251 | 0 | return const_mgr->GetDefiningInstruction(c)->result_id(); |
252 | 0 | } else { |
253 | 0 | const analysis::Type* component_type = |
254 | 0 | c->AsVectorConstant()->component_type(); |
255 | 0 | std::vector<uint32_t> words; |
256 | 0 | for (auto& comp : c->AsVectorConstant()->GetComponents()) { |
257 | 0 | if (component_type->AsFloat()) { |
258 | 0 | words.push_back(NegateFloatingPointConstant(const_mgr, comp)); |
259 | 0 | } else { |
260 | 0 | assert(component_type->AsInteger()); |
261 | 0 | words.push_back(NegateIntegerConstant(const_mgr, comp)); |
262 | 0 | } |
263 | 0 | } |
264 | | |
265 | 0 | const analysis::Constant* negated_const = |
266 | 0 | const_mgr->GetConstant(c->type(), std::move(words)); |
267 | 0 | return const_mgr->GetDefiningInstruction(negated_const)->result_id(); |
268 | 0 | } |
269 | 0 | } |
270 | | |
271 | | // Negates |c|. Returns the id of the defining instruction. |
272 | | uint32_t NegateConstant(analysis::ConstantManager* const_mgr, |
273 | 76 | const analysis::Constant* c) { |
274 | 76 | if (c->type()->AsVector()) { |
275 | 0 | return NegateVectorConstant(const_mgr, c); |
276 | 76 | } else if (c->type()->AsFloat()) { |
277 | 45 | return NegateFloatingPointConstant(const_mgr, c); |
278 | 45 | } else { |
279 | 31 | assert(c->type()->AsInteger()); |
280 | 31 | return NegateIntegerConstant(const_mgr, c); |
281 | 31 | } |
282 | 76 | } |
283 | | |
284 | | // Takes the reciprocal of |c|. |c|'s type must be Float or a vector of Float. |
285 | | // Returns 0 if the reciprocal is NaN, infinite or subnormal. |
286 | | uint32_t Reciprocal(analysis::ConstantManager* const_mgr, |
287 | 137k | const analysis::Constant* c) { |
288 | 137k | assert(const_mgr && c); |
289 | 137k | assert(c->type()->AsFloat()); |
290 | | |
291 | 137k | uint32_t width = c->type()->AsFloat()->width(); |
292 | 137k | assert(width == 32 || width == 64); |
293 | 137k | std::vector<uint32_t> words; |
294 | | |
295 | 137k | if (c->IsZero()) { |
296 | 16.8k | return 0; |
297 | 16.8k | } |
298 | | |
299 | 120k | if (width == 64) { |
300 | 0 | spvtools::utils::FloatProxy<double> result(1.0 / c->GetDouble()); |
301 | 0 | if (!IsValidResult(result.getAsFloat())) return 0; |
302 | 0 | words = result.GetWords(); |
303 | 120k | } else { |
304 | 120k | spvtools::utils::FloatProxy<float> result(1.0f / c->GetFloat()); |
305 | 120k | if (!IsValidResult(result.getAsFloat())) return 0; |
306 | 83.0k | words = result.GetWords(); |
307 | 83.0k | } |
308 | | |
309 | 83.0k | const analysis::Constant* negated_const = |
310 | 83.0k | const_mgr->GetConstant(c->type(), std::move(words)); |
311 | 83.0k | return const_mgr->GetDefiningInstruction(negated_const)->result_id(); |
312 | 120k | } |
313 | | |
314 | | // Replaces fdiv where second operand is constant with fmul. |
315 | 8.16k | FoldingRule ReciprocalFDiv() { |
316 | 8.16k | return [](IRContext* context, Instruction* inst, |
317 | 108k | const std::vector<const analysis::Constant*>& constants) { |
318 | 108k | assert(inst->opcode() == spv::Op::OpFDiv); |
319 | 108k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
320 | 108k | const analysis::Type* type = |
321 | 108k | context->get_type_mgr()->GetType(inst->type_id()); |
322 | | |
323 | 108k | if (IsCooperativeMatrix(type)) { |
324 | 0 | return false; |
325 | 0 | } |
326 | | |
327 | 108k | if (!inst->IsFloatingPointFoldingAllowed()) return false; |
328 | | |
329 | 108k | uint32_t width = ElementWidth(type); |
330 | 108k | if (width != 32 && width != 64) return false; |
331 | | |
332 | 108k | if (constants[1] != nullptr) { |
333 | 87.3k | uint32_t id = 0; |
334 | 87.3k | if (const analysis::VectorConstant* vector_const = |
335 | 87.3k | constants[1]->AsVectorConstant()) { |
336 | 83.6k | std::vector<uint32_t> neg_ids; |
337 | 133k | for (auto& comp : vector_const->GetComponents()) { |
338 | 133k | id = Reciprocal(const_mgr, comp); |
339 | 133k | if (id == 0) return false; |
340 | 81.5k | neg_ids.push_back(id); |
341 | 81.5k | } |
342 | 31.2k | const analysis::Constant* negated_const = |
343 | 31.2k | const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids)); |
344 | 31.2k | id = const_mgr->GetDefiningInstruction(negated_const)->result_id(); |
345 | 31.2k | } else if (constants[1]->AsFloatConstant()) { |
346 | 3.74k | id = Reciprocal(const_mgr, constants[1]); |
347 | 3.74k | if (id == 0) return false; |
348 | 3.74k | } else { |
349 | | // Don't fold a null constant. |
350 | 0 | return false; |
351 | 0 | } |
352 | 32.7k | inst->SetOpcode(spv::Op::OpFMul); |
353 | 32.7k | inst->SetInOperands( |
354 | 32.7k | {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0u)}}, |
355 | 32.7k | {SPV_OPERAND_TYPE_ID, {id}}}); |
356 | 32.7k | return true; |
357 | 87.3k | } |
358 | | |
359 | 21.5k | return false; |
360 | 108k | }; |
361 | 8.16k | } |
362 | | |
363 | | // Elides consecutive negate instructions. |
364 | 16.3k | FoldingRule MergeNegateArithmetic() { |
365 | 16.3k | return [](IRContext* context, Instruction* inst, |
366 | 24.7k | const std::vector<const analysis::Constant*>& constants) { |
367 | 24.7k | assert(inst->opcode() == spv::Op::OpFNegate || |
368 | 24.7k | inst->opcode() == spv::Op::OpSNegate); |
369 | 24.7k | (void)constants; |
370 | 24.7k | const analysis::Type* type = |
371 | 24.7k | context->get_type_mgr()->GetType(inst->type_id()); |
372 | 24.7k | if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed()) |
373 | 0 | return false; |
374 | | |
375 | 24.7k | Instruction* op_inst = |
376 | 24.7k | context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u)); |
377 | 24.7k | if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed()) |
378 | 0 | return false; |
379 | | |
380 | 24.7k | if (op_inst->opcode() == inst->opcode()) { |
381 | | // Elide negates. |
382 | 0 | inst->SetOpcode(spv::Op::OpCopyObject); |
383 | 0 | inst->SetInOperands( |
384 | 0 | {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0u)}}}); |
385 | 0 | return true; |
386 | 0 | } |
387 | | |
388 | 24.7k | return false; |
389 | 24.7k | }; |
390 | 16.3k | } |
391 | | |
392 | | // Merges negate into a mul or div operation if that operation contains a |
393 | | // constant operand. |
394 | | // Cases: |
395 | | // -(x * 2) = x * -2 |
396 | | // -(2 * x) = x * -2 |
397 | | // -(x / 2) = x / -2 |
398 | | // -(2 / x) = -2 / x |
399 | 16.3k | FoldingRule MergeNegateMulDivArithmetic() { |
400 | 16.3k | return [](IRContext* context, Instruction* inst, |
401 | 24.6k | const std::vector<const analysis::Constant*>& constants) { |
402 | 24.6k | assert(inst->opcode() == spv::Op::OpFNegate || |
403 | 24.6k | inst->opcode() == spv::Op::OpSNegate); |
404 | 24.6k | (void)constants; |
405 | 24.6k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
406 | 24.6k | const analysis::Type* type = |
407 | 24.6k | context->get_type_mgr()->GetType(inst->type_id()); |
408 | | |
409 | 24.6k | if (IsCooperativeMatrix(type)) { |
410 | 0 | return false; |
411 | 0 | } |
412 | | |
413 | 24.6k | if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed()) |
414 | 0 | return false; |
415 | | |
416 | 24.6k | Instruction* op_inst = |
417 | 24.6k | context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u)); |
418 | 24.6k | if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed()) |
419 | 0 | return false; |
420 | | |
421 | 24.6k | uint32_t width = ElementWidth(type); |
422 | 24.6k | if (width != 32 && width != 64) return false; |
423 | | |
424 | 24.6k | spv::Op opcode = op_inst->opcode(); |
425 | 24.6k | if (opcode == spv::Op::OpFMul || opcode == spv::Op::OpFDiv || |
426 | 24.6k | opcode == spv::Op::OpIMul || opcode == spv::Op::OpSDiv || |
427 | 24.6k | opcode == spv::Op::OpUDiv) { |
428 | 92 | std::vector<const analysis::Constant*> op_constants = |
429 | 92 | const_mgr->GetOperandConstants(op_inst); |
430 | | // Merge negate into mul or div if one operand is constant. |
431 | 92 | if (op_constants[0] || op_constants[1]) { |
432 | 53 | bool zero_is_variable = op_constants[0] == nullptr; |
433 | 53 | const analysis::Constant* c = ConstInput(op_constants); |
434 | 53 | uint32_t neg_id = NegateConstant(const_mgr, c); |
435 | 53 | uint32_t non_const_id = zero_is_variable |
436 | 53 | ? op_inst->GetSingleWordInOperand(0u) |
437 | 53 | : op_inst->GetSingleWordInOperand(1u); |
438 | | // Change this instruction to a mul/div. |
439 | 53 | inst->SetOpcode(op_inst->opcode()); |
440 | 53 | if (opcode == spv::Op::OpFDiv || opcode == spv::Op::OpUDiv || |
441 | 53 | opcode == spv::Op::OpSDiv) { |
442 | 4 | uint32_t op0 = zero_is_variable ? non_const_id : neg_id; |
443 | 4 | uint32_t op1 = zero_is_variable ? neg_id : non_const_id; |
444 | 4 | inst->SetInOperands( |
445 | 4 | {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}}); |
446 | 49 | } else { |
447 | 49 | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}}, |
448 | 49 | {SPV_OPERAND_TYPE_ID, {neg_id}}}); |
449 | 49 | } |
450 | 53 | return true; |
451 | 53 | } |
452 | 92 | } |
453 | | |
454 | 24.6k | return false; |
455 | 24.6k | }; |
456 | 16.3k | } |
457 | | |
458 | | // Merges negate into a add or sub operation if that operation contains a |
459 | | // constant operand. |
460 | | // Cases: |
461 | | // -(x + 2) = -2 - x |
462 | | // -(2 + x) = -2 - x |
463 | | // -(x - 2) = 2 - x |
464 | | // -(2 - x) = x - 2 |
465 | 16.3k | FoldingRule MergeNegateAddSubArithmetic() { |
466 | 16.3k | return [](IRContext* context, Instruction* inst, |
467 | 24.7k | const std::vector<const analysis::Constant*>& constants) { |
468 | 24.7k | assert(inst->opcode() == spv::Op::OpFNegate || |
469 | 24.7k | inst->opcode() == spv::Op::OpSNegate); |
470 | 24.7k | (void)constants; |
471 | 24.7k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
472 | 24.7k | const analysis::Type* type = |
473 | 24.7k | context->get_type_mgr()->GetType(inst->type_id()); |
474 | | |
475 | 24.7k | if (IsCooperativeMatrix(type)) { |
476 | 0 | return false; |
477 | 0 | } |
478 | | |
479 | 24.7k | if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed()) |
480 | 0 | return false; |
481 | | |
482 | 24.7k | Instruction* op_inst = |
483 | 24.7k | context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u)); |
484 | 24.7k | if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed()) |
485 | 0 | return false; |
486 | | |
487 | 24.7k | uint32_t width = ElementWidth(type); |
488 | 24.7k | if (width != 32 && width != 64) return false; |
489 | | |
490 | 24.7k | if (op_inst->opcode() == spv::Op::OpFAdd || |
491 | 24.7k | op_inst->opcode() == spv::Op::OpFSub || |
492 | 24.7k | op_inst->opcode() == spv::Op::OpIAdd || |
493 | 24.7k | op_inst->opcode() == spv::Op::OpISub) { |
494 | 115 | std::vector<const analysis::Constant*> op_constants = |
495 | 115 | const_mgr->GetOperandConstants(op_inst); |
496 | 115 | if (op_constants[0] || op_constants[1]) { |
497 | 47 | bool zero_is_variable = op_constants[0] == nullptr; |
498 | 47 | bool is_add = (op_inst->opcode() == spv::Op::OpFAdd) || |
499 | 47 | (op_inst->opcode() == spv::Op::OpIAdd); |
500 | 47 | bool swap_operands = !is_add || zero_is_variable; |
501 | 47 | bool negate_const = is_add; |
502 | 47 | const analysis::Constant* c = ConstInput(op_constants); |
503 | 47 | uint32_t const_id = 0; |
504 | 47 | if (negate_const) { |
505 | 18 | const_id = NegateConstant(const_mgr, c); |
506 | 29 | } else { |
507 | 29 | const_id = zero_is_variable ? op_inst->GetSingleWordInOperand(1u) |
508 | 29 | : op_inst->GetSingleWordInOperand(0u); |
509 | 29 | } |
510 | | |
511 | | // Swap operands if necessary and make the instruction a subtraction. |
512 | 47 | uint32_t op0 = |
513 | 47 | zero_is_variable ? op_inst->GetSingleWordInOperand(0u) : const_id; |
514 | 47 | uint32_t op1 = |
515 | 47 | zero_is_variable ? const_id : op_inst->GetSingleWordInOperand(1u); |
516 | 47 | if (swap_operands) std::swap(op0, op1); |
517 | 47 | inst->SetOpcode(HasFloatingPoint(type) ? spv::Op::OpFSub |
518 | 47 | : spv::Op::OpISub); |
519 | 47 | inst->SetInOperands( |
520 | 47 | {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}}); |
521 | 47 | return true; |
522 | 47 | } |
523 | 115 | } |
524 | | |
525 | 24.6k | return false; |
526 | 24.7k | }; |
527 | 16.3k | } |
528 | | |
529 | | // Returns true if |c| has a zero element. |
530 | 290k | bool HasZero(const analysis::Constant* c) { |
531 | 290k | if (c->AsNullConstant()) { |
532 | 0 | return true; |
533 | 0 | } |
534 | 290k | if (const analysis::VectorConstant* vec_const = c->AsVectorConstant()) { |
535 | 105k | for (auto& comp : vec_const->GetComponents()) |
536 | 178k | if (HasZero(comp)) return true; |
537 | 185k | } else { |
538 | 185k | assert(c->AsScalarConstant()); |
539 | 185k | return c->AsScalarConstant()->IsZero(); |
540 | 185k | } |
541 | | |
542 | 69.7k | return false; |
543 | 290k | } |
544 | | |
545 | | // Performs |input1| |opcode| |input2| and returns the merged constant result |
546 | | // id. Returns 0 if the result is not a valid value. The input types must be |
547 | | // Float. |
548 | | uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr, |
549 | | spv::Op opcode, |
550 | | const analysis::Constant* input1, |
551 | 13.1k | const analysis::Constant* input2) { |
552 | 13.1k | const analysis::Type* type = input1->type(); |
553 | 13.1k | assert(type->AsFloat()); |
554 | 13.1k | uint32_t width = type->AsFloat()->width(); |
555 | 13.1k | assert(width == 32 || width == 64); |
556 | 13.1k | std::vector<uint32_t> words; |
557 | 13.1k | #define FOLD_OP(op) \ |
558 | 13.1k | if (width == 64) { \ |
559 | 0 | utils::FloatProxy<double> val = \ |
560 | 0 | input1->GetDouble() op input2->GetDouble(); \ |
561 | 0 | double dval = val.getAsFloat(); \ |
562 | 0 | if (!IsValidResult(dval)) return 0; \ |
563 | 0 | words = val.GetWords(); \ |
564 | 13.1k | } else { \ |
565 | 13.1k | utils::FloatProxy<float> val = input1->GetFloat() op input2->GetFloat(); \ |
566 | 13.1k | float fval = val.getAsFloat(); \ |
567 | 13.1k | if (!IsValidResult(fval)) return 0; \ |
568 | 13.1k | words = val.GetWords(); \ |
569 | 9.51k | } \ |
570 | 13.1k | static_assert(true, "require extra semicolon") |
571 | 13.1k | switch (opcode) { |
572 | 2.58k | case spv::Op::OpFMul: |
573 | 2.58k | FOLD_OP(*); |
574 | 1.18k | break; |
575 | 69 | case spv::Op::OpFDiv: |
576 | 69 | if (HasZero(input2)) return 0; |
577 | 69 | FOLD_OP(/); |
578 | 36 | break; |
579 | 8.07k | case spv::Op::OpFAdd: |
580 | 8.07k | FOLD_OP(+); |
581 | 6.02k | break; |
582 | 2.42k | case spv::Op::OpFSub: |
583 | 2.42k | FOLD_OP(-); |
584 | 2.27k | break; |
585 | 0 | default: |
586 | 0 | assert(false && "Unexpected operation"); |
587 | 0 | break; |
588 | 13.1k | } |
589 | 9.51k | #undef FOLD_OP |
590 | 9.51k | const analysis::Constant* merged_const = const_mgr->GetConstant(type, words); |
591 | 9.51k | return const_mgr->GetDefiningInstruction(merged_const)->result_id(); |
592 | 13.1k | } |
593 | | |
594 | | // Performs |input1| |opcode| |input2| and returns the merged constant result |
595 | | // id. Returns 0 if the result is not a valid value. The input types must be |
596 | | // Integers. |
597 | | uint32_t PerformIntegerOperation(analysis::ConstantManager* const_mgr, |
598 | | spv::Op opcode, |
599 | | const analysis::Constant* input1, |
600 | 3.03k | const analysis::Constant* input2) { |
601 | 3.03k | assert(input1->type()->AsInteger()); |
602 | 3.03k | const analysis::Integer* type = input1->type()->AsInteger(); |
603 | 3.03k | uint32_t width = type->AsInteger()->width(); |
604 | 3.03k | assert(width == 32 || width == 64); |
605 | 3.03k | std::vector<uint32_t> words; |
606 | | // Regardless of the sign of the constant, folding is performed on an unsigned |
607 | | // interpretation of the constant data. This avoids signed integer overflow |
608 | | // while folding, and works because sign is irrelevant for the IAdd, ISub and |
609 | | // IMul instructions. |
610 | 3.03k | #define FOLD_OP(op) \ |
611 | 3.03k | if (width == 64) { \ |
612 | 0 | uint64_t val = input1->GetU64() op input2->GetU64(); \ |
613 | 0 | words = ExtractInts(val); \ |
614 | 3.03k | } else { \ |
615 | 3.03k | uint32_t val = input1->GetU32() op input2->GetU32(); \ |
616 | 3.03k | words.push_back(val); \ |
617 | 3.03k | } \ |
618 | 3.03k | static_assert(true, "require extra semicolon") |
619 | 3.03k | switch (opcode) { |
620 | 110 | case spv::Op::OpIMul: |
621 | 110 | FOLD_OP(*); |
622 | 110 | break; |
623 | 0 | case spv::Op::OpSDiv: |
624 | 0 | case spv::Op::OpUDiv: |
625 | 0 | assert(false && "Should not merge integer division"); |
626 | 0 | break; |
627 | 733 | case spv::Op::OpIAdd: |
628 | 733 | FOLD_OP(+); |
629 | 733 | break; |
630 | 2.19k | case spv::Op::OpISub: |
631 | 2.19k | FOLD_OP(-); |
632 | 2.19k | break; |
633 | 0 | default: |
634 | 0 | assert(false && "Unexpected operation"); |
635 | 0 | break; |
636 | 3.03k | } |
637 | 3.03k | #undef FOLD_OP |
638 | 3.03k | const analysis::Constant* merged_const = const_mgr->GetConstant(type, words); |
639 | 3.03k | return const_mgr->GetDefiningInstruction(merged_const)->result_id(); |
640 | 3.03k | } |
641 | | |
642 | | // Performs |input1| |opcode| |input2| and returns the merged constant result |
643 | | // id. Returns 0 if the result is not a valid value. The input types must be |
644 | | // Integers, Floats or Vectors of such. |
645 | | uint32_t PerformOperation(analysis::ConstantManager* const_mgr, spv::Op opcode, |
646 | | const analysis::Constant* input1, |
647 | 13.0k | const analysis::Constant* input2) { |
648 | 13.0k | assert(input1 && input2); |
649 | 13.0k | const analysis::Type* type = input1->type(); |
650 | 13.0k | std::vector<uint32_t> words; |
651 | 13.0k | if (const analysis::Vector* vector_type = type->AsVector()) { |
652 | 4.56k | const analysis::Type* ele_type = vector_type->element_type(); |
653 | 10.4k | for (uint32_t i = 0; i != vector_type->element_count(); ++i) { |
654 | 7.71k | uint32_t id = 0; |
655 | | |
656 | 7.71k | const analysis::Constant* input1_comp = nullptr; |
657 | 7.71k | if (const analysis::VectorConstant* input1_vector = |
658 | 7.71k | input1->AsVectorConstant()) { |
659 | 7.71k | input1_comp = input1_vector->GetComponents()[i]; |
660 | 7.71k | } else { |
661 | 0 | assert(input1->AsNullConstant()); |
662 | 0 | input1_comp = const_mgr->GetConstant(ele_type, {}); |
663 | 0 | } |
664 | | |
665 | 7.71k | const analysis::Constant* input2_comp = nullptr; |
666 | 7.71k | if (const analysis::VectorConstant* input2_vector = |
667 | 7.71k | input2->AsVectorConstant()) { |
668 | 7.71k | input2_comp = input2_vector->GetComponents()[i]; |
669 | 7.71k | } else { |
670 | 0 | assert(input2->AsNullConstant()); |
671 | 0 | input2_comp = const_mgr->GetConstant(ele_type, {}); |
672 | 0 | } |
673 | | |
674 | 7.71k | if (ele_type->AsFloat()) { |
675 | 7.71k | id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp, |
676 | 7.71k | input2_comp); |
677 | 7.71k | } else { |
678 | 0 | assert(ele_type->AsInteger()); |
679 | 0 | id = PerformIntegerOperation(const_mgr, opcode, input1_comp, |
680 | 0 | input2_comp); |
681 | 0 | } |
682 | 7.71k | if (id == 0) return 0; |
683 | 5.89k | words.push_back(id); |
684 | 5.89k | } |
685 | 2.73k | const analysis::Constant* merged_const = |
686 | 2.73k | const_mgr->GetConstant(type, words); |
687 | 2.73k | return const_mgr->GetDefiningInstruction(merged_const)->result_id(); |
688 | 8.47k | } else if (type->AsFloat()) { |
689 | 5.43k | return PerformFloatingPointOperation(const_mgr, opcode, input1, input2); |
690 | 5.43k | } else { |
691 | 3.03k | assert(type->AsInteger()); |
692 | 3.03k | return PerformIntegerOperation(const_mgr, opcode, input1, input2); |
693 | 3.03k | } |
694 | 13.0k | } |
695 | | |
696 | | // Merges consecutive multiplies where each contains one constant operand. |
697 | | // Cases: |
698 | | // 2 * (x * 2) = x * 4 |
699 | | // 2 * (2 * x) = x * 4 |
700 | | // (x * 2) * 2 = x * 4 |
701 | | // (2 * x) * 2 = x * 4 |
702 | 16.3k | FoldingRule MergeMulMulArithmetic() { |
703 | 16.3k | return [](IRContext* context, Instruction* inst, |
704 | 180k | const std::vector<const analysis::Constant*>& constants) { |
705 | 180k | assert(inst->opcode() == spv::Op::OpFMul || |
706 | 180k | inst->opcode() == spv::Op::OpIMul); |
707 | 180k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
708 | 180k | const analysis::Type* type = |
709 | 180k | context->get_type_mgr()->GetType(inst->type_id()); |
710 | | |
711 | 180k | if (IsCooperativeMatrix(type)) { |
712 | 0 | return false; |
713 | 0 | } |
714 | | |
715 | 180k | if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed()) |
716 | 1 | return false; |
717 | | |
718 | 180k | uint32_t width = ElementWidth(type); |
719 | 180k | if (width != 32 && width != 64) return false; |
720 | | |
721 | | // Determine the constant input and the variable input in |inst|. |
722 | 180k | const analysis::Constant* const_input1 = ConstInput(constants); |
723 | 180k | if (!const_input1) return false; |
724 | 138k | Instruction* other_inst = NonConstInput(context, constants[0], inst); |
725 | 138k | if (HasFloatingPoint(type) && !other_inst->IsFloatingPointFoldingAllowed()) |
726 | 6 | return false; |
727 | | |
728 | 137k | if (other_inst->opcode() == inst->opcode()) { |
729 | 2.89k | std::vector<const analysis::Constant*> other_constants = |
730 | 2.89k | const_mgr->GetOperandConstants(other_inst); |
731 | 2.89k | const analysis::Constant* const_input2 = ConstInput(other_constants); |
732 | 2.89k | if (!const_input2) return false; |
733 | | |
734 | 2.25k | bool other_first_is_variable = other_constants[0] == nullptr; |
735 | 2.25k | uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(), |
736 | 2.25k | const_input1, const_input2); |
737 | 2.25k | if (merged_id == 0) return false; |
738 | | |
739 | 1.07k | uint32_t non_const_id = other_first_is_variable |
740 | 1.07k | ? other_inst->GetSingleWordInOperand(0u) |
741 | 1.07k | : other_inst->GetSingleWordInOperand(1u); |
742 | 1.07k | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}}, |
743 | 1.07k | {SPV_OPERAND_TYPE_ID, {merged_id}}}); |
744 | 1.07k | return true; |
745 | 2.25k | } |
746 | | |
747 | 135k | return false; |
748 | 137k | }; |
749 | 16.3k | } |
750 | | |
751 | | // Merges divides into subsequent multiplies if each instruction contains one |
752 | | // constant operand. Does not support integer operations. |
753 | | // Cases: |
754 | | // 2 * (x / 2) = x * 1 |
755 | | // 2 * (2 / x) = 4 / x |
756 | | // (x / 2) * 2 = x * 1 |
757 | | // (2 / x) * 2 = 4 / x |
758 | | // (y / x) * x = y |
759 | | // x * (y / x) = y |
760 | 8.16k | FoldingRule MergeMulDivArithmetic() { |
761 | 8.16k | return [](IRContext* context, Instruction* inst, |
762 | 170k | const std::vector<const analysis::Constant*>& constants) { |
763 | 170k | assert(inst->opcode() == spv::Op::OpFMul); |
764 | 170k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
765 | 170k | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
766 | | |
767 | 170k | const analysis::Type* type = |
768 | 170k | context->get_type_mgr()->GetType(inst->type_id()); |
769 | | |
770 | 170k | if (IsCooperativeMatrix(type)) { |
771 | 0 | return false; |
772 | 0 | } |
773 | | |
774 | 170k | if (!inst->IsFloatingPointFoldingAllowed()) return false; |
775 | | |
776 | 170k | uint32_t width = ElementWidth(type); |
777 | 170k | if (width != 32 && width != 64) return false; |
778 | | |
779 | 511k | for (uint32_t i = 0; i < 2; i++) { |
780 | 341k | uint32_t op_id = inst->GetSingleWordInOperand(i); |
781 | 341k | Instruction* op_inst = def_use_mgr->GetDef(op_id); |
782 | 341k | if (op_inst->opcode() == spv::Op::OpFDiv) { |
783 | 3.47k | if (op_inst->GetSingleWordInOperand(1) == |
784 | 3.47k | inst->GetSingleWordInOperand(1 - i)) { |
785 | 0 | inst->SetOpcode(spv::Op::OpCopyObject); |
786 | 0 | inst->SetInOperands( |
787 | 0 | {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0)}}}); |
788 | 0 | return true; |
789 | 0 | } |
790 | 3.47k | } |
791 | 341k | } |
792 | | |
793 | 170k | const analysis::Constant* const_input1 = ConstInput(constants); |
794 | 170k | if (!const_input1) return false; |
795 | 129k | Instruction* other_inst = NonConstInput(context, constants[0], inst); |
796 | 129k | if (!other_inst->IsFloatingPointFoldingAllowed()) return false; |
797 | | |
798 | 129k | if (other_inst->opcode() == spv::Op::OpFDiv) { |
799 | 209 | std::vector<const analysis::Constant*> other_constants = |
800 | 209 | const_mgr->GetOperandConstants(other_inst); |
801 | 209 | const analysis::Constant* const_input2 = ConstInput(other_constants); |
802 | 209 | if (!const_input2 || HasZero(const_input2)) return false; |
803 | | |
804 | 18 | bool other_first_is_variable = other_constants[0] == nullptr; |
805 | | // If the variable value is the second operand of the divide, multiply |
806 | | // the constants together. Otherwise divide the constants. |
807 | 18 | uint32_t merged_id = PerformOperation( |
808 | 18 | const_mgr, |
809 | 18 | other_first_is_variable ? other_inst->opcode() : inst->opcode(), |
810 | 18 | const_input1, const_input2); |
811 | 18 | if (merged_id == 0) return false; |
812 | | |
813 | 5 | uint32_t non_const_id = other_first_is_variable |
814 | 5 | ? other_inst->GetSingleWordInOperand(0u) |
815 | 5 | : other_inst->GetSingleWordInOperand(1u); |
816 | | |
817 | | // If the variable value is on the second operand of the div, then this |
818 | | // operation is a div. Otherwise it should be a multiply. |
819 | 5 | inst->SetOpcode(other_first_is_variable ? inst->opcode() |
820 | 5 | : other_inst->opcode()); |
821 | 5 | if (other_first_is_variable) { |
822 | 5 | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}}, |
823 | 5 | {SPV_OPERAND_TYPE_ID, {merged_id}}}); |
824 | 5 | } else { |
825 | 0 | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {merged_id}}, |
826 | 0 | {SPV_OPERAND_TYPE_ID, {non_const_id}}}); |
827 | 0 | } |
828 | 5 | return true; |
829 | 18 | } |
830 | | |
831 | 129k | return false; |
832 | 129k | }; |
833 | 8.16k | } |
834 | | |
835 | | // Merges multiply of constant and negation. |
836 | | // Cases: |
837 | | // (-x) * 2 = x * -2 |
838 | | // 2 * (-x) = x * -2 |
839 | 16.3k | FoldingRule MergeMulNegateArithmetic() { |
840 | 16.3k | return [](IRContext* context, Instruction* inst, |
841 | 179k | const std::vector<const analysis::Constant*>& constants) { |
842 | 179k | assert(inst->opcode() == spv::Op::OpFMul || |
843 | 179k | inst->opcode() == spv::Op::OpIMul); |
844 | 179k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
845 | 179k | const analysis::Type* type = |
846 | 179k | context->get_type_mgr()->GetType(inst->type_id()); |
847 | | |
848 | 179k | if (IsCooperativeMatrix(type)) { |
849 | 0 | return false; |
850 | 0 | } |
851 | | |
852 | 179k | bool uses_float = HasFloatingPoint(type); |
853 | 179k | if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; |
854 | | |
855 | 179k | uint32_t width = ElementWidth(type); |
856 | 179k | if (width != 32 && width != 64) return false; |
857 | | |
858 | 179k | const analysis::Constant* const_input1 = ConstInput(constants); |
859 | 179k | if (!const_input1) return false; |
860 | 136k | Instruction* other_inst = NonConstInput(context, constants[0], inst); |
861 | 136k | if (uses_float && !other_inst->IsFloatingPointFoldingAllowed()) |
862 | 6 | return false; |
863 | | |
864 | 136k | if (other_inst->opcode() == spv::Op::OpFNegate || |
865 | 136k | other_inst->opcode() == spv::Op::OpSNegate) { |
866 | 1 | uint32_t neg_id = NegateConstant(const_mgr, const_input1); |
867 | | |
868 | 1 | inst->SetInOperands( |
869 | 1 | {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}, |
870 | 1 | {SPV_OPERAND_TYPE_ID, {neg_id}}}); |
871 | 1 | return true; |
872 | 1 | } |
873 | | |
874 | 136k | return false; |
875 | 136k | }; |
876 | 16.3k | } |
877 | | |
878 | | // Merges consecutive divides if each instruction contains one constant operand. |
879 | | // Does not support integer division. |
880 | | // Cases: |
881 | | // 2 / (x / 2) = 4 / x |
882 | | // 4 / (2 / x) = 2 * x |
883 | | // (4 / x) / 2 = 2 / x |
884 | | // (x / 2) / 2 = x / 4 |
885 | 8.16k | FoldingRule MergeDivDivArithmetic() { |
886 | 8.16k | return [](IRContext* context, Instruction* inst, |
887 | 76.1k | const std::vector<const analysis::Constant*>& constants) { |
888 | 76.1k | assert(inst->opcode() == spv::Op::OpFDiv); |
889 | 76.1k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
890 | 76.1k | const analysis::Type* type = |
891 | 76.1k | context->get_type_mgr()->GetType(inst->type_id()); |
892 | | |
893 | 76.1k | if (IsCooperativeMatrix(type)) { |
894 | 0 | return false; |
895 | 0 | } |
896 | | |
897 | 76.1k | if (!inst->IsFloatingPointFoldingAllowed()) return false; |
898 | | |
899 | 76.1k | uint32_t width = ElementWidth(type); |
900 | 76.1k | if (width != 32 && width != 64) return false; |
901 | | |
902 | 76.1k | const analysis::Constant* const_input1 = ConstInput(constants); |
903 | 76.1k | if (!const_input1 || HasZero(const_input1)) return false; |
904 | 37.8k | Instruction* other_inst = NonConstInput(context, constants[0], inst); |
905 | 37.8k | if (!other_inst->IsFloatingPointFoldingAllowed()) return false; |
906 | | |
907 | 37.8k | bool first_is_variable = constants[0] == nullptr; |
908 | 37.8k | if (other_inst->opcode() == inst->opcode()) { |
909 | 334 | std::vector<const analysis::Constant*> other_constants = |
910 | 334 | const_mgr->GetOperandConstants(other_inst); |
911 | 334 | const analysis::Constant* const_input2 = ConstInput(other_constants); |
912 | 334 | if (!const_input2 || HasZero(const_input2)) return false; |
913 | | |
914 | 262 | bool other_first_is_variable = other_constants[0] == nullptr; |
915 | | |
916 | 262 | spv::Op merge_op = inst->opcode(); |
917 | 262 | if (other_first_is_variable) { |
918 | | // Constants magnify. |
919 | 249 | merge_op = spv::Op::OpFMul; |
920 | 249 | } |
921 | | |
922 | | // This is an x / (*) case. Swap the inputs. Doesn't harm multiply |
923 | | // because it is commutative. |
924 | 262 | if (first_is_variable) std::swap(const_input1, const_input2); |
925 | 262 | uint32_t merged_id = |
926 | 262 | PerformOperation(const_mgr, merge_op, const_input1, const_input2); |
927 | 262 | if (merged_id == 0) return false; |
928 | | |
929 | 44 | uint32_t non_const_id = other_first_is_variable |
930 | 44 | ? other_inst->GetSingleWordInOperand(0u) |
931 | 44 | : other_inst->GetSingleWordInOperand(1u); |
932 | | |
933 | 44 | spv::Op op = inst->opcode(); |
934 | 44 | if (!first_is_variable && !other_first_is_variable) { |
935 | | // Effectively div of 1/x, so change to multiply. |
936 | 7 | op = spv::Op::OpFMul; |
937 | 7 | } |
938 | | |
939 | 44 | uint32_t op1 = merged_id; |
940 | 44 | uint32_t op2 = non_const_id; |
941 | 44 | if (first_is_variable && other_first_is_variable) std::swap(op1, op2); |
942 | 44 | inst->SetOpcode(op); |
943 | 44 | inst->SetInOperands( |
944 | 44 | {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}}); |
945 | 44 | return true; |
946 | 262 | } |
947 | | |
948 | 37.4k | return false; |
949 | 37.8k | }; |
950 | 8.16k | } |
951 | | |
952 | | // Fold multiplies succeeded by divides where each instruction contains a |
953 | | // constant operand. Does not support integer divide. |
954 | | // Cases: |
955 | | // 4 / (x * 2) = 2 / x |
956 | | // 4 / (2 * x) = 2 / x |
957 | | // (x * 4) / 2 = x * 2 |
958 | | // (4 * x) / 2 = x * 2 |
959 | | // (x * y) / x = y |
960 | | // (y * x) / x = y |
961 | 8.16k | FoldingRule MergeDivMulArithmetic() { |
962 | 8.16k | return [](IRContext* context, Instruction* inst, |
963 | 76.1k | const std::vector<const analysis::Constant*>& constants) { |
964 | 76.1k | assert(inst->opcode() == spv::Op::OpFDiv); |
965 | 76.1k | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
966 | 76.1k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
967 | | |
968 | 76.1k | const analysis::Type* type = |
969 | 76.1k | context->get_type_mgr()->GetType(inst->type_id()); |
970 | | |
971 | 76.1k | if (IsCooperativeMatrix(type)) { |
972 | 0 | return false; |
973 | 0 | } |
974 | | |
975 | 76.1k | if (!inst->IsFloatingPointFoldingAllowed()) return false; |
976 | | |
977 | 76.0k | uint32_t width = ElementWidth(type); |
978 | 76.0k | if (width != 32 && width != 64) return false; |
979 | | |
980 | 76.0k | uint32_t op_id = inst->GetSingleWordInOperand(0); |
981 | 76.0k | Instruction* op_inst = def_use_mgr->GetDef(op_id); |
982 | | |
983 | 76.0k | if (op_inst->opcode() == spv::Op::OpFMul) { |
984 | 565 | for (uint32_t i = 0; i < 2; i++) { |
985 | 377 | if (op_inst->GetSingleWordInOperand(i) == |
986 | 377 | inst->GetSingleWordInOperand(1)) { |
987 | 1 | inst->SetOpcode(spv::Op::OpCopyObject); |
988 | 1 | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, |
989 | 1 | {op_inst->GetSingleWordInOperand(1 - i)}}}); |
990 | 1 | return true; |
991 | 1 | } |
992 | 377 | } |
993 | 189 | } |
994 | | |
995 | 76.0k | const analysis::Constant* const_input1 = ConstInput(constants); |
996 | 76.0k | if (!const_input1 || HasZero(const_input1)) return false; |
997 | 37.7k | Instruction* other_inst = NonConstInput(context, constants[0], inst); |
998 | 37.7k | if (!other_inst->IsFloatingPointFoldingAllowed()) return false; |
999 | | |
1000 | 37.7k | bool first_is_variable = constants[0] == nullptr; |
1001 | 37.7k | if (other_inst->opcode() == spv::Op::OpFMul) { |
1002 | 38 | std::vector<const analysis::Constant*> other_constants = |
1003 | 38 | const_mgr->GetOperandConstants(other_inst); |
1004 | 38 | const analysis::Constant* const_input2 = ConstInput(other_constants); |
1005 | 38 | if (!const_input2) return false; |
1006 | | |
1007 | 36 | bool other_first_is_variable = other_constants[0] == nullptr; |
1008 | | |
1009 | | // This is an x / (*) case. Swap the inputs. |
1010 | 36 | if (first_is_variable) std::swap(const_input1, const_input2); |
1011 | 36 | uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(), |
1012 | 36 | const_input1, const_input2); |
1013 | 36 | if (merged_id == 0) return false; |
1014 | | |
1015 | 17 | uint32_t non_const_id = other_first_is_variable |
1016 | 17 | ? other_inst->GetSingleWordInOperand(0u) |
1017 | 17 | : other_inst->GetSingleWordInOperand(1u); |
1018 | | |
1019 | 17 | uint32_t op1 = merged_id; |
1020 | 17 | uint32_t op2 = non_const_id; |
1021 | 17 | if (first_is_variable) std::swap(op1, op2); |
1022 | | |
1023 | | // Convert to multiply |
1024 | 17 | if (first_is_variable) inst->SetOpcode(other_inst->opcode()); |
1025 | 17 | inst->SetInOperands( |
1026 | 17 | {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}}); |
1027 | 17 | return true; |
1028 | 36 | } |
1029 | | |
1030 | 37.7k | return false; |
1031 | 37.7k | }; |
1032 | 8.16k | } |
1033 | | |
1034 | | // Fold divides of a constant and a negation. |
1035 | | // Cases: |
1036 | | // (-x) / 2 = x / -2 |
1037 | | // 2 / (-x) = -2 / x |
1038 | 8.16k | FoldingRule MergeDivNegateArithmetic() { |
1039 | 8.16k | return [](IRContext* context, Instruction* inst, |
1040 | 76.1k | const std::vector<const analysis::Constant*>& constants) { |
1041 | 76.1k | assert(inst->opcode() == spv::Op::OpFDiv); |
1042 | 76.1k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
1043 | 76.1k | if (!inst->IsFloatingPointFoldingAllowed()) return false; |
1044 | | |
1045 | 76.0k | const analysis::Constant* const_input1 = ConstInput(constants); |
1046 | 76.0k | if (!const_input1) return false; |
1047 | 55.4k | Instruction* other_inst = NonConstInput(context, constants[0], inst); |
1048 | 55.4k | if (!other_inst->IsFloatingPointFoldingAllowed()) return false; |
1049 | | |
1050 | 55.4k | bool first_is_variable = constants[0] == nullptr; |
1051 | 55.4k | if (other_inst->opcode() == spv::Op::OpFNegate) { |
1052 | 0 | uint32_t neg_id = NegateConstant(const_mgr, const_input1); |
1053 | |
|
1054 | 0 | if (first_is_variable) { |
1055 | 0 | inst->SetInOperands( |
1056 | 0 | {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}, |
1057 | 0 | {SPV_OPERAND_TYPE_ID, {neg_id}}}); |
1058 | 0 | } else { |
1059 | 0 | inst->SetInOperands( |
1060 | 0 | {{SPV_OPERAND_TYPE_ID, {neg_id}}, |
1061 | 0 | {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}}); |
1062 | 0 | } |
1063 | 0 | return true; |
1064 | 0 | } |
1065 | | |
1066 | 55.4k | return false; |
1067 | 55.4k | }; |
1068 | 8.16k | } |
1069 | | |
1070 | | // Folds addition of a constant and a negation. |
1071 | | // Cases: |
1072 | | // (-x) + 2 = 2 - x |
1073 | | // 2 + (-x) = 2 - x |
1074 | 16.3k | FoldingRule MergeAddNegateArithmetic() { |
1075 | 16.3k | return [](IRContext* context, Instruction* inst, |
1076 | 1.81M | const std::vector<const analysis::Constant*>& constants) { |
1077 | 1.81M | assert(inst->opcode() == spv::Op::OpFAdd || |
1078 | 1.81M | inst->opcode() == spv::Op::OpIAdd); |
1079 | 1.81M | const analysis::Type* type = |
1080 | 1.81M | context->get_type_mgr()->GetType(inst->type_id()); |
1081 | 1.81M | bool uses_float = HasFloatingPoint(type); |
1082 | 1.81M | if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; |
1083 | | |
1084 | 1.81M | const analysis::Constant* const_input1 = ConstInput(constants); |
1085 | 1.81M | if (!const_input1) return false; |
1086 | 428k | Instruction* other_inst = NonConstInput(context, constants[0], inst); |
1087 | 428k | if (uses_float && !other_inst->IsFloatingPointFoldingAllowed()) |
1088 | 13.8k | return false; |
1089 | | |
1090 | 415k | if (other_inst->opcode() == spv::Op::OpSNegate || |
1091 | 415k | other_inst->opcode() == spv::Op::OpFNegate) { |
1092 | 1 | inst->SetOpcode(HasFloatingPoint(type) ? spv::Op::OpFSub |
1093 | 1 | : spv::Op::OpISub); |
1094 | 1 | uint32_t const_id = constants[0] ? inst->GetSingleWordInOperand(0u) |
1095 | 1 | : inst->GetSingleWordInOperand(1u); |
1096 | 1 | inst->SetInOperands( |
1097 | 1 | {{SPV_OPERAND_TYPE_ID, {const_id}}, |
1098 | 1 | {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}}); |
1099 | 1 | return true; |
1100 | 1 | } |
1101 | 415k | return false; |
1102 | 415k | }; |
1103 | 16.3k | } |
1104 | | |
1105 | | // Folds subtraction of a constant and a negation. |
1106 | | // Cases: |
1107 | | // (-x) - 2 = -2 - x |
1108 | | // 2 - (-x) = x + 2 |
1109 | 16.3k | FoldingRule MergeSubNegateArithmetic() { |
1110 | 16.3k | return [](IRContext* context, Instruction* inst, |
1111 | 200k | const std::vector<const analysis::Constant*>& constants) { |
1112 | 200k | assert(inst->opcode() == spv::Op::OpFSub || |
1113 | 200k | inst->opcode() == spv::Op::OpISub); |
1114 | 200k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
1115 | 200k | const analysis::Type* type = |
1116 | 200k | context->get_type_mgr()->GetType(inst->type_id()); |
1117 | | |
1118 | 200k | if (IsCooperativeMatrix(type)) { |
1119 | 0 | return false; |
1120 | 0 | } |
1121 | | |
1122 | 200k | bool uses_float = HasFloatingPoint(type); |
1123 | 200k | if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; |
1124 | | |
1125 | 200k | uint32_t width = ElementWidth(type); |
1126 | 200k | if (width != 32 && width != 64) return false; |
1127 | | |
1128 | 200k | const analysis::Constant* const_input1 = ConstInput(constants); |
1129 | 200k | if (!const_input1) return false; |
1130 | 50.0k | Instruction* other_inst = NonConstInput(context, constants[0], inst); |
1131 | 50.0k | if (uses_float && !other_inst->IsFloatingPointFoldingAllowed()) |
1132 | 5 | return false; |
1133 | | |
1134 | 50.0k | if (other_inst->opcode() == spv::Op::OpSNegate || |
1135 | 50.0k | other_inst->opcode() == spv::Op::OpFNegate) { |
1136 | 13 | uint32_t op1 = 0; |
1137 | 13 | uint32_t op2 = 0; |
1138 | 13 | spv::Op opcode = inst->opcode(); |
1139 | 13 | if (constants[0] != nullptr) { |
1140 | 9 | op1 = other_inst->GetSingleWordInOperand(0u); |
1141 | 9 | op2 = inst->GetSingleWordInOperand(0u); |
1142 | 9 | opcode = HasFloatingPoint(type) ? spv::Op::OpFAdd : spv::Op::OpIAdd; |
1143 | 9 | } else { |
1144 | 4 | op1 = NegateConstant(const_mgr, const_input1); |
1145 | 4 | op2 = other_inst->GetSingleWordInOperand(0u); |
1146 | 4 | } |
1147 | | |
1148 | 13 | inst->SetOpcode(opcode); |
1149 | 13 | inst->SetInOperands( |
1150 | 13 | {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}}); |
1151 | 13 | return true; |
1152 | 13 | } |
1153 | 50.0k | return false; |
1154 | 50.0k | }; |
1155 | 16.3k | } |
1156 | | |
1157 | | // Folds addition of an addition where each operation has a constant operand. |
1158 | | // Cases: |
1159 | | // (x + 2) + 2 = x + 4 |
1160 | | // (2 + x) + 2 = x + 4 |
1161 | | // 2 + (x + 2) = x + 4 |
1162 | | // 2 + (2 + x) = x + 4 |
1163 | 16.3k | FoldingRule MergeAddAddArithmetic() { |
1164 | 16.3k | return [](IRContext* context, Instruction* inst, |
1165 | 1.81M | const std::vector<const analysis::Constant*>& constants) { |
1166 | 1.81M | assert(inst->opcode() == spv::Op::OpFAdd || |
1167 | 1.81M | inst->opcode() == spv::Op::OpIAdd); |
1168 | 1.81M | const analysis::Type* type = |
1169 | 1.81M | context->get_type_mgr()->GetType(inst->type_id()); |
1170 | | |
1171 | 1.81M | if (IsCooperativeMatrix(type)) { |
1172 | 0 | return false; |
1173 | 0 | } |
1174 | | |
1175 | 1.81M | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
1176 | 1.81M | bool uses_float = HasFloatingPoint(type); |
1177 | 1.81M | if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; |
1178 | | |
1179 | 1.81M | uint32_t width = ElementWidth(type); |
1180 | 1.81M | if (width != 32 && width != 64) return false; |
1181 | | |
1182 | 1.81M | const analysis::Constant* const_input1 = ConstInput(constants); |
1183 | 1.81M | if (!const_input1) return false; |
1184 | 428k | Instruction* other_inst = NonConstInput(context, constants[0], inst); |
1185 | 428k | if (uses_float && !other_inst->IsFloatingPointFoldingAllowed()) |
1186 | 13.8k | return false; |
1187 | | |
1188 | 415k | if (other_inst->opcode() == spv::Op::OpFAdd || |
1189 | 415k | other_inst->opcode() == spv::Op::OpIAdd) { |
1190 | 5.83k | std::vector<const analysis::Constant*> other_constants = |
1191 | 5.83k | const_mgr->GetOperandConstants(other_inst); |
1192 | 5.83k | const analysis::Constant* const_input2 = ConstInput(other_constants); |
1193 | 5.83k | if (!const_input2) return false; |
1194 | | |
1195 | 3.89k | Instruction* non_const_input = |
1196 | 3.89k | NonConstInput(context, other_constants[0], other_inst); |
1197 | 3.89k | uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(), |
1198 | 3.89k | const_input1, const_input2); |
1199 | 3.89k | if (merged_id == 0) return false; |
1200 | | |
1201 | 2.70k | inst->SetInOperands( |
1202 | 2.70k | {{SPV_OPERAND_TYPE_ID, {non_const_input->result_id()}}, |
1203 | 2.70k | {SPV_OPERAND_TYPE_ID, {merged_id}}}); |
1204 | 2.70k | return true; |
1205 | 3.89k | } |
1206 | 409k | return false; |
1207 | 415k | }; |
1208 | 16.3k | } |
1209 | | |
1210 | | // Folds addition of a subtraction where each operation has a constant operand. |
1211 | | // Cases: |
1212 | | // (x - 2) + 2 = x + 0 |
1213 | | // (2 - x) + 2 = 4 - x |
1214 | | // 2 + (x - 2) = x + 0 |
1215 | | // 2 + (2 - x) = 4 - x |
1216 | 16.3k | FoldingRule MergeAddSubArithmetic() { |
1217 | 16.3k | return [](IRContext* context, Instruction* inst, |
1218 | 1.81M | const std::vector<const analysis::Constant*>& constants) { |
1219 | 1.81M | assert(inst->opcode() == spv::Op::OpFAdd || |
1220 | 1.81M | inst->opcode() == spv::Op::OpIAdd); |
1221 | 1.81M | const analysis::Type* type = |
1222 | 1.81M | context->get_type_mgr()->GetType(inst->type_id()); |
1223 | | |
1224 | 1.81M | if (IsCooperativeMatrix(type)) { |
1225 | 0 | return false; |
1226 | 0 | } |
1227 | | |
1228 | 1.81M | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
1229 | 1.81M | bool uses_float = HasFloatingPoint(type); |
1230 | 1.81M | if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; |
1231 | | |
1232 | 1.81M | uint32_t width = ElementWidth(type); |
1233 | 1.81M | if (width != 32 && width != 64) return false; |
1234 | | |
1235 | 1.81M | const analysis::Constant* const_input1 = ConstInput(constants); |
1236 | 1.81M | if (!const_input1) return false; |
1237 | 426k | Instruction* other_inst = NonConstInput(context, constants[0], inst); |
1238 | 426k | if (uses_float && !other_inst->IsFloatingPointFoldingAllowed()) |
1239 | 13.8k | return false; |
1240 | | |
1241 | 412k | if (other_inst->opcode() == spv::Op::OpFSub || |
1242 | 412k | other_inst->opcode() == spv::Op::OpISub) { |
1243 | 2.35k | std::vector<const analysis::Constant*> other_constants = |
1244 | 2.35k | const_mgr->GetOperandConstants(other_inst); |
1245 | 2.35k | const analysis::Constant* const_input2 = ConstInput(other_constants); |
1246 | 2.35k | if (!const_input2) return false; |
1247 | | |
1248 | 2.26k | bool first_is_variable = other_constants[0] == nullptr; |
1249 | 2.26k | spv::Op op = inst->opcode(); |
1250 | 2.26k | uint32_t op1 = 0; |
1251 | 2.26k | uint32_t op2 = 0; |
1252 | 2.26k | if (first_is_variable) { |
1253 | | // Subtract constants. Non-constant operand is first. |
1254 | 2.25k | op1 = other_inst->GetSingleWordInOperand(0u); |
1255 | 2.25k | op2 = PerformOperation(const_mgr, other_inst->opcode(), const_input1, |
1256 | 2.25k | const_input2); |
1257 | 2.25k | } else { |
1258 | | // Add constants. Constant operand is first. Change the opcode. |
1259 | 12 | op1 = PerformOperation(const_mgr, inst->opcode(), const_input1, |
1260 | 12 | const_input2); |
1261 | 12 | op2 = other_inst->GetSingleWordInOperand(1u); |
1262 | 12 | op = other_inst->opcode(); |
1263 | 12 | } |
1264 | 2.26k | if (op1 == 0 || op2 == 0) return false; |
1265 | | |
1266 | 2.23k | inst->SetOpcode(op); |
1267 | 2.23k | inst->SetInOperands( |
1268 | 2.23k | {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}}); |
1269 | 2.23k | return true; |
1270 | 2.26k | } |
1271 | 410k | return false; |
1272 | 412k | }; |
1273 | 16.3k | } |
1274 | | |
1275 | | // Folds subtraction of an addition where each operand has a constant operand. |
1276 | | // Cases: |
1277 | | // (x + 2) - 2 = x + 0 |
1278 | | // (2 + x) - 2 = x + 0 |
1279 | | // 2 - (x + 2) = 0 - x |
1280 | | // 2 - (2 + x) = 0 - x |
1281 | 16.3k | FoldingRule MergeSubAddArithmetic() { |
1282 | 16.3k | return [](IRContext* context, Instruction* inst, |
1283 | 200k | const std::vector<const analysis::Constant*>& constants) { |
1284 | 200k | assert(inst->opcode() == spv::Op::OpFSub || |
1285 | 200k | inst->opcode() == spv::Op::OpISub); |
1286 | 200k | const analysis::Type* type = |
1287 | 200k | context->get_type_mgr()->GetType(inst->type_id()); |
1288 | | |
1289 | 200k | if (IsCooperativeMatrix(type)) { |
1290 | 0 | return false; |
1291 | 0 | } |
1292 | | |
1293 | 200k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
1294 | 200k | bool uses_float = HasFloatingPoint(type); |
1295 | 200k | if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; |
1296 | | |
1297 | 200k | uint32_t width = ElementWidth(type); |
1298 | 200k | if (width != 32 && width != 64) return false; |
1299 | | |
1300 | 200k | const analysis::Constant* const_input1 = ConstInput(constants); |
1301 | 200k | if (!const_input1) return false; |
1302 | 50.0k | Instruction* other_inst = NonConstInput(context, constants[0], inst); |
1303 | 50.0k | if (uses_float && !other_inst->IsFloatingPointFoldingAllowed()) |
1304 | 5 | return false; |
1305 | | |
1306 | 50.0k | if (other_inst->opcode() == spv::Op::OpFAdd || |
1307 | 50.0k | other_inst->opcode() == spv::Op::OpIAdd) { |
1308 | 6.57k | std::vector<const analysis::Constant*> other_constants = |
1309 | 6.57k | const_mgr->GetOperandConstants(other_inst); |
1310 | 6.57k | const analysis::Constant* const_input2 = ConstInput(other_constants); |
1311 | 6.57k | if (!const_input2) return false; |
1312 | | |
1313 | 1.33k | Instruction* non_const_input = |
1314 | 1.33k | NonConstInput(context, other_constants[0], other_inst); |
1315 | | |
1316 | | // If the first operand of the sub is not a constant, swap the constants |
1317 | | // so the subtraction has the correct operands. |
1318 | 1.33k | if (constants[0] == nullptr) std::swap(const_input1, const_input2); |
1319 | | // Subtract the constants. |
1320 | 1.33k | uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(), |
1321 | 1.33k | const_input1, const_input2); |
1322 | 1.33k | spv::Op op = inst->opcode(); |
1323 | 1.33k | uint32_t op1 = 0; |
1324 | 1.33k | uint32_t op2 = 0; |
1325 | 1.33k | if (constants[0] == nullptr) { |
1326 | | // Non-constant operand is first. Change the opcode. |
1327 | 1.30k | op1 = non_const_input->result_id(); |
1328 | 1.30k | op2 = merged_id; |
1329 | 1.30k | op = other_inst->opcode(); |
1330 | 1.30k | } else { |
1331 | | // Constant operand is first. |
1332 | 38 | op1 = merged_id; |
1333 | 38 | op2 = non_const_input->result_id(); |
1334 | 38 | } |
1335 | 1.33k | if (op1 == 0 || op2 == 0) return false; |
1336 | | |
1337 | 1.22k | inst->SetOpcode(op); |
1338 | 1.22k | inst->SetInOperands( |
1339 | 1.22k | {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}}); |
1340 | 1.22k | return true; |
1341 | 1.33k | } |
1342 | 43.4k | return false; |
1343 | 50.0k | }; |
1344 | 16.3k | } |
1345 | | |
1346 | | // Folds subtraction of a subtraction where each operand has a constant operand. |
1347 | | // Cases: |
1348 | | // (x - 2) - 2 = x - 4 |
1349 | | // (2 - x) - 2 = 0 - x |
1350 | | // 2 - (x - 2) = 4 - x |
1351 | | // 2 - (2 - x) = x + 0 |
1352 | 16.3k | FoldingRule MergeSubSubArithmetic() { |
1353 | 16.3k | return [](IRContext* context, Instruction* inst, |
1354 | 199k | const std::vector<const analysis::Constant*>& constants) { |
1355 | 199k | assert(inst->opcode() == spv::Op::OpFSub || |
1356 | 199k | inst->opcode() == spv::Op::OpISub); |
1357 | 199k | const analysis::Type* type = |
1358 | 199k | context->get_type_mgr()->GetType(inst->type_id()); |
1359 | | |
1360 | 199k | if (IsCooperativeMatrix(type)) { |
1361 | 0 | return false; |
1362 | 0 | } |
1363 | | |
1364 | 199k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
1365 | 199k | bool uses_float = HasFloatingPoint(type); |
1366 | 199k | if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; |
1367 | | |
1368 | 199k | uint32_t width = ElementWidth(type); |
1369 | 199k | if (width != 32 && width != 64) return false; |
1370 | | |
1371 | 199k | const analysis::Constant* const_input1 = ConstInput(constants); |
1372 | 199k | if (!const_input1) return false; |
1373 | 48.8k | Instruction* other_inst = NonConstInput(context, constants[0], inst); |
1374 | 48.8k | if (uses_float && !other_inst->IsFloatingPointFoldingAllowed()) |
1375 | 5 | return false; |
1376 | | |
1377 | 48.7k | if (other_inst->opcode() == spv::Op::OpFSub || |
1378 | 48.7k | other_inst->opcode() == spv::Op::OpISub) { |
1379 | 3.16k | std::vector<const analysis::Constant*> other_constants = |
1380 | 3.16k | const_mgr->GetOperandConstants(other_inst); |
1381 | 3.16k | const analysis::Constant* const_input2 = ConstInput(other_constants); |
1382 | 3.16k | if (!const_input2) return false; |
1383 | | |
1384 | 2.96k | Instruction* non_const_input = |
1385 | 2.96k | NonConstInput(context, other_constants[0], other_inst); |
1386 | | |
1387 | | // Merge the constants. |
1388 | 2.96k | uint32_t merged_id = 0; |
1389 | 2.96k | spv::Op merge_op = inst->opcode(); |
1390 | 2.96k | if (other_constants[0] == nullptr) { |
1391 | 2.81k | merge_op = uses_float ? spv::Op::OpFAdd : spv::Op::OpIAdd; |
1392 | 2.81k | } else if (constants[0] == nullptr) { |
1393 | 31 | std::swap(const_input1, const_input2); |
1394 | 31 | } |
1395 | 2.96k | merged_id = |
1396 | 2.96k | PerformOperation(const_mgr, merge_op, const_input1, const_input2); |
1397 | 2.96k | if (merged_id == 0) return false; |
1398 | | |
1399 | 2.09k | spv::Op op = inst->opcode(); |
1400 | 2.09k | if (constants[0] != nullptr && other_constants[0] != nullptr) { |
1401 | | // Change the operation. |
1402 | 120 | op = uses_float ? spv::Op::OpFAdd : spv::Op::OpIAdd; |
1403 | 120 | } |
1404 | | |
1405 | 2.09k | uint32_t op1 = 0; |
1406 | 2.09k | uint32_t op2 = 0; |
1407 | 2.09k | if ((constants[0] == nullptr) ^ (other_constants[0] == nullptr)) { |
1408 | 22 | op1 = merged_id; |
1409 | 22 | op2 = non_const_input->result_id(); |
1410 | 2.07k | } else { |
1411 | 2.07k | op1 = non_const_input->result_id(); |
1412 | 2.07k | op2 = merged_id; |
1413 | 2.07k | } |
1414 | | |
1415 | 2.09k | inst->SetOpcode(op); |
1416 | 2.09k | inst->SetInOperands( |
1417 | 2.09k | {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}}); |
1418 | 2.09k | return true; |
1419 | 2.96k | } |
1420 | 45.6k | return false; |
1421 | 48.7k | }; |
1422 | 16.3k | } |
1423 | | |
1424 | | // Helper function for MergeGenericAddSubArithmetic. If |addend| and |
1425 | | // subtrahend of |sub| is the same, merge to copy of minuend of |sub|. |
1426 | 3.62M | bool MergeGenericAddendSub(uint32_t addend, uint32_t sub, Instruction* inst) { |
1427 | 3.62M | IRContext* context = inst->context(); |
1428 | 3.62M | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
1429 | 3.62M | Instruction* sub_inst = def_use_mgr->GetDef(sub); |
1430 | 3.62M | if (sub_inst->opcode() != spv::Op::OpFSub && |
1431 | 3.62M | sub_inst->opcode() != spv::Op::OpISub) |
1432 | 3.62M | return false; |
1433 | 2.66k | if (sub_inst->opcode() == spv::Op::OpFSub && |
1434 | 2.66k | !sub_inst->IsFloatingPointFoldingAllowed()) |
1435 | 0 | return false; |
1436 | 2.66k | if (addend != sub_inst->GetSingleWordInOperand(1)) return false; |
1437 | 5 | inst->SetOpcode(spv::Op::OpCopyObject); |
1438 | 5 | inst->SetInOperands( |
1439 | 5 | {{SPV_OPERAND_TYPE_ID, {sub_inst->GetSingleWordInOperand(0)}}}); |
1440 | 5 | context->UpdateDefUse(inst); |
1441 | 5 | return true; |
1442 | 2.66k | } |
1443 | | |
1444 | | // Folds addition of a subtraction where the subtrahend is equal to the |
1445 | | // other addend. Return a copy of the minuend. Accepts generic (const and |
1446 | | // non-const) operands. |
1447 | | // Cases: |
1448 | | // (a - b) + b = a |
1449 | | // b + (a - b) = a |
1450 | 16.3k | FoldingRule MergeGenericAddSubArithmetic() { |
1451 | 16.3k | return [](IRContext* context, Instruction* inst, |
1452 | 1.81M | const std::vector<const analysis::Constant*>&) { |
1453 | 1.81M | assert(inst->opcode() == spv::Op::OpFAdd || |
1454 | 1.81M | inst->opcode() == spv::Op::OpIAdd); |
1455 | 1.81M | const analysis::Type* type = |
1456 | 1.81M | context->get_type_mgr()->GetType(inst->type_id()); |
1457 | | |
1458 | 1.81M | if (IsCooperativeMatrix(type)) { |
1459 | 0 | return false; |
1460 | 0 | } |
1461 | | |
1462 | 1.81M | bool uses_float = HasFloatingPoint(type); |
1463 | 1.81M | if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; |
1464 | | |
1465 | 1.81M | uint32_t width = ElementWidth(type); |
1466 | 1.81M | if (width != 32 && width != 64) return false; |
1467 | | |
1468 | 1.81M | uint32_t add_op0 = inst->GetSingleWordInOperand(0); |
1469 | 1.81M | uint32_t add_op1 = inst->GetSingleWordInOperand(1); |
1470 | 1.81M | if (MergeGenericAddendSub(add_op0, add_op1, inst)) return true; |
1471 | 1.81M | return MergeGenericAddendSub(add_op1, add_op0, inst); |
1472 | 1.81M | }; |
1473 | 16.3k | } |
1474 | | |
1475 | | // Helper function for FactorAddMuls. If |factor0_0| is the same as |factor1_0|, |
1476 | | // generate |factor0_0| * (|factor0_1| + |factor1_1|). |
1477 | | bool FactorAddMulsOpnds(uint32_t factor0_0, uint32_t factor0_1, |
1478 | | uint32_t factor1_0, uint32_t factor1_1, |
1479 | 1.32k | Instruction* inst) { |
1480 | 1.32k | IRContext* context = inst->context(); |
1481 | 1.32k | if (factor0_0 != factor1_0) return false; |
1482 | 15 | InstructionBuilder ir_builder( |
1483 | 15 | context, inst, |
1484 | 15 | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); |
1485 | 15 | Instruction* new_add_inst = ir_builder.AddBinaryOp( |
1486 | 15 | inst->type_id(), inst->opcode(), factor0_1, factor1_1); |
1487 | 15 | inst->SetOpcode(inst->opcode() == spv::Op::OpFAdd ? spv::Op::OpFMul |
1488 | 15 | : spv::Op::OpIMul); |
1489 | 15 | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {factor0_0}}, |
1490 | 15 | {SPV_OPERAND_TYPE_ID, {new_add_inst->result_id()}}}); |
1491 | 15 | context->UpdateDefUse(inst); |
1492 | 15 | return true; |
1493 | 1.32k | } |
1494 | | |
1495 | | // Perform the following factoring identity, handling all operand order |
1496 | | // combinations: (a * b) + (a * c) = a * (b + c) |
1497 | 16.3k | FoldingRule FactorAddMuls() { |
1498 | 16.3k | return [](IRContext* context, Instruction* inst, |
1499 | 1.81M | const std::vector<const analysis::Constant*>&) { |
1500 | 1.81M | assert(inst->opcode() == spv::Op::OpFAdd || |
1501 | 1.81M | inst->opcode() == spv::Op::OpIAdd); |
1502 | 1.81M | const analysis::Type* type = |
1503 | 1.81M | context->get_type_mgr()->GetType(inst->type_id()); |
1504 | 1.81M | bool uses_float = HasFloatingPoint(type); |
1505 | 1.81M | if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; |
1506 | | |
1507 | 1.81M | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
1508 | 1.81M | uint32_t add_op0 = inst->GetSingleWordInOperand(0); |
1509 | 1.81M | Instruction* add_op0_inst = def_use_mgr->GetDef(add_op0); |
1510 | 1.81M | if (add_op0_inst->opcode() != spv::Op::OpFMul && |
1511 | 1.81M | add_op0_inst->opcode() != spv::Op::OpIMul) |
1512 | 1.80M | return false; |
1513 | 4.60k | uint32_t add_op1 = inst->GetSingleWordInOperand(1); |
1514 | 4.60k | Instruction* add_op1_inst = def_use_mgr->GetDef(add_op1); |
1515 | 4.60k | if (add_op1_inst->opcode() != spv::Op::OpFMul && |
1516 | 4.60k | add_op1_inst->opcode() != spv::Op::OpIMul) |
1517 | 3.63k | return false; |
1518 | | |
1519 | | // Only perform this optimization if both of the muls only have one use. |
1520 | | // Otherwise this is a deoptimization in size and performance. |
1521 | 972 | if (def_use_mgr->NumUses(add_op0_inst) > 1) return false; |
1522 | 363 | if (def_use_mgr->NumUses(add_op1_inst) > 1) return false; |
1523 | | |
1524 | 342 | if (add_op0_inst->opcode() == spv::Op::OpFMul && |
1525 | 342 | (!add_op0_inst->IsFloatingPointFoldingAllowed() || |
1526 | 305 | !add_op1_inst->IsFloatingPointFoldingAllowed())) |
1527 | 0 | return false; |
1528 | | |
1529 | 996 | for (int i = 0; i < 2; i++) { |
1530 | 1.97k | for (int j = 0; j < 2; j++) { |
1531 | | // Check if operand i in add_op0_inst matches operand j in add_op1_inst. |
1532 | 1.32k | if (FactorAddMulsOpnds(add_op0_inst->GetSingleWordInOperand(i), |
1533 | 1.32k | add_op0_inst->GetSingleWordInOperand(1 - i), |
1534 | 1.32k | add_op1_inst->GetSingleWordInOperand(j), |
1535 | 1.32k | add_op1_inst->GetSingleWordInOperand(1 - j), |
1536 | 1.32k | inst)) |
1537 | 15 | return true; |
1538 | 1.32k | } |
1539 | 669 | } |
1540 | 327 | return false; |
1541 | 342 | }; |
1542 | 16.3k | } |
1543 | | |
1544 | 8.16k | FoldingRule IntMultipleBy1() { |
1545 | 8.16k | return [](IRContext*, Instruction* inst, |
1546 | 8.99k | const std::vector<const analysis::Constant*>& constants) { |
1547 | 8.99k | assert(inst->opcode() == spv::Op::OpIMul && |
1548 | 8.99k | "Wrong opcode. Should be OpIMul."); |
1549 | 26.7k | for (uint32_t i = 0; i < 2; i++) { |
1550 | 17.9k | if (constants[i] == nullptr) { |
1551 | 10.2k | continue; |
1552 | 10.2k | } |
1553 | 7.77k | const analysis::IntConstant* int_constant = constants[i]->AsIntConstant(); |
1554 | 7.77k | if (int_constant) { |
1555 | 7.77k | uint32_t width = ElementWidth(int_constant->type()); |
1556 | 7.77k | if (width != 32 && width != 64) return false; |
1557 | 7.77k | bool is_one = (width == 32) ? int_constant->GetU32BitValue() == 1u |
1558 | 7.77k | : int_constant->GetU64BitValue() == 1ull; |
1559 | 7.77k | if (is_one) { |
1560 | 243 | inst->SetOpcode(spv::Op::OpCopyObject); |
1561 | 243 | inst->SetInOperands( |
1562 | 243 | {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1 - i)}}}); |
1563 | 243 | return true; |
1564 | 243 | } |
1565 | 7.77k | } |
1566 | 7.77k | } |
1567 | 8.74k | return false; |
1568 | 8.99k | }; |
1569 | 8.16k | } |
1570 | | |
1571 | | // Returns the number of elements that the |index|th in operand in |inst| |
1572 | | // contributes to the result of |inst|. |inst| must be an |
1573 | | // OpCompositeConstructInstruction. |
1574 | | uint32_t GetNumOfElementsContributedByOperand(IRContext* context, |
1575 | | const Instruction* inst, |
1576 | 6.48k | uint32_t index) { |
1577 | 6.48k | assert(inst->opcode() == spv::Op::OpCompositeConstruct); |
1578 | 6.48k | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
1579 | 6.48k | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
1580 | | |
1581 | 6.48k | analysis::Vector* result_type = |
1582 | 6.48k | type_mgr->GetType(inst->type_id())->AsVector(); |
1583 | 6.48k | if (result_type == nullptr) { |
1584 | | // If the result of the OpCompositeConstruct is not a vector then every |
1585 | | // operands corresponds to a single element in the result. |
1586 | 0 | return 1; |
1587 | 0 | } |
1588 | | |
1589 | | // If the result type is a vector then the operands are either scalars or |
1590 | | // vectors. If it is a scalar, then it corresponds to a single element. If it |
1591 | | // is a vector, then each element in the vector will be an element in the |
1592 | | // result. |
1593 | 6.48k | uint32_t id = inst->GetSingleWordInOperand(index); |
1594 | 6.48k | Instruction* def = def_use_mgr->GetDef(id); |
1595 | 6.48k | analysis::Vector* type = type_mgr->GetType(def->type_id())->AsVector(); |
1596 | 6.48k | if (type == nullptr) { |
1597 | 6.48k | return 1; |
1598 | 6.48k | } |
1599 | 0 | return type->element_count(); |
1600 | 6.48k | } |
1601 | | |
1602 | | // Returns the in-operands for an OpCompositeExtract instruction that are needed |
1603 | | // to extract the |result_index|th element in the result of |inst| without using |
1604 | | // the result of |inst|. Returns the empty vector if |result_index| is |
1605 | | // out-of-bounds. |inst| must be an |OpCompositeConstruct| instruction. |
1606 | | std::vector<Operand> GetExtractOperandsForElementOfCompositeConstruct( |
1607 | 27.3k | IRContext* context, const Instruction* inst, uint32_t result_index) { |
1608 | 27.3k | assert(inst->opcode() == spv::Op::OpCompositeConstruct); |
1609 | 27.3k | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
1610 | 27.3k | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
1611 | | |
1612 | 27.3k | analysis::Type* result_type = type_mgr->GetType(inst->type_id()); |
1613 | 27.3k | if (result_type->AsVector() == nullptr) { |
1614 | 23.7k | if (result_index < inst->NumInOperands()) { |
1615 | 23.7k | uint32_t id = inst->GetSingleWordInOperand(result_index); |
1616 | 23.7k | return {Operand(SPV_OPERAND_TYPE_ID, {id})}; |
1617 | 23.7k | } |
1618 | 5 | return {}; |
1619 | 23.7k | } |
1620 | | |
1621 | | // If the result type is a vector, then vector operands are concatenated. |
1622 | 3.57k | uint32_t total_element_count = 0; |
1623 | 6.48k | for (uint32_t idx = 0; idx < inst->NumInOperands(); ++idx) { |
1624 | 6.48k | uint32_t element_count = |
1625 | 6.48k | GetNumOfElementsContributedByOperand(context, inst, idx); |
1626 | 6.48k | total_element_count += element_count; |
1627 | 6.48k | if (result_index < total_element_count) { |
1628 | 3.57k | std::vector<Operand> operands; |
1629 | 3.57k | uint32_t id = inst->GetSingleWordInOperand(idx); |
1630 | 3.57k | Instruction* operand_def = def_use_mgr->GetDef(id); |
1631 | 3.57k | analysis::Type* operand_type = type_mgr->GetType(operand_def->type_id()); |
1632 | | |
1633 | 3.57k | operands.push_back({SPV_OPERAND_TYPE_ID, {id}}); |
1634 | 3.57k | if (operand_type->AsVector()) { |
1635 | 0 | uint32_t start_index_of_id = total_element_count - element_count; |
1636 | 0 | uint32_t index_into_id = result_index - start_index_of_id; |
1637 | 0 | operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {index_into_id}}); |
1638 | 0 | } |
1639 | 3.57k | return operands; |
1640 | 3.57k | } |
1641 | 6.48k | } |
1642 | 0 | return {}; |
1643 | 3.57k | } |
1644 | | |
1645 | | bool CompositeConstructFeedingExtract( |
1646 | | IRContext* context, Instruction* inst, |
1647 | 150k | const std::vector<const analysis::Constant*>&) { |
1648 | | // If the input to an OpCompositeExtract is an OpCompositeConstruct, |
1649 | | // then we can simply use the appropriate element in the construction. |
1650 | 150k | assert(inst->opcode() == spv::Op::OpCompositeExtract && |
1651 | 150k | "Wrong opcode. Should be OpCompositeExtract."); |
1652 | 150k | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
1653 | | |
1654 | | // If there are no index operands, then this rule cannot do anything. |
1655 | 150k | if (inst->NumInOperands() <= 1) { |
1656 | 0 | return false; |
1657 | 0 | } |
1658 | | |
1659 | 150k | uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx); |
1660 | 150k | Instruction* cinst = def_use_mgr->GetDef(cid); |
1661 | | |
1662 | 150k | if (cinst->opcode() != spv::Op::OpCompositeConstruct) { |
1663 | 122k | return false; |
1664 | 122k | } |
1665 | | |
1666 | 27.3k | uint32_t index_into_result = inst->GetSingleWordInOperand(1); |
1667 | 27.3k | std::vector<Operand> operands = |
1668 | 27.3k | GetExtractOperandsForElementOfCompositeConstruct(context, cinst, |
1669 | 27.3k | index_into_result); |
1670 | | |
1671 | 27.3k | if (operands.empty()) { |
1672 | 5 | return false; |
1673 | 5 | } |
1674 | | |
1675 | | // Add the remaining indices for extraction. |
1676 | 27.2k | for (uint32_t i = 2; i < inst->NumInOperands(); ++i) { |
1677 | 0 | operands.push_back( |
1678 | 0 | {SPV_OPERAND_TYPE_LITERAL_INTEGER, {inst->GetSingleWordInOperand(i)}}); |
1679 | 0 | } |
1680 | | |
1681 | 27.2k | if (operands.size() == 1) { |
1682 | | // If there were no extra indices, then we have the final object. No need |
1683 | | // to extract any more. |
1684 | 27.2k | inst->SetOpcode(spv::Op::OpCopyObject); |
1685 | 27.2k | } |
1686 | | |
1687 | 27.2k | inst->SetInOperands(std::move(operands)); |
1688 | 27.2k | return true; |
1689 | 27.3k | } |
1690 | | |
1691 | | // Walks the indexes chain from |start| to |end| of an OpCompositeInsert or |
1692 | | // OpCompositeExtract instruction, and returns the type id of the final element |
1693 | | // being accessed. Returns 0 if a valid type could not be found. |
1694 | | uint32_t GetElementType(uint32_t type_id, Instruction::iterator start, |
1695 | | Instruction::iterator end, |
1696 | 186k | const analysis::DefUseManager* def_use_manager) { |
1697 | 186k | for (auto index : make_range(std::move(start), std::move(end))) { |
1698 | 72 | const Instruction* type_inst = def_use_manager->GetDef(type_id); |
1699 | 72 | assert(index.type == SPV_OPERAND_TYPE_LITERAL_INTEGER && |
1700 | 72 | index.words.size() == 1); |
1701 | 72 | if (type_inst->opcode() == spv::Op::OpTypeArray) { |
1702 | 72 | type_id = type_inst->GetSingleWordInOperand(0); |
1703 | 72 | } else if (type_inst->opcode() == spv::Op::OpTypeMatrix) { |
1704 | 0 | type_id = type_inst->GetSingleWordInOperand(0); |
1705 | 0 | } else if (type_inst->opcode() == spv::Op::OpTypeStruct) { |
1706 | 0 | type_id = type_inst->GetSingleWordInOperand(index.words[0]); |
1707 | 0 | } else { |
1708 | 0 | return 0; |
1709 | 0 | } |
1710 | 72 | } |
1711 | 186k | return type_id; |
1712 | 186k | } |
1713 | | |
1714 | | // Returns true of |inst_1| and |inst_2| have the same indexes that will be used |
1715 | | // to index into a composite object, excluding the last index. The two |
1716 | | // instructions must have the same opcode, and be either OpCompositeExtract or |
1717 | | // OpCompositeInsert instructions. |
1718 | 104M | bool HaveSameIndexesExceptForLast(Instruction* inst_1, Instruction* inst_2) { |
1719 | 104M | assert(inst_1->opcode() == inst_2->opcode() && |
1720 | 104M | "Expecting the opcodes to be the same."); |
1721 | 104M | assert((inst_1->opcode() == spv::Op::OpCompositeInsert || |
1722 | 104M | inst_1->opcode() == spv::Op::OpCompositeExtract) && |
1723 | 104M | "Instructions must be OpCompositeInsert or OpCompositeExtract."); |
1724 | | |
1725 | 104M | if (inst_1->NumInOperands() != inst_2->NumInOperands()) { |
1726 | 108 | return false; |
1727 | 108 | } |
1728 | | |
1729 | 104M | uint32_t first_index_position = |
1730 | 104M | (inst_1->opcode() == spv::Op::OpCompositeInsert ? 2 : 1); |
1731 | 104M | for (uint32_t i = first_index_position; i < inst_1->NumInOperands() - 1; |
1732 | 104M | i++) { |
1733 | 165 | if (inst_1->GetSingleWordInOperand(i) != |
1734 | 165 | inst_2->GetSingleWordInOperand(i)) { |
1735 | 0 | return false; |
1736 | 0 | } |
1737 | 165 | } |
1738 | 104M | return true; |
1739 | 104M | } |
1740 | | |
1741 | | // If the OpCompositeConstruct is simply putting back together elements that |
1742 | | // where extracted from the same source, we can simply reuse the source. |
1743 | | // |
1744 | | // This is a common code pattern because of the way that scalar replacement |
1745 | | // works. |
1746 | | bool CompositeExtractFeedingConstruct( |
1747 | | IRContext* context, Instruction* inst, |
1748 | 135k | const std::vector<const analysis::Constant*>&) { |
1749 | 135k | assert(inst->opcode() == spv::Op::OpCompositeConstruct && |
1750 | 135k | "Wrong opcode. Should be OpCompositeConstruct."); |
1751 | 135k | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
1752 | 135k | uint32_t original_id = 0; |
1753 | | |
1754 | 135k | if (inst->NumInOperands() == 0) { |
1755 | | // The struct being constructed has no members. |
1756 | 0 | return false; |
1757 | 0 | } |
1758 | | |
1759 | | // Check each element to make sure they are: |
1760 | | // - extractions |
1761 | | // - extracting the same position they are inserting |
1762 | | // - all extract from the same id. |
1763 | 135k | Instruction* first_element_inst = nullptr; |
1764 | 147k | for (uint32_t i = 0; i < inst->NumInOperands(); ++i) { |
1765 | 147k | const uint32_t element_id = inst->GetSingleWordInOperand(i); |
1766 | 147k | Instruction* element_inst = def_use_mgr->GetDef(element_id); |
1767 | 147k | if (first_element_inst == nullptr) { |
1768 | 135k | first_element_inst = element_inst; |
1769 | 135k | } |
1770 | | |
1771 | 147k | if (element_inst->opcode() != spv::Op::OpCompositeExtract) { |
1772 | 133k | return false; |
1773 | 133k | } |
1774 | | |
1775 | 13.4k | if (!HaveSameIndexesExceptForLast(element_inst, first_element_inst)) { |
1776 | 0 | return false; |
1777 | 0 | } |
1778 | | |
1779 | 13.4k | if (element_inst->GetSingleWordInOperand(element_inst->NumInOperands() - |
1780 | 13.4k | 1) != i) { |
1781 | 1.73k | return false; |
1782 | 1.73k | } |
1783 | | |
1784 | 11.6k | if (i == 0) { |
1785 | 4.08k | original_id = |
1786 | 4.08k | element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx); |
1787 | 7.60k | } else if (original_id != |
1788 | 7.60k | element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx)) { |
1789 | 248 | return false; |
1790 | 248 | } |
1791 | 11.6k | } |
1792 | | |
1793 | | // The last check it to see that the object being extracted from is the |
1794 | | // correct type. |
1795 | 162 | Instruction* original_inst = def_use_mgr->GetDef(original_id); |
1796 | 162 | uint32_t original_type_id = |
1797 | 162 | GetElementType(original_inst->type_id(), first_element_inst->begin() + 3, |
1798 | 162 | first_element_inst->end() - 1, def_use_mgr); |
1799 | | |
1800 | 162 | if (inst->type_id() != original_type_id) { |
1801 | 36 | return false; |
1802 | 36 | } |
1803 | | |
1804 | 126 | if (first_element_inst->NumInOperands() == 2) { |
1805 | | // Simplify by using the original object. |
1806 | 126 | inst->SetOpcode(spv::Op::OpCopyObject); |
1807 | 126 | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}}); |
1808 | 126 | return true; |
1809 | 126 | } |
1810 | | |
1811 | | // Copies the original id and all indexes except for the last to the new |
1812 | | // extract instruction. |
1813 | 0 | inst->SetOpcode(spv::Op::OpCompositeExtract); |
1814 | 0 | inst->SetInOperands(std::vector<Operand>(first_element_inst->begin() + 2, |
1815 | 0 | first_element_inst->end() - 1)); |
1816 | 0 | return true; |
1817 | 126 | } |
1818 | | |
1819 | 8.16k | FoldingRule InsertFeedingExtract() { |
1820 | 8.16k | return [](IRContext* context, Instruction* inst, |
1821 | 415k | const std::vector<const analysis::Constant*>&) { |
1822 | 415k | assert(inst->opcode() == spv::Op::OpCompositeExtract && |
1823 | 415k | "Wrong opcode. Should be OpCompositeExtract."); |
1824 | 415k | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
1825 | 415k | uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx); |
1826 | 415k | Instruction* cinst = def_use_mgr->GetDef(cid); |
1827 | | |
1828 | 415k | if (cinst->opcode() != spv::Op::OpCompositeInsert) { |
1829 | 150k | return false; |
1830 | 150k | } |
1831 | | |
1832 | | // Find the first position where the list of insert and extract indicies |
1833 | | // differ, if at all. |
1834 | 265k | uint32_t i; |
1835 | 415k | for (i = 1; i < inst->NumInOperands(); ++i) { |
1836 | 265k | if (i + 1 >= cinst->NumInOperands()) { |
1837 | 0 | break; |
1838 | 0 | } |
1839 | | |
1840 | 265k | if (inst->GetSingleWordInOperand(i) != |
1841 | 265k | cinst->GetSingleWordInOperand(i + 1)) { |
1842 | 114k | break; |
1843 | 114k | } |
1844 | 265k | } |
1845 | | |
1846 | | // We are extracting the element that was inserted. |
1847 | 265k | if (i == inst->NumInOperands() && i + 1 == cinst->NumInOperands()) { |
1848 | 150k | inst->SetOpcode(spv::Op::OpCopyObject); |
1849 | 150k | inst->SetInOperands( |
1850 | 150k | {{SPV_OPERAND_TYPE_ID, |
1851 | 150k | {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}}}); |
1852 | 150k | return true; |
1853 | 150k | } |
1854 | | |
1855 | | // Extracting the value that was inserted along with values for the base |
1856 | | // composite. Cannot do anything. |
1857 | 114k | if (i == inst->NumInOperands()) { |
1858 | 0 | return false; |
1859 | 0 | } |
1860 | | |
1861 | | // Extracting an element of the value that was inserted. Extract from |
1862 | | // that value directly. |
1863 | 114k | if (i + 1 == cinst->NumInOperands()) { |
1864 | 0 | std::vector<Operand> operands; |
1865 | 0 | operands.push_back( |
1866 | 0 | {SPV_OPERAND_TYPE_ID, |
1867 | 0 | {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}}); |
1868 | 0 | for (; i < inst->NumInOperands(); ++i) { |
1869 | 0 | operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, |
1870 | 0 | {inst->GetSingleWordInOperand(i)}}); |
1871 | 0 | } |
1872 | 0 | inst->SetInOperands(std::move(operands)); |
1873 | 0 | return true; |
1874 | 0 | } |
1875 | | |
1876 | | // Extracting a value that is disjoint from the element being inserted. |
1877 | | // Rewrite the extract to use the composite input to the insert. |
1878 | 114k | std::vector<Operand> operands; |
1879 | 114k | operands.push_back( |
1880 | 114k | {SPV_OPERAND_TYPE_ID, |
1881 | 114k | {cinst->GetSingleWordInOperand(kInsertCompositeIdInIdx)}}); |
1882 | 229k | for (i = 1; i < inst->NumInOperands(); ++i) { |
1883 | 114k | operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, |
1884 | 114k | {inst->GetSingleWordInOperand(i)}}); |
1885 | 114k | } |
1886 | 114k | inst->SetInOperands(std::move(operands)); |
1887 | 114k | return true; |
1888 | 114k | }; |
1889 | 8.16k | } |
1890 | | |
1891 | | // When a VectorShuffle is feeding an Extract, we can extract from one of the |
1892 | | // operands of the VectorShuffle. We just need to adjust the index in the |
1893 | | // extract instruction. |
1894 | 8.16k | FoldingRule VectorShuffleFeedingExtract() { |
1895 | 8.16k | return [](IRContext* context, Instruction* inst, |
1896 | 123k | const std::vector<const analysis::Constant*>&) { |
1897 | 123k | assert(inst->opcode() == spv::Op::OpCompositeExtract && |
1898 | 123k | "Wrong opcode. Should be OpCompositeExtract."); |
1899 | 123k | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
1900 | 123k | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
1901 | 123k | uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx); |
1902 | 123k | Instruction* cinst = def_use_mgr->GetDef(cid); |
1903 | | |
1904 | 123k | if (cinst->opcode() != spv::Op::OpVectorShuffle) { |
1905 | 121k | return false; |
1906 | 121k | } |
1907 | | |
1908 | | // Find the size of the first vector operand of the VectorShuffle |
1909 | 1.01k | Instruction* first_input = |
1910 | 1.01k | def_use_mgr->GetDef(cinst->GetSingleWordInOperand(0)); |
1911 | 1.01k | analysis::Type* first_input_type = |
1912 | 1.01k | type_mgr->GetType(first_input->type_id()); |
1913 | 1.01k | assert(first_input_type->AsVector() && |
1914 | 1.01k | "Input to vector shuffle should be vectors."); |
1915 | 1.01k | uint32_t first_input_size = first_input_type->AsVector()->element_count(); |
1916 | | |
1917 | | // Get index of the element the vector shuffle is placing in the position |
1918 | | // being extracted. |
1919 | 1.01k | uint32_t new_index = |
1920 | 1.01k | cinst->GetSingleWordInOperand(2 + inst->GetSingleWordInOperand(1)); |
1921 | | |
1922 | | // Extracting an undefined value so fold this extract into an undef. |
1923 | 1.01k | const uint32_t undef_literal_value = 0xffffffff; |
1924 | 1.01k | if (new_index == undef_literal_value) { |
1925 | 8 | inst->SetOpcode(spv::Op::OpUndef); |
1926 | 8 | inst->SetInOperands({}); |
1927 | 8 | return true; |
1928 | 8 | } |
1929 | | |
1930 | | // Get the id of the of the vector the elemtent comes from, and update the |
1931 | | // index if needed. |
1932 | 1.00k | uint32_t new_vector = 0; |
1933 | 1.00k | if (new_index < first_input_size) { |
1934 | 700 | new_vector = cinst->GetSingleWordInOperand(0); |
1935 | 700 | } else { |
1936 | 308 | new_vector = cinst->GetSingleWordInOperand(1); |
1937 | 308 | new_index -= first_input_size; |
1938 | 308 | } |
1939 | | |
1940 | | // Update the extract instruction. |
1941 | 1.00k | inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector}); |
1942 | 1.00k | inst->SetInOperand(1, {new_index}); |
1943 | 1.00k | return true; |
1944 | 1.01k | }; |
1945 | 8.16k | } |
1946 | | |
1947 | | // When an FMix with is feeding an Extract that extracts an element whose |
1948 | | // corresponding |a| in the FMix is 0 or 1, we can extract from one of the |
1949 | | // operands of the FMix. |
1950 | 8.16k | FoldingRule FMixFeedingExtract() { |
1951 | 8.16k | return [](IRContext* context, Instruction* inst, |
1952 | 121k | const std::vector<const analysis::Constant*>&) { |
1953 | 121k | assert(inst->opcode() == spv::Op::OpCompositeExtract && |
1954 | 121k | "Wrong opcode. Should be OpCompositeExtract."); |
1955 | 121k | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
1956 | 121k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
1957 | | |
1958 | 121k | uint32_t composite_id = |
1959 | 121k | inst->GetSingleWordInOperand(kExtractCompositeIdInIdx); |
1960 | 121k | Instruction* composite_inst = def_use_mgr->GetDef(composite_id); |
1961 | | |
1962 | 121k | if (composite_inst->opcode() != spv::Op::OpExtInst) { |
1963 | 110k | return false; |
1964 | 110k | } |
1965 | | |
1966 | 11.2k | uint32_t inst_set_id = |
1967 | 11.2k | context->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); |
1968 | | |
1969 | 11.2k | if (composite_inst->GetSingleWordInOperand(kExtInstSetIdInIdx) != |
1970 | 11.2k | inst_set_id || |
1971 | 11.2k | composite_inst->GetSingleWordInOperand(kExtInstInstructionInIdx) != |
1972 | 11.2k | GLSLstd450FMix) { |
1973 | 10.0k | return false; |
1974 | 10.0k | } |
1975 | | |
1976 | | // Get the |a| for the FMix instruction. |
1977 | 1.11k | uint32_t a_id = composite_inst->GetSingleWordInOperand(kFMixAIdInIdx); |
1978 | 1.11k | std::unique_ptr<Instruction> a(inst->Clone(context)); |
1979 | 1.11k | a->SetInOperand(kExtractCompositeIdInIdx, {a_id}); |
1980 | 1.11k | context->get_instruction_folder().FoldInstruction(a.get()); |
1981 | | |
1982 | 1.11k | if (a->opcode() != spv::Op::OpCopyObject) { |
1983 | 0 | return false; |
1984 | 0 | } |
1985 | | |
1986 | 1.11k | const analysis::Constant* a_const = |
1987 | 1.11k | const_mgr->FindDeclaredConstant(a->GetSingleWordInOperand(0)); |
1988 | | |
1989 | 1.11k | if (!a_const) { |
1990 | 805 | return false; |
1991 | 805 | } |
1992 | | |
1993 | 314 | bool use_x = false; |
1994 | | |
1995 | 314 | assert(a_const->type()->AsFloat()); |
1996 | 314 | double element_value = a_const->GetValueAsDouble(); |
1997 | 314 | if (element_value == 0.0) { |
1998 | 0 | use_x = true; |
1999 | 314 | } else if (element_value == 1.0) { |
2000 | 0 | use_x = false; |
2001 | 314 | } else { |
2002 | 314 | return false; |
2003 | 314 | } |
2004 | | |
2005 | | // Get the id of the of the vector the element comes from. |
2006 | 0 | uint32_t new_vector = 0; |
2007 | 0 | if (use_x) { |
2008 | 0 | new_vector = composite_inst->GetSingleWordInOperand(kFMixXIdInIdx); |
2009 | 0 | } else { |
2010 | 0 | new_vector = composite_inst->GetSingleWordInOperand(kFMixYIdInIdx); |
2011 | 0 | } |
2012 | | |
2013 | | // Update the extract instruction. |
2014 | 0 | inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector}); |
2015 | 0 | return true; |
2016 | 314 | }; |
2017 | 8.16k | } |
2018 | | |
2019 | | // Returns the number of elements in the composite type |type|. Returns 0 if |
2020 | | // |type| is a scalar value. Return UINT32_MAX when the size is unknown at |
2021 | | // compile time. |
2022 | 186k | uint32_t GetNumberOfElements(const analysis::Type* type) { |
2023 | 186k | if (auto* vector_type = type->AsVector()) { |
2024 | 40.7k | return vector_type->element_count(); |
2025 | 40.7k | } |
2026 | 145k | if (auto* matrix_type = type->AsMatrix()) { |
2027 | 0 | return matrix_type->element_count(); |
2028 | 0 | } |
2029 | 145k | if (auto* struct_type = type->AsStruct()) { |
2030 | 750 | return static_cast<uint32_t>(struct_type->element_types().size()); |
2031 | 750 | } |
2032 | 145k | if (auto* array_type = type->AsArray()) { |
2033 | 145k | if (array_type->length_info().words[0] == |
2034 | 145k | analysis::Array::LengthInfo::kConstant && |
2035 | 145k | array_type->length_info().words.size() == 2) { |
2036 | 145k | return array_type->length_info().words[1]; |
2037 | 145k | } |
2038 | 0 | return UINT32_MAX; |
2039 | 145k | } |
2040 | 0 | return 0; |
2041 | 145k | } |
2042 | | |
2043 | | // Returns a map with the set of values that were inserted into an object by |
2044 | | // the chain of OpCompositeInsertInstruction starting with |inst|. |
2045 | | // The map will map the index to the value inserted at that index. An empty map |
2046 | | // will be returned if the map could not be properly generated. |
2047 | 186k | std::map<uint32_t, uint32_t> GetInsertedValues(Instruction* inst) { |
2048 | 186k | analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr(); |
2049 | 186k | std::map<uint32_t, uint32_t> values_inserted; |
2050 | 186k | Instruction* current_inst = inst; |
2051 | 104M | while (current_inst->opcode() == spv::Op::OpCompositeInsert) { |
2052 | 104M | if (current_inst->NumInOperands() > inst->NumInOperands()) { |
2053 | | // This is to catch the case |
2054 | | // %2 = OpCompositeInsert %m2x2int %v2int_1_0 %m2x2int_undef 0 |
2055 | | // %3 = OpCompositeInsert %m2x2int %int_4 %2 0 0 |
2056 | | // %4 = OpCompositeInsert %m2x2int %v2int_2_3 %3 1 |
2057 | | // In this case we cannot do a single construct to get the matrix. |
2058 | 24 | uint32_t partially_inserted_element_index = |
2059 | 24 | current_inst->GetSingleWordInOperand(inst->NumInOperands() - 1); |
2060 | 24 | if (values_inserted.count(partially_inserted_element_index) == 0) |
2061 | 24 | return {}; |
2062 | 24 | } |
2063 | 104M | if (HaveSameIndexesExceptForLast(inst, current_inst)) { |
2064 | 104M | values_inserted.insert( |
2065 | 104M | {current_inst->GetSingleWordInOperand(current_inst->NumInOperands() - |
2066 | 104M | 1), |
2067 | 104M | current_inst->GetSingleWordInOperand(kInsertObjectIdInIdx)}); |
2068 | 104M | } |
2069 | 104M | current_inst = def_use_mgr->GetDef( |
2070 | 104M | current_inst->GetSingleWordInOperand(kInsertCompositeIdInIdx)); |
2071 | 104M | } |
2072 | 186k | return values_inserted; |
2073 | 186k | } |
2074 | | |
2075 | | // Returns true of there is an entry in |values_inserted| for every element of |
2076 | | // |Type|. |
2077 | | bool DoInsertedValuesCoverEntireObject( |
2078 | 186k | const analysis::Type* type, std::map<uint32_t, uint32_t>& values_inserted) { |
2079 | 186k | uint32_t container_size = GetNumberOfElements(type); |
2080 | 186k | if (container_size != values_inserted.size()) { |
2081 | 183k | return false; |
2082 | 183k | } |
2083 | | |
2084 | 3.44k | if (values_inserted.rbegin()->first >= container_size) { |
2085 | 0 | return false; |
2086 | 0 | } |
2087 | 3.44k | return true; |
2088 | 3.44k | } |
2089 | | |
2090 | | // Returns id of the type of the element that immediately contains the element |
2091 | | // being inserted by the OpCompositeInsert instruction |inst|. Returns 0 if it |
2092 | | // could not be found. |
2093 | 186k | uint32_t GetContainerTypeId(Instruction* inst) { |
2094 | 186k | assert(inst->opcode() == spv::Op::OpCompositeInsert); |
2095 | 186k | analysis::DefUseManager* def_use_manager = inst->context()->get_def_use_mgr(); |
2096 | 186k | uint32_t container_type_id = GetElementType( |
2097 | 186k | inst->type_id(), inst->begin() + 4, inst->end() - 1, def_use_manager); |
2098 | 186k | return container_type_id; |
2099 | 186k | } |
2100 | | |
2101 | | // Returns an OpCompositeConstruct instruction that build an object with |
2102 | | // |type_id| out of the values in |values_inserted|. Each value will be |
2103 | | // placed at the index corresponding to the value. The new instruction will |
2104 | | // be placed before |insert_before|. |
2105 | | Instruction* BuildCompositeConstruct( |
2106 | | uint32_t type_id, const std::map<uint32_t, uint32_t>& values_inserted, |
2107 | 3.44k | Instruction* insert_before) { |
2108 | 3.44k | InstructionBuilder ir_builder( |
2109 | 3.44k | insert_before->context(), insert_before, |
2110 | 3.44k | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); |
2111 | | |
2112 | 3.44k | std::vector<uint32_t> ids_in_order; |
2113 | 6.89k | for (auto it : values_inserted) { |
2114 | 6.89k | ids_in_order.push_back(it.second); |
2115 | 6.89k | } |
2116 | 3.44k | Instruction* construct = |
2117 | 3.44k | ir_builder.AddCompositeConstruct(type_id, ids_in_order); |
2118 | 3.44k | return construct; |
2119 | 3.44k | } |
2120 | | |
2121 | | // Replaces the OpCompositeInsert |inst| that inserts |construct| into the same |
2122 | | // object as |inst| with final index removed. If the resulting |
2123 | | // OpCompositeInsert instruction would have no remaining indexes, the |
2124 | | // instruction is replaced with an OpCopyObject instead. |
2125 | 3.44k | void InsertConstructedObject(Instruction* inst, const Instruction* construct) { |
2126 | 3.44k | if (inst->NumInOperands() == 3) { |
2127 | 3.44k | inst->SetOpcode(spv::Op::OpCopyObject); |
2128 | 3.44k | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {construct->result_id()}}}); |
2129 | 3.44k | } else { |
2130 | 0 | inst->SetInOperand(kInsertObjectIdInIdx, {construct->result_id()}); |
2131 | 0 | inst->RemoveOperand(inst->NumOperands() - 1); |
2132 | 0 | } |
2133 | 3.44k | } |
2134 | | |
2135 | | // Replaces a series of |OpCompositeInsert| instruction that cover the entire |
2136 | | // object with an |OpCompositeConstruct|. |
2137 | | bool CompositeInsertToCompositeConstruct( |
2138 | | IRContext* context, Instruction* inst, |
2139 | 186k | const std::vector<const analysis::Constant*>&) { |
2140 | 186k | assert(inst->opcode() == spv::Op::OpCompositeInsert && |
2141 | 186k | "Wrong opcode. Should be OpCompositeInsert."); |
2142 | 186k | if (inst->NumInOperands() < 3) return false; |
2143 | | |
2144 | 186k | std::map<uint32_t, uint32_t> values_inserted = GetInsertedValues(inst); |
2145 | 186k | uint32_t container_type_id = GetContainerTypeId(inst); |
2146 | 186k | if (container_type_id == 0) { |
2147 | 0 | return false; |
2148 | 0 | } |
2149 | | |
2150 | 186k | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
2151 | 186k | const analysis::Type* container_type = type_mgr->GetType(container_type_id); |
2152 | 186k | assert(container_type && "GetContainerTypeId returned a bad id."); |
2153 | 186k | if (!DoInsertedValuesCoverEntireObject(container_type, values_inserted)) { |
2154 | 183k | return false; |
2155 | 183k | } |
2156 | | |
2157 | 3.44k | Instruction* construct = |
2158 | 3.44k | BuildCompositeConstruct(container_type_id, values_inserted, inst); |
2159 | 3.44k | InsertConstructedObject(inst, construct); |
2160 | 3.44k | return true; |
2161 | 186k | } |
2162 | | |
2163 | 8.16k | FoldingRule RedundantPhi() { |
2164 | | // An OpPhi instruction where all values are the same or the result of the phi |
2165 | | // itself, can be replaced by the value itself. |
2166 | 8.16k | return [](IRContext*, Instruction* inst, |
2167 | 223k | const std::vector<const analysis::Constant*>&) { |
2168 | 223k | assert(inst->opcode() == spv::Op::OpPhi && |
2169 | 223k | "Wrong opcode. Should be OpPhi."); |
2170 | | |
2171 | 223k | uint32_t incoming_value = 0; |
2172 | | |
2173 | 526k | for (uint32_t i = 0; i < inst->NumInOperands(); i += 2) { |
2174 | 451k | uint32_t op_id = inst->GetSingleWordInOperand(i); |
2175 | 451k | if (op_id == inst->result_id()) { |
2176 | 64.1k | continue; |
2177 | 64.1k | } |
2178 | | |
2179 | 387k | if (incoming_value == 0) { |
2180 | 223k | incoming_value = op_id; |
2181 | 223k | } else if (op_id != incoming_value) { |
2182 | | // Found two possible value. Can't simplify. |
2183 | 148k | return false; |
2184 | 148k | } |
2185 | 387k | } |
2186 | | |
2187 | 74.7k | if (incoming_value == 0) { |
2188 | | // Code looks invalid. Don't do anything. |
2189 | 0 | return false; |
2190 | 0 | } |
2191 | | |
2192 | | // We have a single incoming value. Simplify using that value. |
2193 | 74.7k | inst->SetOpcode(spv::Op::OpCopyObject); |
2194 | 74.7k | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {incoming_value}}}); |
2195 | 74.7k | return true; |
2196 | 74.7k | }; |
2197 | 8.16k | } |
2198 | | |
2199 | 8.16k | FoldingRule BitCastScalarOrVector() { |
2200 | 8.16k | return [](IRContext* context, Instruction* inst, |
2201 | 8.16k | const std::vector<const analysis::Constant*>& constants) { |
2202 | 5.16k | assert(inst->opcode() == spv::Op::OpBitcast && constants.size() == 1); |
2203 | 5.16k | if (constants[0] == nullptr) return false; |
2204 | | |
2205 | 4.51k | const analysis::Type* type = |
2206 | 4.51k | context->get_type_mgr()->GetType(inst->type_id()); |
2207 | 4.51k | if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed()) |
2208 | 0 | return false; |
2209 | | |
2210 | 4.51k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
2211 | 4.51k | std::vector<uint32_t> words = |
2212 | 4.51k | GetWordsFromNumericScalarOrVectorConstant(const_mgr, constants[0]); |
2213 | 4.51k | if (words.size() == 0) return false; |
2214 | | |
2215 | 4.51k | const analysis::Constant* bitcasted_constant = |
2216 | 4.51k | ConvertWordsToNumericScalarOrVectorConstant(const_mgr, words, type); |
2217 | 4.51k | if (!bitcasted_constant) return false; |
2218 | | |
2219 | 1.26k | auto new_feeder_id = |
2220 | 1.26k | const_mgr->GetDefiningInstruction(bitcasted_constant, inst->type_id()) |
2221 | 1.26k | ->result_id(); |
2222 | 1.26k | inst->SetOpcode(spv::Op::OpCopyObject); |
2223 | 1.26k | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {new_feeder_id}}}); |
2224 | 1.26k | return true; |
2225 | 4.51k | }; |
2226 | 8.16k | } |
2227 | | |
2228 | 8.16k | FoldingRule RedundantSelect() { |
2229 | | // An OpSelect instruction where both values are the same or the condition is |
2230 | | // constant can be replaced by one of the values |
2231 | 8.16k | return [](IRContext*, Instruction* inst, |
2232 | 8.75k | const std::vector<const analysis::Constant*>& constants) { |
2233 | 8.75k | assert(inst->opcode() == spv::Op::OpSelect && |
2234 | 8.75k | "Wrong opcode. Should be OpSelect."); |
2235 | 8.75k | assert(inst->NumInOperands() == 3); |
2236 | 8.75k | assert(constants.size() == 3); |
2237 | | |
2238 | 8.75k | uint32_t true_id = inst->GetSingleWordInOperand(1); |
2239 | 8.75k | uint32_t false_id = inst->GetSingleWordInOperand(2); |
2240 | | |
2241 | 8.75k | if (true_id == false_id) { |
2242 | | // Both results are the same, condition doesn't matter |
2243 | 23 | inst->SetOpcode(spv::Op::OpCopyObject); |
2244 | 23 | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}}); |
2245 | 23 | return true; |
2246 | 8.73k | } else if (constants[0]) { |
2247 | 756 | const analysis::Type* type = constants[0]->type(); |
2248 | 756 | if (type->AsBool()) { |
2249 | | // Scalar constant value, select the corresponding value. |
2250 | 701 | inst->SetOpcode(spv::Op::OpCopyObject); |
2251 | 701 | if (constants[0]->AsNullConstant() || |
2252 | 701 | !constants[0]->AsBoolConstant()->value()) { |
2253 | 657 | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}}); |
2254 | 657 | } else { |
2255 | 44 | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}}); |
2256 | 44 | } |
2257 | 701 | return true; |
2258 | 701 | } else { |
2259 | 55 | assert(type->AsVector()); |
2260 | 55 | if (constants[0]->AsNullConstant()) { |
2261 | | // All values come from false id. |
2262 | 0 | inst->SetOpcode(spv::Op::OpCopyObject); |
2263 | 0 | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}}); |
2264 | 0 | return true; |
2265 | 55 | } else { |
2266 | | // Convert to a vector shuffle. |
2267 | 55 | std::vector<Operand> ops; |
2268 | 55 | ops.push_back({SPV_OPERAND_TYPE_ID, {true_id}}); |
2269 | 55 | ops.push_back({SPV_OPERAND_TYPE_ID, {false_id}}); |
2270 | 55 | const analysis::VectorConstant* vector_const = |
2271 | 55 | constants[0]->AsVectorConstant(); |
2272 | 55 | uint32_t size = |
2273 | 55 | static_cast<uint32_t>(vector_const->GetComponents().size()); |
2274 | 165 | for (uint32_t i = 0; i != size; ++i) { |
2275 | 110 | const analysis::Constant* component = |
2276 | 110 | vector_const->GetComponents()[i]; |
2277 | 110 | if (component->AsNullConstant() || |
2278 | 110 | !component->AsBoolConstant()->value()) { |
2279 | | // Selecting from the false vector which is the second input |
2280 | | // vector to the shuffle. Offset the index by |size|. |
2281 | 0 | ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i + size}}); |
2282 | 110 | } else { |
2283 | | // Selecting from true vector which is the first input vector to |
2284 | | // the shuffle. |
2285 | 110 | ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}); |
2286 | 110 | } |
2287 | 110 | } |
2288 | | |
2289 | 55 | inst->SetOpcode(spv::Op::OpVectorShuffle); |
2290 | 55 | inst->SetInOperands(std::move(ops)); |
2291 | 55 | return true; |
2292 | 55 | } |
2293 | 55 | } |
2294 | 756 | } |
2295 | | |
2296 | 7.97k | return false; |
2297 | 8.75k | }; |
2298 | 8.16k | } |
2299 | | |
2300 | | enum class FloatConstantKind { Unknown, Zero, One }; |
2301 | | |
2302 | 5.84M | FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) { |
2303 | 5.84M | if (constant == nullptr) { |
2304 | 3.83M | return FloatConstantKind::Unknown; |
2305 | 3.83M | } |
2306 | | |
2307 | 2.01M | assert(HasFloatingPoint(constant->type()) && "Unexpected constant type"); |
2308 | | |
2309 | 2.01M | if (constant->AsNullConstant()) { |
2310 | 928 | return FloatConstantKind::Zero; |
2311 | 2.01M | } else if (const analysis::VectorConstant* vc = |
2312 | 2.01M | constant->AsVectorConstant()) { |
2313 | 622k | const std::vector<const analysis::Constant*>& components = |
2314 | 622k | vc->GetComponents(); |
2315 | 622k | assert(!components.empty()); |
2316 | | |
2317 | 622k | FloatConstantKind kind = getFloatConstantKind(components[0]); |
2318 | | |
2319 | 1.19M | for (size_t i = 1; i < components.size(); ++i) { |
2320 | 722k | if (getFloatConstantKind(components[i]) != kind) { |
2321 | 151k | return FloatConstantKind::Unknown; |
2322 | 151k | } |
2323 | 722k | } |
2324 | | |
2325 | 470k | return kind; |
2326 | 1.39M | } else if (const analysis::FloatConstant* fc = constant->AsFloatConstant()) { |
2327 | 1.39M | if (fc->IsZero()) return FloatConstantKind::Zero; |
2328 | | |
2329 | 1.17M | uint32_t width = fc->type()->AsFloat()->width(); |
2330 | 1.17M | if (width != 32 && width != 64) return FloatConstantKind::Unknown; |
2331 | | |
2332 | 1.17M | double value = (width == 64) ? fc->GetDoubleValue() : fc->GetFloatValue(); |
2333 | | |
2334 | 1.17M | if (value == 0.0) { |
2335 | 15.8k | return FloatConstantKind::Zero; |
2336 | 1.15M | } else if (value == 1.0) { |
2337 | 9.08k | return FloatConstantKind::One; |
2338 | 1.14M | } else { |
2339 | 1.14M | return FloatConstantKind::Unknown; |
2340 | 1.14M | } |
2341 | 1.17M | } else { |
2342 | 0 | return FloatConstantKind::Unknown; |
2343 | 0 | } |
2344 | 2.01M | } |
2345 | | |
2346 | 8.16k | FoldingRule RedundantFAdd() { |
2347 | 8.16k | return [](IRContext*, Instruction* inst, |
2348 | 1.78M | const std::vector<const analysis::Constant*>& constants) { |
2349 | 1.78M | assert(inst->opcode() == spv::Op::OpFAdd && |
2350 | 1.78M | "Wrong opcode. Should be OpFAdd."); |
2351 | 1.78M | assert(constants.size() == 2); |
2352 | | |
2353 | 1.78M | if (!inst->IsFloatingPointFoldingAllowed()) { |
2354 | 45 | return false; |
2355 | 45 | } |
2356 | | |
2357 | 1.78M | FloatConstantKind kind0 = getFloatConstantKind(constants[0]); |
2358 | 1.78M | FloatConstantKind kind1 = getFloatConstantKind(constants[1]); |
2359 | | |
2360 | 1.78M | if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) { |
2361 | 11.2k | inst->SetOpcode(spv::Op::OpCopyObject); |
2362 | 11.2k | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, |
2363 | 11.2k | {inst->GetSingleWordInOperand( |
2364 | 11.2k | kind0 == FloatConstantKind::Zero ? 1 : 0)}}}); |
2365 | 11.2k | return true; |
2366 | 11.2k | } |
2367 | | |
2368 | 1.77M | return false; |
2369 | 1.78M | }; |
2370 | 8.16k | } |
2371 | | |
2372 | 8.16k | FoldingRule RedundantFSub() { |
2373 | 8.16k | return [](IRContext*, Instruction* inst, |
2374 | 183k | const std::vector<const analysis::Constant*>& constants) { |
2375 | 183k | assert(inst->opcode() == spv::Op::OpFSub && |
2376 | 183k | "Wrong opcode. Should be OpFSub."); |
2377 | 183k | assert(constants.size() == 2); |
2378 | | |
2379 | 183k | if (!inst->IsFloatingPointFoldingAllowed()) { |
2380 | 11 | return false; |
2381 | 11 | } |
2382 | | |
2383 | 183k | FloatConstantKind kind0 = getFloatConstantKind(constants[0]); |
2384 | 183k | FloatConstantKind kind1 = getFloatConstantKind(constants[1]); |
2385 | | |
2386 | 183k | if (kind0 == FloatConstantKind::Zero) { |
2387 | 11.1k | inst->SetOpcode(spv::Op::OpFNegate); |
2388 | 11.1k | inst->SetInOperands( |
2389 | 11.1k | {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1)}}}); |
2390 | 11.1k | return true; |
2391 | 11.1k | } |
2392 | | |
2393 | 172k | if (kind1 == FloatConstantKind::Zero) { |
2394 | 586 | inst->SetOpcode(spv::Op::OpCopyObject); |
2395 | 586 | inst->SetInOperands( |
2396 | 586 | {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}}); |
2397 | 586 | return true; |
2398 | 586 | } |
2399 | | |
2400 | 171k | return false; |
2401 | 172k | }; |
2402 | 8.16k | } |
2403 | | |
2404 | 8.16k | FoldingRule RedundantFMul() { |
2405 | 8.16k | return [](IRContext*, Instruction* inst, |
2406 | 173k | const std::vector<const analysis::Constant*>& constants) { |
2407 | 173k | assert(inst->opcode() == spv::Op::OpFMul && |
2408 | 173k | "Wrong opcode. Should be OpFMul."); |
2409 | 173k | assert(constants.size() == 2); |
2410 | | |
2411 | 173k | if (!inst->IsFloatingPointFoldingAllowed()) { |
2412 | 1 | return false; |
2413 | 1 | } |
2414 | | |
2415 | 173k | FloatConstantKind kind0 = getFloatConstantKind(constants[0]); |
2416 | 173k | FloatConstantKind kind1 = getFloatConstantKind(constants[1]); |
2417 | | |
2418 | 173k | if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) { |
2419 | 1.29k | inst->SetOpcode(spv::Op::OpCopyObject); |
2420 | 1.29k | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, |
2421 | 1.29k | {inst->GetSingleWordInOperand( |
2422 | 1.29k | kind0 == FloatConstantKind::Zero ? 0 : 1)}}}); |
2423 | 1.29k | return true; |
2424 | 1.29k | } |
2425 | | |
2426 | 172k | if (kind0 == FloatConstantKind::One || kind1 == FloatConstantKind::One) { |
2427 | 896 | inst->SetOpcode(spv::Op::OpCopyObject); |
2428 | 896 | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, |
2429 | 896 | {inst->GetSingleWordInOperand( |
2430 | 896 | kind0 == FloatConstantKind::One ? 1 : 0)}}}); |
2431 | 896 | return true; |
2432 | 896 | } |
2433 | | |
2434 | 171k | return false; |
2435 | 172k | }; |
2436 | 8.16k | } |
2437 | | |
2438 | 8.16k | FoldingRule RedundantFDiv() { |
2439 | 8.16k | return [](IRContext*, Instruction* inst, |
2440 | 109k | const std::vector<const analysis::Constant*>& constants) { |
2441 | 109k | assert(inst->opcode() == spv::Op::OpFDiv && |
2442 | 109k | "Wrong opcode. Should be OpFDiv."); |
2443 | 109k | assert(constants.size() == 2); |
2444 | | |
2445 | 109k | if (!inst->IsFloatingPointFoldingAllowed()) { |
2446 | 52 | return false; |
2447 | 52 | } |
2448 | | |
2449 | 108k | FloatConstantKind kind0 = getFloatConstantKind(constants[0]); |
2450 | 108k | FloatConstantKind kind1 = getFloatConstantKind(constants[1]); |
2451 | | |
2452 | 108k | if (kind0 == FloatConstantKind::Zero) { |
2453 | 37 | inst->SetOpcode(spv::Op::OpCopyObject); |
2454 | 37 | inst->SetInOperands( |
2455 | 37 | {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}}); |
2456 | 37 | return true; |
2457 | 37 | } |
2458 | | |
2459 | 108k | if (kind1 == FloatConstantKind::One) { |
2460 | 37 | inst->SetOpcode(spv::Op::OpCopyObject); |
2461 | 37 | inst->SetInOperands( |
2462 | 37 | {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}}); |
2463 | 37 | return true; |
2464 | 37 | } |
2465 | | |
2466 | 108k | return false; |
2467 | 108k | }; |
2468 | 8.16k | } |
2469 | | |
2470 | 4.69k | FoldingRule RedundantFMix() { |
2471 | 4.69k | return [](IRContext* context, Instruction* inst, |
2472 | 4.94k | const std::vector<const analysis::Constant*>& constants) { |
2473 | 4.94k | assert(inst->opcode() == spv::Op::OpExtInst && |
2474 | 4.94k | "Wrong opcode. Should be OpExtInst."); |
2475 | | |
2476 | 4.94k | if (!inst->IsFloatingPointFoldingAllowed()) { |
2477 | 0 | return false; |
2478 | 0 | } |
2479 | | |
2480 | 4.94k | uint32_t instSetId = |
2481 | 4.94k | context->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); |
2482 | | |
2483 | 4.94k | if (inst->GetSingleWordInOperand(kExtInstSetIdInIdx) == instSetId && |
2484 | 4.94k | inst->GetSingleWordInOperand(kExtInstInstructionInIdx) == |
2485 | 4.94k | GLSLstd450FMix) { |
2486 | 4.94k | assert(constants.size() == 5); |
2487 | | |
2488 | 4.94k | FloatConstantKind kind4 = getFloatConstantKind(constants[4]); |
2489 | | |
2490 | 4.94k | if (kind4 == FloatConstantKind::Zero || kind4 == FloatConstantKind::One) { |
2491 | 1 | inst->SetOpcode(spv::Op::OpCopyObject); |
2492 | 1 | inst->SetInOperands( |
2493 | 1 | {{SPV_OPERAND_TYPE_ID, |
2494 | 1 | {inst->GetSingleWordInOperand(kind4 == FloatConstantKind::Zero |
2495 | 1 | ? kFMixXIdInIdx |
2496 | 1 | : kFMixYIdInIdx)}}}); |
2497 | 1 | return true; |
2498 | 1 | } |
2499 | 4.94k | } |
2500 | | |
2501 | 4.94k | return false; |
2502 | 4.94k | }; |
2503 | 4.69k | } |
2504 | | |
2505 | | // This rule handles addition of zero for integers. |
2506 | 8.16k | FoldingRule RedundantIAdd() { |
2507 | 8.16k | return [](IRContext* context, Instruction* inst, |
2508 | 47.9k | const std::vector<const analysis::Constant*>& constants) { |
2509 | 47.9k | assert(inst->opcode() == spv::Op::OpIAdd && |
2510 | 47.9k | "Wrong opcode. Should be OpIAdd."); |
2511 | | |
2512 | 47.9k | uint32_t operand = std::numeric_limits<uint32_t>::max(); |
2513 | 47.9k | const analysis::Type* operand_type = nullptr; |
2514 | 47.9k | if (constants[0] && constants[0]->IsZero()) { |
2515 | 271 | operand = inst->GetSingleWordInOperand(1); |
2516 | 271 | operand_type = constants[0]->type(); |
2517 | 47.6k | } else if (constants[1] && constants[1]->IsZero()) { |
2518 | 1.60k | operand = inst->GetSingleWordInOperand(0); |
2519 | 1.60k | operand_type = constants[1]->type(); |
2520 | 1.60k | } |
2521 | | |
2522 | 47.9k | if (operand != std::numeric_limits<uint32_t>::max()) { |
2523 | 1.87k | const analysis::Type* inst_type = |
2524 | 1.87k | context->get_type_mgr()->GetType(inst->type_id()); |
2525 | 1.87k | if (inst_type->IsSame(operand_type)) { |
2526 | 1.86k | inst->SetOpcode(spv::Op::OpCopyObject); |
2527 | 1.86k | } else { |
2528 | 15 | inst->SetOpcode(spv::Op::OpBitcast); |
2529 | 15 | } |
2530 | 1.87k | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {operand}}}); |
2531 | 1.87k | return true; |
2532 | 1.87k | } |
2533 | 46.0k | return false; |
2534 | 47.9k | }; |
2535 | 8.16k | } |
2536 | | |
2537 | | // This rule look for a dot with a constant vector containing a single 1 and |
2538 | | // the rest 0s. This is the same as doing an extract. |
2539 | 8.16k | FoldingRule DotProductDoingExtract() { |
2540 | 8.16k | return [](IRContext* context, Instruction* inst, |
2541 | 8.16k | const std::vector<const analysis::Constant*>& constants) { |
2542 | 73 | assert(inst->opcode() == spv::Op::OpDot && |
2543 | 73 | "Wrong opcode. Should be OpDot."); |
2544 | | |
2545 | 73 | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
2546 | | |
2547 | 73 | if (!inst->IsFloatingPointFoldingAllowed()) { |
2548 | 0 | return false; |
2549 | 0 | } |
2550 | | |
2551 | 219 | for (int i = 0; i < 2; ++i) { |
2552 | 146 | if (!constants[i]) { |
2553 | 109 | continue; |
2554 | 109 | } |
2555 | | |
2556 | 37 | const analysis::Vector* vector_type = constants[i]->type()->AsVector(); |
2557 | 37 | assert(vector_type && "Inputs to OpDot must be vectors."); |
2558 | 37 | const analysis::Float* element_type = |
2559 | 37 | vector_type->element_type()->AsFloat(); |
2560 | 37 | assert(element_type && "Inputs to OpDot must be vectors of floats."); |
2561 | 37 | uint32_t element_width = element_type->width(); |
2562 | 37 | if (element_width != 32 && element_width != 64) { |
2563 | 0 | return false; |
2564 | 0 | } |
2565 | | |
2566 | 37 | std::vector<const analysis::Constant*> components; |
2567 | 37 | components = constants[i]->GetVectorComponents(const_mgr); |
2568 | | |
2569 | 37 | constexpr uint32_t kNotFound = std::numeric_limits<uint32_t>::max(); |
2570 | | |
2571 | 37 | uint32_t component_with_one = kNotFound; |
2572 | 37 | bool all_others_zero = true; |
2573 | 37 | for (uint32_t j = 0; j < components.size(); ++j) { |
2574 | 37 | const analysis::Constant* element = components[j]; |
2575 | 37 | double value = |
2576 | 37 | (element_width == 32 ? element->GetFloat() : element->GetDouble()); |
2577 | 37 | if (value == 0.0) { |
2578 | 0 | continue; |
2579 | 37 | } else if (value == 1.0) { |
2580 | 0 | if (component_with_one == kNotFound) { |
2581 | 0 | component_with_one = j; |
2582 | 0 | } else { |
2583 | 0 | component_with_one = kNotFound; |
2584 | 0 | break; |
2585 | 0 | } |
2586 | 37 | } else { |
2587 | 37 | all_others_zero = false; |
2588 | 37 | break; |
2589 | 37 | } |
2590 | 37 | } |
2591 | | |
2592 | 37 | if (!all_others_zero || component_with_one == kNotFound) { |
2593 | 37 | continue; |
2594 | 37 | } |
2595 | | |
2596 | 0 | std::vector<Operand> operands; |
2597 | 0 | operands.push_back( |
2598 | 0 | {SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1u - i)}}); |
2599 | 0 | operands.push_back( |
2600 | 0 | {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_with_one}}); |
2601 | |
|
2602 | 0 | inst->SetOpcode(spv::Op::OpCompositeExtract); |
2603 | 0 | inst->SetInOperands(std::move(operands)); |
2604 | 0 | return true; |
2605 | 37 | } |
2606 | 73 | return false; |
2607 | 73 | }; |
2608 | 8.16k | } |
2609 | | |
2610 | | // If we are storing an undef, then we can remove the store. |
2611 | | // |
2612 | | // TODO: We can do something similar for OpImageWrite, but checking for volatile |
2613 | | // is complicated. Waiting to see if it is needed. |
2614 | 8.16k | FoldingRule StoringUndef() { |
2615 | 8.16k | return [](IRContext* context, Instruction* inst, |
2616 | 2.64M | const std::vector<const analysis::Constant*>&) { |
2617 | 2.64M | assert(inst->opcode() == spv::Op::OpStore && |
2618 | 2.64M | "Wrong opcode. Should be OpStore."); |
2619 | | |
2620 | 2.64M | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
2621 | | |
2622 | | // If this is a volatile store, the store cannot be removed. |
2623 | 2.64M | if (inst->NumInOperands() == 3) { |
2624 | 5.89k | if (inst->GetSingleWordInOperand(2) & |
2625 | 5.89k | uint32_t(spv::MemoryAccessMask::Volatile)) { |
2626 | 4.12k | return false; |
2627 | 4.12k | } |
2628 | 5.89k | } |
2629 | | |
2630 | 2.63M | uint32_t object_id = inst->GetSingleWordInOperand(kStoreObjectInIdx); |
2631 | 2.63M | Instruction* object_inst = def_use_mgr->GetDef(object_id); |
2632 | 2.63M | if (object_inst->opcode() == spv::Op::OpUndef) { |
2633 | 16.0k | inst->ToNop(); |
2634 | 16.0k | return true; |
2635 | 16.0k | } |
2636 | 2.62M | return false; |
2637 | 2.63M | }; |
2638 | 8.16k | } |
2639 | | |
// Folds an OpVectorShuffle |inst| whose first or second vector operand is
// itself an OpVectorShuffle (the "feeding" shuffle). Operand 0 is tried first,
// then operand 1. Components of |inst| that come from the feeding shuffle are
// remapped to index directly into one of the feeding shuffle's own two
// inputs. That input then replaces the feeding shuffle as the operand of
// |inst|.
// The fold is abandoned (returns false) if those components would need both
// inputs of the feeding shuffle. The undef literal 0xffffffff is passed
// through unchanged. Returns true when |inst| was rewritten.
FoldingRule VectorShuffleFeedingShuffle() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>&) {
    assert(inst->opcode() == spv::Op::OpVectorShuffle &&
           "Wrong opcode. Should be OpVectorShuffle.");

    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
    analysis::TypeManager* type_mgr = context->get_type_mgr();

    // |op0_length| is the component count of |inst|'s first vector operand.
    // Component indices >= op0_length select from the second vector operand.
    Instruction* feeding_shuffle_inst =
        def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
    analysis::Vector* op0_type =
        type_mgr->GetType(feeding_shuffle_inst->type_id())->AsVector();
    uint32_t op0_length = op0_type->element_count();

    // Use operand 0 as the feeder if it is a shuffle, otherwise try operand 1.
    bool feeder_is_op0 = true;
    if (feeding_shuffle_inst->opcode() != spv::Op::OpVectorShuffle) {
      feeding_shuffle_inst =
          def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
      feeder_is_op0 = false;
    }

    if (feeding_shuffle_inst->opcode() != spv::Op::OpVectorShuffle) {
      return false;
    }

    // Length of the feeding shuffle's own first input. Feeder component
    // indices at or above this select from the feeder's second input.
    Instruction* feeder2 =
        def_use_mgr->GetDef(feeding_shuffle_inst->GetSingleWordInOperand(0));
    analysis::Vector* feeder_op0_type =
        type_mgr->GetType(feeder2->type_id())->AsVector();
    uint32_t feeder_op0_length = feeder_op0_type->element_count();

    // Id of the feeder input that all remapped components must share. It
    // stays 0 until the first remapped, non-undef component is seen.
    uint32_t new_feeder_id = 0;
    std::vector<Operand> new_operands;
    new_operands.resize(
        2, {SPV_OPERAND_TYPE_ID, {0}});  // Place holders for vector operands.
    const uint32_t undef_literal = 0xffffffff;
    for (uint32_t op = 2; op < inst->NumInOperands(); ++op) {
      uint32_t component_index = inst->GetSingleWordInOperand(op);

      // Do not interpret the undefined value literal as coming from operand 1.
      if (component_index != undef_literal &&
          feeder_is_op0 == (component_index < op0_length)) {
        // This component comes from the feeding_shuffle_inst. Update
        // |component_index| to be the index into the operand of the feeder.

        // Adjust component_index to get the index into the operands of the
        // feeding_shuffle_inst.
        if (component_index >= op0_length) {
          component_index -= op0_length;
        }
        // The feeder's component literals start at in-operand 2.
        component_index =
            feeding_shuffle_inst->GetSingleWordInOperand(component_index + 2);

        // Check if we are using a component from the first or second operand of
        // the feeding instruction.
        if (component_index < feeder_op0_length) {
          if (new_feeder_id == 0) {
            // First time through, save the id of the operand the element comes
            // from.
            new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(0);
          } else if (new_feeder_id !=
                     feeding_shuffle_inst->GetSingleWordInOperand(0)) {
            // We need both elements of the feeding_shuffle_inst, so we cannot
            // fold.
            return false;
          }
        } else if (component_index != undef_literal) {
          if (new_feeder_id == 0) {
            // First time through, save the id of the operand the element comes
            // from.
            new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(1);
          } else if (new_feeder_id !=
                     feeding_shuffle_inst->GetSingleWordInOperand(1)) {
            // We need both elements of the feeding_shuffle_inst, so we cannot
            // fold.
            return false;
          }
          // Rebase to an index into the feeder's second input.
          component_index -= feeder_op0_length;
        }

        // When the feeder replaces operand 1 of |inst|, its components are
        // addressed after the op0_length components of operand 0.
        if (!feeder_is_op0 && component_index != undef_literal) {
          component_index += op0_length;
        }
      }
      new_operands.push_back(
          {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_index}});
    }

    // No remapped component read a defined element of the feeder, so the
    // feeder is replaced by a null constant of the feeding shuffle's type.
    if (new_feeder_id == 0) {
      analysis::ConstantManager* const_mgr = context->get_constant_mgr();
      const analysis::Type* type =
          type_mgr->GetType(feeding_shuffle_inst->type_id());
      const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
      new_feeder_id =
          const_mgr->GetDefiningInstruction(null_const, 0)->result_id();
    }

    if (feeder_is_op0) {
      // If the size of the first vector operand changed then the indices
      // referring to the second operand need to be adjusted.
      Instruction* new_feeder_inst = def_use_mgr->GetDef(new_feeder_id);
      analysis::Type* new_feeder_type =
          type_mgr->GetType(new_feeder_inst->type_id());
      uint32_t new_op0_size = new_feeder_type->AsVector()->element_count();
      int32_t adjustment = op0_length - new_op0_size;

      if (adjustment != 0) {
        // The original literal decides whether the component referred to
        // operand 1; undef literals are left untouched.
        for (uint32_t i = 2; i < new_operands.size(); i++) {
          uint32_t operand = inst->GetSingleWordInOperand(i);
          if (operand >= op0_length && operand != undef_literal) {
            new_operands[i].words[0] -= adjustment;
          }
        }
      }

      new_operands[0].words[0] = new_feeder_id;
      new_operands[1] = inst->GetInOperand(1);
    } else {
      new_operands[1].words[0] = new_feeder_id;
      new_operands[0] = inst->GetInOperand(0);
    }

    inst->SetInOperands(std::move(new_operands));
    return true;
  };
}
2767 | | |
2768 | | // Removes duplicate ids from the interface list of an OpEntryPoint |
2769 | | // instruction. |
2770 | 8.16k | FoldingRule RemoveRedundantOperands() { |
2771 | 8.16k | return [](IRContext*, Instruction* inst, |
2772 | 8.16k | const std::vector<const analysis::Constant*>&) { |
2773 | 0 | assert(inst->opcode() == spv::Op::OpEntryPoint && |
2774 | 0 | "Wrong opcode. Should be OpEntryPoint."); |
2775 | 0 | bool has_redundant_operand = false; |
2776 | 0 | std::unordered_set<uint32_t> seen_operands; |
2777 | 0 | std::vector<Operand> new_operands; |
2778 | |
|
2779 | 0 | new_operands.emplace_back(inst->GetOperand(0)); |
2780 | 0 | new_operands.emplace_back(inst->GetOperand(1)); |
2781 | 0 | new_operands.emplace_back(inst->GetOperand(2)); |
2782 | 0 | for (uint32_t i = 3; i < inst->NumOperands(); ++i) { |
2783 | 0 | if (seen_operands.insert(inst->GetSingleWordOperand(i)).second) { |
2784 | 0 | new_operands.emplace_back(inst->GetOperand(i)); |
2785 | 0 | } else { |
2786 | 0 | has_redundant_operand = true; |
2787 | 0 | } |
2788 | 0 | } |
2789 | |
|
2790 | 0 | if (!has_redundant_operand) { |
2791 | 0 | return false; |
2792 | 0 | } |
2793 | | |
2794 | 0 | inst->SetInOperands(std::move(new_operands)); |
2795 | 0 | return true; |
2796 | 0 | }; |
2797 | 8.16k | } |
2798 | | |
2799 | | // If an image instruction's operand is a constant, updates the image operand |
2800 | | // flag from Offset to ConstOffset. |
2801 | 204k | FoldingRule UpdateImageOperands() { |
2802 | 204k | return [](IRContext*, Instruction* inst, |
2803 | 638k | const std::vector<const analysis::Constant*>& constants) { |
2804 | 638k | const auto opcode = inst->opcode(); |
2805 | 638k | (void)opcode; |
2806 | 638k | assert((opcode == spv::Op::OpImageSampleImplicitLod || |
2807 | 638k | opcode == spv::Op::OpImageSampleExplicitLod || |
2808 | 638k | opcode == spv::Op::OpImageSampleDrefImplicitLod || |
2809 | 638k | opcode == spv::Op::OpImageSampleDrefExplicitLod || |
2810 | 638k | opcode == spv::Op::OpImageSampleProjImplicitLod || |
2811 | 638k | opcode == spv::Op::OpImageSampleProjExplicitLod || |
2812 | 638k | opcode == spv::Op::OpImageSampleProjDrefImplicitLod || |
2813 | 638k | opcode == spv::Op::OpImageSampleProjDrefExplicitLod || |
2814 | 638k | opcode == spv::Op::OpImageFetch || |
2815 | 638k | opcode == spv::Op::OpImageGather || |
2816 | 638k | opcode == spv::Op::OpImageDrefGather || |
2817 | 638k | opcode == spv::Op::OpImageRead || opcode == spv::Op::OpImageWrite || |
2818 | 638k | opcode == spv::Op::OpImageSparseSampleImplicitLod || |
2819 | 638k | opcode == spv::Op::OpImageSparseSampleExplicitLod || |
2820 | 638k | opcode == spv::Op::OpImageSparseSampleDrefImplicitLod || |
2821 | 638k | opcode == spv::Op::OpImageSparseSampleDrefExplicitLod || |
2822 | 638k | opcode == spv::Op::OpImageSparseSampleProjImplicitLod || |
2823 | 638k | opcode == spv::Op::OpImageSparseSampleProjExplicitLod || |
2824 | 638k | opcode == spv::Op::OpImageSparseSampleProjDrefImplicitLod || |
2825 | 638k | opcode == spv::Op::OpImageSparseSampleProjDrefExplicitLod || |
2826 | 638k | opcode == spv::Op::OpImageSparseFetch || |
2827 | 638k | opcode == spv::Op::OpImageSparseGather || |
2828 | 638k | opcode == spv::Op::OpImageSparseDrefGather || |
2829 | 638k | opcode == spv::Op::OpImageSparseRead) && |
2830 | 638k | "Wrong opcode. Should be an image instruction."); |
2831 | | |
2832 | 638k | int32_t operand_index = ImageOperandsMaskInOperandIndex(inst); |
2833 | 638k | if (operand_index >= 0) { |
2834 | 4 | auto image_operands = inst->GetSingleWordInOperand(operand_index); |
2835 | 4 | if (image_operands & uint32_t(spv::ImageOperandsMask::Offset)) { |
2836 | 0 | uint32_t offset_operand_index = operand_index + 1; |
2837 | 0 | if (image_operands & uint32_t(spv::ImageOperandsMask::Bias)) |
2838 | 0 | offset_operand_index++; |
2839 | 0 | if (image_operands & uint32_t(spv::ImageOperandsMask::Lod)) |
2840 | 0 | offset_operand_index++; |
2841 | 0 | if (image_operands & uint32_t(spv::ImageOperandsMask::Grad)) |
2842 | 0 | offset_operand_index += 2; |
2843 | 0 | assert(((image_operands & |
2844 | 0 | uint32_t(spv::ImageOperandsMask::ConstOffset)) == 0) && |
2845 | 0 | "Offset and ConstOffset may not be used together"); |
2846 | 0 | if (offset_operand_index < inst->NumOperands()) { |
2847 | 0 | if (constants[offset_operand_index]) { |
2848 | 0 | if (constants[offset_operand_index]->IsZero()) { |
2849 | 0 | inst->RemoveInOperand(offset_operand_index); |
2850 | 0 | } else { |
2851 | 0 | image_operands = image_operands | |
2852 | 0 | uint32_t(spv::ImageOperandsMask::ConstOffset); |
2853 | 0 | } |
2854 | 0 | image_operands = |
2855 | 0 | image_operands & ~uint32_t(spv::ImageOperandsMask::Offset); |
2856 | 0 | inst->SetInOperand(operand_index, {image_operands}); |
2857 | 0 | return true; |
2858 | 0 | } |
2859 | 0 | } |
2860 | 0 | } |
2861 | 4 | } |
2862 | | |
2863 | 638k | return false; |
2864 | 638k | }; |
2865 | 204k | } |
2866 | | |
2867 | | } // namespace |
2868 | | |
2869 | 8.16k | void FoldingRules::AddFoldingRules() { |
2870 | | // Add all folding rules to the list for the opcodes to which they apply. |
2871 | | // Note that the order in which rules are added to the list matters. If a rule |
2872 | | // applies to the instruction, the rest of the rules will not be attempted. |
2873 | | // Take that into consideration. |
2874 | 8.16k | rules_[spv::Op::OpBitcast].push_back(BitCastScalarOrVector()); |
2875 | | |
2876 | 8.16k | rules_[spv::Op::OpCompositeConstruct].push_back( |
2877 | 8.16k | CompositeExtractFeedingConstruct); |
2878 | | |
2879 | 8.16k | rules_[spv::Op::OpCompositeExtract].push_back(InsertFeedingExtract()); |
2880 | 8.16k | rules_[spv::Op::OpCompositeExtract].push_back( |
2881 | 8.16k | CompositeConstructFeedingExtract); |
2882 | 8.16k | rules_[spv::Op::OpCompositeExtract].push_back(VectorShuffleFeedingExtract()); |
2883 | 8.16k | rules_[spv::Op::OpCompositeExtract].push_back(FMixFeedingExtract()); |
2884 | | |
2885 | 8.16k | rules_[spv::Op::OpCompositeInsert].push_back( |
2886 | 8.16k | CompositeInsertToCompositeConstruct); |
2887 | | |
2888 | 8.16k | rules_[spv::Op::OpDot].push_back(DotProductDoingExtract()); |
2889 | | |
2890 | 8.16k | rules_[spv::Op::OpEntryPoint].push_back(RemoveRedundantOperands()); |
2891 | | |
2892 | 8.16k | rules_[spv::Op::OpFAdd].push_back(RedundantFAdd()); |
2893 | 8.16k | rules_[spv::Op::OpFAdd].push_back(MergeAddNegateArithmetic()); |
2894 | 8.16k | rules_[spv::Op::OpFAdd].push_back(MergeAddAddArithmetic()); |
2895 | 8.16k | rules_[spv::Op::OpFAdd].push_back(MergeAddSubArithmetic()); |
2896 | 8.16k | rules_[spv::Op::OpFAdd].push_back(MergeGenericAddSubArithmetic()); |
2897 | 8.16k | rules_[spv::Op::OpFAdd].push_back(FactorAddMuls()); |
2898 | | |
2899 | 8.16k | rules_[spv::Op::OpFDiv].push_back(RedundantFDiv()); |
2900 | 8.16k | rules_[spv::Op::OpFDiv].push_back(ReciprocalFDiv()); |
2901 | 8.16k | rules_[spv::Op::OpFDiv].push_back(MergeDivDivArithmetic()); |
2902 | 8.16k | rules_[spv::Op::OpFDiv].push_back(MergeDivMulArithmetic()); |
2903 | 8.16k | rules_[spv::Op::OpFDiv].push_back(MergeDivNegateArithmetic()); |
2904 | | |
2905 | 8.16k | rules_[spv::Op::OpFMul].push_back(RedundantFMul()); |
2906 | 8.16k | rules_[spv::Op::OpFMul].push_back(MergeMulMulArithmetic()); |
2907 | 8.16k | rules_[spv::Op::OpFMul].push_back(MergeMulDivArithmetic()); |
2908 | 8.16k | rules_[spv::Op::OpFMul].push_back(MergeMulNegateArithmetic()); |
2909 | | |
2910 | 8.16k | rules_[spv::Op::OpFNegate].push_back(MergeNegateArithmetic()); |
2911 | 8.16k | rules_[spv::Op::OpFNegate].push_back(MergeNegateAddSubArithmetic()); |
2912 | 8.16k | rules_[spv::Op::OpFNegate].push_back(MergeNegateMulDivArithmetic()); |
2913 | | |
2914 | 8.16k | rules_[spv::Op::OpFSub].push_back(RedundantFSub()); |
2915 | 8.16k | rules_[spv::Op::OpFSub].push_back(MergeSubNegateArithmetic()); |
2916 | 8.16k | rules_[spv::Op::OpFSub].push_back(MergeSubAddArithmetic()); |
2917 | 8.16k | rules_[spv::Op::OpFSub].push_back(MergeSubSubArithmetic()); |
2918 | | |
2919 | 8.16k | rules_[spv::Op::OpIAdd].push_back(RedundantIAdd()); |
2920 | 8.16k | rules_[spv::Op::OpIAdd].push_back(MergeAddNegateArithmetic()); |
2921 | 8.16k | rules_[spv::Op::OpIAdd].push_back(MergeAddAddArithmetic()); |
2922 | 8.16k | rules_[spv::Op::OpIAdd].push_back(MergeAddSubArithmetic()); |
2923 | 8.16k | rules_[spv::Op::OpIAdd].push_back(MergeGenericAddSubArithmetic()); |
2924 | 8.16k | rules_[spv::Op::OpIAdd].push_back(FactorAddMuls()); |
2925 | | |
2926 | 8.16k | rules_[spv::Op::OpIMul].push_back(IntMultipleBy1()); |
2927 | 8.16k | rules_[spv::Op::OpIMul].push_back(MergeMulMulArithmetic()); |
2928 | 8.16k | rules_[spv::Op::OpIMul].push_back(MergeMulNegateArithmetic()); |
2929 | | |
2930 | 8.16k | rules_[spv::Op::OpISub].push_back(MergeSubNegateArithmetic()); |
2931 | 8.16k | rules_[spv::Op::OpISub].push_back(MergeSubAddArithmetic()); |
2932 | 8.16k | rules_[spv::Op::OpISub].push_back(MergeSubSubArithmetic()); |
2933 | | |
2934 | 8.16k | rules_[spv::Op::OpPhi].push_back(RedundantPhi()); |
2935 | | |
2936 | 8.16k | rules_[spv::Op::OpSNegate].push_back(MergeNegateArithmetic()); |
2937 | 8.16k | rules_[spv::Op::OpSNegate].push_back(MergeNegateMulDivArithmetic()); |
2938 | 8.16k | rules_[spv::Op::OpSNegate].push_back(MergeNegateAddSubArithmetic()); |
2939 | | |
2940 | 8.16k | rules_[spv::Op::OpSelect].push_back(RedundantSelect()); |
2941 | | |
2942 | 8.16k | rules_[spv::Op::OpStore].push_back(StoringUndef()); |
2943 | | |
2944 | 8.16k | rules_[spv::Op::OpVectorShuffle].push_back(VectorShuffleFeedingShuffle()); |
2945 | | |
2946 | 8.16k | rules_[spv::Op::OpImageSampleImplicitLod].push_back(UpdateImageOperands()); |
2947 | 8.16k | rules_[spv::Op::OpImageSampleExplicitLod].push_back(UpdateImageOperands()); |
2948 | 8.16k | rules_[spv::Op::OpImageSampleDrefImplicitLod].push_back( |
2949 | 8.16k | UpdateImageOperands()); |
2950 | 8.16k | rules_[spv::Op::OpImageSampleDrefExplicitLod].push_back( |
2951 | 8.16k | UpdateImageOperands()); |
2952 | 8.16k | rules_[spv::Op::OpImageSampleProjImplicitLod].push_back( |
2953 | 8.16k | UpdateImageOperands()); |
2954 | 8.16k | rules_[spv::Op::OpImageSampleProjExplicitLod].push_back( |
2955 | 8.16k | UpdateImageOperands()); |
2956 | 8.16k | rules_[spv::Op::OpImageSampleProjDrefImplicitLod].push_back( |
2957 | 8.16k | UpdateImageOperands()); |
2958 | 8.16k | rules_[spv::Op::OpImageSampleProjDrefExplicitLod].push_back( |
2959 | 8.16k | UpdateImageOperands()); |
2960 | 8.16k | rules_[spv::Op::OpImageFetch].push_back(UpdateImageOperands()); |
2961 | 8.16k | rules_[spv::Op::OpImageGather].push_back(UpdateImageOperands()); |
2962 | 8.16k | rules_[spv::Op::OpImageDrefGather].push_back(UpdateImageOperands()); |
2963 | 8.16k | rules_[spv::Op::OpImageRead].push_back(UpdateImageOperands()); |
2964 | 8.16k | rules_[spv::Op::OpImageWrite].push_back(UpdateImageOperands()); |
2965 | 8.16k | rules_[spv::Op::OpImageSparseSampleImplicitLod].push_back( |
2966 | 8.16k | UpdateImageOperands()); |
2967 | 8.16k | rules_[spv::Op::OpImageSparseSampleExplicitLod].push_back( |
2968 | 8.16k | UpdateImageOperands()); |
2969 | 8.16k | rules_[spv::Op::OpImageSparseSampleDrefImplicitLod].push_back( |
2970 | 8.16k | UpdateImageOperands()); |
2971 | 8.16k | rules_[spv::Op::OpImageSparseSampleDrefExplicitLod].push_back( |
2972 | 8.16k | UpdateImageOperands()); |
2973 | 8.16k | rules_[spv::Op::OpImageSparseSampleProjImplicitLod].push_back( |
2974 | 8.16k | UpdateImageOperands()); |
2975 | 8.16k | rules_[spv::Op::OpImageSparseSampleProjExplicitLod].push_back( |
2976 | 8.16k | UpdateImageOperands()); |
2977 | 8.16k | rules_[spv::Op::OpImageSparseSampleProjDrefImplicitLod].push_back( |
2978 | 8.16k | UpdateImageOperands()); |
2979 | 8.16k | rules_[spv::Op::OpImageSparseSampleProjDrefExplicitLod].push_back( |
2980 | 8.16k | UpdateImageOperands()); |
2981 | 8.16k | rules_[spv::Op::OpImageSparseFetch].push_back(UpdateImageOperands()); |
2982 | 8.16k | rules_[spv::Op::OpImageSparseGather].push_back(UpdateImageOperands()); |
2983 | 8.16k | rules_[spv::Op::OpImageSparseDrefGather].push_back(UpdateImageOperands()); |
2984 | 8.16k | rules_[spv::Op::OpImageSparseRead].push_back(UpdateImageOperands()); |
2985 | | |
2986 | 8.16k | FeatureManager* feature_manager = context_->get_feature_mgr(); |
2987 | | // Add rules for GLSLstd450 |
2988 | 8.16k | uint32_t ext_inst_glslstd450_id = |
2989 | 8.16k | feature_manager->GetExtInstImportId_GLSLstd450(); |
2990 | 8.16k | if (ext_inst_glslstd450_id != 0) { |
2991 | 4.69k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back( |
2992 | 4.69k | RedundantFMix()); |
2993 | 4.69k | } |
2994 | 8.16k | } |
2995 | | } // namespace opt |
2996 | | } // namespace spvtools |