/src/spirv-tools/source/opt/folding_rules.cpp
Line | Count | Source |
1 | | // Copyright (c) 2018 Google LLC |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | #include "source/opt/folding_rules.h" |
16 | | |
17 | | #include <limits> |
18 | | #include <memory> |
19 | | #include <optional> |
20 | | #include <utility> |
21 | | |
22 | | #include "ir_builder.h" |
23 | | #include "source/latest_version_glsl_std_450_header.h" |
24 | | #include "source/opt/ir_context.h" |
25 | | |
26 | | namespace spvtools { |
27 | | namespace opt { |
28 | | namespace { |
29 | | |
30 | | constexpr uint32_t kExtractCompositeIdInIdx = 0; |
31 | | constexpr uint32_t kInsertObjectIdInIdx = 0; |
32 | | constexpr uint32_t kInsertCompositeIdInIdx = 1; |
33 | | constexpr uint32_t kExtInstSetIdInIdx = 0; |
34 | | constexpr uint32_t kExtInstInstructionInIdx = 1; |
35 | | constexpr uint32_t kFMixXIdInIdx = 2; |
36 | | constexpr uint32_t kFMixYIdInIdx = 3; |
37 | | constexpr uint32_t kFMixAIdInIdx = 4; |
38 | | constexpr uint32_t kStoreObjectInIdx = 1; |
39 | | |
40 | | // Some image instructions may contain an "image operands" argument. |
41 | | // Returns the operand index for the "image operands". |
42 | | // Returns -1 if the instruction does not have image operands. |
43 | 270k | int32_t ImageOperandsMaskInOperandIndex(Instruction* inst) { |
44 | 270k | const auto opcode = inst->opcode(); |
45 | 270k | switch (opcode) { |
46 | 270k | case spv::Op::OpImageSampleImplicitLod: |
47 | 270k | case spv::Op::OpImageSampleExplicitLod: |
48 | 270k | case spv::Op::OpImageSampleProjImplicitLod: |
49 | 270k | case spv::Op::OpImageSampleProjExplicitLod: |
50 | 270k | case spv::Op::OpImageFetch: |
51 | 270k | case spv::Op::OpImageRead: |
52 | 270k | case spv::Op::OpImageSparseSampleImplicitLod: |
53 | 270k | case spv::Op::OpImageSparseSampleExplicitLod: |
54 | 270k | case spv::Op::OpImageSparseSampleProjImplicitLod: |
55 | 270k | case spv::Op::OpImageSparseSampleProjExplicitLod: |
56 | 270k | case spv::Op::OpImageSparseFetch: |
57 | 270k | case spv::Op::OpImageSparseRead: |
58 | 270k | return inst->NumOperands() > 4 ? 2 : -1; |
59 | 2 | case spv::Op::OpImageSampleDrefImplicitLod: |
60 | 2 | case spv::Op::OpImageSampleDrefExplicitLod: |
61 | 2 | case spv::Op::OpImageSampleProjDrefImplicitLod: |
62 | 2 | case spv::Op::OpImageSampleProjDrefExplicitLod: |
63 | 2 | case spv::Op::OpImageGather: |
64 | 2 | case spv::Op::OpImageDrefGather: |
65 | 2 | case spv::Op::OpImageSparseSampleDrefImplicitLod: |
66 | 2 | case spv::Op::OpImageSparseSampleDrefExplicitLod: |
67 | 2 | case spv::Op::OpImageSparseSampleProjDrefImplicitLod: |
68 | 2 | case spv::Op::OpImageSparseSampleProjDrefExplicitLod: |
69 | 2 | case spv::Op::OpImageSparseGather: |
70 | 2 | case spv::Op::OpImageSparseDrefGather: |
71 | 2 | return inst->NumOperands() > 5 ? 3 : -1; |
72 | 0 | case spv::Op::OpImageWrite: |
73 | 0 | return inst->NumOperands() > 3 ? 3 : -1; |
74 | 0 | default: |
75 | 0 | return -1; |
76 | 270k | } |
77 | 270k | } |
78 | | |
79 | | // Returns the element width of |type|. |
80 | 5.76M | uint32_t ElementWidth(const analysis::Type* type) { |
81 | 5.76M | if (const analysis::CooperativeVectorNV* coopvec_type = |
82 | 5.76M | type->AsCooperativeVectorNV()) { |
83 | 0 | return ElementWidth(coopvec_type->component_type()); |
84 | 5.76M | } else if (const analysis::Vector* vec_type = type->AsVector()) { |
85 | 2.13M | return ElementWidth(vec_type->element_type()); |
86 | 3.63M | } else if (const analysis::Float* float_type = type->AsFloat()) { |
87 | 3.03M | return float_type->width(); |
88 | 3.03M | } else { |
89 | 600k | assert(type->AsInteger()); |
90 | 600k | return type->AsInteger()->width(); |
91 | 600k | } |
92 | 5.76M | } |
93 | | |
94 | | // Returns true if |type| is Float or a vector of Float. |
95 | 5.94M | bool HasFloatingPoint(const analysis::Type* type) { |
96 | 5.94M | if (type->AsFloat()) { |
97 | 1.86M | return true; |
98 | 4.08M | } else if (const analysis::Vector* vec_type = type->AsVector()) { |
99 | 3.35M | return vec_type->element_type()->AsFloat() != nullptr; |
100 | 3.35M | } |
101 | | |
102 | 731k | return false; |
103 | 5.94M | } |
104 | | |
105 | | // Returns false if |val| is NaN, infinite or subnormal. |
106 | | template <typename T> |
107 | 176k | bool IsValidResult(T val) { |
108 | 176k | int classified = std::fpclassify(val); |
109 | 176k | switch (classified) { |
110 | 6.14k | case FP_NAN: |
111 | 50.0k | case FP_INFINITE: |
112 | 53.4k | case FP_SUBNORMAL: |
113 | 53.4k | return false; |
114 | 122k | default: |
115 | 122k | return true; |
116 | 176k | } |
117 | 176k | } Unexecuted instantiation: folding_rules.cpp:bool spvtools::opt::(anonymous namespace)::IsValidResult<double>(double) folding_rules.cpp:bool spvtools::opt::(anonymous namespace)::IsValidResult<float>(float) Line | Count | Source | 107 | 176k | bool IsValidResult(T val) { | 108 | 176k | int classified = std::fpclassify(val); | 109 | 176k | switch (classified) { | 110 | 6.14k | case FP_NAN: | 111 | 50.0k | case FP_INFINITE: | 112 | 53.4k | case FP_SUBNORMAL: | 113 | 53.4k | return false; | 114 | 122k | default: | 115 | 122k | return true; | 116 | 176k | } | 117 | 176k | } |
|
118 | | |
119 | | const analysis::Constant* ConstInput( |
120 | 4.14M | const std::vector<const analysis::Constant*>& constants) { |
121 | 4.14M | return constants[0] ? constants[0] : constants[1]; |
122 | 4.14M | } |
123 | | |
124 | | Instruction* NonConstInput(IRContext* context, const analysis::Constant* c, |
125 | 1.98M | Instruction* inst) { |
126 | 1.98M | uint32_t in_op = c ? 1u : 0u; |
127 | 1.98M | return context->get_def_use_mgr()->GetDef( |
128 | 1.98M | inst->GetSingleWordInOperand(in_op)); |
129 | 1.98M | } |
130 | | |
131 | 0 | std::vector<uint32_t> ExtractInts(uint64_t val) { |
132 | 0 | std::vector<uint32_t> words; |
133 | 0 | words.push_back(static_cast<uint32_t>(val)); |
134 | 0 | words.push_back(static_cast<uint32_t>(val >> 32)); |
135 | 0 | return words; |
136 | 0 | } |
137 | | |
138 | | std::vector<uint32_t> GetWordsFromScalarIntConstant( |
139 | 1.29k | const analysis::IntConstant* c) { |
140 | 1.29k | assert(c != nullptr); |
141 | 1.29k | uint32_t width = c->type()->AsInteger()->width(); |
142 | 1.29k | assert(width == 8 || width == 16 || width == 32 || width == 64); |
143 | 1.29k | if (width == 64) { |
144 | 0 | uint64_t uval = static_cast<uint64_t>(c->GetU64()); |
145 | 0 | return ExtractInts(uval); |
146 | 0 | } |
147 | | // Section 2.2.1 of the SPIR-V spec guarantees that all integer types |
148 | | // smaller than 32-bits are automatically zero or sign extended to 32-bits. |
149 | 1.29k | return {c->GetU32BitValue()}; |
150 | 1.29k | } |
151 | | |
152 | | std::vector<uint32_t> GetWordsFromScalarFloatConstant( |
153 | 274 | const analysis::FloatConstant* c) { |
154 | 274 | assert(c != nullptr); |
155 | 274 | uint32_t width = c->type()->AsFloat()->width(); |
156 | 274 | assert(width == 16 || width == 32 || width == 64); |
157 | 274 | if (width == 64) { |
158 | 0 | utils::FloatProxy<double> result(c->GetDouble()); |
159 | 0 | return result.GetWords(); |
160 | 0 | } |
161 | | // Section 2.2.1 of the SPIR-V spec guarantees that all floating-point types |
162 | | // smaller than 32-bits are automatically zero extended to 32-bits. |
163 | 274 | return {c->GetU32BitValue()}; |
164 | 274 | } |
165 | | |
166 | | std::vector<uint32_t> GetWordsFromNumericScalarOrVectorConstant( |
167 | 2.21k | analysis::ConstantManager* const_mgr, const analysis::Constant* c) { |
168 | 2.21k | if (const auto* float_constant = c->AsFloatConstant()) { |
169 | 274 | return GetWordsFromScalarFloatConstant(float_constant); |
170 | 1.93k | } else if (const auto* int_constant = c->AsIntConstant()) { |
171 | 1.29k | return GetWordsFromScalarIntConstant(int_constant); |
172 | 1.29k | } else if (const auto* vec_constant = c->AsVectorConstant()) { |
173 | 86 | std::vector<uint32_t> words; |
174 | | // Retrieve all the components as 32bit words. |
175 | 308 | for (const auto* comp : vec_constant->GetComponents()) { |
176 | 308 | auto comp_in_words = |
177 | 308 | GetWordsFromNumericScalarOrVectorConstant(const_mgr, comp); |
178 | 308 | words.insert(words.end(), comp_in_words.begin(), comp_in_words.end()); |
179 | 308 | } |
180 | | |
181 | 86 | if (ElementWidth(c->type()) >= 32) { |
182 | 86 | return words; |
183 | 86 | } |
184 | | // Check the element width and concactenate if the width is less than 32. |
185 | 0 | if (ElementWidth(c->type()) == 8) { |
186 | 0 | assert(words.size() <= 4); |
187 | | // Each 32-bit word will comprise 4 8-bit integers. |
188 | | // reverse the order when compacting. |
189 | 0 | uint32_t compacted_word = 0; |
190 | 0 | for (int32_t i = static_cast<int32_t>(words.size()) - 1; i >= 0; --i) { |
191 | 0 | compacted_word <<= 8; |
192 | 0 | compacted_word |= (words[i] & 0xFF); |
193 | 0 | } |
194 | 0 | return {compacted_word}; |
195 | 0 | } else if (ElementWidth(c->type()) == 16) { |
196 | 0 | assert(words.size() <= 4); |
197 | 0 | std::vector<uint32_t> compacted_words; |
198 | | // Each 32-bit word will comprise 2 16-bit integers. |
199 | | // reverse the order pair-wise when compacting. |
200 | 0 | for (uint32_t i = 0; i < words.size(); i += 2) { |
201 | 0 | uint32_t word1 = words[i]; |
202 | 0 | uint32_t word2 = (i + 1 < words.size()) ? words[i + 1] : 0; |
203 | 0 | uint32_t compacted_word = (word2 << 16) | (word1 & 0xFFFF); |
204 | 0 | compacted_words.push_back(compacted_word); |
205 | 0 | } |
206 | 0 | return compacted_words; |
207 | 0 | } |
208 | 0 | assert(false && "Unhandled element width"); |
209 | 552 | } else if (c->AsNullConstant()) { |
210 | 552 | uint32_t num_elements = 1; |
211 | | |
212 | 552 | if (const auto* vec_type = c->type()->AsVector()) { |
213 | 0 | num_elements = vec_type->element_count(); |
214 | 0 | } |
215 | | |
216 | | // We need to check the element width to determine how many 32-bit words are |
217 | | // needed. |
218 | 552 | uint32_t element_width = ElementWidth(c->type()); |
219 | 552 | if (element_width < 32) { |
220 | 0 | num_elements = (num_elements + 1) / 2; |
221 | 552 | } else if (element_width == 64) { |
222 | 0 | num_elements = num_elements * 2; |
223 | 0 | } |
224 | 552 | return std::vector<uint32_t>(num_elements, 0); |
225 | 552 | } |
226 | 0 | return {}; |
227 | 2.21k | } |
228 | | |
229 | | const analysis::Constant* ConvertWordsToNumericScalarOrVectorConstant( |
230 | | analysis::ConstantManager* const_mgr, const std::vector<uint32_t>& words, |
231 | 1.90k | const analysis::Type* type) { |
232 | 1.90k | const spvtools::opt::analysis::Integer* int_type = type->AsInteger(); |
233 | | |
234 | 1.90k | if (int_type && int_type->width() <= 32) { |
235 | 876 | assert(words.size() == 1); |
236 | 876 | return const_mgr->GenerateIntegerConstant(int_type, words[0]); |
237 | 876 | } |
238 | | |
239 | 1.02k | if (int_type || type->AsFloat()) return const_mgr->GetConstant(type, words); |
240 | 86 | if (const auto* vec_type = type->AsVector()) |
241 | 86 | return const_mgr->GetNumericVectorConstantWithWords(vec_type, words); |
242 | 0 | return nullptr; |
243 | 86 | } |
244 | | |
245 | | // Returns the negation of |c|. |c| must be a 32 or 64 bit floating point |
246 | | // constant. |
247 | | uint32_t NegateFloatingPointConstant(analysis::ConstantManager* const_mgr, |
248 | 2.50k | const analysis::Constant* c) { |
249 | 2.50k | assert(c); |
250 | 2.50k | assert(c->type()->AsFloat()); |
251 | 2.50k | uint32_t width = c->type()->AsFloat()->width(); |
252 | 2.50k | assert(width == 32 || width == 64); |
253 | 2.50k | std::vector<uint32_t> words; |
254 | 2.50k | if (width == 64) { |
255 | 0 | utils::FloatProxy<double> result(c->GetDouble() * -1.0); |
256 | 0 | words = result.GetWords(); |
257 | 2.50k | } else { |
258 | 2.50k | utils::FloatProxy<float> result(c->GetFloat() * -1.0f); |
259 | 2.50k | words = result.GetWords(); |
260 | 2.50k | } |
261 | | |
262 | 2.50k | const analysis::Constant* negated_const = |
263 | 2.50k | const_mgr->GetConstant(c->type(), std::move(words)); |
264 | 2.50k | return const_mgr->GetDefiningInstruction(negated_const)->result_id(); |
265 | 2.50k | } |
266 | | |
267 | | // Negates the integer constant |c|. Returns the id of the defining instruction. |
268 | | uint32_t NegateIntegerConstant(analysis::ConstantManager* const_mgr, |
269 | 616 | const analysis::Constant* c) { |
270 | 616 | assert(c); |
271 | 616 | assert(c->type()->AsInteger()); |
272 | 616 | uint32_t width = c->type()->AsInteger()->width(); |
273 | 616 | assert(width == 32 || width == 64); |
274 | 616 | std::vector<uint32_t> words; |
275 | 616 | if (width == 64) { |
276 | 0 | uint64_t uval = static_cast<uint64_t>(0 - c->GetU64()); |
277 | 0 | words = ExtractInts(uval); |
278 | 616 | } else { |
279 | 616 | words.push_back(static_cast<uint32_t>(0 - c->GetU32())); |
280 | 616 | } |
281 | | |
282 | 616 | const analysis::Constant* negated_const = |
283 | 616 | const_mgr->GetConstant(c->type(), std::move(words)); |
284 | 616 | return const_mgr->GetDefiningInstruction(negated_const)->result_id(); |
285 | 616 | } |
286 | | |
287 | | // Negates the vector constant |c|. Returns the id of the defining instruction. |
288 | | uint32_t NegateVectorConstant(analysis::ConstantManager* const_mgr, |
289 | 656 | const analysis::Constant* c) { |
290 | 656 | assert(const_mgr && c); |
291 | 656 | assert(c->type()->AsVector()); |
292 | 656 | if (c->AsNullConstant()) { |
293 | | // 0.0 vs -0.0 shouldn't matter. |
294 | 0 | return const_mgr->GetDefiningInstruction(c)->result_id(); |
295 | 656 | } else { |
296 | 656 | const analysis::Type* component_type = |
297 | 656 | c->AsVectorConstant()->component_type(); |
298 | 656 | std::vector<uint32_t> words; |
299 | 1.31k | for (auto& comp : c->AsVectorConstant()->GetComponents()) { |
300 | 1.31k | if (component_type->AsFloat()) { |
301 | 1.31k | words.push_back(NegateFloatingPointConstant(const_mgr, comp)); |
302 | 1.31k | } else { |
303 | 0 | assert(component_type->AsInteger()); |
304 | 0 | words.push_back(NegateIntegerConstant(const_mgr, comp)); |
305 | 0 | } |
306 | 1.31k | } |
307 | | |
308 | 656 | const analysis::Constant* negated_const = |
309 | 656 | const_mgr->GetConstant(c->type(), std::move(words)); |
310 | 656 | return const_mgr->GetDefiningInstruction(negated_const)->result_id(); |
311 | 656 | } |
312 | 656 | } |
313 | | |
314 | | // Negates |c|. Returns the id of the defining instruction. |
315 | | uint32_t NegateConstant(analysis::ConstantManager* const_mgr, |
316 | 2.46k | const analysis::Constant* c) { |
317 | 2.46k | if (c->type()->AsVector()) { |
318 | 656 | return NegateVectorConstant(const_mgr, c); |
319 | 1.80k | } else if (c->type()->AsFloat()) { |
320 | 1.19k | return NegateFloatingPointConstant(const_mgr, c); |
321 | 1.19k | } else { |
322 | 616 | assert(c->type()->AsInteger()); |
323 | 616 | return NegateIntegerConstant(const_mgr, c); |
324 | 616 | } |
325 | 2.46k | } |
326 | | |
327 | | // Takes the reciprocal of |c|. |c|'s type must be Float or a vector of Float. |
328 | | // Returns 0 if the reciprocal is NaN, infinite or subnormal. |
329 | | uint32_t Reciprocal(analysis::ConstantManager* const_mgr, |
330 | 152k | const analysis::Constant* c) { |
331 | 152k | assert(const_mgr && c); |
332 | 152k | assert(c->type()->AsFloat()); |
333 | | |
334 | 152k | uint32_t width = c->type()->AsFloat()->width(); |
335 | 152k | assert(width == 32 || width == 64); |
336 | 152k | std::vector<uint32_t> words; |
337 | | |
338 | 152k | if (c->IsZero()) { |
339 | 18.1k | return 0; |
340 | 18.1k | } |
341 | | |
342 | 133k | if (width == 64) { |
343 | 0 | spvtools::utils::FloatProxy<double> result(1.0 / c->GetDouble()); |
344 | 0 | if (!IsValidResult(result.getAsFloat())) return 0; |
345 | 0 | words = result.GetWords(); |
346 | 133k | } else { |
347 | 133k | spvtools::utils::FloatProxy<float> result(1.0f / c->GetFloat()); |
348 | 133k | if (!IsValidResult(result.getAsFloat())) return 0; |
349 | 92.4k | words = result.GetWords(); |
350 | 92.4k | } |
351 | | |
352 | 92.4k | const analysis::Constant* negated_const = |
353 | 92.4k | const_mgr->GetConstant(c->type(), std::move(words)); |
354 | 92.4k | return const_mgr->GetDefiningInstruction(negated_const)->result_id(); |
355 | 133k | } |
356 | | |
357 | | // Replaces fdiv where second operand is constant with fmul. |
358 | 16.7k | FoldingRule ReciprocalFDiv() { |
359 | 16.7k | return [](IRContext* context, Instruction* inst, |
360 | 145k | const std::vector<const analysis::Constant*>& constants) { |
361 | 145k | assert(inst->opcode() == spv::Op::OpFDiv); |
362 | 145k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
363 | 145k | const analysis::Type* type = |
364 | 145k | context->get_type_mgr()->GetType(inst->type_id()); |
365 | | |
366 | 145k | if (type->IsCooperativeMatrix()) { |
367 | 0 | return false; |
368 | 0 | } |
369 | | |
370 | 145k | if (!inst->IsFloatingPointFoldingAllowed()) return false; |
371 | | |
372 | 145k | uint32_t width = ElementWidth(type); |
373 | 145k | if (width != 32 && width != 64) return false; |
374 | | |
375 | 145k | if (constants[1] != nullptr) { |
376 | 105k | uint32_t id = 0; |
377 | 105k | if (const analysis::VectorConstant* vector_const = |
378 | 105k | constants[1]->AsVectorConstant()) { |
379 | 94.3k | std::vector<uint32_t> neg_ids; |
380 | 140k | for (auto& comp : vector_const->GetComponents()) { |
381 | 140k | id = Reciprocal(const_mgr, comp); |
382 | 140k | if (id == 0) return false; |
383 | 83.4k | neg_ids.push_back(id); |
384 | 83.4k | } |
385 | 37.2k | const analysis::Constant* negated_const = |
386 | 37.2k | const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids)); |
387 | 37.2k | id = const_mgr->GetDefiningInstruction(negated_const)->result_id(); |
388 | 37.2k | } else if (constants[1]->AsFloatConstant()) { |
389 | 11.2k | id = Reciprocal(const_mgr, constants[1]); |
390 | 11.2k | if (id == 0) return false; |
391 | 11.2k | } else { |
392 | | // Don't fold a null constant. |
393 | 257 | return false; |
394 | 257 | } |
395 | 46.0k | inst->SetOpcode(spv::Op::OpFMul); |
396 | 46.0k | inst->SetInOperands( |
397 | 46.0k | {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0u)}}, |
398 | 46.0k | {SPV_OPERAND_TYPE_ID, {id}}}); |
399 | 46.0k | return true; |
400 | 105k | } |
401 | | |
402 | 40.0k | return false; |
403 | 145k | }; |
404 | 16.7k | } |
405 | | |
406 | | // Elides consecutive negate instructions. |
407 | 33.5k | FoldingRule MergeNegateArithmetic() { |
408 | 33.5k | return [](IRContext* context, Instruction* inst, |
409 | 33.5k | const std::vector<const analysis::Constant*>& constants) { |
410 | 7.29k | assert(inst->opcode() == spv::Op::OpFNegate || |
411 | 7.29k | inst->opcode() == spv::Op::OpSNegate); |
412 | 7.29k | (void)constants; |
413 | 7.29k | const analysis::Type* type = |
414 | 7.29k | context->get_type_mgr()->GetType(inst->type_id()); |
415 | 7.29k | if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed()) |
416 | 0 | return false; |
417 | | |
418 | 7.29k | Instruction* op_inst = |
419 | 7.29k | context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u)); |
420 | 7.29k | if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed()) |
421 | 0 | return false; |
422 | | |
423 | 7.29k | if (op_inst->opcode() == inst->opcode()) { |
424 | | // Elide negates. |
425 | 155 | inst->SetOpcode(spv::Op::OpCopyObject); |
426 | 155 | inst->SetInOperands( |
427 | 155 | {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0u)}}}); |
428 | 155 | return true; |
429 | 155 | } |
430 | | |
431 | 7.13k | return false; |
432 | 7.29k | }; |
433 | 33.5k | } |
434 | | |
435 | | // Merges negate into a mul or div operation if that operation contains a |
436 | | // constant operand. |
437 | | // Cases: |
438 | | // -(x * 2) = x * -2 |
439 | | // -(2 * x) = x * -2 |
440 | | // -(x / 2) = x / -2 |
441 | | // -(2 / x) = -2 / x |
442 | 33.5k | FoldingRule MergeNegateMulDivArithmetic() { |
443 | 33.5k | return [](IRContext* context, Instruction* inst, |
444 | 33.5k | const std::vector<const analysis::Constant*>& constants) { |
445 | 6.99k | assert(inst->opcode() == spv::Op::OpFNegate || |
446 | 6.99k | inst->opcode() == spv::Op::OpSNegate); |
447 | 6.99k | (void)constants; |
448 | 6.99k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
449 | 6.99k | const analysis::Type* type = |
450 | 6.99k | context->get_type_mgr()->GetType(inst->type_id()); |
451 | | |
452 | 6.99k | if (type->IsCooperativeMatrix()) { |
453 | 0 | return false; |
454 | 0 | } |
455 | | |
456 | 6.99k | if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed()) |
457 | 0 | return false; |
458 | | |
459 | 6.99k | Instruction* op_inst = |
460 | 6.99k | context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u)); |
461 | 6.99k | if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed()) |
462 | 0 | return false; |
463 | | |
464 | 6.99k | uint32_t width = ElementWidth(type); |
465 | 6.99k | if (width != 32 && width != 64) return false; |
466 | | |
467 | 6.99k | spv::Op opcode = op_inst->opcode(); |
468 | 6.99k | if (opcode != spv::Op::OpFMul && opcode != spv::Op::OpFDiv && |
469 | 6.27k | opcode != spv::Op::OpIMul && opcode != spv::Op::OpSDiv) { |
470 | 6.21k | return false; |
471 | 6.21k | } |
472 | | |
473 | 771 | std::vector<const analysis::Constant*> op_constants = |
474 | 771 | const_mgr->GetOperandConstants(op_inst); |
475 | | // Merge negate into mul or div if one operand is constant. |
476 | 771 | if (op_constants[0] == nullptr && op_constants[1] == nullptr) { |
477 | 248 | return false; |
478 | 248 | } |
479 | | |
480 | 523 | bool zero_is_variable = op_constants[0] == nullptr; |
481 | 523 | const analysis::Constant* c = ConstInput(op_constants); |
482 | 523 | uint32_t neg_id = NegateConstant(const_mgr, c); |
483 | 523 | uint32_t non_const_id = zero_is_variable |
484 | 523 | ? op_inst->GetSingleWordInOperand(0u) |
485 | 523 | : op_inst->GetSingleWordInOperand(1u); |
486 | | // Change this instruction to a mul/div. |
487 | 523 | inst->SetOpcode(op_inst->opcode()); |
488 | 523 | if (opcode == spv::Op::OpFDiv || opcode == spv::Op::OpUDiv || |
489 | 491 | opcode == spv::Op::OpSDiv) { |
490 | 39 | uint32_t op0 = zero_is_variable ? non_const_id : neg_id; |
491 | 39 | uint32_t op1 = zero_is_variable ? neg_id : non_const_id; |
492 | 39 | inst->SetInOperands( |
493 | 39 | {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}}); |
494 | 484 | } else { |
495 | 484 | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}}, |
496 | 484 | {SPV_OPERAND_TYPE_ID, {neg_id}}}); |
497 | 484 | } |
498 | 523 | return true; |
499 | 771 | }; |
500 | 33.5k | } |
501 | | |
502 | | // Merges negate into a add or sub operation if that operation contains a |
503 | | // constant operand. |
504 | | // Cases: |
505 | | // -(x + 2) = -2 - x |
506 | | // -(2 + x) = -2 - x |
507 | | // -(x - 2) = 2 - x |
508 | | // -(2 - x) = x - 2 |
509 | 33.5k | FoldingRule MergeNegateAddSubArithmetic() { |
510 | 33.5k | return [](IRContext* context, Instruction* inst, |
511 | 33.5k | const std::vector<const analysis::Constant*>& constants) { |
512 | 7.11k | assert(inst->opcode() == spv::Op::OpFNegate || |
513 | 7.11k | inst->opcode() == spv::Op::OpSNegate); |
514 | 7.11k | (void)constants; |
515 | 7.11k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
516 | 7.11k | const analysis::Type* type = |
517 | 7.11k | context->get_type_mgr()->GetType(inst->type_id()); |
518 | | |
519 | 7.11k | if (type->IsCooperativeMatrix()) { |
520 | 0 | return false; |
521 | 0 | } |
522 | | |
523 | 7.11k | if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed()) |
524 | 0 | return false; |
525 | | |
526 | 7.11k | Instruction* op_inst = |
527 | 7.11k | context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u)); |
528 | 7.11k | if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed()) |
529 | 0 | return false; |
530 | | |
531 | 7.11k | uint32_t width = ElementWidth(type); |
532 | 7.11k | if (width != 32 && width != 64) return false; |
533 | | |
534 | 7.11k | if (op_inst->opcode() == spv::Op::OpFAdd || |
535 | 6.92k | op_inst->opcode() == spv::Op::OpFSub || |
536 | 6.69k | op_inst->opcode() == spv::Op::OpIAdd || |
537 | 6.44k | op_inst->opcode() == spv::Op::OpISub) { |
538 | 674 | std::vector<const analysis::Constant*> op_constants = |
539 | 674 | const_mgr->GetOperandConstants(op_inst); |
540 | 674 | if (op_constants[0] || op_constants[1]) { |
541 | 236 | bool zero_is_variable = op_constants[0] == nullptr; |
542 | 236 | bool is_add = (op_inst->opcode() == spv::Op::OpFAdd) || |
543 | 179 | (op_inst->opcode() == spv::Op::OpIAdd); |
544 | 236 | bool swap_operands = !is_add || zero_is_variable; |
545 | 236 | bool negate_const = is_add; |
546 | 236 | const analysis::Constant* c = ConstInput(op_constants); |
547 | 236 | uint32_t const_id = 0; |
548 | 236 | if (negate_const) { |
549 | 141 | const_id = NegateConstant(const_mgr, c); |
550 | 141 | } else { |
551 | 95 | const_id = zero_is_variable ? op_inst->GetSingleWordInOperand(1u) |
552 | 95 | : op_inst->GetSingleWordInOperand(0u); |
553 | 95 | } |
554 | | |
555 | | // Swap operands if necessary and make the instruction a subtraction. |
556 | 236 | uint32_t op0 = |
557 | 236 | zero_is_variable ? op_inst->GetSingleWordInOperand(0u) : const_id; |
558 | 236 | uint32_t op1 = |
559 | 236 | zero_is_variable ? const_id : op_inst->GetSingleWordInOperand(1u); |
560 | 236 | if (swap_operands) std::swap(op0, op1); |
561 | 236 | inst->SetOpcode(HasFloatingPoint(type) ? spv::Op::OpFSub |
562 | 236 | : spv::Op::OpISub); |
563 | 236 | inst->SetInOperands( |
564 | 236 | {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}}); |
565 | 236 | return true; |
566 | 236 | } |
567 | 674 | } |
568 | | |
569 | 6.87k | return false; |
570 | 7.11k | }; |
571 | 33.5k | } |
572 | | |
573 | | // Returns true if |c| has a zero element. |
574 | 362k | bool HasZero(const analysis::Constant* c) { |
575 | 362k | if (c->AsNullConstant()) { |
576 | 522 | return true; |
577 | 522 | } |
578 | 362k | if (const analysis::VectorConstant* vec_const = c->AsVectorConstant()) { |
579 | 124k | for (auto& comp : vec_const->GetComponents()) |
580 | 212k | if (HasZero(comp)) return true; |
581 | 238k | } else { |
582 | 238k | assert(c->AsScalarConstant()); |
583 | 238k | return c->AsScalarConstant()->IsZero(); |
584 | 238k | } |
585 | | |
586 | 84.0k | return false; |
587 | 362k | } |
588 | | |
589 | | // Performs |input1| |opcode| |input2| and returns the merged constant result |
590 | | // id. Returns 0 if the result is not a valid value. The input types must be |
591 | | // Float. |
592 | | uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr, |
593 | | spv::Op opcode, |
594 | | const analysis::Constant* input1, |
595 | 42.2k | const analysis::Constant* input2) { |
596 | 42.2k | const analysis::Type* type = input1->type(); |
597 | 42.2k | assert(type->AsFloat()); |
598 | 42.2k | uint32_t width = type->AsFloat()->width(); |
599 | 42.2k | assert(width == 32 || width == 64); |
600 | 42.2k | std::vector<uint32_t> words; |
601 | 42.2k | #define FOLD_OP(op) \ |
602 | 42.2k | if (width == 64) { \ |
603 | 0 | utils::FloatProxy<double> val = \ |
604 | 0 | input1->GetDouble() op input2->GetDouble(); \ |
605 | 0 | double dval = val.getAsFloat(); \ |
606 | 0 | if (!IsValidResult(dval)) return 0; \ |
607 | 0 | words = val.GetWords(); \ |
608 | 42.2k | } else { \ |
609 | 42.2k | utils::FloatProxy<float> val = input1->GetFloat() op input2->GetFloat(); \ |
610 | 42.2k | float fval = val.getAsFloat(); \ |
611 | 42.2k | if (!IsValidResult(fval)) return 0; \ |
612 | 42.2k | words = val.GetWords(); \ |
613 | 30.3k | } \ |
614 | 42.2k | static_assert(true, "require extra semicolon") |
615 | 42.2k | switch (opcode) { |
616 | 6.66k | case spv::Op::OpFMul: |
617 | 6.66k | FOLD_OP(*); |
618 | 3.44k | break; |
619 | 2.49k | case spv::Op::OpFDiv: |
620 | 2.49k | if (HasZero(input2)) return 0; |
621 | 2.49k | FOLD_OP(/); |
622 | 1.66k | break; |
623 | 26.7k | case spv::Op::OpFAdd: |
624 | 26.7k | FOLD_OP(+); |
625 | 20.5k | break; |
626 | 6.41k | case spv::Op::OpFSub: |
627 | 6.41k | FOLD_OP(-); |
628 | 4.61k | break; |
629 | 0 | default: |
630 | 0 | assert(false && "Unexpected operation"); |
631 | 0 | break; |
632 | 42.2k | } |
633 | 30.3k | #undef FOLD_OP |
634 | 30.3k | const analysis::Constant* merged_const = const_mgr->GetConstant(type, words); |
635 | 30.3k | return const_mgr->GetDefiningInstruction(merged_const)->result_id(); |
636 | 42.2k | } |
637 | | |
638 | | // Performs |input1| |opcode| |input2| and returns the merged constant result |
639 | | // id. Returns 0 if the result is not a valid value. The input types must be |
640 | | // Integers. |
641 | | uint32_t PerformIntegerOperation(analysis::ConstantManager* const_mgr, |
642 | | spv::Op opcode, |
643 | | const analysis::Constant* input1, |
644 | 7.69k | const analysis::Constant* input2) { |
645 | 7.69k | assert(input1->type()->AsInteger()); |
646 | 7.69k | const analysis::Integer* type = input1->type()->AsInteger(); |
647 | 7.69k | uint32_t width = type->AsInteger()->width(); |
648 | 7.69k | assert(width == 32 || width == 64); |
649 | 7.69k | std::vector<uint32_t> words; |
650 | | // Regardless of the sign of the constant, folding is performed on an unsigned |
651 | | // interpretation of the constant data. This avoids signed integer overflow |
652 | | // while folding, and works because sign is irrelevant for the IAdd, ISub and |
653 | | // IMul instructions. |
654 | 7.69k | #define FOLD_OP(op) \ |
655 | 7.69k | if (width == 64) { \ |
656 | 0 | uint64_t val = input1->GetU64() op input2->GetU64(); \ |
657 | 0 | words = ExtractInts(val); \ |
658 | 7.69k | } else { \ |
659 | 7.69k | uint32_t val = input1->GetU32() op input2->GetU32(); \ |
660 | 7.69k | words.push_back(val); \ |
661 | 7.69k | } \ |
662 | 7.69k | static_assert(true, "require extra semicolon") |
663 | 7.69k | switch (opcode) { |
664 | 349 | case spv::Op::OpIMul: |
665 | 349 | FOLD_OP(*); |
666 | 349 | break; |
667 | 0 | case spv::Op::OpSDiv: |
668 | 0 | case spv::Op::OpUDiv: |
669 | 0 | assert(false && "Should not merge integer division"); |
670 | 0 | break; |
671 | 3.23k | case spv::Op::OpIAdd: |
672 | 3.23k | FOLD_OP(+); |
673 | 3.23k | break; |
674 | 2.04k | case spv::Op::OpISub: |
675 | 2.04k | FOLD_OP(-); |
676 | 2.04k | break; |
677 | 8 | case spv::Op::OpBitwiseXor: |
678 | 8 | FOLD_OP(^); |
679 | 8 | break; |
680 | 1.02k | case spv::Op::OpBitwiseOr: |
681 | 1.02k | FOLD_OP(|); |
682 | 1.02k | break; |
683 | 1.03k | case spv::Op::OpBitwiseAnd: |
684 | 1.03k | FOLD_OP(&); |
685 | 1.03k | break; |
686 | 0 | default: |
687 | 0 | assert(false && "Unexpected operation"); |
688 | 0 | break; |
689 | 7.69k | } |
690 | 7.69k | #undef FOLD_OP |
691 | 7.69k | const analysis::Constant* merged_const = const_mgr->GetConstant(type, words); |
692 | 7.69k | return const_mgr->GetDefiningInstruction(merged_const)->result_id(); |
693 | 7.69k | } |
694 | | |
695 | | // Performs |input1| |opcode| |input2| and returns the merged constant result |
696 | | // id. Returns 0 if the result is not a valid value. The input types must be |
697 | | // Integers, Floats or Vectors of such. |
698 | | uint32_t PerformOperation(analysis::ConstantManager* const_mgr, spv::Op opcode, |
699 | | const analysis::Constant* input1, |
700 | 40.3k | const analysis::Constant* input2) { |
701 | 40.3k | assert(input1 && input2); |
702 | 40.3k | const analysis::Type* type = input1->type(); |
703 | 40.3k | std::vector<uint32_t> words; |
704 | 40.3k | if (const analysis::Vector* vector_type = type->AsVector()) { |
705 | 12.3k | const analysis::Type* ele_type = vector_type->element_type(); |
706 | 31.0k | for (uint32_t i = 0; i != vector_type->element_count(); ++i) { |
707 | 22.0k | uint32_t id = 0; |
708 | | |
709 | 22.0k | const analysis::Constant* input1_comp = nullptr; |
710 | 22.0k | if (const analysis::VectorConstant* input1_vector = |
711 | 22.0k | input1->AsVectorConstant()) { |
712 | 22.0k | input1_comp = input1_vector->GetComponents()[i]; |
713 | 22.0k | } else { |
714 | 0 | assert(input1->AsNullConstant()); |
715 | 0 | input1_comp = const_mgr->GetConstant(ele_type, {}); |
716 | 0 | } |
717 | | |
718 | 22.0k | const analysis::Constant* input2_comp = nullptr; |
719 | 22.0k | if (const analysis::VectorConstant* input2_vector = |
720 | 22.0k | input2->AsVectorConstant()) { |
721 | 22.0k | input2_comp = input2_vector->GetComponents()[i]; |
722 | 22.0k | } else { |
723 | 0 | assert(input2->AsNullConstant()); |
724 | 0 | input2_comp = const_mgr->GetConstant(ele_type, {}); |
725 | 0 | } |
726 | | |
727 | 22.0k | if (ele_type->AsFloat()) { |
728 | 22.0k | id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp, |
729 | 22.0k | input2_comp); |
730 | 22.0k | } else { |
731 | 0 | assert(ele_type->AsInteger()); |
732 | 0 | id = PerformIntegerOperation(const_mgr, opcode, input1_comp, |
733 | 0 | input2_comp); |
734 | 0 | } |
735 | 22.0k | if (id == 0) return 0; |
736 | 18.6k | words.push_back(id); |
737 | 18.6k | } |
738 | 9.03k | const analysis::Constant* merged_const = |
739 | 9.03k | const_mgr->GetConstant(type, words); |
740 | 9.03k | return const_mgr->GetDefiningInstruction(merged_const)->result_id(); |
741 | 27.9k | } else if (type->AsFloat()) { |
742 | 20.2k | return PerformFloatingPointOperation(const_mgr, opcode, input1, input2); |
743 | 20.2k | } else { |
744 | 7.69k | assert(type->AsInteger()); |
745 | 7.69k | return PerformIntegerOperation(const_mgr, opcode, input1, input2); |
746 | 7.69k | } |
747 | 40.3k | } |
748 | | |
749 | | // Merges consecutive multiplies where each contains one constant operand. |
750 | | // Cases: |
751 | | // 2 * (x * 2) = x * 4 |
752 | | // 2 * (2 * x) = x * 4 |
753 | | // (x * 2) * 2 = x * 4 |
754 | | // (2 * x) * 2 = x * 4 |
755 | 33.5k | FoldingRule MergeMulMulArithmetic() { |
756 | 33.5k | return [](IRContext* context, Instruction* inst, |
757 | 214k | const std::vector<const analysis::Constant*>& constants) { |
758 | 214k | assert(inst->opcode() == spv::Op::OpFMul || |
759 | 214k | inst->opcode() == spv::Op::OpIMul); |
760 | 214k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
761 | 214k | const analysis::Type* type = |
762 | 214k | context->get_type_mgr()->GetType(inst->type_id()); |
763 | | |
764 | 214k | if (type->IsCooperativeMatrix()) { |
765 | 0 | return false; |
766 | 0 | } |
767 | | |
768 | 214k | if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed()) |
769 | 19 | return false; |
770 | | |
771 | 214k | uint32_t width = ElementWidth(type); |
772 | 214k | if (width != 32 && width != 64) return false; |
773 | | |
774 | | // Determine the constant input and the variable input in |inst|. |
775 | 214k | const analysis::Constant* const_input1 = ConstInput(constants); |
776 | 214k | if (!const_input1) return false; |
777 | 152k | Instruction* other_inst = NonConstInput(context, constants[0], inst); |
778 | 152k | if (HasFloatingPoint(type) && !other_inst->IsFloatingPointFoldingAllowed()) |
779 | 29 | return false; |
780 | | |
781 | 152k | if (other_inst->opcode() == inst->opcode()) { |
782 | 5.74k | std::vector<const analysis::Constant*> other_constants = |
783 | 5.74k | const_mgr->GetOperandConstants(other_inst); |
784 | 5.74k | const analysis::Constant* const_input2 = ConstInput(other_constants); |
785 | 5.74k | if (!const_input2) return false; |
786 | | |
787 | 4.07k | bool other_first_is_variable = other_constants[0] == nullptr; |
788 | 4.07k | uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(), |
789 | 4.07k | const_input1, const_input2); |
790 | 4.07k | if (merged_id == 0) return false; |
791 | | |
792 | 1.99k | uint32_t non_const_id = other_first_is_variable |
793 | 1.99k | ? other_inst->GetSingleWordInOperand(0u) |
794 | 1.99k | : other_inst->GetSingleWordInOperand(1u); |
795 | 1.99k | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}}, |
796 | 1.99k | {SPV_OPERAND_TYPE_ID, {merged_id}}}); |
797 | 1.99k | return true; |
798 | 4.07k | } |
799 | | |
800 | 146k | return false; |
801 | 152k | }; |
802 | 33.5k | } |
803 | | |
804 | | // Merges divides into subsequent multiplies if each instruction contains one |
805 | | // constant operand. Does not support integer operations. |
806 | | // Cases: |
807 | | // 2 * (x / 2) = x * 1 |
808 | | // 2 * (2 / x) = 4 / x |
809 | | // (x / 2) * 2 = x * 1 |
810 | | // (2 / x) * 2 = 4 / x |
811 | | // (y / x) * x = y |
812 | | // x * (y / x) = y |
813 | 16.7k | FoldingRule MergeMulDivArithmetic() { |
814 | 16.7k | return [](IRContext* context, Instruction* inst, |
815 | 199k | const std::vector<const analysis::Constant*>& constants) { |
816 | 199k | assert(inst->opcode() == spv::Op::OpFMul); |
817 | 199k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
818 | 199k | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
819 | | |
820 | 199k | const analysis::Type* type = |
821 | 199k | context->get_type_mgr()->GetType(inst->type_id()); |
822 | | |
823 | 199k | if (type->IsCooperativeMatrix()) { |
824 | 0 | return false; |
825 | 0 | } |
826 | | |
827 | 199k | if (!inst->IsFloatingPointFoldingAllowed()) return false; |
828 | | |
829 | 199k | uint32_t width = ElementWidth(type); |
830 | 199k | if (width != 32 && width != 64) return false; |
831 | | |
832 | 596k | for (uint32_t i = 0; i < 2; i++) { |
833 | 398k | uint32_t op_id = inst->GetSingleWordInOperand(i); |
834 | 398k | Instruction* op_inst = def_use_mgr->GetDef(op_id); |
835 | 398k | if (op_inst->opcode() == spv::Op::OpFDiv) { |
836 | 15.1k | if (op_inst->GetSingleWordInOperand(1) == |
837 | 15.1k | inst->GetSingleWordInOperand(1 - i)) { |
838 | 284 | inst->SetOpcode(spv::Op::OpCopyObject); |
839 | 284 | inst->SetInOperands( |
840 | 284 | {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0)}}}); |
841 | 284 | return true; |
842 | 284 | } |
843 | 15.1k | } |
844 | 398k | } |
845 | | |
846 | 198k | const analysis::Constant* const_input1 = ConstInput(constants); |
847 | 198k | if (!const_input1) return false; |
848 | 137k | Instruction* other_inst = NonConstInput(context, constants[0], inst); |
849 | 137k | if (!other_inst->IsFloatingPointFoldingAllowed()) return false; |
850 | | |
851 | 137k | if (other_inst->opcode() == spv::Op::OpFDiv) { |
852 | 2.80k | std::vector<const analysis::Constant*> other_constants = |
853 | 2.80k | const_mgr->GetOperandConstants(other_inst); |
854 | 2.80k | const analysis::Constant* const_input2 = ConstInput(other_constants); |
855 | 2.80k | if (!const_input2 || HasZero(const_input2)) return false; |
856 | | |
857 | 776 | bool other_first_is_variable = other_constants[0] == nullptr; |
858 | | // If the variable value is the second operand of the divide, multiply |
859 | | // the constants together. Otherwise divide the constants. |
860 | 776 | uint32_t merged_id = PerformOperation( |
861 | 776 | const_mgr, |
862 | 776 | other_first_is_variable ? other_inst->opcode() : inst->opcode(), |
863 | 776 | const_input1, const_input2); |
864 | 776 | if (merged_id == 0) return false; |
865 | | |
866 | 334 | uint32_t non_const_id = other_first_is_variable |
867 | 334 | ? other_inst->GetSingleWordInOperand(0u) |
868 | 334 | : other_inst->GetSingleWordInOperand(1u); |
869 | | |
870 | | // If the variable value is on the second operand of the div, then this |
871 | | // operation is a div. Otherwise it should be a multiply. |
872 | 334 | inst->SetOpcode(other_first_is_variable ? inst->opcode() |
873 | 334 | : other_inst->opcode()); |
874 | 334 | if (other_first_is_variable) { |
875 | 27 | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}}, |
876 | 27 | {SPV_OPERAND_TYPE_ID, {merged_id}}}); |
877 | 307 | } else { |
878 | 307 | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {merged_id}}, |
879 | 307 | {SPV_OPERAND_TYPE_ID, {non_const_id}}}); |
880 | 307 | } |
881 | 334 | return true; |
882 | 776 | } |
883 | | |
884 | 135k | return false; |
885 | 137k | }; |
886 | 16.7k | } |
887 | | |
888 | | // Merges multiply of constant and negation. |
889 | | // Cases: |
890 | | // (-x) * 2 = x * -2 |
891 | | // 2 * (-x) = x * -2 |
892 | 33.5k | FoldingRule MergeMulNegateArithmetic() { |
893 | 33.5k | return [](IRContext* context, Instruction* inst, |
894 | 212k | const std::vector<const analysis::Constant*>& constants) { |
895 | 212k | assert(inst->opcode() == spv::Op::OpFMul || |
896 | 212k | inst->opcode() == spv::Op::OpIMul); |
897 | 212k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
898 | 212k | const analysis::Type* type = |
899 | 212k | context->get_type_mgr()->GetType(inst->type_id()); |
900 | | |
901 | 212k | if (type->IsCooperativeMatrix()) { |
902 | 0 | return false; |
903 | 0 | } |
904 | | |
905 | 212k | bool uses_float = HasFloatingPoint(type); |
906 | 212k | if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; |
907 | | |
908 | 212k | uint32_t width = ElementWidth(type); |
909 | 212k | if (width != 32 && width != 64) return false; |
910 | | |
911 | 212k | const analysis::Constant* const_input1 = ConstInput(constants); |
912 | 212k | if (!const_input1) return false; |
913 | 149k | Instruction* other_inst = NonConstInput(context, constants[0], inst); |
914 | 149k | if (uses_float && !other_inst->IsFloatingPointFoldingAllowed()) |
915 | 29 | return false; |
916 | | |
917 | 149k | if (other_inst->opcode() == spv::Op::OpFNegate || |
918 | 149k | other_inst->opcode() == spv::Op::OpSNegate) { |
919 | 36 | uint32_t neg_id = NegateConstant(const_mgr, const_input1); |
920 | | |
921 | 36 | inst->SetInOperands( |
922 | 36 | {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}, |
923 | 36 | {SPV_OPERAND_TYPE_ID, {neg_id}}}); |
924 | 36 | return true; |
925 | 36 | } |
926 | | |
927 | 149k | return false; |
928 | 149k | }; |
929 | 33.5k | } |
930 | | |
931 | | // Returns true if |inst| is negation op and is safe to fold. |
932 | 1.92M | static bool IsFoldableNegation(const Instruction* inst) { |
933 | 1.92M | return (inst->opcode() == spv::Op::OpSNegate || |
934 | 1.92M | (inst->opcode() == spv::Op::OpFNegate && |
935 | 803 | inst->IsFloatingPointFoldingAllowed())); |
936 | 1.92M | } |
937 | | |
938 | | // Merges multiplies / divisions of two negations. |
939 | | // Cases: |
940 | | // (-x) * (-y) = x * y |
941 | | // (-x) / (-y) = x / y |
942 | 83.8k | FoldingRule MergeDivMulDoubleNegative() { |
943 | 83.8k | return [](IRContext* context, Instruction* inst, |
944 | 656k | const std::vector<const analysis::Constant*>&) { |
945 | 656k | assert(inst->opcode() == spv::Op::OpFMul || |
946 | 656k | inst->opcode() == spv::Op::OpVectorTimesScalar || |
947 | 656k | inst->opcode() == spv::Op::OpFDiv || |
948 | 656k | inst->opcode() == spv::Op::OpIMul || |
949 | 656k | inst->opcode() == spv::Op::OpSDiv); |
950 | | |
951 | 656k | const analysis::Type* type = |
952 | 656k | context->get_type_mgr()->GetType(inst->type_id()); |
953 | | |
954 | 656k | bool uses_float = HasFloatingPoint(type); |
955 | 656k | if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; |
956 | | |
957 | 655k | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
958 | 655k | Instruction* lhs = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); |
959 | 655k | Instruction* rhs = def_use_mgr->GetDef(inst->GetSingleWordInOperand(1)); |
960 | | |
961 | 655k | if (IsFoldableNegation(lhs) && IsFoldableNegation(rhs)) { |
962 | 12 | inst->SetInOperands( |
963 | 12 | {{SPV_OPERAND_TYPE_ID, {lhs->GetSingleWordInOperand(0u)}}, |
964 | 12 | {SPV_OPERAND_TYPE_ID, {rhs->GetSingleWordInOperand(0u)}}}); |
965 | 12 | return true; |
966 | 12 | } |
967 | 655k | return false; |
968 | 655k | }; |
969 | 83.8k | } |
970 | | |
971 | | // Merges consecutive divides if each instruction contains one constant operand. |
972 | | // Does not support integer division. |
973 | | // Cases: |
974 | | // 2 / (x / 2) = 4 / x |
975 | | // 4 / (2 / x) = 2 * x |
976 | | // (4 / x) / 2 = 2 / x |
977 | | // (x / 2) / 2 = x / 4 |
978 | 16.7k | FoldingRule MergeDivDivArithmetic() { |
979 | 16.7k | return [](IRContext* context, Instruction* inst, |
980 | 99.9k | const std::vector<const analysis::Constant*>& constants) { |
981 | 99.9k | assert(inst->opcode() == spv::Op::OpFDiv); |
982 | 99.9k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
983 | 99.9k | const analysis::Type* type = |
984 | 99.9k | context->get_type_mgr()->GetType(inst->type_id()); |
985 | | |
986 | 99.9k | if (type->IsCooperativeMatrix()) { |
987 | 0 | return false; |
988 | 0 | } |
989 | | |
990 | 99.9k | if (!inst->IsFloatingPointFoldingAllowed()) return false; |
991 | | |
992 | 99.9k | uint32_t width = ElementWidth(type); |
993 | 99.9k | if (width != 32 && width != 64) return false; |
994 | | |
995 | 99.9k | const analysis::Constant* const_input1 = ConstInput(constants); |
996 | 99.9k | if (!const_input1 || HasZero(const_input1)) return false; |
997 | 52.4k | Instruction* other_inst = NonConstInput(context, constants[0], inst); |
998 | 52.4k | if (!other_inst->IsFloatingPointFoldingAllowed()) return false; |
999 | | |
1000 | 52.4k | bool first_is_variable = constants[0] == nullptr; |
1001 | 52.4k | if (other_inst->opcode() == inst->opcode()) { |
1002 | 2.05k | std::vector<const analysis::Constant*> other_constants = |
1003 | 2.05k | const_mgr->GetOperandConstants(other_inst); |
1004 | 2.05k | const analysis::Constant* const_input2 = ConstInput(other_constants); |
1005 | 2.05k | if (!const_input2 || HasZero(const_input2)) return false; |
1006 | | |
1007 | 1.65k | bool other_first_is_variable = other_constants[0] == nullptr; |
1008 | | |
1009 | 1.65k | spv::Op merge_op = inst->opcode(); |
1010 | 1.65k | if (other_first_is_variable) { |
1011 | | // Constants magnify. |
1012 | 883 | merge_op = spv::Op::OpFMul; |
1013 | 883 | } |
1014 | | |
1015 | | // This is an x / (*) case. Swap the inputs. Doesn't harm multiply |
1016 | | // because it is commutative. |
1017 | 1.65k | if (first_is_variable) std::swap(const_input1, const_input2); |
1018 | 1.65k | uint32_t merged_id = |
1019 | 1.65k | PerformOperation(const_mgr, merge_op, const_input1, const_input2); |
1020 | 1.65k | if (merged_id == 0) return false; |
1021 | | |
1022 | 726 | uint32_t non_const_id = other_first_is_variable |
1023 | 726 | ? other_inst->GetSingleWordInOperand(0u) |
1024 | 726 | : other_inst->GetSingleWordInOperand(1u); |
1025 | | |
1026 | 726 | spv::Op op = inst->opcode(); |
1027 | 726 | if (!first_is_variable && !other_first_is_variable) { |
1028 | | // Effectively div of 1/x, so change to multiply. |
1029 | 516 | op = spv::Op::OpFMul; |
1030 | 516 | } |
1031 | | |
1032 | 726 | uint32_t op1 = merged_id; |
1033 | 726 | uint32_t op2 = non_const_id; |
1034 | 726 | if (first_is_variable && other_first_is_variable) std::swap(op1, op2); |
1035 | 726 | inst->SetOpcode(op); |
1036 | 726 | inst->SetInOperands( |
1037 | 726 | {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}}); |
1038 | 726 | return true; |
1039 | 1.65k | } |
1040 | | |
1041 | 50.4k | return false; |
1042 | 52.4k | }; |
1043 | 16.7k | } |
1044 | | |
1045 | | // Fold multiplies succeeded by divides where each instruction contains a |
1046 | | // constant operand. Does not support integer divide. |
1047 | | // Cases: |
1048 | | // 4 / (x * 2) = 2 / x |
1049 | | // 4 / (2 * x) = 2 / x |
1050 | | // (x * 4) / 2 = x * 2 |
1051 | | // (4 * x) / 2 = x * 2 |
1052 | | // (x * y) / x = y |
1053 | | // (y * x) / x = y |
1054 | 16.7k | FoldingRule MergeDivMulArithmetic() { |
1055 | 16.7k | return [](IRContext* context, Instruction* inst, |
1056 | 99.2k | const std::vector<const analysis::Constant*>& constants) { |
1057 | 99.2k | assert(inst->opcode() == spv::Op::OpFDiv); |
1058 | 99.2k | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
1059 | 99.2k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
1060 | | |
1061 | 99.2k | const analysis::Type* type = |
1062 | 99.2k | context->get_type_mgr()->GetType(inst->type_id()); |
1063 | | |
1064 | 99.2k | if (type->IsCooperativeMatrix()) { |
1065 | 0 | return false; |
1066 | 0 | } |
1067 | | |
1068 | 99.2k | if (!inst->IsFloatingPointFoldingAllowed()) return false; |
1069 | | |
1070 | 99.1k | uint32_t width = ElementWidth(type); |
1071 | 99.1k | if (width != 32 && width != 64) return false; |
1072 | | |
1073 | 99.1k | uint32_t op_id = inst->GetSingleWordInOperand(0); |
1074 | 99.1k | Instruction* op_inst = def_use_mgr->GetDef(op_id); |
1075 | | |
1076 | 99.1k | if (op_inst->opcode() == spv::Op::OpFMul) { |
1077 | 3.31k | for (uint32_t i = 0; i < 2; i++) { |
1078 | 2.36k | if (op_inst->GetSingleWordInOperand(i) == |
1079 | 2.36k | inst->GetSingleWordInOperand(1)) { |
1080 | 242 | inst->SetOpcode(spv::Op::OpCopyObject); |
1081 | 242 | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, |
1082 | 242 | {op_inst->GetSingleWordInOperand(1 - i)}}}); |
1083 | 242 | return true; |
1084 | 242 | } |
1085 | 2.36k | } |
1086 | 1.19k | } |
1087 | | |
1088 | 98.9k | const analysis::Constant* const_input1 = ConstInput(constants); |
1089 | 98.9k | if (!const_input1 || HasZero(const_input1)) return false; |
1090 | 51.5k | Instruction* other_inst = NonConstInput(context, constants[0], inst); |
1091 | 51.5k | if (!other_inst->IsFloatingPointFoldingAllowed()) return false; |
1092 | | |
1093 | 51.5k | bool first_is_variable = constants[0] == nullptr; |
1094 | 51.5k | if (other_inst->opcode() == spv::Op::OpFMul) { |
1095 | 1.37k | std::vector<const analysis::Constant*> other_constants = |
1096 | 1.37k | const_mgr->GetOperandConstants(other_inst); |
1097 | 1.37k | const analysis::Constant* const_input2 = ConstInput(other_constants); |
1098 | 1.37k | if (!const_input2) return false; |
1099 | | |
1100 | 1.01k | bool other_first_is_variable = other_constants[0] == nullptr; |
1101 | | |
1102 | | // This is an x / (*) case. Swap the inputs. |
1103 | 1.01k | if (first_is_variable) std::swap(const_input1, const_input2); |
1104 | 1.01k | uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(), |
1105 | 1.01k | const_input1, const_input2); |
1106 | 1.01k | if (merged_id == 0) return false; |
1107 | | |
1108 | 670 | uint32_t non_const_id = other_first_is_variable |
1109 | 670 | ? other_inst->GetSingleWordInOperand(0u) |
1110 | 670 | : other_inst->GetSingleWordInOperand(1u); |
1111 | | |
1112 | 670 | uint32_t op1 = merged_id; |
1113 | 670 | uint32_t op2 = non_const_id; |
1114 | 670 | if (first_is_variable) std::swap(op1, op2); |
1115 | | |
1116 | | // Convert to multiply |
1117 | 670 | if (first_is_variable) inst->SetOpcode(other_inst->opcode()); |
1118 | 670 | inst->SetInOperands( |
1119 | 670 | {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}}); |
1120 | 670 | return true; |
1121 | 1.01k | } |
1122 | | |
1123 | 50.1k | return false; |
1124 | 51.5k | }; |
1125 | 16.7k | } |
1126 | | |
1127 | | // Fold divides of a constant and a negation. |
1128 | | // Cases: |
1129 | | // (-x) / 2 = x / -2 |
1130 | | // 2 / (-x) = -2 / x |
1131 | 16.7k | FoldingRule MergeDivNegateArithmetic() { |
1132 | 16.7k | return [](IRContext* context, Instruction* inst, |
1133 | 98.2k | const std::vector<const analysis::Constant*>& constants) { |
1134 | 98.2k | assert(inst->opcode() == spv::Op::OpFDiv); |
1135 | 98.2k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
1136 | 98.2k | if (!inst->IsFloatingPointFoldingAllowed()) return false; |
1137 | | |
1138 | 98.2k | const analysis::Constant* const_input1 = ConstInput(constants); |
1139 | 98.2k | if (!const_input1) return false; |
1140 | 71.1k | Instruction* other_inst = NonConstInput(context, constants[0], inst); |
1141 | 71.1k | if (!other_inst->IsFloatingPointFoldingAllowed()) return false; |
1142 | | |
1143 | 71.1k | bool first_is_variable = constants[0] == nullptr; |
1144 | 71.1k | if (other_inst->opcode() == spv::Op::OpFNegate) { |
1145 | 40 | uint32_t neg_id = NegateConstant(const_mgr, const_input1); |
1146 | | |
1147 | 40 | if (first_is_variable) { |
1148 | 26 | inst->SetInOperands( |
1149 | 26 | {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}, |
1150 | 26 | {SPV_OPERAND_TYPE_ID, {neg_id}}}); |
1151 | 26 | } else { |
1152 | 14 | inst->SetInOperands( |
1153 | 14 | {{SPV_OPERAND_TYPE_ID, {neg_id}}, |
1154 | 14 | {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}}); |
1155 | 14 | } |
1156 | 40 | return true; |
1157 | 40 | } |
1158 | | |
1159 | 71.1k | return false; |
1160 | 71.1k | }; |
1161 | 16.7k | } |
1162 | | |
1163 | | // Folds addition, where one side is a negation. |
1164 | | // (-x) + y = y - x |
1165 | | // y + (-x) = y - x |
1166 | 33.5k | FoldingRule MergeAddNegateArithmetic() { |
1167 | 33.5k | return [](IRContext* context, Instruction* inst, |
1168 | 528k | const std::vector<const analysis::Constant*>&) { |
1169 | 528k | assert(inst->opcode() == spv::Op::OpFAdd || |
1170 | 528k | inst->opcode() == spv::Op::OpIAdd); |
1171 | 528k | const analysis::Type* type = |
1172 | 528k | context->get_type_mgr()->GetType(inst->type_id()); |
1173 | 528k | bool uses_float = HasFloatingPoint(type); |
1174 | 528k | if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; |
1175 | | |
1176 | 528k | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
1177 | 528k | Instruction* lhs = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); |
1178 | 528k | Instruction* rhs = def_use_mgr->GetDef(inst->GetSingleWordInOperand(1)); |
1179 | | |
1180 | 528k | auto TrySubstitute = [inst, uses_float](Instruction* first, |
1181 | 1.05M | Instruction* second) { |
1182 | 1.05M | if (IsFoldableNegation(first)) { |
1183 | 55 | inst->SetOpcode(uses_float ? spv::Op::OpFSub : spv::Op::OpISub); |
1184 | 55 | inst->SetInOperands( |
1185 | 55 | {{SPV_OPERAND_TYPE_ID, {second->result_id()}}, |
1186 | 55 | {SPV_OPERAND_TYPE_ID, {first->GetSingleWordInOperand(0u)}}}); |
1187 | 55 | return true; |
1188 | 55 | } |
1189 | 1.05M | return false; |
1190 | 1.05M | }; |
1191 | | |
1192 | 528k | return TrySubstitute(lhs, rhs) || TrySubstitute(rhs, lhs); |
1193 | 528k | }; |
1194 | 33.5k | } |
1195 | | |
1196 | | // Folds subtraction, where one side is a negation. |
1197 | | // Cases: |
1198 | | // (-x) - 2 = -2 - x |
1199 | | // y - (-x) = x + y |
1200 | 33.5k | FoldingRule MergeSubNegateArithmetic() { |
1201 | 33.5k | return [](IRContext* context, Instruction* inst, |
1202 | 144k | const std::vector<const analysis::Constant*>& constants) { |
1203 | 144k | assert(inst->opcode() == spv::Op::OpFSub || |
1204 | 144k | inst->opcode() == spv::Op::OpISub); |
1205 | 144k | const analysis::Type* type = |
1206 | 144k | context->get_type_mgr()->GetType(inst->type_id()); |
1207 | | |
1208 | 144k | bool uses_float = HasFloatingPoint(type); |
1209 | 144k | if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; |
1210 | | |
1211 | 140k | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
1212 | 140k | Instruction* lhs = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); |
1213 | 140k | Instruction* rhs = def_use_mgr->GetDef(inst->GetSingleWordInOperand(1)); |
1214 | | |
1215 | 140k | if (IsFoldableNegation(rhs)) { |
1216 | 573 | inst->SetOpcode(uses_float ? spv::Op::OpFAdd : spv::Op::OpIAdd); |
1217 | 573 | inst->SetInOperands( |
1218 | 573 | {{SPV_OPERAND_TYPE_ID, {lhs->result_id()}}, |
1219 | 573 | {SPV_OPERAND_TYPE_ID, {rhs->GetSingleWordInOperand(0)}}}); |
1220 | 573 | return true; |
1221 | 573 | } |
1222 | | |
1223 | 139k | if (type->IsCooperativeMatrix()) { |
1224 | 0 | return false; |
1225 | 0 | } |
1226 | | |
1227 | 139k | uint32_t width = ElementWidth(type); |
1228 | 139k | if (width != 32 && width != 64) return false; |
1229 | | |
1230 | 139k | if (constants[1] && IsFoldableNegation(lhs)) { |
1231 | 35 | inst->SetInOperands( |
1232 | 35 | {{SPV_OPERAND_TYPE_ID, |
1233 | 35 | {NegateConstant(context->get_constant_mgr(), constants[1])}}, |
1234 | 35 | {SPV_OPERAND_TYPE_ID, {lhs->GetSingleWordInOperand(0)}}}); |
1235 | 35 | return true; |
1236 | 35 | } |
1237 | 139k | return false; |
1238 | 139k | }; |
1239 | 33.5k | } |
1240 | | |
1241 | | // Folds addition of an addition where each operation has a constant operand. |
1242 | | // Cases: |
1243 | | // (x + 2) + 2 = x + 4 |
1244 | | // (2 + x) + 2 = x + 4 |
1245 | | // 2 + (x + 2) = x + 4 |
1246 | | // 2 + (2 + x) = x + 4 |
1247 | 33.5k | FoldingRule MergeAddAddArithmetic() { |
1248 | 33.5k | return [](IRContext* context, Instruction* inst, |
1249 | 528k | const std::vector<const analysis::Constant*>& constants) { |
1250 | 528k | assert(inst->opcode() == spv::Op::OpFAdd || |
1251 | 528k | inst->opcode() == spv::Op::OpIAdd); |
1252 | 528k | const analysis::Type* type = |
1253 | 528k | context->get_type_mgr()->GetType(inst->type_id()); |
1254 | | |
1255 | 528k | if (type->IsCooperativeMatrix()) { |
1256 | 0 | return false; |
1257 | 0 | } |
1258 | | |
1259 | 528k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
1260 | 528k | bool uses_float = HasFloatingPoint(type); |
1261 | 528k | if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; |
1262 | | |
1263 | 528k | uint32_t width = ElementWidth(type); |
1264 | 528k | if (width != 32 && width != 64) return false; |
1265 | | |
1266 | 528k | const analysis::Constant* const_input1 = ConstInput(constants); |
1267 | 528k | if (!const_input1) return false; |
1268 | 139k | Instruction* other_inst = NonConstInput(context, constants[0], inst); |
1269 | 139k | if (uses_float && !other_inst->IsFloatingPointFoldingAllowed()) |
1270 | 465 | return false; |
1271 | | |
1272 | 139k | if (other_inst->opcode() == spv::Op::OpFAdd || |
1273 | 126k | other_inst->opcode() == spv::Op::OpIAdd) { |
1274 | 16.0k | std::vector<const analysis::Constant*> other_constants = |
1275 | 16.0k | const_mgr->GetOperandConstants(other_inst); |
1276 | 16.0k | const analysis::Constant* const_input2 = ConstInput(other_constants); |
1277 | 16.0k | if (!const_input2) return false; |
1278 | | |
1279 | 13.6k | Instruction* non_const_input = |
1280 | 13.6k | NonConstInput(context, other_constants[0], other_inst); |
1281 | 13.6k | uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(), |
1282 | 13.6k | const_input1, const_input2); |
1283 | 13.6k | if (merged_id == 0) return false; |
1284 | | |
1285 | 9.93k | inst->SetInOperands( |
1286 | 9.93k | {{SPV_OPERAND_TYPE_ID, {non_const_input->result_id()}}, |
1287 | 9.93k | {SPV_OPERAND_TYPE_ID, {merged_id}}}); |
1288 | 9.93k | return true; |
1289 | 13.6k | } |
1290 | 123k | return false; |
1291 | 139k | }; |
1292 | 33.5k | } |
1293 | | |
1294 | | // Folds addition of a subtraction where each operation has a constant operand. |
1295 | | // Cases: |
1296 | | // (x - 2) + 2 = x + 0 |
1297 | | // (2 - x) + 2 = 4 - x |
1298 | | // 2 + (x - 2) = x + 0 |
1299 | | // 2 + (2 - x) = 4 - x |
1300 | 33.5k | FoldingRule MergeAddSubArithmetic() { |
1301 | 33.5k | return [](IRContext* context, Instruction* inst, |
1302 | 518k | const std::vector<const analysis::Constant*>& constants) { |
1303 | 518k | assert(inst->opcode() == spv::Op::OpFAdd || |
1304 | 518k | inst->opcode() == spv::Op::OpIAdd); |
1305 | 518k | const analysis::Type* type = |
1306 | 518k | context->get_type_mgr()->GetType(inst->type_id()); |
1307 | | |
1308 | 518k | if (type->IsCooperativeMatrix()) { |
1309 | 0 | return false; |
1310 | 0 | } |
1311 | | |
1312 | 518k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
1313 | 518k | bool uses_float = HasFloatingPoint(type); |
1314 | 518k | if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; |
1315 | | |
1316 | 518k | uint32_t width = ElementWidth(type); |
1317 | 518k | if (width != 32 && width != 64) return false; |
1318 | | |
1319 | 518k | const analysis::Constant* const_input1 = ConstInput(constants); |
1320 | 518k | if (!const_input1) return false; |
1321 | 129k | Instruction* other_inst = NonConstInput(context, constants[0], inst); |
1322 | 129k | if (uses_float && !other_inst->IsFloatingPointFoldingAllowed()) |
1323 | 465 | return false; |
1324 | | |
1325 | 129k | if (other_inst->opcode() == spv::Op::OpFSub || |
1326 | 127k | other_inst->opcode() == spv::Op::OpISub) { |
1327 | 3.35k | std::vector<const analysis::Constant*> other_constants = |
1328 | 3.35k | const_mgr->GetOperandConstants(other_inst); |
1329 | 3.35k | const analysis::Constant* const_input2 = ConstInput(other_constants); |
1330 | 3.35k | if (!const_input2) return false; |
1331 | | |
1332 | 3.21k | bool first_is_variable = other_constants[0] == nullptr; |
1333 | 3.21k | spv::Op op = inst->opcode(); |
1334 | 3.21k | uint32_t op1 = 0; |
1335 | 3.21k | uint32_t op2 = 0; |
1336 | 3.21k | if (first_is_variable) { |
1337 | | // Subtract constants. Non-constant operand is first. |
1338 | 2.84k | op1 = other_inst->GetSingleWordInOperand(0u); |
1339 | 2.84k | op2 = PerformOperation(const_mgr, other_inst->opcode(), const_input1, |
1340 | 2.84k | const_input2); |
1341 | 2.84k | } else { |
1342 | | // Add constants. Constant operand is first. Change the opcode. |
1343 | 373 | op1 = PerformOperation(const_mgr, inst->opcode(), const_input1, |
1344 | 373 | const_input2); |
1345 | 373 | op2 = other_inst->GetSingleWordInOperand(1u); |
1346 | 373 | op = other_inst->opcode(); |
1347 | 373 | } |
1348 | 3.21k | if (op1 == 0 || op2 == 0) return false; |
1349 | | |
1350 | 2.29k | inst->SetOpcode(op); |
1351 | 2.29k | inst->SetInOperands( |
1352 | 2.29k | {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}}); |
1353 | 2.29k | return true; |
1354 | 3.21k | } |
1355 | 126k | return false; |
1356 | 129k | }; |
1357 | 33.5k | } |
1358 | | |
1359 | | // Folds subtraction of an addition where each operand has a constant operand. |
1360 | | // Cases: |
1361 | | // (x + 2) - 2 = x + 0 |
1362 | | // (2 + x) - 2 = x + 0 |
1363 | | // 2 - (x + 2) = 0 - x |
1364 | | // 2 - (2 + x) = 0 - x |
1365 | 33.5k | FoldingRule MergeSubAddArithmetic() { |
1366 | 33.5k | return [](IRContext* context, Instruction* inst, |
1367 | 143k | const std::vector<const analysis::Constant*>& constants) { |
1368 | 143k | assert(inst->opcode() == spv::Op::OpFSub || |
1369 | 143k | inst->opcode() == spv::Op::OpISub); |
1370 | 143k | const analysis::Type* type = |
1371 | 143k | context->get_type_mgr()->GetType(inst->type_id()); |
1372 | | |
1373 | 143k | if (type->IsCooperativeMatrix()) { |
1374 | 0 | return false; |
1375 | 0 | } |
1376 | | |
1377 | 143k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
1378 | 143k | bool uses_float = HasFloatingPoint(type); |
1379 | 143k | if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; |
1380 | | |
1381 | 139k | uint32_t width = ElementWidth(type); |
1382 | 139k | if (width != 32 && width != 64) return false; |
1383 | | |
1384 | 139k | const analysis::Constant* const_input1 = ConstInput(constants); |
1385 | 139k | if (!const_input1) return false; |
1386 | 86.7k | Instruction* other_inst = NonConstInput(context, constants[0], inst); |
1387 | 86.7k | if (uses_float && !other_inst->IsFloatingPointFoldingAllowed()) |
1388 | 11 | return false; |
1389 | | |
1390 | 86.7k | if (other_inst->opcode() == spv::Op::OpFAdd || |
1391 | 83.2k | other_inst->opcode() == spv::Op::OpIAdd) { |
1392 | 11.4k | std::vector<const analysis::Constant*> other_constants = |
1393 | 11.4k | const_mgr->GetOperandConstants(other_inst); |
1394 | 11.4k | const analysis::Constant* const_input2 = ConstInput(other_constants); |
1395 | 11.4k | if (!const_input2) return false; |
1396 | | |
1397 | 2.98k | Instruction* non_const_input = |
1398 | 2.98k | NonConstInput(context, other_constants[0], other_inst); |
1399 | | |
1400 | | // If the first operand of the sub is not a constant, swap the constants |
1401 | | // so the subtraction has the correct operands. |
1402 | 2.98k | if (constants[0] == nullptr) std::swap(const_input1, const_input2); |
1403 | | // Subtract the constants. |
1404 | 2.98k | uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(), |
1405 | 2.98k | const_input1, const_input2); |
1406 | 2.98k | spv::Op op = inst->opcode(); |
1407 | 2.98k | uint32_t op1 = 0; |
1408 | 2.98k | uint32_t op2 = 0; |
1409 | 2.98k | if (constants[0] == nullptr) { |
1410 | | // Non-constant operand is first. Change the opcode. |
1411 | 1.71k | op1 = non_const_input->result_id(); |
1412 | 1.71k | op2 = merged_id; |
1413 | 1.71k | op = other_inst->opcode(); |
1414 | 1.71k | } else { |
1415 | | // Constant operand is first. |
1416 | 1.26k | op1 = merged_id; |
1417 | 1.26k | op2 = non_const_input->result_id(); |
1418 | 1.26k | } |
1419 | 2.98k | if (op1 == 0 || op2 == 0) return false; |
1420 | | |
1421 | 2.41k | inst->SetOpcode(op); |
1422 | 2.41k | inst->SetInOperands( |
1423 | 2.41k | {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}}); |
1424 | 2.41k | return true; |
1425 | 2.98k | } |
1426 | 75.2k | return false; |
1427 | 86.7k | }; |
1428 | 33.5k | } |
1429 | | |
1430 | | // Folds subtraction of a subtraction where each operand has a constant operand. |
1431 | | // Cases: |
1432 | | // (x - 2) - 2 = x - 4 |
1433 | | // (2 - x) - 2 = 0 - x |
1434 | | // 2 - (x - 2) = 4 - x |
1435 | | // 2 - (2 - x) = x + 0 |
1436 | 33.5k | FoldingRule MergeSubSubArithmetic() { |
1437 | 33.5k | return [](IRContext* context, Instruction* inst, |
1438 | 141k | const std::vector<const analysis::Constant*>& constants) { |
1439 | 141k | assert(inst->opcode() == spv::Op::OpFSub || |
1440 | 141k | inst->opcode() == spv::Op::OpISub); |
1441 | 141k | const analysis::Type* type = |
1442 | 141k | context->get_type_mgr()->GetType(inst->type_id()); |
1443 | | |
1444 | 141k | if (type->IsCooperativeMatrix()) { |
1445 | 0 | return false; |
1446 | 0 | } |
1447 | | |
1448 | 141k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
1449 | 141k | bool uses_float = HasFloatingPoint(type); |
1450 | 141k | if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; |
1451 | | |
1452 | 137k | uint32_t width = ElementWidth(type); |
1453 | 137k | if (width != 32 && width != 64) return false; |
1454 | | |
1455 | 137k | const analysis::Constant* const_input1 = ConstInput(constants); |
1456 | 137k | if (!const_input1) return false; |
1457 | 84.3k | Instruction* other_inst = NonConstInput(context, constants[0], inst); |
1458 | 84.3k | if (uses_float && !other_inst->IsFloatingPointFoldingAllowed()) |
1459 | 11 | return false; |
1460 | | |
1461 | 84.2k | if (other_inst->opcode() == spv::Op::OpFSub || |
1462 | 76.6k | other_inst->opcode() == spv::Op::OpISub) { |
1463 | 9.17k | std::vector<const analysis::Constant*> other_constants = |
1464 | 9.17k | const_mgr->GetOperandConstants(other_inst); |
1465 | 9.17k | const analysis::Constant* const_input2 = ConstInput(other_constants); |
1466 | 9.17k | if (!const_input2) return false; |
1467 | | |
1468 | 8.68k | Instruction* non_const_input = |
1469 | 8.68k | NonConstInput(context, other_constants[0], other_inst); |
1470 | | |
1471 | | // Merge the constants. |
1472 | 8.68k | uint32_t merged_id = 0; |
1473 | 8.68k | spv::Op merge_op = inst->opcode(); |
1474 | 8.68k | if (other_constants[0] == nullptr) { |
1475 | 7.18k | merge_op = uses_float ? spv::Op::OpFAdd : spv::Op::OpIAdd; |
1476 | 7.18k | } else if (constants[0] == nullptr) { |
1477 | 402 | std::swap(const_input1, const_input2); |
1478 | 402 | } |
1479 | 8.68k | merged_id = |
1480 | 8.68k | PerformOperation(const_mgr, merge_op, const_input1, const_input2); |
1481 | 8.68k | if (merged_id == 0) return false; |
1482 | | |
1483 | 6.22k | spv::Op op = inst->opcode(); |
1484 | 6.22k | if (constants[0] != nullptr && other_constants[0] != nullptr) { |
1485 | | // Change the operation. |
1486 | 809 | op = uses_float ? spv::Op::OpFAdd : spv::Op::OpIAdd; |
1487 | 809 | } |
1488 | | |
1489 | 6.22k | uint32_t op1 = 0; |
1490 | 6.22k | uint32_t op2 = 0; |
1491 | 6.22k | if ((constants[0] == nullptr) ^ (other_constants[0] == nullptr)) { |
1492 | 572 | op1 = merged_id; |
1493 | 572 | op2 = non_const_input->result_id(); |
1494 | 5.65k | } else { |
1495 | 5.65k | op1 = non_const_input->result_id(); |
1496 | 5.65k | op2 = merged_id; |
1497 | 5.65k | } |
1498 | | |
1499 | 6.22k | inst->SetOpcode(op); |
1500 | 6.22k | inst->SetInOperands( |
1501 | 6.22k | {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}}); |
1502 | 6.22k | return true; |
1503 | 8.68k | } |
1504 | 75.1k | return false; |
1505 | 84.2k | }; |
1506 | 33.5k | } |
1507 | | |
1508 | | // Helper function for MergeGenericAddSubArithmetic. If |addend| and |
1509 | | // subtrahend of |sub| is the same, merge to copy of minuend of |sub|. |
1510 | 1.03M | bool MergeGenericAddendSub(uint32_t addend, uint32_t sub, Instruction* inst) { |
1511 | 1.03M | IRContext* context = inst->context(); |
1512 | 1.03M | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
1513 | 1.03M | Instruction* sub_inst = def_use_mgr->GetDef(sub); |
1514 | 1.03M | if (sub_inst->opcode() != spv::Op::OpFSub && |
1515 | 1.02M | sub_inst->opcode() != spv::Op::OpISub) |
1516 | 1.02M | return false; |
1517 | 9.64k | if (sub_inst->opcode() == spv::Op::OpFSub && |
1518 | 7.78k | !sub_inst->IsFloatingPointFoldingAllowed()) |
1519 | 0 | return false; |
1520 | 9.64k | if (addend != sub_inst->GetSingleWordInOperand(1)) return false; |
1521 | 1.07k | inst->SetOpcode(spv::Op::OpCopyObject); |
1522 | 1.07k | inst->SetInOperands( |
1523 | 1.07k | {{SPV_OPERAND_TYPE_ID, {sub_inst->GetSingleWordInOperand(0)}}}); |
1524 | 1.07k | context->UpdateDefUse(inst); |
1525 | 1.07k | return true; |
1526 | 9.64k | } |
1527 | | |
1528 | | // Folds addition of a subtraction where the subtrahend is equal to the |
1529 | | // other addend. Return a copy of the minuend. Accepts generic (const and |
1530 | | // non-const) operands. |
1531 | | // Cases: |
1532 | | // (a - b) + b = a |
1533 | | // b + (a - b) = a |
1534 | 33.5k | FoldingRule MergeGenericAddSubArithmetic() { |
1535 | 33.5k | return [](IRContext* context, Instruction* inst, |
1536 | 516k | const std::vector<const analysis::Constant*>&) { |
1537 | 516k | assert(inst->opcode() == spv::Op::OpFAdd || |
1538 | 516k | inst->opcode() == spv::Op::OpIAdd); |
1539 | 516k | const analysis::Type* type = |
1540 | 516k | context->get_type_mgr()->GetType(inst->type_id()); |
1541 | | |
1542 | 516k | if (type->IsCooperativeMatrix()) { |
1543 | 0 | return false; |
1544 | 0 | } |
1545 | | |
1546 | 516k | bool uses_float = HasFloatingPoint(type); |
1547 | 516k | if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; |
1548 | | |
1549 | 515k | uint32_t width = ElementWidth(type); |
1550 | 515k | if (width != 32 && width != 64) return false; |
1551 | | |
1552 | 515k | uint32_t add_op0 = inst->GetSingleWordInOperand(0); |
1553 | 515k | uint32_t add_op1 = inst->GetSingleWordInOperand(1); |
1554 | 515k | if (MergeGenericAddendSub(add_op0, add_op1, inst)) return true; |
1555 | 515k | return MergeGenericAddendSub(add_op1, add_op0, inst); |
1556 | 515k | }; |
1557 | 33.5k | } |
1558 | | |
1559 | | // Helper function for FactorAddSubMuls. |
1560 | | // If |factor0_0| is the same as |factor1_0|, generate: |
1561 | | // |factor0_0| * (|factor0_1| + |factor1_1|) |
1562 | | // |factor0_0| * (|factor0_1| - |factor1_1|) |
1563 | | bool FactorAddSubMulsOpnds(uint32_t factor0_0, uint32_t factor0_1, |
1564 | | uint32_t factor1_0, uint32_t factor1_1, |
1565 | 5.37k | Instruction* inst) { |
1566 | 5.37k | IRContext* context = inst->context(); |
1567 | 5.37k | if (factor0_0 != factor1_0) return false; |
1568 | 237 | InstructionBuilder ir_builder( |
1569 | 237 | context, inst, |
1570 | 237 | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); |
1571 | 237 | Instruction* new_add_inst = ir_builder.AddBinaryOp( |
1572 | 237 | inst->type_id(), inst->opcode(), factor0_1, factor1_1); |
1573 | 237 | if (!new_add_inst) { |
1574 | 0 | return false; |
1575 | 0 | } |
1576 | | |
1577 | 237 | bool is_float = |
1578 | 237 | inst->opcode() == spv::Op::OpFAdd || inst->opcode() == spv::Op::OpFSub; |
1579 | 237 | inst->SetOpcode(is_float ? spv::Op::OpFMul : spv::Op::OpIMul); |
1580 | 237 | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {factor0_0}}, |
1581 | 237 | {SPV_OPERAND_TYPE_ID, {new_add_inst->result_id()}}}); |
1582 | 237 | context->UpdateDefUse(inst); |
1583 | 237 | return true; |
1584 | 237 | } |
1585 | | |
1586 | | // Perform the following factoring identity, handling all operand order |
1587 | | // combinations: |
1588 | | // (a * b) + (a * c) = a * (b + c) |
1589 | | // (a * b) - (a * c) = a * (b - c) |
1590 | 67.1k | FoldingRule FactorAddSubMuls() { |
1591 | 67.1k | return [](IRContext* context, Instruction* inst, |
1592 | 648k | const std::vector<const analysis::Constant*>&) { |
1593 | 648k | assert(inst->opcode() == spv::Op::OpFAdd || |
1594 | 648k | inst->opcode() == spv::Op::OpFSub || |
1595 | 648k | inst->opcode() == spv::Op::OpIAdd || |
1596 | 648k | inst->opcode() == spv::Op::OpISub); |
1597 | 648k | const analysis::Type* type = |
1598 | 648k | context->get_type_mgr()->GetType(inst->type_id()); |
1599 | 648k | bool uses_float = HasFloatingPoint(type); |
1600 | 648k | if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; |
1601 | | |
1602 | 644k | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
1603 | 644k | uint32_t add_op0 = inst->GetSingleWordInOperand(0); |
1604 | 644k | Instruction* add_op0_inst = def_use_mgr->GetDef(add_op0); |
1605 | 644k | if (add_op0_inst->opcode() != spv::Op::OpFMul && |
1606 | 626k | add_op0_inst->opcode() != spv::Op::OpIMul) |
1607 | 622k | return false; |
1608 | 21.5k | uint32_t add_op1 = inst->GetSingleWordInOperand(1); |
1609 | 21.5k | Instruction* add_op1_inst = def_use_mgr->GetDef(add_op1); |
1610 | 21.5k | if (add_op1_inst->opcode() != spv::Op::OpFMul && |
1611 | 19.3k | add_op1_inst->opcode() != spv::Op::OpIMul) |
1612 | 18.6k | return false; |
1613 | | |
1614 | | // Only perform this optimization if both of the muls only have one use. |
1615 | | // Otherwise this is a deoptimization in size and performance. |
1616 | 2.81k | if (def_use_mgr->NumUses(add_op0_inst) > 1) return false; |
1617 | 1.59k | if (def_use_mgr->NumUses(add_op1_inst) > 1) return false; |
1618 | | |
1619 | 1.46k | if (add_op0_inst->opcode() == spv::Op::OpFMul && |
1620 | 1.38k | (!add_op0_inst->IsFloatingPointFoldingAllowed() || |
1621 | 1.38k | !add_op1_inst->IsFloatingPointFoldingAllowed())) |
1622 | 0 | return false; |
1623 | | |
1624 | 3.95k | for (int i = 0; i < 2; i++) { |
1625 | 7.86k | for (int j = 0; j < 2; j++) { |
1626 | | // Check if operand i in add_op0_inst matches operand j in add_op1_inst. |
1627 | 5.37k | if (FactorAddSubMulsOpnds(add_op0_inst->GetSingleWordInOperand(i), |
1628 | 5.37k | add_op0_inst->GetSingleWordInOperand(1 - i), |
1629 | 5.37k | add_op1_inst->GetSingleWordInOperand(j), |
1630 | 5.37k | add_op1_inst->GetSingleWordInOperand(1 - j), |
1631 | 5.37k | inst)) |
1632 | 237 | return true; |
1633 | 5.37k | } |
1634 | 2.72k | } |
1635 | 1.22k | return false; |
1636 | 1.46k | }; |
1637 | 67.1k | } |
1638 | | |
1639 | | // Reassociate integer instructions where both operands share the same opcode |
1640 | | // and both source instructions contain a constant. |
1641 | | // e.g: |
1642 | | // (a * C0) * (C1 * b) = (C0 * C1) * (a * b) |
1643 | | // (a ^ C0) ^ (b ^ C1) = (C0 ^ C1) ^ (a ^ b) |
1644 | | // (C0 | a) | (b | C1) = (C0 | C1) | (a | b) |
1645 | | // (a & C0) & (b & C1) = (C0 & C1) & (a & b) |
1646 | | static const constexpr spv::Op ReassociateNestedGenericIntOps[] = { |
1647 | | spv::Op::OpIMul, spv::Op::OpBitwiseOr, spv::Op::OpBitwiseXor, |
1648 | | spv::Op::OpBitwiseAnd}; |
1649 | | |
1650 | 67.1k | FoldingRule ReassociateNestedGenericInt(spv::Op opcode) { |
1651 | 67.1k | assert(std::find(std::begin(ReassociateNestedGenericIntOps), |
1652 | 67.1k | std::end(ReassociateNestedGenericIntOps), |
1653 | 67.1k | opcode) != std::end(ReassociateNestedGenericIntOps) && |
1654 | 67.1k | "Wrong opcode."); |
1655 | | |
1656 | 67.1k | return [opcode](IRContext* context, Instruction* inst, |
1657 | 67.1k | const std::vector<const analysis::Constant*>& constants) { |
1658 | | // Handled by other folding rules. |
1659 | 63.3k | if (constants[0] || constants[1]) { |
1660 | 31.8k | return false; |
1661 | 31.8k | } |
1662 | | |
1663 | 31.4k | if (inst->opcode() != opcode) { |
1664 | 0 | return false; |
1665 | 0 | } |
1666 | | |
1667 | 31.4k | const analysis::Type* type = |
1668 | 31.4k | context->get_type_mgr()->GetType(inst->type_id()); |
1669 | | |
1670 | 31.4k | if (type->IsCooperativeMatrix()) { |
1671 | 0 | return false; |
1672 | 0 | } |
1673 | | |
1674 | 31.4k | uint32_t width = ElementWidth(type); |
1675 | 31.4k | if (width != 32 && width != 64) return false; |
1676 | | |
1677 | 31.4k | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
1678 | 31.4k | Instruction* lhs = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); |
1679 | 31.4k | Instruction* rhs = def_use_mgr->GetDef(inst->GetSingleWordInOperand(1)); |
1680 | | |
1681 | 31.4k | if (lhs->opcode() != opcode || rhs->opcode() != opcode) { |
1682 | 30.2k | return false; |
1683 | 30.2k | } |
1684 | | |
1685 | 1.22k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
1686 | 1.22k | std::vector<const analysis::Constant*> lhs_constants = |
1687 | 1.22k | const_mgr->GetOperandConstants(lhs); |
1688 | 1.22k | const analysis::Constant* lhs_const = ConstInput(lhs_constants); |
1689 | 1.22k | if (!lhs_const) { |
1690 | 899 | return false; |
1691 | 899 | } |
1692 | | |
1693 | 322 | std::vector<const analysis::Constant*> rhs_constants = |
1694 | 322 | const_mgr->GetOperandConstants(rhs); |
1695 | 322 | const analysis::Constant* rhs_const = ConstInput(rhs_constants); |
1696 | 322 | if (!rhs_const) { |
1697 | 117 | return false; |
1698 | 117 | } |
1699 | | |
1700 | 205 | uint32_t merged_constant = |
1701 | 205 | PerformOperation(const_mgr, opcode, lhs_const, rhs_const); |
1702 | 205 | if (!merged_constant) { |
1703 | 0 | return false; |
1704 | 0 | } |
1705 | | |
1706 | 205 | InstructionBuilder ir_builder( |
1707 | 205 | context, inst, |
1708 | 205 | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); |
1709 | | |
1710 | 205 | Instruction* new_rhs = ir_builder.AddBinaryOp( |
1711 | 205 | inst->type_id(), opcode, |
1712 | 205 | NonConstInput(context, lhs_constants[0], lhs)->result_id(), |
1713 | 205 | NonConstInput(context, rhs_constants[0], rhs)->result_id()); |
1714 | | |
1715 | 205 | if (!new_rhs) { |
1716 | 0 | return false; |
1717 | 0 | } |
1718 | | |
1719 | 205 | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {merged_constant}}, |
1720 | 205 | {SPV_OPERAND_TYPE_ID, {new_rhs->result_id()}}}); |
1721 | 205 | return true; |
1722 | 205 | }; |
1723 | 67.1k | } |
1724 | | |
1725 | | // Reassociate floating point mul/div instructions, which have mul/div inputs, |
1726 | | // both of which contain a constant. |
1727 | | // e.g: |
1728 | | // (a * C0) / (C1 / b) = (C0 / C1) * (a * b) |
1729 | | // (C0 / a) * (b / C1) = (C0 / C1) * (b / a) |
1730 | | // (a / C0) / (b * C1) = (1 / (C0 * C1)) * (a / b) |
1731 | 33.5k | FoldingRule ReassociateNestedMulDivFloat() { |
1732 | 33.5k | return [](IRContext* context, Instruction* inst, |
1733 | 296k | const std::vector<const analysis::Constant*>& constants) { |
1734 | 296k | assert(inst->opcode() == spv::Op::OpFMul || |
1735 | 296k | inst->opcode() == spv::Op::OpFDiv); |
1736 | | |
1737 | | // Handled by other folding rules. |
1738 | 296k | if (constants[0] || constants[1]) { |
1739 | 208k | return false; |
1740 | 208k | } |
1741 | | |
1742 | 87.8k | const analysis::Type* type = |
1743 | 87.8k | context->get_type_mgr()->GetType(inst->type_id()); |
1744 | | |
1745 | 87.8k | if (type->IsCooperativeMatrix()) { |
1746 | 0 | return false; |
1747 | 0 | } |
1748 | | |
1749 | 87.8k | uint32_t width = ElementWidth(type); |
1750 | 87.8k | if (width != 32 && width != 64) return false; |
1751 | | |
1752 | 87.8k | if (!inst->IsFloatingPointFoldingAllowed()) return false; |
1753 | | |
1754 | 87.8k | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
1755 | 87.8k | Instruction* lhs = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); |
1756 | 87.8k | Instruction* rhs = def_use_mgr->GetDef(inst->GetSingleWordInOperand(1)); |
1757 | | |
1758 | 87.8k | bool lhs_is_mul = lhs->opcode() == spv::Op::OpFMul; |
1759 | 87.8k | bool lhs_is_div = lhs->opcode() == spv::Op::OpFDiv; |
1760 | 87.8k | bool rhs_is_mul = rhs->opcode() == spv::Op::OpFMul; |
1761 | 87.8k | bool rhs_is_div = rhs->opcode() == spv::Op::OpFDiv; |
1762 | 87.8k | if (!(lhs_is_mul || lhs_is_div) || !(rhs_is_mul || rhs_is_div)) { |
1763 | 84.4k | return false; |
1764 | 84.4k | } |
1765 | | |
1766 | 3.38k | if (!lhs->IsFloatingPointFoldingAllowed() || |
1767 | 3.38k | !rhs->IsFloatingPointFoldingAllowed()) { |
1768 | 0 | return false; |
1769 | 0 | } |
1770 | | |
1771 | 3.38k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
1772 | 3.38k | std::vector<const analysis::Constant*> lhs_constants = |
1773 | 3.38k | const_mgr->GetOperandConstants(lhs); |
1774 | 3.38k | if (!lhs_constants[0] && !lhs_constants[1]) { |
1775 | 2.42k | return false; |
1776 | 2.42k | } |
1777 | | |
1778 | 966 | std::vector<const analysis::Constant*> rhs_constants = |
1779 | 966 | const_mgr->GetOperandConstants(rhs); |
1780 | 966 | if (!rhs_constants[0] && !rhs_constants[1]) { |
1781 | 201 | return false; |
1782 | 201 | } |
1783 | | |
1784 | 765 | const analysis::Constant* lhs_const = |
1785 | 765 | lhs_constants[0] ? lhs_constants[0] : lhs_constants[1]; |
1786 | 765 | const analysis::Constant* rhs_const = |
1787 | 765 | rhs_constants[0] ? rhs_constants[0] : rhs_constants[1]; |
1788 | 765 | if (!lhs_const || !rhs_const) return false; |
1789 | | |
1790 | 765 | bool const_lhs_rcp = lhs_constants[0] ? false : lhs_is_div; |
1791 | 765 | bool const_rhs_rcp = rhs_constants[0] ? false : rhs_is_div; |
1792 | | |
1793 | 765 | uint32_t non_const_lhs = lhs_constants[0] ? lhs->GetSingleWordInOperand(1) |
1794 | 765 | : lhs->GetSingleWordInOperand(0); |
1795 | 765 | bool non_const_lhs_rcp = lhs_constants[0] ? lhs_is_div : false; |
1796 | | |
1797 | 765 | uint32_t non_const_rhs = rhs_constants[0] ? rhs->GetSingleWordInOperand(1) |
1798 | 765 | : rhs->GetSingleWordInOperand(0); |
1799 | 765 | bool non_const_rhs_rcp = rhs_constants[0] ? rhs_is_div : false; |
1800 | | |
1801 | | // Rcp the rhs if we're actually dividing it. |
1802 | 765 | if (inst->opcode() == spv::Op::OpFDiv) { |
1803 | 132 | const_rhs_rcp = !const_rhs_rcp; |
1804 | 132 | non_const_rhs_rcp = !non_const_rhs_rcp; |
1805 | 132 | } |
1806 | | |
1807 | 765 | if (const_lhs_rcp) { |
1808 | 19 | lhs_const = |
1809 | 19 | const_mgr->FindDeclaredConstant(Reciprocal(const_mgr, lhs_const)); |
1810 | 19 | if (!lhs_const) { |
1811 | 19 | return false; |
1812 | 19 | } |
1813 | 19 | } |
1814 | 746 | if (const_rhs_rcp) { |
1815 | 126 | rhs_const = |
1816 | 126 | const_mgr->FindDeclaredConstant(Reciprocal(const_mgr, rhs_const)); |
1817 | 126 | if (!rhs_const) { |
1818 | 32 | return false; |
1819 | 32 | } |
1820 | 126 | } |
1821 | | |
1822 | 714 | uint32_t merged_constant = |
1823 | 714 | PerformOperation(const_mgr, spv::Op::OpFMul, lhs_const, rhs_const); |
1824 | | |
1825 | 714 | if (!merged_constant) { |
1826 | 251 | return false; |
1827 | 251 | } |
1828 | | |
1829 | 463 | spv::Op op = spv::Op::OpNop; |
1830 | 463 | Instruction* new_rhs = nullptr; |
1831 | | |
1832 | 463 | InstructionBuilder ir_builder( |
1833 | 463 | context, inst, |
1834 | 463 | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); |
1835 | | |
1836 | | // a * b => C * (b * a) |
1837 | 463 | if (!non_const_lhs_rcp && !non_const_rhs_rcp) { |
1838 | 279 | new_rhs = ir_builder.AddBinaryOp(inst->type_id(), spv::Op::OpFMul, |
1839 | 279 | non_const_lhs, non_const_rhs); |
1840 | 279 | op = spv::Op::OpFMul; |
1841 | 279 | } |
1842 | | // 1/a * b => C * (b / a) |
1843 | 184 | else if (non_const_lhs_rcp && !non_const_rhs_rcp) { |
1844 | 7 | new_rhs = ir_builder.AddBinaryOp(inst->type_id(), spv::Op::OpFDiv, |
1845 | 7 | non_const_rhs, non_const_lhs); |
1846 | 7 | op = spv::Op::OpFMul; |
1847 | 7 | } |
1848 | | // a * 1/b => C * (a / b) |
1849 | 177 | else if (!non_const_lhs_rcp && non_const_rhs_rcp) { |
1850 | 55 | new_rhs = ir_builder.AddBinaryOp(inst->type_id(), spv::Op::OpFDiv, |
1851 | 55 | non_const_lhs, non_const_rhs); |
1852 | 55 | op = spv::Op::OpFMul; |
1853 | 55 | } |
1854 | | // 1/a * 1/b => C / (a * b) |
1855 | 122 | else { |
1856 | 122 | new_rhs = ir_builder.AddBinaryOp(inst->type_id(), spv::Op::OpFMul, |
1857 | 122 | non_const_lhs, non_const_rhs); |
1858 | 122 | op = spv::Op::OpFDiv; |
1859 | 122 | } |
1860 | | |
1861 | 463 | if (!new_rhs) { |
1862 | 0 | return false; |
1863 | 0 | } |
1864 | | |
1865 | 463 | inst->SetOpcode(op); |
1866 | 463 | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {merged_constant}}, |
1867 | 463 | {SPV_OPERAND_TYPE_ID, {new_rhs->result_id()}}}); |
1868 | 463 | return true; |
1869 | 463 | }; |
1870 | 33.5k | } |
1871 | | |
1872 | | // Reassociate add/sub instructions, which have add/sub inputs, |
1873 | | // both of which contain a constant. |
1874 | | // e.g: |
1875 | | // (a + C0) - (C1 - b) = (C0 - C1) + (a + b) |
1876 | | // (C0 - a) + (b - C1) = (C0 - C1) + (b - a) |
1877 | | // (a - C0) - (b + C1) = (-C0 - C1) + (a - b) |
1878 | 67.1k | FoldingRule ReassociateNestedAddSub() { |
1879 | 67.1k | return [](IRContext* context, Instruction* inst, |
1880 | 650k | const std::vector<const analysis::Constant*>& constants) { |
1881 | 650k | assert(inst->opcode() == spv::Op::OpFAdd || |
1882 | 650k | inst->opcode() == spv::Op::OpIAdd || |
1883 | 650k | inst->opcode() == spv::Op::OpFSub || |
1884 | 650k | inst->opcode() == spv::Op::OpISub); |
1885 | | |
1886 | | // Handled by other folding rules. |
1887 | 650k | if (constants[0] || constants[1]) { |
1888 | 206k | return false; |
1889 | 206k | } |
1890 | | |
1891 | 443k | const analysis::Type* type = |
1892 | 443k | context->get_type_mgr()->GetType(inst->type_id()); |
1893 | | |
1894 | 443k | if (type->IsCooperativeMatrix()) { |
1895 | 0 | return false; |
1896 | 0 | } |
1897 | | |
1898 | 443k | uint32_t width = ElementWidth(type); |
1899 | 443k | if (width != 32 && width != 64) return false; |
1900 | | |
1901 | 443k | bool uses_float = HasFloatingPoint(type); |
1902 | 443k | if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; |
1903 | | |
1904 | 440k | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
1905 | 440k | Instruction* lhs = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); |
1906 | 440k | Instruction* rhs = def_use_mgr->GetDef(inst->GetSingleWordInOperand(1)); |
1907 | | |
1908 | 440k | spv::Op add_op = uses_float ? spv::Op::OpFAdd : spv::Op::OpIAdd; |
1909 | 440k | spv::Op sub_op = uses_float ? spv::Op::OpFSub : spv::Op::OpISub; |
1910 | | |
1911 | 440k | bool lhs_is_add = lhs->opcode() == add_op; |
1912 | 440k | bool lhs_is_sub = lhs->opcode() == sub_op; |
1913 | 440k | bool rhs_is_add = rhs->opcode() == add_op; |
1914 | 440k | bool rhs_is_sub = rhs->opcode() == sub_op; |
1915 | 440k | if (!(lhs_is_add || lhs_is_sub) || !(rhs_is_add || rhs_is_sub)) { |
1916 | 438k | return false; |
1917 | 438k | } |
1918 | | |
1919 | 2.65k | if (uses_float && (!lhs->IsFloatingPointFoldingAllowed() || |
1920 | 2.35k | !rhs->IsFloatingPointFoldingAllowed())) { |
1921 | 0 | return false; |
1922 | 0 | } |
1923 | | |
1924 | 2.65k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
1925 | 2.65k | std::vector<const analysis::Constant*> lhs_constants = |
1926 | 2.65k | const_mgr->GetOperandConstants(lhs); |
1927 | 2.65k | if (!lhs_constants[0] && !lhs_constants[1]) { |
1928 | 1.11k | return false; |
1929 | 1.11k | } |
1930 | | |
1931 | 1.54k | std::vector<const analysis::Constant*> rhs_constants = |
1932 | 1.54k | const_mgr->GetOperandConstants(rhs); |
1933 | 1.54k | if (!rhs_constants[0] && !rhs_constants[1]) { |
1934 | 48 | return false; |
1935 | 48 | } |
1936 | | |
1937 | 1.49k | const analysis::Constant* lhs_const = |
1938 | 1.49k | lhs_constants[0] ? lhs_constants[0] : lhs_constants[1]; |
1939 | 1.49k | const analysis::Constant* rhs_const = |
1940 | 1.49k | rhs_constants[0] ? rhs_constants[0] : rhs_constants[1]; |
1941 | 1.49k | if (!lhs_const || !rhs_const) return false; |
1942 | | |
1943 | 1.49k | bool const_lhs_neg = lhs_constants[0] ? false : lhs_is_sub; |
1944 | 1.49k | bool const_rhs_neg = rhs_constants[0] ? false : rhs_is_sub; |
1945 | | |
1946 | 1.49k | uint32_t non_const_lhs = lhs_constants[0] ? lhs->GetSingleWordInOperand(1) |
1947 | 1.49k | : lhs->GetSingleWordInOperand(0); |
1948 | 1.49k | bool non_const_lhs_neg = lhs_constants[0] ? lhs_is_sub : false; |
1949 | | |
1950 | 1.49k | uint32_t non_const_rhs = rhs_constants[0] ? rhs->GetSingleWordInOperand(1) |
1951 | 1.49k | : rhs->GetSingleWordInOperand(0); |
1952 | 1.49k | bool non_const_rhs_neg = rhs_constants[0] ? rhs_is_sub : false; |
1953 | | |
1954 | | // Negate the rhs if we're actually subtracting it. |
1955 | 1.49k | if (inst->opcode() == spv::Op::OpFSub || |
1956 | 810 | inst->opcode() == spv::Op::OpISub) { |
1957 | 687 | const_rhs_neg = !const_rhs_neg; |
1958 | 687 | non_const_rhs_neg = !non_const_rhs_neg; |
1959 | 687 | } |
1960 | | |
1961 | 1.49k | if (const_lhs_neg) { |
1962 | 489 | lhs_const = |
1963 | 489 | const_mgr->FindDeclaredConstant(NegateConstant(const_mgr, lhs_const)); |
1964 | 489 | if (!lhs_const) { |
1965 | 0 | return false; |
1966 | 0 | } |
1967 | 489 | } |
1968 | 1.49k | if (const_rhs_neg) { |
1969 | 1.19k | rhs_const = |
1970 | 1.19k | const_mgr->FindDeclaredConstant(NegateConstant(const_mgr, rhs_const)); |
1971 | 1.19k | if (!rhs_const) { |
1972 | 0 | return false; |
1973 | 0 | } |
1974 | 1.19k | } |
1975 | | |
1976 | 1.49k | uint32_t merged_constant = |
1977 | 1.49k | PerformOperation(const_mgr, add_op, lhs_const, rhs_const); |
1978 | | |
1979 | 1.49k | if (!merged_constant) { |
1980 | 260 | return false; |
1981 | 260 | } |
1982 | | |
1983 | 1.23k | spv::Op op = spv::Op::OpNop; |
1984 | 1.23k | Instruction* new_rhs = nullptr; |
1985 | | |
1986 | 1.23k | InstructionBuilder ir_builder( |
1987 | 1.23k | context, inst, |
1988 | 1.23k | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); |
1989 | | |
1990 | | // a + b => C + (b + a) |
1991 | 1.23k | if (!non_const_lhs_neg && !non_const_rhs_neg) { |
1992 | 599 | new_rhs = ir_builder.AddBinaryOp(inst->type_id(), add_op, non_const_lhs, |
1993 | 599 | non_const_rhs); |
1994 | 599 | op = add_op; |
1995 | 599 | } |
1996 | | // -a + b => C + (b - a) |
1997 | 638 | else if (non_const_lhs_neg && !non_const_rhs_neg) { |
1998 | 44 | new_rhs = ir_builder.AddBinaryOp(inst->type_id(), sub_op, non_const_rhs, |
1999 | 44 | non_const_lhs); |
2000 | 44 | op = add_op; |
2001 | 44 | } |
2002 | | // a + -b => C + (a - b) |
2003 | 594 | else if (!non_const_lhs_neg && non_const_rhs_neg) { |
2004 | 9 | new_rhs = ir_builder.AddBinaryOp(inst->type_id(), sub_op, non_const_lhs, |
2005 | 9 | non_const_rhs); |
2006 | 9 | op = add_op; |
2007 | 9 | } |
2008 | | // -a + -b => C - (a + b) |
2009 | 585 | else { |
2010 | 585 | new_rhs = ir_builder.AddBinaryOp(inst->type_id(), add_op, non_const_lhs, |
2011 | 585 | non_const_rhs); |
2012 | 585 | op = sub_op; |
2013 | 585 | } |
2014 | | |
2015 | 1.23k | if (!new_rhs) { |
2016 | 0 | return false; |
2017 | 0 | } |
2018 | | |
2019 | 1.23k | inst->SetOpcode(op); |
2020 | 1.23k | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {merged_constant}}, |
2021 | 1.23k | {SPV_OPERAND_TYPE_ID, {new_rhs->result_id()}}}); |
2022 | 1.23k | return true; |
2023 | 1.23k | }; |
2024 | 67.1k | } |
2025 | | |
2026 | 16.7k | FoldingRule IntMultipleBy1() { |
2027 | 16.7k | return [](IRContext*, Instruction* inst, |
2028 | 16.7k | const std::vector<const analysis::Constant*>& constants) { |
2029 | 14.7k | assert(inst->opcode() == spv::Op::OpIMul && |
2030 | 14.7k | "Wrong opcode. Should be OpIMul."); |
2031 | 43.6k | for (uint32_t i = 0; i < 2; i++) { |
2032 | 29.4k | if (constants[i] == nullptr) { |
2033 | 16.3k | continue; |
2034 | 16.3k | } |
2035 | 13.0k | const analysis::IntConstant* int_constant = constants[i]->AsIntConstant(); |
2036 | 13.0k | if (int_constant) { |
2037 | 12.8k | uint32_t width = ElementWidth(int_constant->type()); |
2038 | 12.8k | if (width != 32 && width != 64) return false; |
2039 | 12.8k | bool is_one = (width == 32) ? int_constant->GetU32BitValue() == 1u |
2040 | 12.8k | : int_constant->GetU64BitValue() == 1ull; |
2041 | 12.8k | if (is_one) { |
2042 | 479 | inst->SetOpcode(spv::Op::OpCopyObject); |
2043 | 479 | inst->SetInOperands( |
2044 | 479 | {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1 - i)}}}); |
2045 | 479 | return true; |
2046 | 479 | } |
2047 | 12.8k | } |
2048 | 13.0k | } |
2049 | 14.2k | return false; |
2050 | 14.7k | }; |
2051 | 16.7k | } |
2052 | | |
2053 | | // Returns the number of elements that the |index|th in operand in |inst| |
2054 | | // contributes to the result of |inst|. |inst| must be an |
2055 | | // OpCompositeConstructInstruction. |
2056 | | uint32_t GetNumOfElementsContributedByOperand(IRContext* context, |
2057 | | const Instruction* inst, |
2058 | 17.4k | uint32_t index) { |
2059 | 17.4k | assert(inst->opcode() == spv::Op::OpCompositeConstruct); |
2060 | 17.4k | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
2061 | 17.4k | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
2062 | | |
2063 | 17.4k | analysis::Vector* result_type = |
2064 | 17.4k | type_mgr->GetType(inst->type_id())->AsVector(); |
2065 | 17.4k | if (result_type == nullptr) { |
2066 | | // If the result of the OpCompositeConstruct is not a vector then every |
2067 | | // operands corresponds to a single element in the result. |
2068 | 0 | return 1; |
2069 | 0 | } |
2070 | | |
2071 | | // If the result type is a vector then the operands are either scalars or |
2072 | | // vectors. If it is a scalar, then it corresponds to a single element. If it |
2073 | | // is a vector, then each element in the vector will be an element in the |
2074 | | // result. |
2075 | 17.4k | uint32_t id = inst->GetSingleWordInOperand(index); |
2076 | 17.4k | Instruction* def = def_use_mgr->GetDef(id); |
2077 | 17.4k | analysis::Vector* type = type_mgr->GetType(def->type_id())->AsVector(); |
2078 | 17.4k | if (type == nullptr) { |
2079 | 17.4k | return 1; |
2080 | 17.4k | } |
2081 | 0 | return type->element_count(); |
2082 | 17.4k | } |
2083 | | |
2084 | | // Returns the in-operands for an OpCompositeExtract instruction that are needed |
2085 | | // to extract the |result_index|th element in the result of |inst| without using |
2086 | | // the result of |inst|. Returns the empty vector if |result_index| is |
2087 | | // out-of-bounds. |inst| must be an |OpCompositeConstruct| instruction. |
2088 | | std::vector<Operand> GetExtractOperandsForElementOfCompositeConstruct( |
2089 | 39.6k | IRContext* context, const Instruction* inst, uint32_t result_index) { |
2090 | 39.6k | assert(inst->opcode() == spv::Op::OpCompositeConstruct); |
2091 | 39.6k | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
2092 | 39.6k | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
2093 | | |
2094 | 39.6k | analysis::Type* result_type = type_mgr->GetType(inst->type_id()); |
2095 | 39.6k | if (result_type->AsVector() == nullptr) { |
2096 | 28.7k | if (result_index < inst->NumInOperands()) { |
2097 | 28.7k | uint32_t id = inst->GetSingleWordInOperand(result_index); |
2098 | 28.7k | return {Operand(SPV_OPERAND_TYPE_ID, {id})}; |
2099 | 28.7k | } |
2100 | 0 | return {}; |
2101 | 28.7k | } |
2102 | | |
2103 | | // If the result type is a vector, then vector operands are concatenated. |
2104 | 10.8k | uint32_t total_element_count = 0; |
2105 | 17.4k | for (uint32_t idx = 0; idx < inst->NumInOperands(); ++idx) { |
2106 | 17.4k | uint32_t element_count = |
2107 | 17.4k | GetNumOfElementsContributedByOperand(context, inst, idx); |
2108 | 17.4k | total_element_count += element_count; |
2109 | 17.4k | if (result_index < total_element_count) { |
2110 | 10.8k | std::vector<Operand> operands; |
2111 | 10.8k | uint32_t id = inst->GetSingleWordInOperand(idx); |
2112 | 10.8k | Instruction* operand_def = def_use_mgr->GetDef(id); |
2113 | 10.8k | analysis::Type* operand_type = type_mgr->GetType(operand_def->type_id()); |
2114 | | |
2115 | 10.8k | operands.push_back({SPV_OPERAND_TYPE_ID, {id}}); |
2116 | 10.8k | if (operand_type->AsVector()) { |
2117 | 0 | uint32_t start_index_of_id = total_element_count - element_count; |
2118 | 0 | uint32_t index_into_id = result_index - start_index_of_id; |
2119 | 0 | operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {index_into_id}}); |
2120 | 0 | } |
2121 | 10.8k | return operands; |
2122 | 10.8k | } |
2123 | 17.4k | } |
2124 | 0 | return {}; |
2125 | 10.8k | } |
2126 | | |
2127 | | bool CompositeConstructFeedingExtract( |
2128 | | IRContext* context, Instruction* inst, |
2129 | 290k | const std::vector<const analysis::Constant*>&) { |
2130 | | // If the input to an OpCompositeExtract is an OpCompositeConstruct, |
2131 | | // then we can simply use the appropriate element in the construction. |
2132 | 290k | assert(inst->opcode() == spv::Op::OpCompositeExtract && |
2133 | 290k | "Wrong opcode. Should be OpCompositeExtract."); |
2134 | 290k | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
2135 | | |
2136 | | // If there are no index operands, then this rule cannot do anything. |
2137 | 290k | if (inst->NumInOperands() <= 1) { |
2138 | 0 | return false; |
2139 | 0 | } |
2140 | | |
2141 | 290k | uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx); |
2142 | 290k | Instruction* cinst = def_use_mgr->GetDef(cid); |
2143 | | |
2144 | 290k | if (cinst->opcode() != spv::Op::OpCompositeConstruct) { |
2145 | 250k | return false; |
2146 | 250k | } |
2147 | | |
2148 | 39.6k | uint32_t index_into_result = inst->GetSingleWordInOperand(1); |
2149 | 39.6k | std::vector<Operand> operands = |
2150 | 39.6k | GetExtractOperandsForElementOfCompositeConstruct(context, cinst, |
2151 | 39.6k | index_into_result); |
2152 | | |
2153 | 39.6k | if (operands.empty()) { |
2154 | 0 | return false; |
2155 | 0 | } |
2156 | | |
2157 | | // Add the remaining indices for extraction. |
2158 | 39.6k | for (uint32_t i = 2; i < inst->NumInOperands(); ++i) { |
2159 | 34 | operands.push_back( |
2160 | 34 | {SPV_OPERAND_TYPE_LITERAL_INTEGER, {inst->GetSingleWordInOperand(i)}}); |
2161 | 34 | } |
2162 | | |
2163 | 39.6k | if (operands.size() == 1) { |
2164 | | // If there were no extra indices, then we have the final object. No need |
2165 | | // to extract any more. |
2166 | 39.5k | inst->SetOpcode(spv::Op::OpCopyObject); |
2167 | 39.5k | } |
2168 | | |
2169 | 39.6k | inst->SetInOperands(std::move(operands)); |
2170 | 39.6k | return true; |
2171 | 39.6k | } |
2172 | | |
2173 | | // Walks the indexes chain from |start| to |end| of an OpCompositeInsert or |
2174 | | // OpCompositeExtract instruction, and returns the type id of the final element |
2175 | | // being accessed. Returns 0 if a valid type could not be found. |
2176 | | uint32_t GetElementType(uint32_t type_id, Instruction::iterator start, |
2177 | | Instruction::iterator end, |
2178 | 96.0k | const analysis::DefUseManager* def_use_manager) { |
2179 | 96.0k | for (auto index : make_range(std::move(start), std::move(end))) { |
2180 | 1.12k | const Instruction* type_inst = def_use_manager->GetDef(type_id); |
2181 | 1.12k | assert(index.type == SPV_OPERAND_TYPE_LITERAL_INTEGER && |
2182 | 1.12k | index.words.size() == 1); |
2183 | 1.12k | switch (type_inst->opcode()) { |
2184 | 536 | case spv::Op::OpTypeArray: |
2185 | 536 | case spv::Op::OpTypeMatrix: |
2186 | 536 | case spv::Op::OpTypeVector: |
2187 | 536 | case spv::Op::OpTypeVectorIdEXT: |
2188 | 536 | type_id = type_inst->GetSingleWordInOperand(0); |
2189 | 536 | break; |
2190 | 586 | case spv::Op::OpTypeStruct: |
2191 | 586 | type_id = type_inst->GetSingleWordInOperand(index.words[0]); |
2192 | 586 | break; |
2193 | 0 | default: |
2194 | 0 | return 0; |
2195 | 1.12k | } |
2196 | 1.12k | } |
2197 | 96.0k | return type_id; |
2198 | 96.0k | } |
2199 | | |
2200 | | // If the input to an OpCompositeExtract is an OpCopyLogical, then we can |
2201 | | // hoist the extraction before the copy. |
2202 | | bool CopyLogicalFeedingExtract(IRContext* context, Instruction* inst, |
2203 | 245k | const std::vector<const analysis::Constant*>&) { |
2204 | 245k | assert(inst->opcode() == spv::Op::OpCompositeExtract && |
2205 | 245k | "Wrong opcode. Should be OpCompositeExtract."); |
2206 | | |
2207 | 245k | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
2208 | 245k | uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx); |
2209 | 245k | Instruction* cinst = def_use_mgr->GetDef(cid); |
2210 | | |
2211 | 245k | if (cinst->opcode() != spv::Op::OpCopyLogical) { |
2212 | 245k | return false; |
2213 | 245k | } |
2214 | | |
2215 | 0 | uint32_t original_composite_id = cinst->GetSingleWordInOperand(0); |
2216 | 0 | Instruction* original_composite_inst = |
2217 | 0 | def_use_mgr->GetDef(original_composite_id); |
2218 | |
|
2219 | 0 | std::vector<uint32_t> indices; |
2220 | 0 | for (uint32_t i = 1; i < inst->NumInOperands(); ++i) { |
2221 | 0 | indices.push_back(inst->GetSingleWordInOperand(i)); |
2222 | 0 | } |
2223 | |
|
2224 | 0 | uint32_t original_element_type_id = |
2225 | 0 | GetElementType(original_composite_inst->type_id(), inst->begin() + 3, |
2226 | 0 | inst->end(), def_use_mgr); |
2227 | 0 | assert(original_element_type_id != 0 && |
2228 | 0 | "Could not find the element type. Invalid SPIR-V."); |
2229 | | |
2230 | 0 | InstructionBuilder ir_builder( |
2231 | 0 | context, inst, |
2232 | 0 | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); |
2233 | |
|
2234 | 0 | Instruction* new_extract = ir_builder.AddCompositeExtract( |
2235 | 0 | original_element_type_id, original_composite_id, indices); |
2236 | |
|
2237 | 0 | if (original_element_type_id == inst->type_id()) |
2238 | 0 | inst->SetOpcode(spv::Op::OpCopyObject); |
2239 | 0 | else |
2240 | 0 | inst->SetOpcode(spv::Op::OpCopyLogical); |
2241 | 0 | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {new_extract->result_id()}}}); |
2242 | 0 | return true; |
2243 | 0 | } |
2244 | | |
2245 | | // If the input to an OpCompositeExtract is an OpLoad, we can change the |
2246 | | // load into a load of an OpAccessChain. |
2247 | | bool LoadFeedingExtract(IRContext* context, Instruction* inst, |
2248 | 245k | const std::vector<const analysis::Constant*>&) { |
2249 | 245k | assert(inst->opcode() == spv::Op::OpCompositeExtract && |
2250 | 245k | "Wrong opcode. Should be OpCompositeExtract."); |
2251 | | |
2252 | 245k | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
2253 | 245k | uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx); |
2254 | 245k | Instruction* cinst = def_use_mgr->GetDef(cid); |
2255 | | |
2256 | 245k | if (cinst->opcode() != spv::Op::OpLoad) { |
2257 | 182k | return false; |
2258 | 182k | } |
2259 | | |
2260 | 63.0k | Instruction* composite_type_inst = def_use_mgr->GetDef(cinst->type_id()); |
2261 | 63.0k | if (composite_type_inst->opcode() != spv::Op::OpTypeStruct && |
2262 | 32.9k | composite_type_inst->opcode() != spv::Op::OpTypeArray) { |
2263 | 11.6k | return false; |
2264 | 11.6k | } |
2265 | | |
2266 | | // Check the memory operands. |
2267 | 51.4k | if (cinst->NumInOperands() > 1) { |
2268 | 603 | uint32_t memory_access_mask = cinst->GetSingleWordInOperand(1); |
2269 | 603 | if (memory_access_mask & uint32_t(spv::MemoryAccessMask::Volatile)) { |
2270 | 24 | return false; |
2271 | 24 | } |
2272 | 603 | } |
2273 | | |
2274 | 51.4k | uint32_t ptr_id = cinst->GetSingleWordInOperand(0); |
2275 | 51.4k | Instruction* ptr_inst = def_use_mgr->GetDef(ptr_id); |
2276 | 51.4k | Instruction* ptr_type_inst = def_use_mgr->GetDef(ptr_inst->type_id()); |
2277 | 51.4k | assert(ptr_type_inst->opcode() == spv::Op::OpTypePointer); |
2278 | 51.4k | spv::StorageClass storage_class = |
2279 | 51.4k | static_cast<spv::StorageClass>(ptr_type_inst->GetSingleWordInOperand(0)); |
2280 | | |
2281 | | // If the storage class is Function or Private, we do not want to fold. |
2282 | | // These are the storage classes that the local-access-chain-convert pass |
2283 | | // works on. |
2284 | 51.4k | if (storage_class == spv::StorageClass::Function || |
2285 | 51.4k | storage_class == spv::StorageClass::Private) { |
2286 | 51.4k | return false; |
2287 | 51.4k | } |
2288 | | |
2289 | 0 | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
2290 | 0 | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
2291 | 0 | std::vector<uint32_t> index_ids; |
2292 | 0 | for (uint32_t i = 1; i < inst->NumInOperands(); ++i) { |
2293 | 0 | uint32_t index = inst->GetSingleWordInOperand(i); |
2294 | 0 | const analysis::Constant* index_const = |
2295 | 0 | const_mgr->GetConstant(type_mgr->GetUIntType(), {index}); |
2296 | 0 | index_ids.push_back( |
2297 | 0 | const_mgr->GetDefiningInstruction(index_const)->result_id()); |
2298 | 0 | } |
2299 | |
|
2300 | 0 | InstructionBuilder ir_builder( |
2301 | 0 | context, cinst, |
2302 | 0 | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); |
2303 | |
|
2304 | 0 | uint32_t element_ptr_type_id = |
2305 | 0 | type_mgr->FindPointerToType(inst->type_id(), storage_class); |
2306 | 0 | if (element_ptr_type_id == 0) { |
2307 | 0 | return false; |
2308 | 0 | } |
2309 | | |
2310 | 0 | Instruction* access_chain = |
2311 | 0 | ir_builder.AddAccessChain(element_ptr_type_id, ptr_id, index_ids); |
2312 | 0 | std::vector<Operand> load_operands; |
2313 | 0 | load_operands.push_back({SPV_OPERAND_TYPE_ID, {access_chain->result_id()}}); |
2314 | |
|
2315 | 0 | if (cinst->NumInOperands() > 1) { |
2316 | 0 | uint32_t memory_access_mask = cinst->GetSingleWordInOperand(1); |
2317 | 0 | load_operands.push_back( |
2318 | 0 | {SPV_OPERAND_TYPE_MEMORY_ACCESS, {memory_access_mask}}); |
2319 | |
|
2320 | 0 | uint32_t current_operand_index = 2; |
2321 | 0 | if (memory_access_mask & uint32_t(spv::MemoryAccessMask::Aligned)) { |
2322 | 0 | uint32_t original_alignment = |
2323 | 0 | cinst->GetSingleWordInOperand(current_operand_index); |
2324 | |
|
2325 | 0 | std::vector<uint32_t> extract_indices; |
2326 | 0 | for (uint32_t i = 1; i < inst->NumInOperands(); ++i) { |
2327 | 0 | extract_indices.push_back(inst->GetSingleWordInOperand(i)); |
2328 | 0 | } |
2329 | |
|
2330 | 0 | std::optional<uint32_t> offset = |
2331 | 0 | type_mgr->GetType(cinst->type_id())->GetByteOffset(extract_indices); |
2332 | 0 | if (!offset) { |
2333 | 0 | return false; |
2334 | 0 | } |
2335 | | |
2336 | 0 | uint32_t new_alignment = original_alignment; |
2337 | 0 | if (*offset != 0) { |
2338 | 0 | uint32_t offset_alignment = *offset & ~(*offset - 1); |
2339 | 0 | new_alignment = std::min(original_alignment, offset_alignment); |
2340 | 0 | } |
2341 | |
|
2342 | 0 | load_operands.push_back( |
2343 | 0 | {SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER, {new_alignment}}); |
2344 | 0 | current_operand_index++; |
2345 | 0 | } |
2346 | | |
2347 | | // Copy the remaining operands |
2348 | 0 | for (; current_operand_index < cinst->NumInOperands(); |
2349 | 0 | ++current_operand_index) { |
2350 | 0 | load_operands.push_back(cinst->GetInOperand(current_operand_index)); |
2351 | 0 | } |
2352 | 0 | } |
2353 | | |
2354 | 0 | uint32_t load_result_id = context->TakeNextId(); |
2355 | 0 | if (load_result_id == 0) return false; |
2356 | | |
2357 | 0 | std::unique_ptr<Instruction> new_load_inst( |
2358 | 0 | new Instruction(context, spv::Op::OpLoad, inst->type_id(), load_result_id, |
2359 | 0 | load_operands)); |
2360 | 0 | Instruction* new_load = ir_builder.AddInstruction(std::move(new_load_inst)); |
2361 | |
|
2362 | 0 | inst->SetOpcode(spv::Op::OpCopyObject); |
2363 | 0 | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {new_load->result_id()}}}); |
2364 | |
|
2365 | 0 | return true; |
2366 | 0 | } |
2367 | | |
2368 | | // Returns true of |inst_1| and |inst_2| have the same indexes that will be used |
2369 | | // to index into a composite object, excluding the last index. The two |
2370 | | // instructions must have the same opcode, and be either OpCompositeExtract or |
2371 | | // OpCompositeInsert instructions. |
2372 | 401k | bool HaveSameIndexesExceptForLast(Instruction* inst_1, Instruction* inst_2) { |
2373 | 401k | assert(inst_1->opcode() == inst_2->opcode() && |
2374 | 401k | "Expecting the opcodes to be the same."); |
2375 | 401k | assert((inst_1->opcode() == spv::Op::OpCompositeInsert || |
2376 | 401k | inst_1->opcode() == spv::Op::OpCompositeExtract) && |
2377 | 401k | "Instructions must be OpCompositeInsert or OpCompositeExtract."); |
2378 | | |
2379 | 401k | if (inst_1->NumInOperands() != inst_2->NumInOperands()) { |
2380 | 2.08k | return false; |
2381 | 2.08k | } |
2382 | | |
2383 | 399k | uint32_t first_index_position = |
2384 | 399k | (inst_1->opcode() == spv::Op::OpCompositeInsert ? 2 : 1); |
2385 | 409k | for (uint32_t i = first_index_position; i < inst_1->NumInOperands() - 1; |
2386 | 399k | i++) { |
2387 | 10.1k | if (inst_1->GetSingleWordInOperand(i) != |
2388 | 10.1k | inst_2->GetSingleWordInOperand(i)) { |
2389 | 492 | return false; |
2390 | 492 | } |
2391 | 10.1k | } |
2392 | 398k | return true; |
2393 | 399k | } |
2394 | | |
2395 | | // If the OpCompositeConstruct is simply putting back together elements that |
2396 | | // where extracted from the same source, we can simply reuse the source. |
2397 | | // |
2398 | | // This is a common code pattern because of the way that scalar replacement |
2399 | | // works. |
2400 | | bool CompositeExtractFeedingConstruct( |
2401 | | IRContext* context, Instruction* inst, |
2402 | 97.5k | const std::vector<const analysis::Constant*>&) { |
2403 | 97.5k | assert(inst->opcode() == spv::Op::OpCompositeConstruct && |
2404 | 97.5k | "Wrong opcode. Should be OpCompositeConstruct."); |
2405 | 97.5k | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
2406 | 97.5k | uint32_t original_id = 0; |
2407 | | |
2408 | 97.5k | if (inst->NumInOperands() == 0) { |
2409 | | // The struct being constructed has no members. |
2410 | 0 | return false; |
2411 | 0 | } |
2412 | | |
2413 | | // Check each element to make sure they are: |
2414 | | // - extractions |
2415 | | // - extracting the same position they are inserting |
2416 | | // - all extract from the same id. |
2417 | 97.5k | Instruction* first_element_inst = nullptr; |
2418 | 126k | for (uint32_t i = 0; i < inst->NumInOperands(); ++i) { |
2419 | 124k | const uint32_t element_id = inst->GetSingleWordInOperand(i); |
2420 | 124k | Instruction* element_inst = def_use_mgr->GetDef(element_id); |
2421 | 124k | if (first_element_inst == nullptr) { |
2422 | 97.5k | first_element_inst = element_inst; |
2423 | 97.5k | } |
2424 | | |
2425 | 124k | if (element_inst->opcode() != spv::Op::OpCompositeExtract) { |
2426 | 92.0k | return false; |
2427 | 92.0k | } |
2428 | | |
2429 | 32.6k | if (!HaveSameIndexesExceptForLast(element_inst, first_element_inst)) { |
2430 | 0 | return false; |
2431 | 0 | } |
2432 | | |
2433 | 32.6k | if (element_inst->GetSingleWordInOperand(element_inst->NumInOperands() - |
2434 | 32.6k | 1) != i) { |
2435 | 3.65k | return false; |
2436 | 3.65k | } |
2437 | | |
2438 | 29.0k | if (i == 0) { |
2439 | 9.86k | original_id = |
2440 | 9.86k | element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx); |
2441 | 19.1k | } else if (original_id != |
2442 | 19.1k | element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx)) { |
2443 | 438 | return false; |
2444 | 438 | } |
2445 | 29.0k | } |
2446 | 97.5k | assert(first_element_inst != nullptr); |
2447 | | |
2448 | | // The last check it to see that the object being extracted from is the |
2449 | | // correct type. |
2450 | 1.40k | Instruction* original_inst = def_use_mgr->GetDef(original_id); |
2451 | 1.40k | uint32_t original_type_id = |
2452 | 1.40k | GetElementType(original_inst->type_id(), first_element_inst->begin() + 3, |
2453 | 1.40k | first_element_inst->end() - 1, def_use_mgr); |
2454 | | |
2455 | 1.40k | if (inst->type_id() != original_type_id) { |
2456 | 191 | return false; |
2457 | 191 | } |
2458 | | |
2459 | 1.21k | if (first_element_inst->NumInOperands() == 2) { |
2460 | | // Simplify by using the original object. |
2461 | 1.21k | inst->SetOpcode(spv::Op::OpCopyObject); |
2462 | 1.21k | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}}); |
2463 | 1.21k | return true; |
2464 | 1.21k | } |
2465 | | |
2466 | | // Copies the original id and all indexes except for the last to the new |
2467 | | // extract instruction. |
2468 | 0 | inst->SetOpcode(spv::Op::OpCompositeExtract); |
2469 | 0 | inst->SetInOperands(std::vector<Operand>(first_element_inst->begin() + 2, |
2470 | 0 | first_element_inst->end() - 1)); |
2471 | 0 | return true; |
2472 | 1.21k | } |
2473 | | |
2474 | 16.7k | FoldingRule InsertFeedingExtract() { |
2475 | 16.7k | return [](IRContext* context, Instruction* inst, |
2476 | 378k | const std::vector<const analysis::Constant*>&) { |
2477 | 378k | assert(inst->opcode() == spv::Op::OpCompositeExtract && |
2478 | 378k | "Wrong opcode. Should be OpCompositeExtract."); |
2479 | 378k | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
2480 | 378k | uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx); |
2481 | 378k | Instruction* cinst = def_use_mgr->GetDef(cid); |
2482 | | |
2483 | 378k | if (cinst->opcode() != spv::Op::OpCompositeInsert) { |
2484 | 289k | return false; |
2485 | 289k | } |
2486 | | |
2487 | | // Find the first position where the list of insert and extract indicies |
2488 | | // differ, if at all. |
2489 | 88.5k | uint32_t i; |
2490 | 108k | for (i = 1; i < inst->NumInOperands(); ++i) { |
2491 | 89.1k | if (i + 1 >= cinst->NumInOperands()) { |
2492 | 0 | break; |
2493 | 0 | } |
2494 | | |
2495 | 89.1k | if (inst->GetSingleWordInOperand(i) != |
2496 | 89.1k | cinst->GetSingleWordInOperand(i + 1)) { |
2497 | 69.3k | break; |
2498 | 69.3k | } |
2499 | 89.1k | } |
2500 | | |
2501 | | // We are extracting the element that was inserted. |
2502 | 88.5k | if (i == inst->NumInOperands() && i + 1 == cinst->NumInOperands()) { |
2503 | 18.8k | inst->SetOpcode(spv::Op::OpCopyObject); |
2504 | 18.8k | inst->SetInOperands( |
2505 | 18.8k | {{SPV_OPERAND_TYPE_ID, |
2506 | 18.8k | {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}}}); |
2507 | 18.8k | return true; |
2508 | 18.8k | } |
2509 | | |
2510 | | // Extracting the value that was inserted along with values for the base |
2511 | | // composite. Cannot do anything. |
2512 | 69.6k | if (i == inst->NumInOperands()) { |
2513 | 357 | return false; |
2514 | 357 | } |
2515 | | |
2516 | | // Extracting an element of the value that was inserted. Extract from |
2517 | | // that value directly. |
2518 | 69.3k | if (i + 1 == cinst->NumInOperands()) { |
2519 | 0 | std::vector<Operand> operands; |
2520 | 0 | operands.push_back( |
2521 | 0 | {SPV_OPERAND_TYPE_ID, |
2522 | 0 | {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}}); |
2523 | 0 | for (; i < inst->NumInOperands(); ++i) { |
2524 | 0 | operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, |
2525 | 0 | {inst->GetSingleWordInOperand(i)}}); |
2526 | 0 | } |
2527 | 0 | inst->SetInOperands(std::move(operands)); |
2528 | 0 | return true; |
2529 | 0 | } |
2530 | | |
2531 | | // Extracting a value that is disjoint from the element being inserted. |
2532 | | // Rewrite the extract to use the composite input to the insert. |
2533 | 69.3k | std::vector<Operand> operands; |
2534 | 69.3k | operands.push_back( |
2535 | 69.3k | {SPV_OPERAND_TYPE_ID, |
2536 | 69.3k | {cinst->GetSingleWordInOperand(kInsertCompositeIdInIdx)}}); |
2537 | 139k | for (i = 1; i < inst->NumInOperands(); ++i) { |
2538 | 69.8k | operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, |
2539 | 69.8k | {inst->GetSingleWordInOperand(i)}}); |
2540 | 69.8k | } |
2541 | 69.3k | inst->SetInOperands(std::move(operands)); |
2542 | 69.3k | return true; |
2543 | 69.3k | }; |
2544 | 16.7k | } |
2545 | | |
2546 | | // When a VectorShuffle is feeding an Extract, we can extract from one of the |
2547 | | // operands of the VectorShuffle. We just need to adjust the index in the |
2548 | | // extract instruction. |
2549 | 16.7k | FoldingRule VectorShuffleFeedingExtract() { |
2550 | 16.7k | return [](IRContext* context, Instruction* inst, |
2551 | 250k | const std::vector<const analysis::Constant*>&) { |
2552 | 250k | assert(inst->opcode() == spv::Op::OpCompositeExtract && |
2553 | 250k | "Wrong opcode. Should be OpCompositeExtract."); |
2554 | 250k | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
2555 | 250k | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
2556 | 250k | uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx); |
2557 | 250k | Instruction* cinst = def_use_mgr->GetDef(cid); |
2558 | | |
2559 | 250k | if (cinst->opcode() != spv::Op::OpVectorShuffle) { |
2560 | 246k | return false; |
2561 | 246k | } |
2562 | | |
2563 | | // Find the size of the first vector operand of the VectorShuffle |
2564 | 4.64k | Instruction* first_input = |
2565 | 4.64k | def_use_mgr->GetDef(cinst->GetSingleWordInOperand(0)); |
2566 | 4.64k | analysis::Type* first_input_type = |
2567 | 4.64k | type_mgr->GetType(first_input->type_id()); |
2568 | 4.64k | assert(first_input_type->AsVector() && |
2569 | 4.64k | "Input to vector shuffle should be vectors."); |
2570 | 4.64k | uint32_t first_input_size = first_input_type->AsVector()->element_count(); |
2571 | | |
2572 | | // Get index of the element the vector shuffle is placing in the position |
2573 | | // being extracted. |
2574 | 4.64k | uint32_t new_index = |
2575 | 4.64k | cinst->GetSingleWordInOperand(2 + inst->GetSingleWordInOperand(1)); |
2576 | | |
2577 | | // Extracting an undefined value so fold this extract into an undef. |
2578 | 4.64k | const uint32_t undef_literal_value = 0xffffffff; |
2579 | 4.64k | if (new_index == undef_literal_value) { |
2580 | 187 | inst->SetOpcode(spv::Op::OpUndef); |
2581 | 187 | inst->SetInOperands({}); |
2582 | 187 | return true; |
2583 | 187 | } |
2584 | | |
2585 | | // Get the id of the of the vector the elemtent comes from, and update the |
2586 | | // index if needed. |
2587 | 4.46k | uint32_t new_vector = 0; |
2588 | 4.46k | if (new_index < first_input_size) { |
2589 | 2.76k | new_vector = cinst->GetSingleWordInOperand(0); |
2590 | 2.76k | } else { |
2591 | 1.69k | new_vector = cinst->GetSingleWordInOperand(1); |
2592 | 1.69k | new_index -= first_input_size; |
2593 | 1.69k | } |
2594 | | |
2595 | | // Update the extract instruction. |
2596 | 4.46k | inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector}); |
2597 | 4.46k | inst->SetInOperand(1, {new_index}); |
2598 | 4.46k | return true; |
2599 | 4.64k | }; |
2600 | 16.7k | } |
2601 | | |
2602 | | // When an FMix with is feeding an Extract that extracts an element whose |
2603 | | // corresponding |a| in the FMix is 0 or 1, we can extract from one of the |
2604 | | // operands of the FMix. |
2605 | 16.7k | FoldingRule FMixFeedingExtract() { |
2606 | 16.7k | return [](IRContext* context, Instruction* inst, |
2607 | 246k | const std::vector<const analysis::Constant*>&) { |
2608 | 246k | assert(inst->opcode() == spv::Op::OpCompositeExtract && |
2609 | 246k | "Wrong opcode. Should be OpCompositeExtract."); |
2610 | 246k | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
2611 | 246k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
2612 | | |
2613 | 246k | uint32_t composite_id = |
2614 | 246k | inst->GetSingleWordInOperand(kExtractCompositeIdInIdx); |
2615 | 246k | Instruction* composite_inst = def_use_mgr->GetDef(composite_id); |
2616 | | |
2617 | 246k | if (composite_inst->opcode() != spv::Op::OpExtInst) { |
2618 | 225k | return false; |
2619 | 225k | } |
2620 | | |
2621 | 20.4k | uint32_t inst_set_id = |
2622 | 20.4k | context->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); |
2623 | | |
2624 | 20.4k | if (composite_inst->GetSingleWordInOperand(kExtInstSetIdInIdx) != |
2625 | 20.4k | inst_set_id || |
2626 | 20.4k | composite_inst->GetSingleWordInOperand(kExtInstInstructionInIdx) != |
2627 | 20.4k | GLSLstd450FMix) { |
2628 | 15.3k | return false; |
2629 | 15.3k | } |
2630 | | |
2631 | | // Get the |a| for the FMix instruction. |
2632 | 5.03k | uint32_t a_id = composite_inst->GetSingleWordInOperand(kFMixAIdInIdx); |
2633 | 5.03k | std::unique_ptr<Instruction> a(inst->Clone(context)); |
2634 | 5.03k | a->SetInOperand(kExtractCompositeIdInIdx, {a_id}); |
2635 | 5.03k | context->get_instruction_folder().FoldInstruction(a.get()); |
2636 | | |
2637 | 5.03k | if (a->opcode() != spv::Op::OpCopyObject) { |
2638 | 1.31k | return false; |
2639 | 1.31k | } |
2640 | | |
2641 | 3.71k | const analysis::Constant* a_const = |
2642 | 3.71k | const_mgr->FindDeclaredConstant(a->GetSingleWordInOperand(0)); |
2643 | | |
2644 | 3.71k | if (!a_const) { |
2645 | 1.86k | return false; |
2646 | 1.86k | } |
2647 | | |
2648 | 1.84k | bool use_x = false; |
2649 | | |
2650 | 1.84k | assert(a_const->type()->AsFloat()); |
2651 | | |
2652 | 1.84k | const analysis::Type* type = |
2653 | 1.84k | context->get_type_mgr()->GetType(inst->type_id()); |
2654 | 1.84k | uint32_t width = ElementWidth(type); |
2655 | 1.84k | if (width != 32 && width != 64) { |
2656 | | // We won't support folding half float values. |
2657 | 0 | return false; |
2658 | 0 | } |
2659 | | |
2660 | 1.84k | double element_value = a_const->GetValueAsDouble(); |
2661 | 1.84k | if (element_value == 0.0) { |
2662 | 67 | use_x = true; |
2663 | 1.77k | } else if (element_value == 1.0) { |
2664 | 28 | use_x = false; |
2665 | 1.75k | } else { |
2666 | 1.75k | return false; |
2667 | 1.75k | } |
2668 | | |
2669 | | // Get the id of the of the vector the element comes from. |
2670 | 95 | uint32_t new_vector = 0; |
2671 | 95 | if (use_x) { |
2672 | 67 | new_vector = composite_inst->GetSingleWordInOperand(kFMixXIdInIdx); |
2673 | 67 | } else { |
2674 | 28 | new_vector = composite_inst->GetSingleWordInOperand(kFMixYIdInIdx); |
2675 | 28 | } |
2676 | | |
2677 | | // Update the extract instruction. |
2678 | 95 | inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector}); |
2679 | 95 | return true; |
2680 | 1.84k | }; |
2681 | 16.7k | } |
2682 | | |
2683 | | // Returns the number of elements in the composite type |type|. Returns 0 if |
2684 | | // |type| is a scalar value. Return UINT32_MAX when the size is unknown at |
2685 | | // compile time. |
2686 | 94.6k | uint32_t GetNumberOfElements(const analysis::Type* type) { |
2687 | 94.6k | if (auto* vector_type = type->AsVector()) { |
2688 | 86.2k | return vector_type->element_count(); |
2689 | 86.2k | } |
2690 | 8.46k | if (auto* matrix_type = type->AsMatrix()) { |
2691 | 0 | return matrix_type->element_count(); |
2692 | 0 | } |
2693 | 8.46k | if (auto* struct_type = type->AsStruct()) { |
2694 | 3.61k | return static_cast<uint32_t>(struct_type->element_types().size()); |
2695 | 3.61k | } |
2696 | 4.84k | if (auto* array_type = type->AsArray()) { |
2697 | 4.84k | if (array_type->length_info().words[0] == |
2698 | 4.84k | analysis::Array::LengthInfo::kConstant && |
2699 | 4.84k | array_type->length_info().words.size() == 2) { |
2700 | 4.84k | return array_type->length_info().words[1]; |
2701 | 4.84k | } |
2702 | 0 | return UINT32_MAX; |
2703 | 4.84k | } |
2704 | 0 | return 0; |
2705 | 4.84k | } |
2706 | | |
2707 | | // Returns a map with the set of values that were inserted into an object by |
2708 | | // the chain of OpCompositeInsertInstruction starting with |inst|. |
2709 | | // The map will map the index to the value inserted at that index. An empty map |
2710 | | // will be returned if the map could not be properly generated. |
2711 | 94.6k | std::map<uint32_t, uint32_t> GetInsertedValues(Instruction* inst) { |
2712 | 94.6k | analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr(); |
2713 | 94.6k | std::map<uint32_t, uint32_t> values_inserted; |
2714 | 94.6k | Instruction* current_inst = inst; |
2715 | 463k | while (current_inst->opcode() == spv::Op::OpCompositeInsert) { |
2716 | 369k | if (current_inst->NumInOperands() > inst->NumInOperands()) { |
2717 | | // This is to catch the case |
2718 | | // %2 = OpCompositeInsert %m2x2int %v2int_1_0 %m2x2int_undef 0 |
2719 | | // %3 = OpCompositeInsert %m2x2int %int_4 %2 0 0 |
2720 | | // %4 = OpCompositeInsert %m2x2int %v2int_2_3 %3 1 |
2721 | | // In this case we cannot do a single construct to get the matrix. |
2722 | 1.05k | uint32_t partially_inserted_element_index = |
2723 | 1.05k | current_inst->GetSingleWordInOperand(inst->NumInOperands() - 1); |
2724 | 1.05k | if (values_inserted.count(partially_inserted_element_index) == 0) |
2725 | 231 | return {}; |
2726 | 1.05k | } |
2727 | 368k | if (HaveSameIndexesExceptForLast(inst, current_inst)) { |
2728 | 366k | values_inserted.insert( |
2729 | 366k | {current_inst->GetSingleWordInOperand(current_inst->NumInOperands() - |
2730 | 366k | 1), |
2731 | 366k | current_inst->GetSingleWordInOperand(kInsertObjectIdInIdx)}); |
2732 | 366k | } |
2733 | 368k | current_inst = def_use_mgr->GetDef( |
2734 | 368k | current_inst->GetSingleWordInOperand(kInsertCompositeIdInIdx)); |
2735 | 368k | } |
2736 | 94.4k | return values_inserted; |
2737 | 94.6k | } |
2738 | | |
2739 | | // Returns true of there is an entry in |values_inserted| for every element of |
2740 | | // |Type|. |
2741 | | bool DoInsertedValuesCoverEntireObject( |
2742 | 94.6k | const analysis::Type* type, std::map<uint32_t, uint32_t>& values_inserted) { |
2743 | 94.6k | uint32_t container_size = GetNumberOfElements(type); |
2744 | 94.6k | if (container_size != values_inserted.size()) { |
2745 | 83.6k | return false; |
2746 | 83.6k | } |
2747 | | |
2748 | 11.0k | if (values_inserted.rbegin()->first >= container_size) { |
2749 | 0 | return false; |
2750 | 0 | } |
2751 | 11.0k | return true; |
2752 | 11.0k | } |
2753 | | |
2754 | | // Returns id of the type of the element that immediately contains the element |
2755 | | // being inserted by the OpCompositeInsert instruction |inst|. Returns 0 if it |
2756 | | // could not be found. |
2757 | 94.6k | uint32_t GetContainerTypeId(Instruction* inst) { |
2758 | 94.6k | assert(inst->opcode() == spv::Op::OpCompositeInsert); |
2759 | 94.6k | analysis::DefUseManager* def_use_manager = inst->context()->get_def_use_mgr(); |
2760 | 94.6k | uint32_t container_type_id = GetElementType( |
2761 | 94.6k | inst->type_id(), inst->begin() + 4, inst->end() - 1, def_use_manager); |
2762 | 94.6k | return container_type_id; |
2763 | 94.6k | } |
2764 | | |
2765 | | // Returns an OpCompositeConstruct instruction that build an object with |
2766 | | // |type_id| out of the values in |values_inserted|. Each value will be |
2767 | | // placed at the index corresponding to the value. The new instruction will |
2768 | | // be placed before |insert_before|. |
2769 | | Instruction* BuildCompositeConstruct( |
2770 | | uint32_t type_id, const std::map<uint32_t, uint32_t>& values_inserted, |
2771 | 11.0k | Instruction* insert_before) { |
2772 | 11.0k | InstructionBuilder ir_builder( |
2773 | 11.0k | insert_before->context(), insert_before, |
2774 | 11.0k | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); |
2775 | | |
2776 | 11.0k | std::vector<uint32_t> ids_in_order; |
2777 | 22.5k | for (auto it : values_inserted) { |
2778 | 22.5k | ids_in_order.push_back(it.second); |
2779 | 22.5k | } |
2780 | 11.0k | Instruction* construct = |
2781 | 11.0k | ir_builder.AddCompositeConstruct(type_id, ids_in_order); |
2782 | 11.0k | return construct; |
2783 | 11.0k | } |
2784 | | |
2785 | | // Replaces the OpCompositeInsert |inst| that inserts |construct| into the same |
2786 | | // object as |inst| with final index removed. If the resulting |
2787 | | // OpCompositeInsert instruction would have no remaining indexes, the |
2788 | | // instruction is replaced with an OpCopyObject instead. |
2789 | 11.0k | void InsertConstructedObject(Instruction* inst, const Instruction* construct) { |
2790 | 11.0k | if (inst->NumInOperands() == 3) { |
2791 | 10.9k | inst->SetOpcode(spv::Op::OpCopyObject); |
2792 | 10.9k | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {construct->result_id()}}}); |
2793 | 10.9k | } else { |
2794 | 54 | inst->SetInOperand(kInsertObjectIdInIdx, {construct->result_id()}); |
2795 | 54 | inst->RemoveOperand(inst->NumOperands() - 1); |
2796 | 54 | } |
2797 | 11.0k | } |
2798 | | |
2799 | | // Replaces a series of |OpCompositeInsert| instruction that cover the entire |
2800 | | // object with an |OpCompositeConstruct|. |
2801 | | bool CompositeInsertToCompositeConstruct( |
2802 | | IRContext* context, Instruction* inst, |
2803 | 94.6k | const std::vector<const analysis::Constant*>&) { |
2804 | 94.6k | assert(inst->opcode() == spv::Op::OpCompositeInsert && |
2805 | 94.6k | "Wrong opcode. Should be OpCompositeInsert."); |
2806 | 94.6k | if (inst->NumInOperands() < 3) return false; |
2807 | | |
2808 | 94.6k | std::map<uint32_t, uint32_t> values_inserted = GetInsertedValues(inst); |
2809 | 94.6k | uint32_t container_type_id = GetContainerTypeId(inst); |
2810 | 94.6k | if (container_type_id == 0) { |
2811 | 0 | return false; |
2812 | 0 | } |
2813 | | |
2814 | 94.6k | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
2815 | 94.6k | const analysis::Type* container_type = type_mgr->GetType(container_type_id); |
2816 | 94.6k | assert(container_type && "GetContainerTypeId returned a bad id."); |
2817 | 94.6k | if (!DoInsertedValuesCoverEntireObject(container_type, values_inserted)) { |
2818 | 83.6k | return false; |
2819 | 83.6k | } |
2820 | | |
2821 | 11.0k | Instruction* construct = |
2822 | 11.0k | BuildCompositeConstruct(container_type_id, values_inserted, inst); |
2823 | 11.0k | InsertConstructedObject(inst, construct); |
2824 | 11.0k | return true; |
2825 | 94.6k | } |
2826 | | |
2827 | 16.7k | FoldingRule RedundantPhi() { |
2828 | | // An OpPhi instruction where all values are the same or the result of the phi |
2829 | | // itself, can be replaced by the value itself. |
2830 | 16.7k | return [](IRContext*, Instruction* inst, |
2831 | 350k | const std::vector<const analysis::Constant*>&) { |
2832 | 350k | assert(inst->opcode() == spv::Op::OpPhi && |
2833 | 350k | "Wrong opcode. Should be OpPhi."); |
2834 | | |
2835 | 350k | uint32_t incoming_value = 0; |
2836 | | |
2837 | 813k | for (uint32_t i = 0; i < inst->NumInOperands(); i += 2) { |
2838 | 721k | uint32_t op_id = inst->GetSingleWordInOperand(i); |
2839 | 721k | if (op_id == inst->result_id()) { |
2840 | 71.6k | continue; |
2841 | 71.6k | } |
2842 | | |
2843 | 650k | if (incoming_value == 0) { |
2844 | 350k | incoming_value = op_id; |
2845 | 350k | } else if (op_id != incoming_value) { |
2846 | | // Found two possible value. Can't simplify. |
2847 | 258k | return false; |
2848 | 258k | } |
2849 | 650k | } |
2850 | | |
2851 | 91.7k | if (incoming_value == 0) { |
2852 | | // Code looks invalid. Don't do anything. |
2853 | 0 | return false; |
2854 | 0 | } |
2855 | | |
2856 | | // We have a single incoming value. Simplify using that value. |
2857 | 91.7k | inst->SetOpcode(spv::Op::OpCopyObject); |
2858 | 91.7k | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {incoming_value}}}); |
2859 | 91.7k | return true; |
2860 | 91.7k | }; |
2861 | 16.7k | } |
2862 | | |
2863 | 16.7k | FoldingRule BitCastScalarOrVector() { |
2864 | 16.7k | return [](IRContext* context, Instruction* inst, |
2865 | 16.7k | const std::vector<const analysis::Constant*>& constants) { |
2866 | 2.41k | assert(inst->opcode() == spv::Op::OpBitcast && constants.size() == 1); |
2867 | 2.41k | if (constants[0] == nullptr) return false; |
2868 | | |
2869 | 1.41k | const analysis::Type* type = |
2870 | 1.41k | context->get_type_mgr()->GetType(inst->type_id()); |
2871 | 1.41k | if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed()) |
2872 | 8 | return false; |
2873 | | |
2874 | 1.40k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
2875 | 1.40k | std::vector<uint32_t> words = |
2876 | 1.40k | GetWordsFromNumericScalarOrVectorConstant(const_mgr, constants[0]); |
2877 | 1.40k | if (words.size() == 0) return false; |
2878 | | |
2879 | 1.40k | const analysis::Constant* bitcasted_constant = |
2880 | 1.40k | ConvertWordsToNumericScalarOrVectorConstant(const_mgr, words, type); |
2881 | 1.40k | if (!bitcasted_constant) return false; |
2882 | | |
2883 | 1.40k | auto new_feeder_id = |
2884 | 1.40k | const_mgr->GetDefiningInstruction(bitcasted_constant, inst->type_id()) |
2885 | 1.40k | ->result_id(); |
2886 | 1.40k | inst->SetOpcode(spv::Op::OpCopyObject); |
2887 | 1.40k | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {new_feeder_id}}}); |
2888 | 1.40k | return true; |
2889 | 1.40k | }; |
2890 | 16.7k | } |
2891 | | |
2892 | | // Remove indirect bitcasts which have no effect. |
2893 | | // uint32 x; asuint32(x) => x |
2894 | | // uint32 x; asuint32(asint32(x)) => x |
2895 | | // float32 x; asuint32(asint32(x)) => asuint32(x) |
2896 | 16.7k | FoldingRule RedundantBitcast() { |
2897 | 16.7k | return [](IRContext* context, Instruction* inst, |
2898 | 16.7k | const std::vector<const analysis::Constant*>&) { |
2899 | 1.00k | assert(inst->opcode() == spv::Op::OpBitcast); |
2900 | | |
2901 | 1.00k | analysis::DefUseManager* def_mgr = context->get_def_use_mgr(); |
2902 | 1.00k | Instruction* child = def_mgr->GetDef(inst->GetSingleWordInOperand(0)); |
2903 | | |
2904 | 1.00k | if (inst->type_id() == child->type_id()) { |
2905 | 165 | inst->SetOpcode(spv::Op::OpCopyObject); |
2906 | 165 | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {child->result_id()}}}); |
2907 | 165 | return true; |
2908 | 165 | } |
2909 | | |
2910 | 843 | if (child->opcode() != spv::Op::OpBitcast) { |
2911 | 843 | return false; |
2912 | 843 | } |
2913 | | |
2914 | 0 | if (def_mgr->GetDef(child->GetSingleWordInOperand(0))->type_id() == |
2915 | 0 | inst->type_id()) { |
2916 | 0 | inst->SetOpcode(spv::Op::OpCopyObject); |
2917 | 0 | } |
2918 | 0 | inst->SetInOperands( |
2919 | 0 | {{SPV_OPERAND_TYPE_ID, {child->GetSingleWordInOperand(0)}}}); |
2920 | |
|
2921 | 0 | return true; |
2922 | 843 | }; |
2923 | 16.7k | } |
2924 | | |
2925 | 16.7k | FoldingRule BitReverseScalarOrVector() { |
2926 | 16.7k | return [](IRContext* context, Instruction* inst, |
2927 | 16.7k | const std::vector<const analysis::Constant*>& constants) { |
2928 | 750 | assert(inst->opcode() == spv::Op::OpBitReverse && constants.size() == 1); |
2929 | 750 | if (constants[0] == nullptr) return false; |
2930 | | |
2931 | 501 | const analysis::Type* type = |
2932 | 501 | context->get_type_mgr()->GetType(inst->type_id()); |
2933 | 501 | assert(!HasFloatingPoint(type) && |
2934 | 501 | "BitReverse cannot be applied to floating point types."); |
2935 | 501 | assert((type->AsInteger() || type->AsVector()) && |
2936 | 501 | "BitReverse can only be applied to integer scalars or vectors."); |
2937 | 501 | assert((ElementWidth(type) == 32) && |
2938 | 501 | "BitReverse can only be applied to integer types of width 32"); |
2939 | | |
2940 | 501 | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
2941 | 501 | std::vector<uint32_t> words = |
2942 | 501 | GetWordsFromNumericScalarOrVectorConstant(const_mgr, constants[0]); |
2943 | 501 | if (words.size() == 0) return false; |
2944 | | |
2945 | 504 | for (uint32_t& word : words) { |
2946 | | // Reverse the bits in each word. |
2947 | 504 | word = ((word & 0x55555555) << 1) | ((word >> 1) & 0x55555555); |
2948 | 504 | word = ((word & 0x33333333) << 2) | ((word >> 2) & 0x33333333); |
2949 | 504 | word = ((word & 0x0F0F0F0F) << 4) | ((word >> 4) & 0x0F0F0F0F); |
2950 | 504 | word = ((word & 0x00FF00FF) << 8) | ((word >> 8) & 0x00FF00FF); |
2951 | 504 | word = (word << 16) | (word >> 16); |
2952 | 504 | } |
2953 | | |
2954 | 501 | const analysis::Constant* bitreversed_constant = |
2955 | 501 | ConvertWordsToNumericScalarOrVectorConstant(const_mgr, words, type); |
2956 | 501 | if (!bitreversed_constant) return false; |
2957 | | |
2958 | 501 | auto new_feeder_id = |
2959 | 501 | const_mgr->GetDefiningInstruction(bitreversed_constant, inst->type_id()) |
2960 | 501 | ->result_id(); |
2961 | 501 | inst->SetOpcode(spv::Op::OpCopyObject); |
2962 | 501 | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {new_feeder_id}}}); |
2963 | 501 | return true; |
2964 | 501 | }; |
2965 | 16.7k | } |
2966 | | |
2967 | 16.7k | FoldingRule RedundantSelect() { |
2968 | | // An OpSelect instruction where both values are the same or the condition is |
2969 | | // constant can be replaced by one of the values |
2970 | 16.7k | return [](IRContext*, Instruction* inst, |
2971 | 17.1k | const std::vector<const analysis::Constant*>& constants) { |
2972 | 17.1k | assert(inst->opcode() == spv::Op::OpSelect && |
2973 | 17.1k | "Wrong opcode. Should be OpSelect."); |
2974 | 17.1k | assert(inst->NumInOperands() == 3); |
2975 | 17.1k | assert(constants.size() == 3); |
2976 | | |
2977 | 17.1k | uint32_t true_id = inst->GetSingleWordInOperand(1); |
2978 | 17.1k | uint32_t false_id = inst->GetSingleWordInOperand(2); |
2979 | | |
2980 | 17.1k | if (true_id == false_id) { |
2981 | | // Both results are the same, condition doesn't matter |
2982 | 117 | inst->SetOpcode(spv::Op::OpCopyObject); |
2983 | 117 | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}}); |
2984 | 117 | return true; |
2985 | 17.0k | } else if (constants[0]) { |
2986 | 1.41k | const analysis::Type* type = constants[0]->type(); |
2987 | 1.41k | if (type->AsBool()) { |
2988 | | // Scalar constant value, select the corresponding value. |
2989 | 959 | inst->SetOpcode(spv::Op::OpCopyObject); |
2990 | 959 | if (constants[0]->AsNullConstant() || |
2991 | 959 | !constants[0]->AsBoolConstant()->value()) { |
2992 | 650 | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}}); |
2993 | 650 | } else { |
2994 | 309 | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}}); |
2995 | 309 | } |
2996 | 959 | return true; |
2997 | 959 | } else { |
2998 | 454 | assert(type->AsVector()); |
2999 | 454 | if (constants[0]->AsNullConstant()) { |
3000 | | // All values come from false id. |
3001 | 0 | inst->SetOpcode(spv::Op::OpCopyObject); |
3002 | 0 | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}}); |
3003 | 0 | return true; |
3004 | 454 | } else { |
3005 | | // Convert to a vector shuffle. |
3006 | 454 | std::vector<Operand> ops; |
3007 | 454 | ops.push_back({SPV_OPERAND_TYPE_ID, {true_id}}); |
3008 | 454 | ops.push_back({SPV_OPERAND_TYPE_ID, {false_id}}); |
3009 | 454 | const analysis::VectorConstant* vector_const = |
3010 | 454 | constants[0]->AsVectorConstant(); |
3011 | 454 | uint32_t size = |
3012 | 454 | static_cast<uint32_t>(vector_const->GetComponents().size()); |
3013 | 1.36k | for (uint32_t i = 0; i != size; ++i) { |
3014 | 910 | const analysis::Constant* component = |
3015 | 910 | vector_const->GetComponents()[i]; |
3016 | 910 | if (component->AsNullConstant() || |
3017 | 910 | !component->AsBoolConstant()->value()) { |
3018 | | // Selecting from the false vector which is the second input |
3019 | | // vector to the shuffle. Offset the index by |size|. |
3020 | 27 | ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i + size}}); |
3021 | 883 | } else { |
3022 | | // Selecting from true vector which is the first input vector to |
3023 | | // the shuffle. |
3024 | 883 | ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}); |
3025 | 883 | } |
3026 | 910 | } |
3027 | | |
3028 | 454 | inst->SetOpcode(spv::Op::OpVectorShuffle); |
3029 | 454 | inst->SetInOperands(std::move(ops)); |
3030 | 454 | return true; |
3031 | 454 | } |
3032 | 454 | } |
3033 | 1.41k | } |
3034 | | |
3035 | 15.6k | return false; |
3036 | 17.1k | }; |
3037 | 16.7k | } |
3038 | | |
3039 | 10.8k | std::optional<bool> GetBoolConstantKind(const analysis::Constant* c) { |
3040 | 10.8k | if (!c) { |
3041 | 3.64k | return {}; |
3042 | 3.64k | } |
3043 | 7.19k | if (auto composite = c->AsCompositeConstant()) { |
3044 | 0 | auto& components = composite->GetComponents(); |
3045 | 0 | if (components.empty()) { |
3046 | 0 | return {}; |
3047 | 0 | } |
3048 | 0 | auto first = GetBoolConstantKind(components[0]); |
3049 | 0 | if (!first) { |
3050 | 0 | return {}; |
3051 | 0 | } |
3052 | 0 | if (std::all_of(std::begin(components) + 1, std::end(components), |
3053 | 0 | [first](const analysis::Constant* c2) { |
3054 | 0 | return GetBoolConstantKind(c2) == first; |
3055 | 0 | })) { |
3056 | 0 | return first; |
3057 | 0 | } |
3058 | 0 | return {}; |
3059 | 7.19k | } else if (c->AsNullConstant()) { |
3060 | 5 | return false; |
3061 | 7.19k | } else if (c->AsBoolConstant()) { |
3062 | 7.19k | return c->AsBoolConstant()->value(); |
3063 | 7.19k | } |
3064 | 0 | return {}; |
3065 | 7.19k | } |
3066 | | |
3067 | | // Fold OpSelect instructions which have constant booleans as their result. |
3068 | | // x ? true : false = x |
3069 | | // x ? false : true = !x |
3070 | 16.7k | FoldingRule FoldConstantBooleanSelect() { |
3071 | 16.7k | return [](IRContext* context, Instruction* inst, |
3072 | 16.7k | const std::vector<const analysis::Constant*>& constants) { |
3073 | 15.6k | assert(inst->opcode() == spv::Op::OpSelect); |
3074 | 15.6k | assert(inst->NumInOperands() == 3); |
3075 | 15.6k | assert(constants.size() == 3); |
3076 | | |
3077 | 15.6k | if (!constants[1] || !constants[2]) { |
3078 | 10.5k | return false; |
3079 | 10.5k | } |
3080 | | |
3081 | 5.07k | analysis::DefUseManager* def_mgr = context->get_def_use_mgr(); |
3082 | 5.07k | if (inst->type_id() != |
3083 | 5.07k | def_mgr->GetDef(inst->GetSingleWordInOperand(0))->type_id()) { |
3084 | 5.00k | return false; |
3085 | 5.00k | } |
3086 | | |
3087 | 70 | std::optional<bool> uniform_true = GetBoolConstantKind(constants[1]); |
3088 | 70 | std::optional<bool> uniform_false = GetBoolConstantKind(constants[2]); |
3089 | | |
3090 | 70 | if (!uniform_true || !uniform_false) { |
3091 | 0 | return false; |
3092 | 0 | } |
3093 | | |
3094 | 70 | if (uniform_true.value() && !uniform_false.value()) { |
3095 | 28 | inst->SetOpcode(spv::Op::OpCopyObject); |
3096 | 28 | inst->SetInOperands( |
3097 | 28 | {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}}); |
3098 | 28 | return true; |
3099 | 42 | } else if (!uniform_true.value() && uniform_false.value()) { |
3100 | 42 | inst->SetOpcode(spv::Op::OpLogicalNot); |
3101 | 42 | inst->SetInOperands( |
3102 | 42 | {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}}); |
3103 | 42 | return true; |
3104 | 42 | } |
3105 | 0 | return false; |
3106 | 70 | }; |
3107 | 16.7k | } |
3108 | | |
3109 | | // Fold OpLogicalAnd instructions which have a constant true on one side. |
3110 | | // x && true = x |
3111 | | // true && x = x |
3112 | 16.7k | FoldingRule RedundantLogicalAnd() { |
3113 | 16.7k | return [](IRContext* context, Instruction* inst, |
3114 | 16.7k | const std::vector<const analysis::Constant*>& constants) { |
3115 | 8.55k | assert(inst->opcode() == spv::Op::OpLogicalAnd); |
3116 | | |
3117 | 8.55k | if (GetBoolConstantKind(ConstInput(constants)) == |
3118 | 8.55k | std::optional<bool>(true)) { |
3119 | 5.28k | Instruction* other_inst = NonConstInput(context, constants[0], inst); |
3120 | 5.28k | inst->SetOpcode(spv::Op::OpCopyObject); |
3121 | 5.28k | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {other_inst->result_id()}}}); |
3122 | 5.28k | return true; |
3123 | 5.28k | } |
3124 | 3.27k | return false; |
3125 | 8.55k | }; |
3126 | 16.7k | } |
3127 | | |
3128 | | // Fold OpLogicalOr instructions which have a constant false on one side. |
3129 | | // x || false = x |
3130 | | // false || x = x |
3131 | 16.7k | FoldingRule RedundantLogicalOr() { |
3132 | 16.7k | return [](IRContext* context, Instruction* inst, |
3133 | 16.7k | const std::vector<const analysis::Constant*>& constants) { |
3134 | 490 | assert(inst->opcode() == spv::Op::OpLogicalOr); |
3135 | | |
3136 | 490 | if (GetBoolConstantKind(ConstInput(constants)) == |
3137 | 490 | std::optional<bool>(false)) { |
3138 | 115 | Instruction* other_inst = NonConstInput(context, constants[0], inst); |
3139 | 115 | inst->SetOpcode(spv::Op::OpCopyObject); |
3140 | 115 | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {other_inst->result_id()}}}); |
3141 | 115 | return true; |
3142 | 115 | } |
3143 | 375 | return false; |
3144 | 490 | }; |
3145 | 16.7k | } |
3146 | | |
3147 | | // Fold concurrent OpLogicalNot instructions: |
3148 | | // !!x = x |
3149 | 16.7k | FoldingRule RedundantLogicalNot() { |
3150 | 16.7k | return [](IRContext* context, Instruction* inst, |
3151 | 16.7k | const std::vector<const analysis::Constant*>&) { |
3152 | 15.2k | assert(inst->opcode() == spv::Op::OpLogicalNot); |
3153 | 15.2k | Instruction* child = |
3154 | 15.2k | context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0)); |
3155 | 15.2k | if (child->opcode() == spv::Op::OpLogicalNot) { |
3156 | 51 | inst->SetOpcode(spv::Op::OpCopyObject); |
3157 | 51 | inst->SetInOperands( |
3158 | 51 | {{SPV_OPERAND_TYPE_ID, {child->GetSingleWordInOperand(0)}}}); |
3159 | 51 | return true; |
3160 | 51 | } |
3161 | 15.2k | return false; |
3162 | 15.2k | }; |
3163 | 16.7k | } |
3164 | | |
3165 | | // Cases handled: |
3166 | | // ((a ? C0 : C1) == C2) = ((a ? (C0 == C2) : (C1 == C2)) |
3167 | | // ((a ? C0 : C1) != C2) = ((a ? (C0 != C2) : (C1 != C2)) |
3168 | | // ((a ? C0 : C1) < C2) = ((a ? (C0 < C2) : (C1 < C2)) |
3169 | | // ((a ? C0 : C1) <= C2) = ((a ? (C0 <= C2) : (C1 <= C2)) |
3170 | | // ((a ? C0 : C1) > C2) = ((a ? (C0 > C2) : (C1 > C2)) |
3171 | | // ((a ? C0 : C1) >= C2) = ((a ? (C0 >= C2) : (C1 >= C2)) |
3172 | | // ((a ? C0 : C1) || C2) = ((a ? (C0 || C2) : (C1 || C2)) |
3173 | | // ((a ? C0 : C1) && C2) = ((a ? (C0 && C2) : (C1 && C2)) |
3174 | | // ((a ? C0 : C1) + C2) = ((a ? (C0 + C2) : (C1 + C2)) |
3175 | | // ((a ? C0 : C1) - C2) = ((a ? (C0 - C2) : (C1 - C2)) |
3176 | | // ((a ? C0 : C1) * C2) = ((a ? (C0 * C2) : (C1 * C2)) |
3177 | | // ((a ? C0 : C1) / C2) = ((a ? (C0 / C2) : (C1 / C2)) |
3178 | | // ((a ? C0 : C1) >> C2) = ((a ? (C0 >> C2) : (C1 >> C2)) |
3179 | | // ((a ? C0 : C1) << C2) = ((a ? (C0 << C2) : (C1 << C2)) |
3180 | | // ((a ? C0 : C1) ^ C2) = ((a ? (C0 ^ C2) : (C1 ^ C2)) |
3181 | | // ((a ? C0 : C1) | C2) = ((a ? (C0 | C2) : (C1 | C2)) |
3182 | | // ((a ? C0 : C1) & C2) = ((a ? (C0 & C2) : (C1 & C2)) |
3183 | | static const constexpr spv::Op MergeBinaryOpSelectOps[] = { |
3184 | | spv::Op::OpLogicalEqual, |
3185 | | spv::Op::OpLogicalNotEqual, |
3186 | | spv::Op::OpLogicalAnd, |
3187 | | spv::Op::OpLogicalOr, |
3188 | | spv::Op::OpIEqual, |
3189 | | spv::Op::OpINotEqual, |
3190 | | spv::Op::OpUGreaterThan, |
3191 | | spv::Op::OpSGreaterThan, |
3192 | | spv::Op::OpUGreaterThanEqual, |
3193 | | spv::Op::OpSGreaterThanEqual, |
3194 | | spv::Op::OpULessThan, |
3195 | | spv::Op::OpSLessThan, |
3196 | | spv::Op::OpULessThanEqual, |
3197 | | spv::Op::OpSLessThanEqual, |
3198 | | spv::Op::OpFOrdEqual, |
3199 | | spv::Op::OpFUnordEqual, |
3200 | | spv::Op::OpFOrdNotEqual, |
3201 | | spv::Op::OpFUnordNotEqual, |
3202 | | spv::Op::OpFOrdLessThan, |
3203 | | spv::Op::OpFUnordLessThan, |
3204 | | spv::Op::OpFOrdGreaterThan, |
3205 | | spv::Op::OpFUnordGreaterThan, |
3206 | | spv::Op::OpFOrdLessThanEqual, |
3207 | | spv::Op::OpFUnordLessThanEqual, |
3208 | | spv::Op::OpFOrdGreaterThanEqual, |
3209 | | spv::Op::OpFUnordGreaterThanEqual, |
3210 | | spv::Op::OpIAdd, |
3211 | | spv::Op::OpFAdd, |
3212 | | spv::Op::OpISub, |
3213 | | spv::Op::OpFSub, |
3214 | | spv::Op::OpIMul, |
3215 | | spv::Op::OpFMul, |
3216 | | spv::Op::OpUDiv, |
3217 | | spv::Op::OpSDiv, |
3218 | | spv::Op::OpFDiv, |
3219 | | spv::Op::OpVectorTimesScalar, |
3220 | | spv::Op::OpShiftRightLogical, |
3221 | | spv::Op::OpShiftRightArithmetic, |
3222 | | spv::Op::OpShiftLeftLogical, |
3223 | | spv::Op::OpBitwiseXor, |
3224 | | spv::Op::OpBitwiseOr, |
3225 | | spv::Op::OpBitwiseAnd}; |
3226 | | |
3227 | 704k | FoldingRule MergeBinaryOpSelect(spv::Op opcode) { |
3228 | 704k | assert(std::find(std::begin(MergeBinaryOpSelectOps), |
3229 | 704k | std::end(MergeBinaryOpSelectOps), |
3230 | 704k | opcode) != std::end(MergeBinaryOpSelectOps) && |
3231 | 704k | "Wrong opcode."); |
3232 | | |
3233 | 704k | return [opcode](IRContext* context, Instruction* inst, |
3234 | 1.73M | const std::vector<const analysis::Constant*>& constants) { |
3235 | 1.73M | const analysis::Constant* const_input = ConstInput(constants); |
3236 | 1.73M | if (!const_input) { |
3237 | 887k | return false; |
3238 | 887k | } |
3239 | 848k | Instruction* non_const = NonConstInput(context, constants[0], inst); |
3240 | 848k | if (non_const->opcode() != spv::Op::OpSelect) { |
3241 | 848k | return false; |
3242 | 848k | } |
3243 | 901 | std::vector<const analysis::Constant*> select_constants = |
3244 | 901 | context->get_constant_mgr()->GetOperandConstants(non_const); |
3245 | 901 | if (!select_constants[1] || !select_constants[2]) { |
3246 | 223 | return false; |
3247 | 223 | } |
3248 | | |
3249 | 678 | InstructionBuilder ir_builder( |
3250 | 678 | context, inst, |
3251 | 678 | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); |
3252 | | |
3253 | 678 | Instruction *lhs, *rhs; |
3254 | 678 | if (constants[0]) { |
3255 | 354 | lhs = ir_builder.AddBinaryOp(inst->type_id(), opcode, |
3256 | 354 | inst->GetSingleWordInOperand(0), |
3257 | 354 | non_const->GetSingleWordInOperand(1)); |
3258 | 354 | rhs = ir_builder.AddBinaryOp(inst->type_id(), opcode, |
3259 | 354 | inst->GetSingleWordInOperand(0), |
3260 | 354 | non_const->GetSingleWordInOperand(2)); |
3261 | 354 | } else { |
3262 | 324 | lhs = ir_builder.AddBinaryOp(inst->type_id(), opcode, |
3263 | 324 | non_const->GetSingleWordInOperand(1), |
3264 | 324 | inst->GetSingleWordInOperand(1)); |
3265 | 324 | rhs = ir_builder.AddBinaryOp(inst->type_id(), opcode, |
3266 | 324 | non_const->GetSingleWordInOperand(2), |
3267 | 324 | inst->GetSingleWordInOperand(1)); |
3268 | 324 | } |
3269 | | |
3270 | 678 | if (!lhs || !rhs) { |
3271 | 0 | return false; |
3272 | 0 | } |
3273 | | |
3274 | 678 | if (context->get_instruction_folder().FoldInstruction(lhs)) { |
3275 | 678 | context->AnalyzeDefUse(lhs); |
3276 | 1.35k | while (lhs->opcode() == spv::Op::OpCopyObject) { |
3277 | 678 | lhs = |
3278 | 678 | context->get_def_use_mgr()->GetDef(lhs->GetSingleWordInOperand(0)); |
3279 | 678 | } |
3280 | 678 | } |
3281 | 678 | if (context->get_instruction_folder().FoldInstruction(rhs)) { |
3282 | 678 | context->AnalyzeDefUse(rhs); |
3283 | 1.35k | while (rhs->opcode() == spv::Op::OpCopyObject) { |
3284 | 678 | rhs = |
3285 | 678 | context->get_def_use_mgr()->GetDef(rhs->GetSingleWordInOperand(0)); |
3286 | 678 | } |
3287 | 678 | } |
3288 | 678 | inst->SetOpcode(spv::Op::OpSelect); |
3289 | 678 | inst->SetInOperands( |
3290 | 678 | {{SPV_OPERAND_TYPE_ID, {non_const->GetSingleWordInOperand(0)}}, |
3291 | 678 | {SPV_OPERAND_TYPE_ID, {lhs->result_id()}}, |
3292 | 678 | {SPV_OPERAND_TYPE_ID, {rhs->result_id()}}}); |
3293 | 678 | return true; |
3294 | 678 | }; |
3295 | 704k | } |
3296 | | |
3297 | | // Fold OpLogicalNot instructions that follow a comparison, |
3298 | | // if the comparison is only used by that instruction. |
3299 | | // |
3300 | | // !(a == b) = (a != b) |
3301 | | // !(a != b) = (a == b) |
3302 | | // !(a < b) = (a >= b) |
3303 | | // !(a >= b) = (a < b) |
3304 | | // !(a > b) = (a <= b) |
3305 | | // !(a <= b) = (a > b) |
3306 | 16.7k | FoldingRule FoldLogicalNotComparison() { |
3307 | 16.7k | return [](IRContext* context, Instruction* inst, |
3308 | 16.7k | const std::vector<const analysis::Constant*>&) { |
3309 | 15.2k | assert(inst->opcode() == spv::Op::OpLogicalNot); |
3310 | 15.2k | analysis::DefUseManager* def_mgr = context->get_def_use_mgr(); |
3311 | 15.2k | Instruction* child = |
3312 | 15.2k | context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0)); |
3313 | | |
3314 | 15.2k | if (def_mgr->NumUses(child) > 1) { |
3315 | 8.07k | return false; |
3316 | 8.07k | } |
3317 | | |
3318 | 7.16k | spv::Op new_opcode = spv::Op::OpNop; |
3319 | 7.16k | switch (child->opcode()) { |
3320 | | // (a == b) <=> (a != b) |
3321 | 2 | case spv::Op::OpIEqual: |
3322 | 2 | new_opcode = spv::Op::OpINotEqual; |
3323 | 2 | break; |
3324 | 29 | case spv::Op::OpINotEqual: |
3325 | 29 | new_opcode = spv::Op::OpIEqual; |
3326 | 29 | break; |
3327 | 138 | case spv::Op::OpFOrdEqual: |
3328 | 138 | new_opcode = spv::Op::OpFUnordNotEqual; |
3329 | 138 | break; |
3330 | 61 | case spv::Op::OpFOrdNotEqual: |
3331 | 61 | new_opcode = spv::Op::OpFUnordEqual; |
3332 | 61 | break; |
3333 | 112 | case spv::Op::OpFUnordEqual: |
3334 | 112 | new_opcode = spv::Op::OpFOrdNotEqual; |
3335 | 112 | break; |
3336 | 47 | case spv::Op::OpFUnordNotEqual: |
3337 | 47 | new_opcode = spv::Op::OpFOrdEqual; |
3338 | 47 | break; |
3339 | 6 | case spv::Op::OpLogicalEqual: |
3340 | 6 | new_opcode = spv::Op::OpLogicalNotEqual; |
3341 | 6 | break; |
3342 | 9 | case spv::Op::OpLogicalNotEqual: |
3343 | 9 | new_opcode = spv::Op::OpLogicalEqual; |
3344 | 9 | break; |
3345 | | |
3346 | | // (a > b) <=> (a <= b) |
3347 | 5 | case spv::Op::OpUGreaterThan: |
3348 | 5 | new_opcode = spv::Op::OpULessThanEqual; |
3349 | 5 | break; |
3350 | 4 | case spv::Op::OpULessThanEqual: |
3351 | 4 | new_opcode = spv::Op::OpUGreaterThan; |
3352 | 4 | break; |
3353 | 13 | case spv::Op::OpSGreaterThan: |
3354 | 13 | new_opcode = spv::Op::OpSLessThanEqual; |
3355 | 13 | break; |
3356 | 2 | case spv::Op::OpSLessThanEqual: |
3357 | 2 | new_opcode = spv::Op::OpSGreaterThan; |
3358 | 2 | break; |
3359 | 2.48k | case spv::Op::OpFOrdGreaterThan: |
3360 | 2.48k | new_opcode = spv::Op::OpFUnordLessThanEqual; |
3361 | 2.48k | break; |
3362 | 123 | case spv::Op::OpFOrdLessThanEqual: |
3363 | 123 | new_opcode = spv::Op::OpFUnordGreaterThan; |
3364 | 123 | break; |
3365 | 114 | case spv::Op::OpFUnordGreaterThan: |
3366 | 114 | new_opcode = spv::Op::OpFOrdLessThanEqual; |
3367 | 114 | break; |
3368 | 37 | case spv::Op::OpFUnordLessThanEqual: |
3369 | 37 | new_opcode = spv::Op::OpFOrdGreaterThan; |
3370 | 37 | break; |
3371 | | |
3372 | | // (a < b) <=> (a >= b) |
3373 | 1 | case spv::Op::OpULessThan: |
3374 | 1 | new_opcode = spv::Op::OpUGreaterThanEqual; |
3375 | 1 | break; |
3376 | 1 | case spv::Op::OpUGreaterThanEqual: |
3377 | 1 | new_opcode = spv::Op::OpULessThan; |
3378 | 1 | break; |
3379 | 2 | case spv::Op::OpSLessThan: |
3380 | 2 | new_opcode = spv::Op::OpSGreaterThanEqual; |
3381 | 2 | break; |
3382 | 2 | case spv::Op::OpSGreaterThanEqual: |
3383 | 2 | new_opcode = spv::Op::OpSLessThan; |
3384 | 2 | break; |
3385 | 2.81k | case spv::Op::OpFOrdLessThan: |
3386 | 2.81k | new_opcode = spv::Op::OpFUnordGreaterThanEqual; |
3387 | 2.81k | break; |
3388 | 67 | case spv::Op::OpFOrdGreaterThanEqual: |
3389 | 67 | new_opcode = spv::Op::OpFUnordLessThan; |
3390 | 67 | break; |
3391 | 30 | case spv::Op::OpFUnordLessThan: |
3392 | 30 | new_opcode = spv::Op::OpFOrdGreaterThanEqual; |
3393 | 30 | break; |
3394 | 39 | case spv::Op::OpFUnordGreaterThanEqual: |
3395 | 39 | new_opcode = spv::Op::OpFOrdLessThan; |
3396 | 39 | break; |
3397 | | |
3398 | 1.01k | default: |
3399 | 1.01k | break; |
3400 | 7.16k | } |
3401 | | |
3402 | 7.16k | if (new_opcode == spv::Op::OpNop) { |
3403 | 1.01k | return false; |
3404 | 1.01k | } |
3405 | | |
3406 | 6.14k | inst->SetOpcode(new_opcode); |
3407 | 6.14k | inst->SetInOperands( |
3408 | 6.14k | {{SPV_OPERAND_TYPE_ID, {child->GetSingleWordInOperand(0)}}, |
3409 | 6.14k | {SPV_OPERAND_TYPE_ID, {child->GetSingleWordInOperand(1)}}}); |
3410 | | |
3411 | 6.14k | return true; |
3412 | 7.16k | }; |
3413 | 16.7k | } |
3414 | | |
3415 | | // (a == true) = a |
3416 | | // (a == false) = !a |
3417 | | // (a != true) = !a |
3418 | | // (a != false) = a |
3419 | 33.5k | FoldingRule RedundantLogicalEqual() { |
3420 | 33.5k | return [](IRContext* context, Instruction* inst, |
3421 | 33.5k | const std::vector<const analysis::Constant*>& constants) { |
3422 | 1.96k | assert(inst->opcode() == spv::Op::OpLogicalEqual || |
3423 | 1.96k | inst->opcode() == spv::Op::OpLogicalNotEqual); |
3424 | | |
3425 | 1.96k | const analysis::Constant* const_input = ConstInput(constants); |
3426 | 1.96k | if (!const_input) { |
3427 | 310 | return false; |
3428 | 310 | } |
3429 | | |
3430 | 1.65k | analysis::DefUseManager* def_mgr = context->get_def_use_mgr(); |
3431 | 1.65k | if (inst->type_id() != |
3432 | 1.65k | def_mgr->GetDef(inst->GetSingleWordInOperand(0))->type_id()) { |
3433 | 0 | return false; |
3434 | 0 | } |
3435 | | |
3436 | 1.65k | std::optional<bool> uniform_const = GetBoolConstantKind(const_input); |
3437 | 1.65k | if (!uniform_const) { |
3438 | 0 | return false; |
3439 | 0 | } |
3440 | | |
3441 | 1.65k | bool direct_copy = inst->opcode() == spv::Op::OpLogicalEqual |
3442 | 1.65k | ? uniform_const.value() |
3443 | 1.65k | : !uniform_const.value(); |
3444 | | |
3445 | 1.65k | inst->SetOpcode(direct_copy ? spv::Op::OpCopyObject |
3446 | 1.65k | : spv::Op::OpLogicalNot); |
3447 | 1.65k | inst->SetInOperands( |
3448 | 1.65k | {{SPV_OPERAND_TYPE_ID, |
3449 | 1.65k | {NonConstInput(context, constants[0], inst)->result_id()}}}); |
3450 | 1.65k | return true; |
3451 | 1.65k | }; |
3452 | 33.5k | } |
3453 | | |
3454 | | enum class FloatConstantKind { Unknown, Zero, One }; |
3455 | | |
3456 | 2.47M | FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) { |
3457 | 2.47M | if (constant == nullptr) { |
3458 | 1.42M | return FloatConstantKind::Unknown; |
3459 | 1.42M | } |
3460 | | |
3461 | 2.47M | assert(HasFloatingPoint(constant->type()) && "Unexpected constant type"); |
3462 | | |
3463 | 1.05M | if (constant->AsNullConstant()) { |
3464 | 4.72k | return FloatConstantKind::Zero; |
3465 | 1.04M | } else if (const analysis::VectorConstant* vc = |
3466 | 1.04M | constant->AsVectorConstant()) { |
3467 | 272k | const std::vector<const analysis::Constant*>& components = |
3468 | 272k | vc->GetComponents(); |
3469 | 272k | assert(!components.empty()); |
3470 | | |
3471 | 272k | FloatConstantKind kind = getFloatConstantKind(components[0]); |
3472 | | |
3473 | 518k | for (size_t i = 1; i < components.size(); ++i) { |
3474 | 309k | if (getFloatConstantKind(components[i]) != kind) { |
3475 | 63.1k | return FloatConstantKind::Unknown; |
3476 | 63.1k | } |
3477 | 309k | } |
3478 | | |
3479 | 209k | return kind; |
3480 | 776k | } else if (const analysis::FloatConstant* fc = constant->AsFloatConstant()) { |
3481 | 776k | if (fc->IsZero()) return FloatConstantKind::Zero; |
3482 | | |
3483 | 706k | uint32_t width = fc->type()->AsFloat()->width(); |
3484 | 706k | if (width != 32 && width != 64) return FloatConstantKind::Unknown; |
3485 | | |
3486 | 706k | double value = (width == 64) ? fc->GetDoubleValue() : fc->GetFloatValue(); |
3487 | | |
3488 | 706k | if (value == 0.0) { |
3489 | 34.1k | return FloatConstantKind::Zero; |
3490 | 672k | } else if (value == 1.0) { |
3491 | 31.2k | return FloatConstantKind::One; |
3492 | 641k | } else { |
3493 | 641k | return FloatConstantKind::Unknown; |
3494 | 641k | } |
3495 | 706k | } else { |
3496 | 0 | return FloatConstantKind::Unknown; |
3497 | 0 | } |
3498 | 1.05M | } |
3499 | | |
3500 | 16.7k | FoldingRule RedundantFAdd() { |
3501 | 16.7k | return [](IRContext*, Instruction* inst, |
3502 | 458k | const std::vector<const analysis::Constant*>& constants) { |
3503 | 458k | assert(inst->opcode() == spv::Op::OpFAdd && |
3504 | 458k | "Wrong opcode. Should be OpFAdd."); |
3505 | 458k | assert(constants.size() == 2); |
3506 | | |
3507 | 458k | if (!inst->IsFloatingPointFoldingAllowed()) { |
3508 | 559 | return false; |
3509 | 559 | } |
3510 | | |
3511 | 458k | FloatConstantKind kind0 = getFloatConstantKind(constants[0]); |
3512 | 458k | FloatConstantKind kind1 = getFloatConstantKind(constants[1]); |
3513 | | |
3514 | 458k | if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) { |
3515 | 13.1k | inst->SetOpcode(spv::Op::OpCopyObject); |
3516 | 13.1k | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, |
3517 | 13.1k | {inst->GetSingleWordInOperand( |
3518 | 13.1k | kind0 == FloatConstantKind::Zero ? 1 : 0)}}}); |
3519 | 13.1k | return true; |
3520 | 13.1k | } |
3521 | | |
3522 | 445k | return false; |
3523 | 458k | }; |
3524 | 16.7k | } |
3525 | | |
3526 | 16.7k | FoldingRule RedundantFSub() { |
3527 | 16.7k | return [](IRContext*, Instruction* inst, |
3528 | 95.7k | const std::vector<const analysis::Constant*>& constants) { |
3529 | 95.7k | assert(inst->opcode() == spv::Op::OpFSub && |
3530 | 95.7k | "Wrong opcode. Should be OpFSub."); |
3531 | 95.7k | assert(constants.size() == 2); |
3532 | | |
3533 | 95.7k | if (!inst->IsFloatingPointFoldingAllowed()) { |
3534 | 3.90k | return false; |
3535 | 3.90k | } |
3536 | | |
3537 | 91.8k | FloatConstantKind kind0 = getFloatConstantKind(constants[0]); |
3538 | 91.8k | FloatConstantKind kind1 = getFloatConstantKind(constants[1]); |
3539 | | |
3540 | 91.8k | if (kind0 == FloatConstantKind::Zero) { |
3541 | 1.86k | inst->SetOpcode(spv::Op::OpFNegate); |
3542 | 1.86k | inst->SetInOperands( |
3543 | 1.86k | {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1)}}}); |
3544 | 1.86k | return true; |
3545 | 1.86k | } |
3546 | | |
3547 | 89.9k | if (kind1 == FloatConstantKind::Zero) { |
3548 | 2.26k | inst->SetOpcode(spv::Op::OpCopyObject); |
3549 | 2.26k | inst->SetInOperands( |
3550 | 2.26k | {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}}); |
3551 | 2.26k | return true; |
3552 | 2.26k | } |
3553 | | |
3554 | 87.6k | return false; |
3555 | 89.9k | }; |
3556 | 16.7k | } |
3557 | | |
3558 | 16.7k | FoldingRule RedundantFMul() { |
3559 | 16.7k | return [](IRContext*, Instruction* inst, |
3560 | 207k | const std::vector<const analysis::Constant*>& constants) { |
3561 | 207k | assert(inst->opcode() == spv::Op::OpFMul && |
3562 | 207k | "Wrong opcode. Should be OpFMul."); |
3563 | 207k | assert(constants.size() == 2); |
3564 | | |
3565 | 207k | if (!inst->IsFloatingPointFoldingAllowed()) { |
3566 | 19 | return false; |
3567 | 19 | } |
3568 | | |
3569 | 207k | FloatConstantKind kind0 = getFloatConstantKind(constants[0]); |
3570 | 207k | FloatConstantKind kind1 = getFloatConstantKind(constants[1]); |
3571 | | |
3572 | 207k | if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) { |
3573 | 5.56k | inst->SetOpcode(spv::Op::OpCopyObject); |
3574 | 5.56k | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, |
3575 | 5.56k | {inst->GetSingleWordInOperand( |
3576 | 5.56k | kind0 == FloatConstantKind::Zero ? 0 : 1)}}}); |
3577 | 5.56k | return true; |
3578 | 5.56k | } |
3579 | | |
3580 | 202k | if (kind0 == FloatConstantKind::One || kind1 == FloatConstantKind::One) { |
3581 | 1.66k | inst->SetOpcode(spv::Op::OpCopyObject); |
3582 | 1.66k | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, |
3583 | 1.66k | {inst->GetSingleWordInOperand( |
3584 | 1.66k | kind0 == FloatConstantKind::One ? 1 : 0)}}}); |
3585 | 1.66k | return true; |
3586 | 1.66k | } |
3587 | | |
3588 | 200k | return false; |
3589 | 202k | }; |
3590 | 16.7k | } |
3591 | | |
3592 | 16.7k | FoldingRule RedundantFDiv() { |
3593 | 16.7k | return [](IRContext*, Instruction* inst, |
3594 | 147k | const std::vector<const analysis::Constant*>& constants) { |
3595 | 147k | assert(inst->opcode() == spv::Op::OpFDiv && |
3596 | 147k | "Wrong opcode. Should be OpFDiv."); |
3597 | 147k | assert(constants.size() == 2); |
3598 | | |
3599 | 147k | if (!inst->IsFloatingPointFoldingAllowed()) { |
3600 | 30 | return false; |
3601 | 30 | } |
3602 | | |
3603 | 147k | FloatConstantKind kind0 = getFloatConstantKind(constants[0]); |
3604 | 147k | FloatConstantKind kind1 = getFloatConstantKind(constants[1]); |
3605 | | |
3606 | 147k | if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::One) { |
3607 | 1.24k | inst->SetOpcode(spv::Op::OpCopyObject); |
3608 | 1.24k | inst->SetInOperands( |
3609 | 1.24k | {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}}); |
3610 | 1.24k | return true; |
3611 | 1.24k | } |
3612 | | |
3613 | 145k | return false; |
3614 | 147k | }; |
3615 | 16.7k | } |
3616 | | |
3617 | 16.7k | FoldingRule RedundantFMod() { |
3618 | 16.7k | return [](IRContext*, Instruction* inst, |
3619 | 71.2k | const std::vector<const analysis::Constant*>& constants) { |
3620 | 71.2k | assert(inst->opcode() == spv::Op::OpFMod && |
3621 | 71.2k | "Wrong opcode. Should be OpFMod."); |
3622 | 71.2k | assert(constants.size() == 2); |
3623 | | |
3624 | 71.2k | if (!inst->IsFloatingPointFoldingAllowed()) { |
3625 | 6 | return false; |
3626 | 6 | } |
3627 | | |
3628 | 71.2k | FloatConstantKind kind0 = getFloatConstantKind(constants[0]); |
3629 | | |
3630 | 71.2k | if (kind0 == FloatConstantKind::Zero) { |
3631 | 5.35k | inst->SetOpcode(spv::Op::OpCopyObject); |
3632 | 5.35k | inst->SetInOperands( |
3633 | 5.35k | {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}}); |
3634 | 5.35k | return true; |
3635 | 5.35k | } |
3636 | | |
3637 | 65.8k | return false; |
3638 | 71.2k | }; |
3639 | 16.7k | } |
3640 | | |
3641 | 10.0k | FoldingRule RedundantFMix() { |
3642 | 10.0k | return [](IRContext* context, Instruction* inst, |
3643 | 10.0k | const std::vector<const analysis::Constant*>& constants) { |
3644 | 9.40k | assert(inst->opcode() == spv::Op::OpExtInst && |
3645 | 9.40k | "Wrong opcode. Should be OpExtInst."); |
3646 | | |
3647 | 9.40k | if (!inst->IsFloatingPointFoldingAllowed()) { |
3648 | 0 | return false; |
3649 | 0 | } |
3650 | | |
3651 | 9.40k | uint32_t instSetId = |
3652 | 9.40k | context->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); |
3653 | | |
3654 | 9.40k | if (inst->GetSingleWordInOperand(kExtInstSetIdInIdx) == instSetId && |
3655 | 9.40k | inst->GetSingleWordInOperand(kExtInstInstructionInIdx) == |
3656 | 9.40k | GLSLstd450FMix) { |
3657 | 9.40k | assert(constants.size() == 5); |
3658 | | |
3659 | 9.40k | FloatConstantKind kind4 = getFloatConstantKind(constants[4]); |
3660 | | |
3661 | 9.40k | if (kind4 == FloatConstantKind::Zero || kind4 == FloatConstantKind::One) { |
3662 | 121 | inst->SetOpcode(spv::Op::OpCopyObject); |
3663 | 121 | inst->SetInOperands( |
3664 | 121 | {{SPV_OPERAND_TYPE_ID, |
3665 | 121 | {inst->GetSingleWordInOperand(kind4 == FloatConstantKind::Zero |
3666 | 121 | ? kFMixXIdInIdx |
3667 | 121 | : kFMixYIdInIdx)}}}); |
3668 | 121 | return true; |
3669 | 121 | } |
3670 | 9.40k | } |
3671 | | |
3672 | 9.28k | return false; |
3673 | 9.40k | }; |
3674 | 10.0k | } |
3675 | | |
3676 | | // Returns a folding rule that folds the instruction to operand |foldToArg| |
3677 | | // (0 or 1) if operand |arg| (0 or 1) is a zero constant. |
3678 | 285k | FoldingRule RedundantBinaryOpWithZeroOperand(uint32_t arg, uint32_t foldToArg) { |
3679 | 285k | return [arg, foldToArg]( |
3680 | 285k | IRContext* context, Instruction* inst, |
3681 | 333k | const std::vector<const analysis::Constant*>& constants) { |
3682 | 333k | assert(constants.size() == 2); |
3683 | | |
3684 | 333k | if (constants[arg] && constants[arg]->IsZero()) { |
3685 | 7.79k | auto operand = inst->GetSingleWordInOperand(foldToArg); |
3686 | 7.79k | auto operand_type = constants[arg]->type(); |
3687 | | |
3688 | 7.79k | const analysis::Type* inst_type = |
3689 | 7.79k | context->get_type_mgr()->GetType(inst->type_id()); |
3690 | 7.79k | if (inst_type->IsSame(operand_type)) { |
3691 | 7.68k | inst->SetOpcode(spv::Op::OpCopyObject); |
3692 | 7.68k | } else { |
3693 | 105 | inst->SetOpcode(spv::Op::OpBitcast); |
3694 | 105 | } |
3695 | 7.79k | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {operand}}}); |
3696 | 7.79k | return true; |
3697 | 7.79k | } |
3698 | 325k | return false; |
3699 | 333k | }; |
3700 | 285k | } |
3701 | | |
3702 | | // This rule handles any of RedundantBinaryRhs0Ops with a 0 or vector 0 on the |
3703 | | // right-hand side (a | 0 => a). |
3704 | | static const constexpr spv::Op RedundantBinaryRhs0Ops[] = { |
3705 | | spv::Op::OpBitwiseOr, |
3706 | | spv::Op::OpBitwiseXor, |
3707 | | spv::Op::OpShiftRightLogical, |
3708 | | spv::Op::OpShiftRightArithmetic, |
3709 | | spv::Op::OpShiftLeftLogical, |
3710 | | spv::Op::OpIAdd, |
3711 | | spv::Op::OpISub}; |
3712 | 117k | FoldingRule RedundantBinaryRhs0(spv::Op op) { |
3713 | 117k | assert(std::find(std::begin(RedundantBinaryRhs0Ops), |
3714 | 117k | std::end(RedundantBinaryRhs0Ops), |
3715 | 117k | op) != std::end(RedundantBinaryRhs0Ops) && |
3716 | 117k | "Wrong opcode."); |
3717 | 117k | (void)op; |
3718 | 117k | return RedundantBinaryOpWithZeroOperand(1, 0); |
3719 | 117k | } |
3720 | | |
3721 | | // This rule handles any of RedundantBinaryLhs0Ops with a 0 or vector 0 on the |
3722 | | // left-hand side (0 | a => a). |
3723 | | static const constexpr spv::Op RedundantBinaryLhs0Ops[] = { |
3724 | | spv::Op::OpBitwiseOr, spv::Op::OpBitwiseXor, spv::Op::OpIAdd}; |
3725 | 50.3k | FoldingRule RedundantBinaryLhs0(spv::Op op) { |
3726 | 50.3k | assert(std::find(std::begin(RedundantBinaryLhs0Ops), |
3727 | 50.3k | std::end(RedundantBinaryLhs0Ops), |
3728 | 50.3k | op) != std::end(RedundantBinaryLhs0Ops) && |
3729 | 50.3k | "Wrong opcode."); |
3730 | 50.3k | (void)op; |
3731 | 50.3k | return RedundantBinaryOpWithZeroOperand(0, 1); |
3732 | 50.3k | } |
3733 | | |
3734 | | // This rule handles shifts and divisions of 0 or vector 0 by any amount |
3735 | | // (0 >> a => 0). |
3736 | | static const constexpr spv::Op RedundantBinaryLhs0To0Ops[] = { |
3737 | | spv::Op::OpShiftRightLogical, |
3738 | | spv::Op::OpShiftRightArithmetic, |
3739 | | spv::Op::OpShiftLeftLogical, |
3740 | | spv::Op::OpSDiv, |
3741 | | spv::Op::OpUDiv, |
3742 | | spv::Op::OpSMod, |
3743 | | spv::Op::OpUMod}; |
3744 | 117k | FoldingRule RedundantBinaryLhs0To0(spv::Op op) { |
3745 | 117k | assert(std::find(std::begin(RedundantBinaryLhs0To0Ops), |
3746 | 117k | std::end(RedundantBinaryLhs0To0Ops), |
3747 | 117k | op) != std::end(RedundantBinaryLhs0To0Ops) && |
3748 | 117k | "Wrong opcode."); |
3749 | 117k | (void)op; |
3750 | 117k | return RedundantBinaryOpWithZeroOperand(0, 0); |
3751 | 117k | } |
3752 | | |
3753 | 50.3k | FoldingRule ReassociateCommutiveOp() { |
3754 | 50.3k | return [](IRContext* context, Instruction* inst, |
3755 | 50.4k | const std::vector<const analysis::Constant*>& constants) { |
3756 | 50.4k | const analysis::Type* type = |
3757 | 50.4k | context->get_type_mgr()->GetType(inst->type_id()); |
3758 | 50.4k | uint32_t width = ElementWidth(type); |
3759 | 50.4k | if (width != 32) return false; |
3760 | | |
3761 | 50.4k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
3762 | 50.4k | const analysis::Constant* const_input1 = ConstInput(constants); |
3763 | 50.4k | if (!const_input1) return false; |
3764 | 20.6k | Instruction* other_inst = NonConstInput(context, constants[0], inst); |
3765 | | |
3766 | 20.6k | if (other_inst->opcode() == inst->opcode()) { |
3767 | 7.10k | std::vector<const analysis::Constant*> other_constants = |
3768 | 7.10k | const_mgr->GetOperandConstants(other_inst); |
3769 | 7.10k | const analysis::Constant* const_input2 = ConstInput(other_constants); |
3770 | 7.10k | if (!const_input2) return false; |
3771 | | |
3772 | 1.86k | Instruction* non_const_input = |
3773 | 1.86k | NonConstInput(context, other_constants[0], other_inst); |
3774 | 1.86k | uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(), |
3775 | 1.86k | const_input1, const_input2); |
3776 | | |
3777 | 1.86k | if (merged_id == 0) return false; |
3778 | 1.86k | inst->SetInOperands( |
3779 | 1.86k | {{SPV_OPERAND_TYPE_ID, {non_const_input->result_id()}}, |
3780 | 1.86k | {SPV_OPERAND_TYPE_ID, {merged_id}}}); |
3781 | 1.86k | return true; |
3782 | 1.86k | } |
3783 | | |
3784 | 13.5k | return false; |
3785 | 20.6k | }; |
3786 | 50.3k | } |
3787 | | |
3788 | | // A | (b | C) = b | (A | C) |
3789 | | // A ^ (b ^ C) = b ^ (A ^ C) |
3790 | | // A & (b & C) = b & (A & C) |
3791 | | // Where A and C are constants |
3792 | | static const constexpr spv::Op ReassociateCommutiveBitwiseOps[] = { |
3793 | | spv::Op::OpBitwiseOr, spv::Op::OpBitwiseXor, spv::Op::OpBitwiseAnd}; |
3794 | 50.3k | FoldingRule ReassociateCommutiveBitwise(spv::Op op) { |
3795 | 50.3k | assert(std::find(std::begin(ReassociateCommutiveBitwiseOps), |
3796 | 50.3k | std::end(ReassociateCommutiveBitwiseOps), |
3797 | 50.3k | op) != std::end(ReassociateCommutiveBitwiseOps) && |
3798 | 50.3k | "Wrong opcode."); |
3799 | 50.3k | (void)op; |
3800 | 50.3k | return ReassociateCommutiveOp(); |
3801 | 50.3k | } |
3802 | | |
3803 | | // Returns true if all elements in |c| are 1. |
3804 | 16.8k | bool IsAllInt1(const analysis::Constant* c) { |
3805 | 16.8k | if (auto composite = c->AsCompositeConstant()) { |
3806 | 0 | auto& components = composite->GetComponents(); |
3807 | 0 | return std::all_of(std::begin(components), std::end(components), IsAllInt1); |
3808 | 16.8k | } else if (c->AsIntConstant()) { |
3809 | 16.7k | return c->GetSignExtendedValue() == 1; |
3810 | 16.7k | } |
3811 | | |
3812 | 27 | return false; |
3813 | 16.8k | } |
3814 | | |
3815 | | // This rule handles divisions by 1 or vector 1 (a / 1 => a). |
3816 | 33.5k | FoldingRule RedundantSUDiv() { |
3817 | 33.5k | return [](IRContext* context, Instruction* inst, |
3818 | 33.5k | const std::vector<const analysis::Constant*>& constants) { |
3819 | 12.7k | assert(constants.size() == 2); |
3820 | 12.7k | assert((inst->opcode() == spv::Op::OpUDiv || |
3821 | 12.7k | inst->opcode() == spv::Op::OpSDiv) && |
3822 | 12.7k | "Wrong opcode."); |
3823 | | |
3824 | 12.7k | if (constants[1] && IsAllInt1(constants[1])) { |
3825 | 1.06k | auto operand = inst->GetSingleWordInOperand(0); |
3826 | 1.06k | auto operand_type = constants[1]->type(); |
3827 | | |
3828 | 1.06k | const analysis::Type* inst_type = |
3829 | 1.06k | context->get_type_mgr()->GetType(inst->type_id()); |
3830 | 1.06k | if (inst_type->IsSame(operand_type)) { |
3831 | 786 | inst->SetOpcode(spv::Op::OpCopyObject); |
3832 | 786 | } else { |
3833 | 274 | inst->SetOpcode(spv::Op::OpBitcast); |
3834 | 274 | } |
3835 | 1.06k | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {operand}}}); |
3836 | 1.06k | return true; |
3837 | 1.06k | } |
3838 | 11.7k | return false; |
3839 | 12.7k | }; |
3840 | 33.5k | } |
3841 | | |
3842 | | // This rule handles modulo from division by 1 or vector 1 (a % 1 => 0). |
3843 | 33.5k | FoldingRule RedundantSUMod() { |
3844 | 33.5k | return [](IRContext* context, Instruction* inst, |
3845 | 33.5k | const std::vector<const analysis::Constant*>& constants) { |
3846 | 6.48k | assert(constants.size() == 2); |
3847 | 6.48k | assert((inst->opcode() == spv::Op::OpUMod || |
3848 | 6.48k | inst->opcode() == spv::Op::OpSMod) && |
3849 | 6.48k | "Wrong opcode."); |
3850 | | |
3851 | 6.48k | if (constants[1] && IsAllInt1(constants[1])) { |
3852 | 807 | auto type = context->get_type_mgr()->GetType(inst->type_id()); |
3853 | 807 | auto zero_id = context->get_constant_mgr()->GetNullConstId(type); |
3854 | | |
3855 | 807 | inst->SetOpcode(spv::Op::OpCopyObject); |
3856 | 807 | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {zero_id}}}); |
3857 | 807 | return true; |
3858 | 807 | } |
3859 | 5.68k | return false; |
3860 | 6.48k | }; |
3861 | 33.5k | } |
3862 | | |
3863 | | // Utility function for applying |callback| to |input1| and |input2|. |
3864 | | // If they are vectors it applies element wise. |
3865 | | // The constants |input1| and |input2| must be integers or a vector of integers. |
3866 | | template <typename Callback> |
3867 | | void ForEachIntegerConstantPair(analysis::ConstantManager* const_mgr, |
3868 | | const analysis::Constant* input1, |
3869 | | const analysis::Constant* input2, |
3870 | 1.55k | Callback&& callback) { |
3871 | 1.55k | assert(input1 && input2); |
3872 | | |
3873 | 1.55k | auto Dispatch = [&callback](const analysis::Constant* lhs, |
3874 | 1.55k | const analysis::Constant* rhs) { |
3875 | 1.55k | assert(lhs->type()->AsInteger()); |
3876 | 1.55k | const analysis::Integer* type = lhs->type()->AsInteger(); |
3877 | 1.55k | uint32_t width = type->AsInteger()->width(); |
3878 | 1.55k | assert(width == 32 || width == 64); |
3879 | 1.55k | if (width == 32) { |
3880 | 1.55k | callback(lhs->GetU32(), rhs->GetU32()); |
3881 | 1.55k | } else { |
3882 | 0 | callback(lhs->GetU64(), rhs->GetU64()); |
3883 | 0 | } |
3884 | 1.55k | }; folding_rules.cpp:spvtools::opt::(anonymous namespace)::ForEachIntegerConstantPair<spvtools::opt::(anonymous namespace)::RedundantAndOrXor()::$_0::operator()(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&) const::{lambda(auto:1, auto:2)#1}>(spvtools::opt::analysis::ConstantManager*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*, spvtools::opt::(anonymous namespace)::RedundantAndOrXor()::$_0::operator()(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&) const::{lambda(auto:1, auto:2)#1}&&)::{lambda(spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*)#1}::operator()(spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*) constLine | Count | Source | 3874 | 812 | const analysis::Constant* rhs) { | 3875 | 812 | assert(lhs->type()->AsInteger()); | 3876 | 812 | const analysis::Integer* type = lhs->type()->AsInteger(); | 3877 | 812 | uint32_t width = type->AsInteger()->width(); | 3878 | 812 | assert(width == 32 || width == 64); | 3879 | 812 | if (width == 32) { | 3880 | 812 | callback(lhs->GetU32(), rhs->GetU32()); | 3881 | 812 | } else { | 3882 | 0 | callback(lhs->GetU64(), rhs->GetU64()); | 3883 | 0 | } | 3884 | 812 | }; |
folding_rules.cpp:spvtools::opt::(anonymous namespace)::ForEachIntegerConstantPair<spvtools::opt::(anonymous namespace)::RedundantAndAddSub()::$_0::operator()(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&) const::{lambda(auto:1, auto:2)#1}>(spvtools::opt::analysis::ConstantManager*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*, spvtools::opt::(anonymous namespace)::RedundantAndAddSub()::$_0::operator()(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&) const::{lambda(auto:1, auto:2)#1}&&)::{lambda(spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*)#1}::operator()(spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*) constLine | Count | Source | 3874 | 744 | const analysis::Constant* rhs) { | 3875 | 744 | assert(lhs->type()->AsInteger()); | 3876 | 744 | const analysis::Integer* type = lhs->type()->AsInteger(); | 3877 | 744 | uint32_t width = type->AsInteger()->width(); | 3878 | 744 | assert(width == 32 || width == 64); | 3879 | 744 | if (width == 32) { | 3880 | 744 | callback(lhs->GetU32(), rhs->GetU32()); | 3881 | 744 | } else { | 3882 | 0 | callback(lhs->GetU64(), rhs->GetU64()); | 3883 | 0 | } | 3884 | 744 | }; |
|
3885 | | |
3886 | 1.55k | const analysis::Type* type = input1->type(); |
3887 | 1.55k | if (const analysis::Vector* vector_type = type->AsVector()) { |
3888 | 0 | const analysis::Type* ele_type = vector_type->element_type(); |
3889 | 0 | assert(ele_type->AsInteger()); |
3890 | 0 | for (uint32_t i = 0; i != vector_type->element_count(); ++i) { |
3891 | 0 | const analysis::Constant* input1_comp = nullptr; |
3892 | 0 | if (const analysis::VectorConstant* input1_vector = |
3893 | 0 | input1->AsVectorConstant()) { |
3894 | 0 | input1_comp = input1_vector->GetComponents()[i]; |
3895 | 0 | } else { |
3896 | 0 | assert(input1->AsNullConstant()); |
3897 | 0 | input1_comp = const_mgr->GetConstant(ele_type, {}); |
3898 | 0 | } |
3899 | | |
3900 | 0 | const analysis::Constant* input2_comp = nullptr; |
3901 | 0 | if (const analysis::VectorConstant* input2_vector = |
3902 | 0 | input2->AsVectorConstant()) { |
3903 | 0 | input2_comp = input2_vector->GetComponents()[i]; |
3904 | 0 | } else { |
3905 | 0 | assert(input2->AsNullConstant()); |
3906 | 0 | input2_comp = const_mgr->GetConstant(ele_type, {}); |
3907 | 0 | } |
3908 | | |
3909 | 0 | assert(ele_type->AsInteger()); |
3910 | 0 | Dispatch(input1_comp, input2_comp); |
3911 | 0 | } |
3912 | |
|
3913 | 1.55k | } else { |
3914 | 1.55k | assert(type->AsInteger()); |
3915 | 1.55k | Dispatch(input1, input2); |
3916 | 1.55k | } |
3917 | 1.55k | } folding_rules.cpp:void spvtools::opt::(anonymous namespace)::ForEachIntegerConstantPair<spvtools::opt::(anonymous namespace)::RedundantAndOrXor()::$_0::operator()(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&) const::{lambda(auto:1, auto:2)#1}>(spvtools::opt::analysis::ConstantManager*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*, spvtools::opt::(anonymous namespace)::RedundantAndOrXor()::$_0::operator()(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&) const::{lambda(auto:1, auto:2)#1}&&)Line | Count | Source | 3870 | 812 | Callback&& callback) { | 3871 | 812 | assert(input1 && input2); | 3872 | | | 3873 | 812 | auto Dispatch = [&callback](const analysis::Constant* lhs, | 3874 | 812 | const analysis::Constant* rhs) { | 3875 | 812 | assert(lhs->type()->AsInteger()); | 3876 | 812 | const analysis::Integer* type = lhs->type()->AsInteger(); | 3877 | 812 | uint32_t width = type->AsInteger()->width(); | 3878 | 812 | assert(width == 32 || width == 64); | 3879 | 812 | if (width == 32) { | 3880 | 812 | callback(lhs->GetU32(), rhs->GetU32()); | 3881 | 812 | } else { | 3882 | 812 | callback(lhs->GetU64(), rhs->GetU64()); | 3883 | 812 | } | 3884 | 812 | }; | 3885 | | | 3886 | 812 | const analysis::Type* type = input1->type(); | 3887 | 812 | if (const analysis::Vector* vector_type = type->AsVector()) { | 3888 | 0 | const analysis::Type* ele_type = vector_type->element_type(); | 3889 | 0 | assert(ele_type->AsInteger()); | 3890 | 0 | for (uint32_t i = 0; i != vector_type->element_count(); ++i) { | 3891 | 0 | const analysis::Constant* input1_comp = nullptr; | 3892 | 0 | if (const analysis::VectorConstant* input1_vector = | 3893 | 0 | input1->AsVectorConstant()) { | 3894 | 0 | input1_comp = input1_vector->GetComponents()[i]; | 3895 | 0 | } else { | 3896 | 0 | assert(input1->AsNullConstant()); | 3897 | 0 | input1_comp = const_mgr->GetConstant(ele_type, {}); | 3898 | 0 | } | 3899 | | | 3900 | 0 | const analysis::Constant* input2_comp = nullptr; | 3901 | 0 | if (const analysis::VectorConstant* input2_vector = | 3902 | 0 | input2->AsVectorConstant()) { | 3903 | 0 | input2_comp = input2_vector->GetComponents()[i]; | 3904 | 0 | } else { | 3905 | 0 | assert(input2->AsNullConstant()); | 3906 | 0 | input2_comp = const_mgr->GetConstant(ele_type, {}); | 3907 | 0 | } | 3908 | | | 3909 | 0 | assert(ele_type->AsInteger()); | 3910 | 0 | Dispatch(input1_comp, input2_comp); | 3911 | 0 | } | 3912 | |
| 3913 | 812 | } else { | 3914 | 812 | assert(type->AsInteger()); | 3915 | 812 | Dispatch(input1, input2); | 3916 | 812 | } | 3917 | 812 | } |
folding_rules.cpp:void spvtools::opt::(anonymous namespace)::ForEachIntegerConstantPair<spvtools::opt::(anonymous namespace)::RedundantAndAddSub()::$_0::operator()(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&) const::{lambda(auto:1, auto:2)#1}>(spvtools::opt::analysis::ConstantManager*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*, spvtools::opt::(anonymous namespace)::RedundantAndAddSub()::$_0::operator()(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&) const::{lambda(auto:1, auto:2)#1}&&)Line | Count | Source | 3870 | 744 | Callback&& callback) { | 3871 | 744 | assert(input1 && input2); | 3872 | | | 3873 | 744 | auto Dispatch = [&callback](const analysis::Constant* lhs, | 3874 | 744 | const analysis::Constant* rhs) { | 3875 | 744 | assert(lhs->type()->AsInteger()); | 3876 | 744 | const analysis::Integer* type = lhs->type()->AsInteger(); | 3877 | 744 | uint32_t width = type->AsInteger()->width(); | 3878 | 744 | assert(width == 32 || width == 64); | 3879 | 744 | if (width == 32) { | 3880 | 744 | callback(lhs->GetU32(), rhs->GetU32()); | 3881 | 744 | } else { | 3882 | 744 | callback(lhs->GetU64(), rhs->GetU64()); | 3883 | 744 | } | 3884 | 744 | }; | 3885 | | | 3886 | 744 | const analysis::Type* type = input1->type(); | 3887 | 744 | if (const analysis::Vector* vector_type = type->AsVector()) { | 3888 | 0 | const analysis::Type* ele_type = vector_type->element_type(); | 3889 | 0 | assert(ele_type->AsInteger()); | 3890 | 0 | for (uint32_t i = 0; i != vector_type->element_count(); ++i) { | 3891 | 0 | const analysis::Constant* input1_comp = nullptr; | 3892 | 0 | if (const analysis::VectorConstant* input1_vector = | 3893 | 0 | input1->AsVectorConstant()) { | 3894 | 0 | input1_comp = input1_vector->GetComponents()[i]; | 3895 | 0 | } else { | 3896 | 0 | assert(input1->AsNullConstant()); | 3897 | 0 | input1_comp = const_mgr->GetConstant(ele_type, {}); | 3898 | 0 | } | 3899 | | | 3900 | 0 | const analysis::Constant* input2_comp = nullptr; | 3901 | 0 | if (const analysis::VectorConstant* input2_vector = | 3902 | 0 | input2->AsVectorConstant()) { | 3903 | 0 | input2_comp = input2_vector->GetComponents()[i]; | 3904 | 0 | } else { | 3905 | 0 | assert(input2->AsNullConstant()); | 3906 | 0 | input2_comp = const_mgr->GetConstant(ele_type, {}); | 3907 | 0 | } | 3908 | | | 3909 | 0 | assert(ele_type->AsInteger()); | 3910 | 0 | Dispatch(input1_comp, input2_comp); | 3911 | 0 | } | 3912 | |
| 3913 | 744 | } else { | 3914 | 744 | assert(type->AsInteger()); | 3915 | 744 | Dispatch(input1, input2); | 3916 | 744 | } | 3917 | 744 | } |
|
3918 | | |
3919 | | // Folds redundant xor and or ops that are part of an and. |
3920 | | // Cases handled: |
3921 | | // 0b1110 & (a | 0b0001) = a & 0b1110 |
3922 | | // 0b1110 & (a ^ 0b0001) = a & 0b1110 |
3923 | | // 0b0110 & (a | 0b1110) = 0b0110 |
3924 | 16.7k | FoldingRule RedundantAndOrXor() { |
3925 | 16.7k | return [](IRContext* context, Instruction* inst, |
3926 | 16.7k | const std::vector<const analysis::Constant*>& constants) { |
3927 | 12.8k | assert(inst->opcode() == spv::Op::OpBitwiseAnd && "Wrong opcode."); |
3928 | 12.8k | const analysis::Type* type = |
3929 | 12.8k | context->get_type_mgr()->GetType(inst->type_id()); |
3930 | 12.8k | uint32_t width = ElementWidth(type); |
3931 | 12.8k | if ((width != 32) && (width != 64)) return false; |
3932 | | |
3933 | 12.8k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
3934 | 12.8k | const analysis::Constant* const_input1 = ConstInput(constants); |
3935 | 12.8k | if (!const_input1) return false; |
3936 | 9.23k | Instruction* other_inst = NonConstInput(context, constants[0], inst); |
3937 | | |
3938 | 9.23k | if (other_inst->opcode() == spv::Op::OpBitwiseOr || |
3939 | 7.74k | other_inst->opcode() == spv::Op::OpBitwiseXor) { |
3940 | 1.60k | std::vector<const analysis::Constant*> other_constants = |
3941 | 1.60k | const_mgr->GetOperandConstants(other_inst); |
3942 | 1.60k | const analysis::Constant* const_input2 = ConstInput(other_constants); |
3943 | 1.60k | if (!const_input2) return false; |
3944 | | |
3945 | 812 | bool can_convert_to_const = other_inst->opcode() == spv::Op::OpBitwiseOr; |
3946 | 812 | bool can_remove_inner = true; |
3947 | | |
3948 | 812 | ForEachIntegerConstantPair( |
3949 | 812 | const_mgr, const_input1, const_input2, |
3950 | 812 | [&can_remove_inner, &can_convert_to_const](auto lhs, auto rhs) { |
3951 | | // Only convert to const if 'and' is a subset of 'or' |
3952 | 812 | can_convert_to_const = can_convert_to_const && ((lhs & rhs) == lhs); |
3953 | | // Only remove 'xor'/'or' if no bits intersect with 'and' |
3954 | 812 | can_remove_inner = can_remove_inner && ((lhs & rhs) == 0); |
3955 | 812 | }); folding_rules.cpp:auto spvtools::opt::(anonymous namespace)::RedundantAndOrXor()::$_0::operator()(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&) const::{lambda(auto:1, auto:2)#1}::operator()<unsigned int, unsigned int>(unsigned int, unsigned int) constLine | Count | Source | 3950 | 812 | [&can_remove_inner, &can_convert_to_const](auto lhs, auto rhs) { | 3951 | | // Only convert to const if 'and' is a subset of 'or' | 3952 | 812 | can_convert_to_const = can_convert_to_const && ((lhs & rhs) == lhs); | 3953 | | // Only remove 'xor'/'or' if no bits intersect with 'and' | 3954 | 812 | can_remove_inner = can_remove_inner && ((lhs & rhs) == 0); | 3955 | 812 | }); |
Unexecuted instantiation: folding_rules.cpp:auto spvtools::opt::(anonymous namespace)::RedundantAndOrXor()::$_0::operator()(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&) const::{lambda(auto:1, auto:2)#1}::operator()<unsigned long, unsigned long>(unsigned long, unsigned long) const |
3956 | | |
3957 | 812 | if (can_convert_to_const) { |
3958 | 63 | Instruction* const_inst = |
3959 | 63 | const_mgr->GetDefiningInstruction(const_input1); |
3960 | 63 | inst->SetOpcode(spv::Op::OpCopyObject); |
3961 | 63 | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {const_inst->result_id()}}}); |
3962 | 63 | return true; |
3963 | 749 | } else if (can_remove_inner) { |
3964 | 59 | Instruction* non_const_input = |
3965 | 59 | NonConstInput(context, other_constants[0], other_inst); |
3966 | 59 | Instruction* const_inst = |
3967 | 59 | const_mgr->GetDefiningInstruction(const_input1); |
3968 | 59 | inst->SetInOperands( |
3969 | 59 | {{SPV_OPERAND_TYPE_ID, {non_const_input->result_id()}}, |
3970 | 59 | {SPV_OPERAND_TYPE_ID, {const_inst->result_id()}}}); |
3971 | 59 | return true; |
3972 | 59 | } |
3973 | 812 | } |
3974 | 8.31k | return false; |
3975 | 9.23k | }; |
3976 | 16.7k | } |
3977 | | |
3978 | | // Folds redundant add and sub ops that are part of an and. |
3979 | | // Cases handled: |
3980 | | // 1 & (b + 2) = b & 1 |
3981 | | // 1 & (b - 2) = b & 1 |
3982 | 16.7k | FoldingRule RedundantAndAddSub() { |
3983 | 16.7k | return [](IRContext* context, Instruction* inst, |
3984 | 16.7k | const std::vector<const analysis::Constant*>& constants) { |
3985 | 12.7k | assert(inst->opcode() == spv::Op::OpBitwiseAnd && "Wrong opcode."); |
3986 | 12.7k | const analysis::Type* type = |
3987 | 12.7k | context->get_type_mgr()->GetType(inst->type_id()); |
3988 | 12.7k | uint32_t width = ElementWidth(type); |
3989 | 12.7k | if ((width != 32) && (width != 64)) return false; |
3990 | | |
3991 | 12.7k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
3992 | 12.7k | const analysis::Constant* const_input1 = ConstInput(constants); |
3993 | 12.7k | if (!const_input1) return false; |
3994 | 9.10k | Instruction* other_inst = NonConstInput(context, constants[0], inst); |
3995 | | |
3996 | 9.10k | if (other_inst->opcode() != spv::Op::OpIAdd && |
3997 | 8.35k | other_inst->opcode() != spv::Op::OpISub) { |
3998 | 8.20k | return false; |
3999 | 8.20k | } |
4000 | 902 | std::vector<const analysis::Constant*> other_constants = |
4001 | 902 | const_mgr->GetOperandConstants(other_inst); |
4002 | 902 | const analysis::Constant* const_input2 = ConstInput(other_constants); |
4003 | 902 | if (!const_input2) return false; |
4004 | | |
4005 | | // Only valid for subtraction if const is on the right |
4006 | 774 | if ((other_inst->opcode() == spv::Op::OpISub) && other_constants[0]) { |
4007 | 30 | return false; |
4008 | 30 | } |
4009 | | |
4010 | 744 | bool can_remove_inner = true; |
4011 | 744 | ForEachIntegerConstantPair(const_mgr, const_input1, const_input2, |
4012 | 744 | [&can_remove_inner](auto and_op, auto add_op) { |
4013 | 744 | if (can_remove_inner) { |
4014 | | // Only valid if no bits from the +/- could |
4015 | | // affect bits from the & operation. |
4016 | 744 | can_remove_inner = |
4017 | 744 | utils::LSB(add_op) > and_op; |
4018 | 744 | } |
4019 | 744 | }); folding_rules.cpp:auto spvtools::opt::(anonymous namespace)::RedundantAndAddSub()::$_0::operator()(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&) const::{lambda(auto:1, auto:2)#1}::operator()<unsigned int, unsigned int>(unsigned int, unsigned int) constLine | Count | Source | 4012 | 744 | [&can_remove_inner](auto and_op, auto add_op) { | 4013 | 744 | if (can_remove_inner) { | 4014 | | // Only valid if no bits from the +/- could | 4015 | | // affect bits from the & operation. | 4016 | 744 | can_remove_inner = | 4017 | 744 | utils::LSB(add_op) > and_op; | 4018 | 744 | } | 4019 | 744 | }); |
Unexecuted instantiation: folding_rules.cpp:auto spvtools::opt::(anonymous namespace)::RedundantAndAddSub()::$_0::operator()(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&) const::{lambda(auto:1, auto:2)#1}::operator()<unsigned long, unsigned long>(unsigned long, unsigned long) const |
4020 | | |
4021 | 744 | if (can_remove_inner) { |
4022 | 22 | Instruction* non_const_input = |
4023 | 22 | NonConstInput(context, other_constants[0], other_inst); |
4024 | 22 | Instruction* const_inst = const_mgr->GetDefiningInstruction(const_input1); |
4025 | 22 | inst->SetInOperands( |
4026 | 22 | {{SPV_OPERAND_TYPE_ID, {non_const_input->result_id()}}, |
4027 | 22 | {SPV_OPERAND_TYPE_ID, {const_inst->result_id()}}}); |
4028 | 22 | return true; |
4029 | 22 | } |
4030 | 722 | return false; |
4031 | 744 | }; |
4032 | 16.7k | } |
4033 | | |
4034 | | // Folds redundant shift ops that are part of an and. |
4035 | | // Cases handled: |
4036 | | // 1 & (b << 1) = 0 |
4037 | | // 0x80000000 & (b >> 1) = 0 |
4038 | 16.7k | FoldingRule RedundantAndShift() { |
4039 | 16.7k | return [](IRContext* context, Instruction* inst, |
4040 | 16.7k | const std::vector<const analysis::Constant*>& constants) { |
4041 | 12.6k | assert(inst->opcode() == spv::Op::OpBitwiseAnd && "Wrong opcode."); |
4042 | 12.6k | const analysis::Type* type = |
4043 | 12.6k | context->get_type_mgr()->GetType(inst->type_id()); |
4044 | 12.6k | uint32_t width = ElementWidth(type); |
4045 | 12.6k | if (width != 8 && width != 16 && width != 32 && width != 64) return false; |
4046 | 12.6k | const uint64_t width_mask = |
4047 | 12.6k | (width == 64) ? ~0ull : ((1ull << width) - 1ull); |
4048 | | |
4049 | 12.6k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
4050 | 12.6k | const analysis::Constant* const_input1 = ConstInput(constants); |
4051 | 12.6k | if (!const_input1) return false; |
4052 | 9.08k | Instruction* other_inst = NonConstInput(context, constants[0], inst); |
4053 | | |
4054 | 9.08k | spv::Op other_op = other_inst->opcode(); |
4055 | 9.08k | if (other_op != spv::Op::OpShiftLeftLogical && |
4056 | 8.90k | other_op != spv::Op::OpShiftRightLogical) { |
4057 | 8.65k | return false; |
4058 | 8.65k | } |
4059 | | |
4060 | 431 | std::vector<const analysis::Constant*> other_constants = |
4061 | 431 | const_mgr->GetOperandConstants(other_inst); |
4062 | | |
4063 | | // Only valid if const is on the right. |
4064 | 431 | if (other_constants[0]) return false; |
4065 | 352 | const analysis::Constant* const_input2 = other_constants[1]; |
4066 | 352 | if (!const_input2) return false; |
4067 | | |
4068 | 138 | auto get_value_u64 = |
4069 | 276 | [](const analysis::Constant* c) -> std::optional<uint64_t> { |
4070 | 276 | if (!c) return std::nullopt; |
4071 | 276 | const analysis::Integer* int_t = c->type()->AsInteger(); |
4072 | 276 | if (!int_t) return std::nullopt; |
4073 | 276 | return c->GetZeroExtendedValue(); |
4074 | 276 | }; |
4075 | | |
4076 | 138 | auto can_fold_component = |
4077 | 138 | [&](const analysis::Constant* mask_const, |
4078 | 138 | const analysis::Constant* shift_const) -> std::optional<bool> { |
4079 | 138 | auto lhs = get_value_u64(mask_const); |
4080 | 138 | auto rhs = get_value_u64(shift_const); |
4081 | 138 | if (!lhs || !rhs) return std::nullopt; |
4082 | 138 | if (*rhs >= width) return false; |
4083 | 138 | uint64_t lhs_masked = *lhs & width_mask; |
4084 | 138 | if (other_op == spv::Op::OpShiftRightLogical) { |
4085 | 69 | return ((lhs_masked << *rhs) & width_mask) == 0; |
4086 | 69 | } |
4087 | 69 | return ((lhs_masked >> *rhs) & width_mask) == 0; |
4088 | 138 | }; |
4089 | | |
4090 | 138 | if (const analysis::Vector* mask_vec = type->AsVector()) { |
4091 | 0 | const analysis::Vector* shift_vec = const_input2->type()->AsVector(); |
4092 | 0 | if (!shift_vec || |
4093 | 0 | shift_vec->element_count() != mask_vec->element_count()) { |
4094 | 0 | return false; |
4095 | 0 | } |
4096 | 0 | const auto mask_components = const_input1->GetVectorComponents(const_mgr); |
4097 | 0 | const auto shift_components = |
4098 | 0 | const_input2->GetVectorComponents(const_mgr); |
4099 | 0 | for (uint32_t i = 0; i != mask_vec->element_count(); ++i) { |
4100 | 0 | auto result = |
4101 | 0 | can_fold_component(mask_components[i], shift_components[i]); |
4102 | 0 | if (!result || !*result) return false; |
4103 | 0 | } |
4104 | 138 | } else { |
4105 | 138 | if (const_input2->type()->AsVector()) return false; |
4106 | 138 | auto result = can_fold_component(const_input1, const_input2); |
4107 | 138 | if (!result || !*result) return false; |
4108 | 138 | } |
4109 | | |
4110 | 33 | auto zero_id = context->get_constant_mgr()->GetNullConstId(type); |
4111 | 33 | inst->SetOpcode(spv::Op::OpCopyObject); |
4112 | 33 | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {zero_id}}}); |
4113 | 33 | return true; |
4114 | 138 | }; |
4115 | 16.7k | } |
4116 | | |
4117 | | // This rule look for a dot with a constant vector containing a single 1 and |
4118 | | // the rest 0s. This is the same as doing an extract. |
4119 | 16.7k | FoldingRule DotProductDoingExtract() { |
4120 | 16.7k | return [](IRContext* context, Instruction* inst, |
4121 | 16.7k | const std::vector<const analysis::Constant*>& constants) { |
4122 | 51 | assert(inst->opcode() == spv::Op::OpDot && |
4123 | 51 | "Wrong opcode. Should be OpDot."); |
4124 | | |
4125 | 51 | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
4126 | | |
4127 | 51 | if (!inst->IsFloatingPointFoldingAllowed()) { |
4128 | 0 | return false; |
4129 | 0 | } |
4130 | | |
4131 | 153 | for (int i = 0; i < 2; ++i) { |
4132 | 102 | if (!constants[i]) { |
4133 | 60 | continue; |
4134 | 60 | } |
4135 | | |
4136 | 42 | const analysis::Vector* vector_type = constants[i]->type()->AsVector(); |
4137 | 42 | assert(vector_type && "Inputs to OpDot must be vectors."); |
4138 | 42 | const analysis::Float* element_type = |
4139 | 42 | vector_type->element_type()->AsFloat(); |
4140 | 42 | assert(element_type && "Inputs to OpDot must be vectors of floats."); |
4141 | 42 | uint32_t element_width = element_type->width(); |
4142 | 42 | if (element_width != 32 && element_width != 64) { |
4143 | 0 | return false; |
4144 | 0 | } |
4145 | | |
4146 | 42 | std::vector<const analysis::Constant*> components; |
4147 | 42 | components = constants[i]->GetVectorComponents(const_mgr); |
4148 | | |
4149 | 42 | constexpr uint32_t kNotFound = std::numeric_limits<uint32_t>::max(); |
4150 | | |
4151 | 42 | uint32_t component_with_one = kNotFound; |
4152 | 42 | bool all_others_zero = true; |
4153 | 43 | for (uint32_t j = 0; j < components.size(); ++j) { |
4154 | 43 | const analysis::Constant* element = components[j]; |
4155 | 43 | double value = |
4156 | 43 | (element_width == 32 ? element->GetFloat() : element->GetDouble()); |
4157 | 43 | if (value == 0.0) { |
4158 | 1 | continue; |
4159 | 42 | } else if (value == 1.0) { |
4160 | 0 | if (component_with_one == kNotFound) { |
4161 | 0 | component_with_one = j; |
4162 | 0 | } else { |
4163 | 0 | component_with_one = kNotFound; |
4164 | 0 | break; |
4165 | 0 | } |
4166 | 42 | } else { |
4167 | 42 | all_others_zero = false; |
4168 | 42 | break; |
4169 | 42 | } |
4170 | 43 | } |
4171 | | |
4172 | 42 | if (!all_others_zero || component_with_one == kNotFound) { |
4173 | 42 | continue; |
4174 | 42 | } |
4175 | | |
4176 | 0 | std::vector<Operand> operands; |
4177 | 0 | operands.push_back( |
4178 | 0 | {SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1u - i)}}); |
4179 | 0 | operands.push_back( |
4180 | 0 | {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_with_one}}); |
4181 | |
|
4182 | 0 | inst->SetOpcode(spv::Op::OpCompositeExtract); |
4183 | 0 | inst->SetInOperands(std::move(operands)); |
4184 | 0 | return true; |
4185 | 42 | } |
4186 | 51 | return false; |
4187 | 51 | }; |
4188 | 16.7k | } |
4189 | | |
4190 | | // If we are storing an undef, then we can remove the store. |
4191 | | // |
4192 | | // TODO: We can do something similar for OpImageWrite, but checking for volatile |
4193 | | // is complicated. Waiting to see if it is needed. |
4194 | 16.7k | FoldingRule StoringUndef() { |
4195 | 16.7k | return [](IRContext* context, Instruction* inst, |
4196 | 1.04M | const std::vector<const analysis::Constant*>&) { |
4197 | 1.04M | assert(inst->opcode() == spv::Op::OpStore && |
4198 | 1.04M | "Wrong opcode. Should be OpStore."); |
4199 | | |
4200 | 1.04M | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
4201 | | |
4202 | | // If this is a volatile store, the store cannot be removed. |
4203 | 1.04M | if (inst->NumInOperands() == 3) { |
4204 | 8.97k | if (inst->GetSingleWordInOperand(2) & |
4205 | 8.97k | uint32_t(spv::MemoryAccessMask::Volatile)) { |
4206 | 5.95k | return false; |
4207 | 5.95k | } |
4208 | 8.97k | } |
4209 | | |
4210 | 1.04M | uint32_t object_id = inst->GetSingleWordInOperand(kStoreObjectInIdx); |
4211 | 1.04M | Instruction* object_inst = def_use_mgr->GetDef(object_id); |
4212 | 1.04M | if (object_inst->opcode() == spv::Op::OpUndef) { |
4213 | 21.4k | inst->ToNop(); |
4214 | 21.4k | return true; |
4215 | 21.4k | } |
4216 | 1.01M | return false; |
4217 | 1.04M | }; |
4218 | 16.7k | } |
4219 | | |
4220 | 16.7k | FoldingRule VectorShuffleFeedingShuffle() { |
4221 | 16.7k | return [](IRContext* context, Instruction* inst, |
4222 | 26.0k | const std::vector<const analysis::Constant*>&) { |
4223 | 26.0k | assert(inst->opcode() == spv::Op::OpVectorShuffle && |
4224 | 26.0k | "Wrong opcode. Should be OpVectorShuffle."); |
4225 | | |
4226 | 26.0k | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
4227 | 26.0k | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
4228 | | |
4229 | 26.0k | Instruction* feeding_shuffle_inst = |
4230 | 26.0k | def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); |
4231 | 26.0k | analysis::Vector* op0_type = |
4232 | 26.0k | type_mgr->GetType(feeding_shuffle_inst->type_id())->AsVector(); |
4233 | 26.0k | uint32_t op0_length = op0_type->element_count(); |
4234 | | |
4235 | 26.0k | bool feeder_is_op0 = true; |
4236 | 26.0k | if (feeding_shuffle_inst->opcode() != spv::Op::OpVectorShuffle) { |
4237 | 25.8k | feeding_shuffle_inst = |
4238 | 25.8k | def_use_mgr->GetDef(inst->GetSingleWordInOperand(1)); |
4239 | 25.8k | feeder_is_op0 = false; |
4240 | 25.8k | } |
4241 | | |
4242 | 26.0k | if (feeding_shuffle_inst->opcode() != spv::Op::OpVectorShuffle) { |
4243 | 25.3k | return false; |
4244 | 25.3k | } |
4245 | | |
4246 | 671 | Instruction* feeder2 = |
4247 | 671 | def_use_mgr->GetDef(feeding_shuffle_inst->GetSingleWordInOperand(0)); |
4248 | 671 | analysis::Vector* feeder_op0_type = |
4249 | 671 | type_mgr->GetType(feeder2->type_id())->AsVector(); |
4250 | 671 | uint32_t feeder_op0_length = feeder_op0_type->element_count(); |
4251 | | |
4252 | 671 | uint32_t new_feeder_id = 0; |
4253 | 671 | std::vector<Operand> new_operands; |
4254 | 671 | new_operands.resize( |
4255 | 671 | 2, {SPV_OPERAND_TYPE_ID, {0}}); // Place holders for vector operands. |
4256 | 671 | const uint32_t undef_literal = 0xffffffff; |
4257 | 2.22k | for (uint32_t op = 2; op < inst->NumInOperands(); ++op) { |
4258 | 1.61k | uint32_t component_index = inst->GetSingleWordInOperand(op); |
4259 | | |
4260 | | // Do not interpret the undefined value literal as coming from operand 1. |
4261 | 1.61k | if (component_index != undef_literal && |
4262 | 1.51k | feeder_is_op0 == (component_index < op0_length)) { |
4263 | | // This component comes from the feeding_shuffle_inst. Update |
4264 | | // |component_index| to be the index into the operand of the feeder. |
4265 | | |
4266 | | // Adjust component_index to get the index into the operands of the |
4267 | | // feeding_shuffle_inst. |
4268 | 828 | if (component_index >= op0_length) { |
4269 | 484 | component_index -= op0_length; |
4270 | 484 | } |
4271 | 828 | component_index = |
4272 | 828 | feeding_shuffle_inst->GetSingleWordInOperand(component_index + 2); |
4273 | | |
4274 | | // Check if we are using a component from the first or second operand of |
4275 | | // the feeding instruction. |
4276 | 828 | if (component_index < feeder_op0_length) { |
4277 | 643 | if (new_feeder_id == 0) { |
4278 | | // First time through, save the id of the operand the element comes |
4279 | | // from. |
4280 | 374 | new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(0); |
4281 | 374 | } else if (new_feeder_id != |
4282 | 269 | feeding_shuffle_inst->GetSingleWordInOperand(0)) { |
4283 | | // We need both elements of the feeding_shuffle_inst, so we cannot |
4284 | | // fold. |
4285 | 41 | return false; |
4286 | 41 | } |
4287 | 643 | } else if (component_index != undef_literal) { |
4288 | 122 | if (new_feeder_id == 0) { |
4289 | | // First time through, save the id of the operand the element comes |
4290 | | // from. |
4291 | 83 | new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(1); |
4292 | 83 | } else if (new_feeder_id != |
4293 | 39 | feeding_shuffle_inst->GetSingleWordInOperand(1)) { |
4294 | | // We need both elements of the feeding_shuffle_inst, so we cannot |
4295 | | // fold. |
4296 | 22 | return false; |
4297 | 22 | } |
4298 | 100 | component_index -= feeder_op0_length; |
4299 | 100 | } |
4300 | | |
4301 | 765 | if (!feeder_is_op0 && component_index != undef_literal) { |
4302 | 456 | component_index += op0_length; |
4303 | 456 | } |
4304 | 765 | } |
4305 | 1.55k | new_operands.push_back( |
4306 | 1.55k | {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_index}}); |
4307 | 1.55k | } |
4308 | | |
4309 | 608 | if (new_feeder_id == 0) { |
4310 | 214 | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
4311 | 214 | const analysis::Type* type = |
4312 | 214 | type_mgr->GetType(feeding_shuffle_inst->type_id()); |
4313 | 214 | const analysis::Constant* null_const = const_mgr->GetConstant(type, {}); |
4314 | 214 | new_feeder_id = |
4315 | 214 | const_mgr->GetDefiningInstruction(null_const, 0)->result_id(); |
4316 | 214 | } |
4317 | | |
4318 | 608 | if (feeder_is_op0) { |
4319 | | // If the size of the first vector operand changed then the indices |
4320 | | // referring to the second operand need to be adjusted. |
4321 | 169 | Instruction* new_feeder_inst = def_use_mgr->GetDef(new_feeder_id); |
4322 | 169 | analysis::Type* new_feeder_type = |
4323 | 169 | type_mgr->GetType(new_feeder_inst->type_id()); |
4324 | 169 | uint32_t new_op0_size = new_feeder_type->AsVector()->element_count(); |
4325 | 169 | int32_t adjustment = op0_length - new_op0_size; |
4326 | | |
4327 | 169 | if (adjustment != 0) { |
4328 | 250 | for (uint32_t i = 2; i < new_operands.size(); i++) { |
4329 | 168 | uint32_t operand = inst->GetSingleWordInOperand(i); |
4330 | 168 | if (operand >= op0_length && operand != undef_literal) { |
4331 | 47 | new_operands[i].words[0] -= adjustment; |
4332 | 47 | } |
4333 | 168 | } |
4334 | 82 | } |
4335 | | |
4336 | 169 | new_operands[0].words[0] = new_feeder_id; |
4337 | 169 | new_operands[1] = inst->GetInOperand(1); |
4338 | 439 | } else { |
4339 | 439 | new_operands[1].words[0] = new_feeder_id; |
4340 | 439 | new_operands[0] = inst->GetInOperand(0); |
4341 | 439 | } |
4342 | | |
4343 | 608 | inst->SetInOperands(std::move(new_operands)); |
4344 | 608 | return true; |
4345 | 671 | }; |
4346 | 16.7k | } |
4347 | | |
4348 | | // Removes duplicate ids from the interface list of an OpEntryPoint |
4349 | | // instruction. |
4350 | 16.7k | FoldingRule RemoveRedundantOperands() { |
4351 | 16.7k | return [](IRContext*, Instruction* inst, |
4352 | 16.7k | const std::vector<const analysis::Constant*>&) { |
4353 | 0 | assert(inst->opcode() == spv::Op::OpEntryPoint && |
4354 | 0 | "Wrong opcode. Should be OpEntryPoint."); |
4355 | 0 | bool has_redundant_operand = false; |
4356 | 0 | std::unordered_set<uint32_t> seen_operands; |
4357 | 0 | std::vector<Operand> new_operands; |
4358 | |
|
4359 | 0 | new_operands.emplace_back(inst->GetOperand(0)); |
4360 | 0 | new_operands.emplace_back(inst->GetOperand(1)); |
4361 | 0 | new_operands.emplace_back(inst->GetOperand(2)); |
4362 | 0 | for (uint32_t i = 3; i < inst->NumOperands(); ++i) { |
4363 | 0 | if (seen_operands.insert(inst->GetSingleWordOperand(i)).second) { |
4364 | 0 | new_operands.emplace_back(inst->GetOperand(i)); |
4365 | 0 | } else { |
4366 | 0 | has_redundant_operand = true; |
4367 | 0 | } |
4368 | 0 | } |
4369 | |
|
4370 | 0 | if (!has_redundant_operand) { |
4371 | 0 | return false; |
4372 | 0 | } |
4373 | | |
4374 | 0 | inst->SetInOperands(std::move(new_operands)); |
4375 | 0 | return true; |
4376 | 0 | }; |
4377 | 16.7k | } |
4378 | | |
4379 | | // If an image instruction's operand is a constant, updates the image operand |
4380 | | // flag from Offset to ConstOffset. |
4381 | 419k | FoldingRule UpdateImageOperands() { |
4382 | 419k | return [](IRContext*, Instruction* inst, |
4383 | 419k | const std::vector<const analysis::Constant*>& constants) { |
4384 | 270k | const auto opcode = inst->opcode(); |
4385 | 270k | (void)opcode; |
4386 | 270k | assert((opcode == spv::Op::OpImageSampleImplicitLod || |
4387 | 270k | opcode == spv::Op::OpImageSampleExplicitLod || |
4388 | 270k | opcode == spv::Op::OpImageSampleDrefImplicitLod || |
4389 | 270k | opcode == spv::Op::OpImageSampleDrefExplicitLod || |
4390 | 270k | opcode == spv::Op::OpImageSampleProjImplicitLod || |
4391 | 270k | opcode == spv::Op::OpImageSampleProjExplicitLod || |
4392 | 270k | opcode == spv::Op::OpImageSampleProjDrefImplicitLod || |
4393 | 270k | opcode == spv::Op::OpImageSampleProjDrefExplicitLod || |
4394 | 270k | opcode == spv::Op::OpImageFetch || |
4395 | 270k | opcode == spv::Op::OpImageGather || |
4396 | 270k | opcode == spv::Op::OpImageDrefGather || |
4397 | 270k | opcode == spv::Op::OpImageRead || opcode == spv::Op::OpImageWrite || |
4398 | 270k | opcode == spv::Op::OpImageSparseSampleImplicitLod || |
4399 | 270k | opcode == spv::Op::OpImageSparseSampleExplicitLod || |
4400 | 270k | opcode == spv::Op::OpImageSparseSampleDrefImplicitLod || |
4401 | 270k | opcode == spv::Op::OpImageSparseSampleDrefExplicitLod || |
4402 | 270k | opcode == spv::Op::OpImageSparseSampleProjImplicitLod || |
4403 | 270k | opcode == spv::Op::OpImageSparseSampleProjExplicitLod || |
4404 | 270k | opcode == spv::Op::OpImageSparseSampleProjDrefImplicitLod || |
4405 | 270k | opcode == spv::Op::OpImageSparseSampleProjDrefExplicitLod || |
4406 | 270k | opcode == spv::Op::OpImageSparseFetch || |
4407 | 270k | opcode == spv::Op::OpImageSparseGather || |
4408 | 270k | opcode == spv::Op::OpImageSparseDrefGather || |
4409 | 270k | opcode == spv::Op::OpImageSparseRead) && |
4410 | 270k | "Wrong opcode. Should be an image instruction."); |
4411 | | |
4412 | 270k | int32_t operand_index = ImageOperandsMaskInOperandIndex(inst); |
4413 | 270k | if (operand_index >= 0) { |
4414 | 12 | auto image_operands = inst->GetSingleWordInOperand(operand_index); |
4415 | 12 | if (image_operands & uint32_t(spv::ImageOperandsMask::Offset)) { |
4416 | 0 | uint32_t offset_operand_index = operand_index + 1; |
4417 | 0 | if (image_operands & uint32_t(spv::ImageOperandsMask::Bias)) |
4418 | 0 | offset_operand_index++; |
4419 | 0 | if (image_operands & uint32_t(spv::ImageOperandsMask::Lod)) |
4420 | 0 | offset_operand_index++; |
4421 | 0 | if (image_operands & uint32_t(spv::ImageOperandsMask::Grad)) |
4422 | 0 | offset_operand_index += 2; |
4423 | 0 | assert(((image_operands & |
4424 | 0 | uint32_t(spv::ImageOperandsMask::ConstOffset)) == 0) && |
4425 | 0 | "Offset and ConstOffset may not be used together"); |
4426 | 0 | if (offset_operand_index < inst->NumOperands()) { |
4427 | 0 | if (constants[offset_operand_index]) { |
4428 | 0 | if (constants[offset_operand_index]->IsZero()) { |
4429 | 0 | inst->RemoveInOperand(offset_operand_index); |
4430 | 0 | } else { |
4431 | 0 | image_operands = image_operands | |
4432 | 0 | uint32_t(spv::ImageOperandsMask::ConstOffset); |
4433 | 0 | } |
4434 | 0 | image_operands = |
4435 | 0 | image_operands & ~uint32_t(spv::ImageOperandsMask::Offset); |
4436 | 0 | inst->SetInOperand(operand_index, {image_operands}); |
4437 | 0 | return true; |
4438 | 0 | } |
4439 | 0 | } |
4440 | 0 | } |
4441 | 12 | } |
4442 | | |
4443 | 270k | return false; |
4444 | 270k | }; |
4445 | 419k | } |
4446 | | |
4447 | | } // namespace |
4448 | | |
4449 | 16.7k | void FoldingRules::AddFoldingRules() { |
4450 | | // Add all folding rules to the list for the opcodes to which they apply. |
4451 | | // Note that the order in which rules are added to the list matters. If a rule |
4452 | | // applies to the instruction, the rest of the rules will not be attempted. |
4453 | | // Take that into consideration. |
4454 | 16.7k | for (auto op : RedundantBinaryRhs0Ops) |
4455 | 117k | rules_[op].push_back(RedundantBinaryRhs0(op)); |
4456 | 16.7k | for (auto op : RedundantBinaryLhs0Ops) |
4457 | 50.3k | rules_[op].push_back(RedundantBinaryLhs0(op)); |
4458 | 16.7k | for (auto op : RedundantBinaryLhs0To0Ops) |
4459 | 117k | rules_[op].push_back(RedundantBinaryLhs0To0(op)); |
4460 | 16.7k | for (auto op : ReassociateCommutiveBitwiseOps) |
4461 | 50.3k | rules_[op].push_back(ReassociateCommutiveBitwise(op)); |
4462 | 16.7k | for (auto op : ReassociateNestedGenericIntOps) |
4463 | 67.1k | rules_[op].push_back(ReassociateNestedGenericInt(op)); |
4464 | 16.7k | for (auto op : MergeBinaryOpSelectOps) |
4465 | 704k | rules_[op].push_back(MergeBinaryOpSelect(op)); |
4466 | 16.7k | rules_[spv::Op::OpSDiv].push_back(RedundantSUDiv()); |
4467 | 16.7k | rules_[spv::Op::OpUDiv].push_back(RedundantSUDiv()); |
4468 | 16.7k | rules_[spv::Op::OpSMod].push_back(RedundantSUMod()); |
4469 | 16.7k | rules_[spv::Op::OpUMod].push_back(RedundantSUMod()); |
4470 | | |
4471 | 16.7k | rules_[spv::Op::OpBitcast].push_back(BitCastScalarOrVector()); |
4472 | 16.7k | rules_[spv::Op::OpBitcast].push_back(RedundantBitcast()); |
4473 | | |
4474 | 16.7k | rules_[spv::Op::OpBitReverse].push_back(BitReverseScalarOrVector()); |
4475 | | |
4476 | 16.7k | rules_[spv::Op::OpCompositeConstruct].push_back( |
4477 | 16.7k | CompositeExtractFeedingConstruct); |
4478 | | |
4479 | 16.7k | rules_[spv::Op::OpCompositeExtract].push_back(InsertFeedingExtract()); |
4480 | 16.7k | rules_[spv::Op::OpCompositeExtract].push_back( |
4481 | 16.7k | CompositeConstructFeedingExtract); |
4482 | 16.7k | rules_[spv::Op::OpCompositeExtract].push_back(VectorShuffleFeedingExtract()); |
4483 | 16.7k | rules_[spv::Op::OpCompositeExtract].push_back(FMixFeedingExtract()); |
4484 | 16.7k | rules_[spv::Op::OpCompositeExtract].push_back(CopyLogicalFeedingExtract); |
4485 | 16.7k | rules_[spv::Op::OpCompositeExtract].push_back(LoadFeedingExtract); |
4486 | | |
4487 | 16.7k | rules_[spv::Op::OpCompositeInsert].push_back( |
4488 | 16.7k | CompositeInsertToCompositeConstruct); |
4489 | | |
4490 | 16.7k | rules_[spv::Op::OpDot].push_back(DotProductDoingExtract()); |
4491 | | |
4492 | 16.7k | rules_[spv::Op::OpEntryPoint].push_back(RemoveRedundantOperands()); |
4493 | | |
4494 | 16.7k | rules_[spv::Op::OpFAdd].push_back(RedundantFAdd()); |
4495 | 16.7k | rules_[spv::Op::OpFAdd].push_back(MergeAddNegateArithmetic()); |
4496 | 16.7k | rules_[spv::Op::OpFAdd].push_back(MergeAddAddArithmetic()); |
4497 | 16.7k | rules_[spv::Op::OpFAdd].push_back(MergeAddSubArithmetic()); |
4498 | 16.7k | rules_[spv::Op::OpFAdd].push_back(MergeGenericAddSubArithmetic()); |
4499 | 16.7k | rules_[spv::Op::OpFAdd].push_back(ReassociateNestedAddSub()); |
4500 | 16.7k | rules_[spv::Op::OpFAdd].push_back(FactorAddSubMuls()); |
4501 | | |
4502 | 16.7k | rules_[spv::Op::OpFDiv].push_back(RedundantFDiv()); |
4503 | 16.7k | rules_[spv::Op::OpFDiv].push_back(ReciprocalFDiv()); |
4504 | 16.7k | rules_[spv::Op::OpFDiv].push_back(MergeDivDivArithmetic()); |
4505 | 16.7k | rules_[spv::Op::OpFDiv].push_back(MergeDivMulArithmetic()); |
4506 | 16.7k | rules_[spv::Op::OpFDiv].push_back(MergeDivNegateArithmetic()); |
4507 | 16.7k | rules_[spv::Op::OpFDiv].push_back(MergeDivMulDoubleNegative()); |
4508 | 16.7k | rules_[spv::Op::OpFDiv].push_back(ReassociateNestedMulDivFloat()); |
4509 | | |
4510 | 16.7k | rules_[spv::Op::OpFMod].push_back(RedundantFMod()); |
4511 | | |
4512 | 16.7k | rules_[spv::Op::OpFMul].push_back(RedundantFMul()); |
4513 | 16.7k | rules_[spv::Op::OpFMul].push_back(MergeMulMulArithmetic()); |
4514 | 16.7k | rules_[spv::Op::OpFMul].push_back(MergeMulDivArithmetic()); |
4515 | 16.7k | rules_[spv::Op::OpFMul].push_back(MergeMulNegateArithmetic()); |
4516 | 16.7k | rules_[spv::Op::OpFMul].push_back(MergeDivMulDoubleNegative()); |
4517 | 16.7k | rules_[spv::Op::OpFMul].push_back(ReassociateNestedMulDivFloat()); |
4518 | | |
4519 | 16.7k | rules_[spv::Op::OpVectorTimesScalar].push_back(MergeDivMulDoubleNegative()); |
4520 | | |
4521 | 16.7k | rules_[spv::Op::OpFNegate].push_back(MergeNegateArithmetic()); |
4522 | 16.7k | rules_[spv::Op::OpFNegate].push_back(MergeNegateAddSubArithmetic()); |
4523 | 16.7k | rules_[spv::Op::OpFNegate].push_back(MergeNegateMulDivArithmetic()); |
4524 | | |
4525 | 16.7k | rules_[spv::Op::OpFSub].push_back(RedundantFSub()); |
4526 | 16.7k | rules_[spv::Op::OpFSub].push_back(MergeSubNegateArithmetic()); |
4527 | 16.7k | rules_[spv::Op::OpFSub].push_back(MergeSubAddArithmetic()); |
4528 | 16.7k | rules_[spv::Op::OpFSub].push_back(MergeSubSubArithmetic()); |
4529 | 16.7k | rules_[spv::Op::OpFSub].push_back(ReassociateNestedAddSub()); |
4530 | 16.7k | rules_[spv::Op::OpFSub].push_back(FactorAddSubMuls()); |
4531 | | |
4532 | 16.7k | rules_[spv::Op::OpIAdd].push_back(MergeAddNegateArithmetic()); |
4533 | 16.7k | rules_[spv::Op::OpIAdd].push_back(MergeAddAddArithmetic()); |
4534 | 16.7k | rules_[spv::Op::OpIAdd].push_back(MergeAddSubArithmetic()); |
4535 | 16.7k | rules_[spv::Op::OpIAdd].push_back(MergeGenericAddSubArithmetic()); |
4536 | 16.7k | rules_[spv::Op::OpIAdd].push_back(ReassociateNestedAddSub()); |
4537 | 16.7k | rules_[spv::Op::OpIAdd].push_back(FactorAddSubMuls()); |
4538 | | |
4539 | 16.7k | rules_[spv::Op::OpSDiv].push_back(MergeDivMulDoubleNegative()); |
4540 | | |
4541 | 16.7k | rules_[spv::Op::OpIMul].push_back(IntMultipleBy1()); |
4542 | 16.7k | rules_[spv::Op::OpIMul].push_back(MergeMulMulArithmetic()); |
4543 | 16.7k | rules_[spv::Op::OpIMul].push_back(MergeMulNegateArithmetic()); |
4544 | 16.7k | rules_[spv::Op::OpIMul].push_back(MergeDivMulDoubleNegative()); |
4545 | | |
4546 | 16.7k | rules_[spv::Op::OpISub].push_back(MergeSubNegateArithmetic()); |
4547 | 16.7k | rules_[spv::Op::OpISub].push_back(MergeSubAddArithmetic()); |
4548 | 16.7k | rules_[spv::Op::OpISub].push_back(MergeSubSubArithmetic()); |
4549 | 16.7k | rules_[spv::Op::OpISub].push_back(ReassociateNestedAddSub()); |
4550 | 16.7k | rules_[spv::Op::OpISub].push_back(FactorAddSubMuls()); |
4551 | | |
4552 | 16.7k | rules_[spv::Op::OpBitwiseAnd].push_back(RedundantAndOrXor()); |
4553 | 16.7k | rules_[spv::Op::OpBitwiseAnd].push_back(RedundantAndAddSub()); |
4554 | 16.7k | rules_[spv::Op::OpBitwiseAnd].push_back(RedundantAndShift()); |
4555 | | |
4556 | 16.7k | rules_[spv::Op::OpPhi].push_back(RedundantPhi()); |
4557 | | |
4558 | 16.7k | rules_[spv::Op::OpSNegate].push_back(MergeNegateArithmetic()); |
4559 | 16.7k | rules_[spv::Op::OpSNegate].push_back(MergeNegateMulDivArithmetic()); |
4560 | 16.7k | rules_[spv::Op::OpSNegate].push_back(MergeNegateAddSubArithmetic()); |
4561 | | |
4562 | 16.7k | rules_[spv::Op::OpSelect].push_back(RedundantSelect()); |
4563 | 16.7k | rules_[spv::Op::OpSelect].push_back(FoldConstantBooleanSelect()); |
4564 | | |
4565 | 16.7k | rules_[spv::Op::OpLogicalAnd].push_back(RedundantLogicalAnd()); |
4566 | | |
4567 | 16.7k | rules_[spv::Op::OpLogicalOr].push_back(RedundantLogicalOr()); |
4568 | | |
4569 | 16.7k | rules_[spv::Op::OpLogicalNot].push_back(RedundantLogicalNot()); |
4570 | 16.7k | rules_[spv::Op::OpLogicalNot].push_back(FoldLogicalNotComparison()); |
4571 | | |
4572 | 16.7k | rules_[spv::Op::OpLogicalEqual].push_back(RedundantLogicalEqual()); |
4573 | 16.7k | rules_[spv::Op::OpLogicalNotEqual].push_back(RedundantLogicalEqual()); |
4574 | | |
4575 | 16.7k | rules_[spv::Op::OpStore].push_back(StoringUndef()); |
4576 | | |
4577 | 16.7k | rules_[spv::Op::OpVectorShuffle].push_back(VectorShuffleFeedingShuffle()); |
4578 | | |
4579 | 16.7k | rules_[spv::Op::OpImageSampleImplicitLod].push_back(UpdateImageOperands()); |
4580 | 16.7k | rules_[spv::Op::OpImageSampleExplicitLod].push_back(UpdateImageOperands()); |
4581 | 16.7k | rules_[spv::Op::OpImageSampleDrefImplicitLod].push_back( |
4582 | 16.7k | UpdateImageOperands()); |
4583 | 16.7k | rules_[spv::Op::OpImageSampleDrefExplicitLod].push_back( |
4584 | 16.7k | UpdateImageOperands()); |
4585 | 16.7k | rules_[spv::Op::OpImageSampleProjImplicitLod].push_back( |
4586 | 16.7k | UpdateImageOperands()); |
4587 | 16.7k | rules_[spv::Op::OpImageSampleProjExplicitLod].push_back( |
4588 | 16.7k | UpdateImageOperands()); |
4589 | 16.7k | rules_[spv::Op::OpImageSampleProjDrefImplicitLod].push_back( |
4590 | 16.7k | UpdateImageOperands()); |
4591 | 16.7k | rules_[spv::Op::OpImageSampleProjDrefExplicitLod].push_back( |
4592 | 16.7k | UpdateImageOperands()); |
4593 | 16.7k | rules_[spv::Op::OpImageFetch].push_back(UpdateImageOperands()); |
4594 | 16.7k | rules_[spv::Op::OpImageGather].push_back(UpdateImageOperands()); |
4595 | 16.7k | rules_[spv::Op::OpImageDrefGather].push_back(UpdateImageOperands()); |
4596 | 16.7k | rules_[spv::Op::OpImageRead].push_back(UpdateImageOperands()); |
4597 | 16.7k | rules_[spv::Op::OpImageWrite].push_back(UpdateImageOperands()); |
4598 | 16.7k | rules_[spv::Op::OpImageSparseSampleImplicitLod].push_back( |
4599 | 16.7k | UpdateImageOperands()); |
4600 | 16.7k | rules_[spv::Op::OpImageSparseSampleExplicitLod].push_back( |
4601 | 16.7k | UpdateImageOperands()); |
4602 | 16.7k | rules_[spv::Op::OpImageSparseSampleDrefImplicitLod].push_back( |
4603 | 16.7k | UpdateImageOperands()); |
4604 | 16.7k | rules_[spv::Op::OpImageSparseSampleDrefExplicitLod].push_back( |
4605 | 16.7k | UpdateImageOperands()); |
4606 | 16.7k | rules_[spv::Op::OpImageSparseSampleProjImplicitLod].push_back( |
4607 | 16.7k | UpdateImageOperands()); |
4608 | 16.7k | rules_[spv::Op::OpImageSparseSampleProjExplicitLod].push_back( |
4609 | 16.7k | UpdateImageOperands()); |
4610 | 16.7k | rules_[spv::Op::OpImageSparseSampleProjDrefImplicitLod].push_back( |
4611 | 16.7k | UpdateImageOperands()); |
4612 | 16.7k | rules_[spv::Op::OpImageSparseSampleProjDrefExplicitLod].push_back( |
4613 | 16.7k | UpdateImageOperands()); |
4614 | 16.7k | rules_[spv::Op::OpImageSparseFetch].push_back(UpdateImageOperands()); |
4615 | 16.7k | rules_[spv::Op::OpImageSparseGather].push_back(UpdateImageOperands()); |
4616 | 16.7k | rules_[spv::Op::OpImageSparseDrefGather].push_back(UpdateImageOperands()); |
4617 | 16.7k | rules_[spv::Op::OpImageSparseRead].push_back(UpdateImageOperands()); |
4618 | | |
4619 | 16.7k | FeatureManager* feature_manager = context_->get_feature_mgr(); |
4620 | | // Add rules for GLSLstd450 |
4621 | 16.7k | uint32_t ext_inst_glslstd450_id = |
4622 | 16.7k | feature_manager->GetExtInstImportId_GLSLstd450(); |
4623 | 16.7k | if (ext_inst_glslstd450_id != 0) { |
4624 | 10.0k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back( |
4625 | 10.0k | RedundantFMix()); |
4626 | 10.0k | } |
4627 | 16.7k | } |
4628 | | } // namespace opt |
4629 | | } // namespace spvtools |