/src/spirv-tools/source/opt/fold.cpp
Line | Count | Source |
1 | | // Copyright (c) 2017 Google Inc. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | #include "source/opt/fold.h" |
16 | | |
17 | | #include <cassert> |
18 | | #include <cstdint> |
19 | | #include <vector> |
20 | | |
21 | | #include "source/opt/const_folding_rules.h" |
22 | | #include "source/opt/def_use_manager.h" |
23 | | #include "source/opt/folding_rules.h" |
24 | | #include "source/opt/ir_context.h" |
25 | | |
26 | | namespace spvtools { |
27 | | namespace opt { |
namespace {

// Fallback definitions of the 32-bit integer limits. Each one is only used
// when the macro has not already been defined (e.g. by <cstdint>).

#ifndef INT32_MIN
// Spelled as (-INT32_MAX - 1): the literal 2147483648 does not fit in a 32-bit
// int, so "-2147483648" would silently get a wider type than int32_t.
#define INT32_MIN (-2147483647 - 1)
#endif

#ifndef INT32_MAX
#define INT32_MAX 2147483647
#endif

#ifndef UINT32_MAX
#define UINT32_MAX 0xffffffff /* 4294967295U */
#endif

}  // namespace
43 | | |
44 | | uint32_t InstructionFolder::UnaryOperate(spv::Op opcode, |
45 | 1.30k | uint32_t operand) const { |
46 | 1.30k | switch (opcode) { |
47 | | // Arthimetics |
48 | 0 | case spv::Op::OpSNegate: { |
49 | 0 | int32_t s_operand = static_cast<int32_t>(operand); |
50 | 0 | if (s_operand == std::numeric_limits<int32_t>::min()) { |
51 | 0 | return s_operand; |
52 | 0 | } |
53 | 0 | return static_cast<uint32_t>(-s_operand); |
54 | 0 | } |
55 | 831 | case spv::Op::OpNot: |
56 | 831 | return ~operand; |
57 | 475 | case spv::Op::OpLogicalNot: |
58 | 475 | return !static_cast<bool>(operand); |
59 | 0 | case spv::Op::OpUConvert: |
60 | 0 | return operand; |
61 | 0 | case spv::Op::OpSConvert: |
62 | 0 | return operand; |
63 | 0 | default: |
64 | 0 | assert(false && |
65 | 0 | "Unsupported unary operation for OpSpecConstantOp instruction"); |
66 | 0 | return 0u; |
67 | 1.30k | } |
68 | 1.30k | } |
69 | | |
70 | | uint32_t InstructionFolder::BinaryOperate(spv::Op opcode, uint32_t a, |
71 | 524k | uint32_t b) const { |
72 | 524k | switch (opcode) { |
73 | | // Shifting |
74 | 19.8k | case spv::Op::OpShiftRightLogical: |
75 | 19.8k | if (b >= 32) { |
76 | | // This is undefined behaviour when |b| > 32. Choose 0 for consistency. |
77 | | // When |b| == 32, doing the shift in C++ in undefined, but the result |
78 | | // will be 0, so just return that value. |
79 | 16.0k | return 0; |
80 | 16.0k | } |
81 | 3.80k | return a >> b; |
82 | 10.9k | case spv::Op::OpShiftRightArithmetic: |
83 | 10.9k | if (b > 32) { |
84 | | // This is undefined behaviour. Choose 0 for consistency. |
85 | 3.89k | return 0; |
86 | 3.89k | } |
87 | 7.06k | if (b == 32) { |
88 | | // Doing the shift in C++ is undefined, but the result is defined in the |
89 | | // spir-v spec. Find that value another way. |
90 | 505 | if (static_cast<int32_t>(a) >= 0) { |
91 | 458 | return 0; |
92 | 458 | } else { |
93 | 47 | return static_cast<uint32_t>(-1); |
94 | 47 | } |
95 | 505 | } |
96 | 6.56k | return (static_cast<int32_t>(a)) >> b; |
97 | 4.18k | case spv::Op::OpShiftLeftLogical: |
98 | 4.18k | if (b >= 32) { |
99 | | // This is undefined behaviour when |b| > 32. Choose 0 for consistency. |
100 | | // When |b| == 32, doing the shift in C++ in undefined, but the result |
101 | | // will be 0, so just return that value. |
102 | 1.70k | return 0; |
103 | 1.70k | } |
104 | 2.47k | return a << b; |
105 | | |
106 | | // Bitwise operations |
107 | 8.75k | case spv::Op::OpBitwiseOr: |
108 | 8.75k | return a | b; |
109 | 23.2k | case spv::Op::OpBitwiseAnd: |
110 | 23.2k | return a & b; |
111 | 52.6k | case spv::Op::OpBitwiseXor: |
112 | 52.6k | return a ^ b; |
113 | | |
114 | | // Logical |
115 | 210 | case spv::Op::OpLogicalEqual: |
116 | 210 | return (static_cast<bool>(a)) == (static_cast<bool>(b)); |
117 | 183 | case spv::Op::OpLogicalNotEqual: |
118 | 183 | return (static_cast<bool>(a)) != (static_cast<bool>(b)); |
119 | 482 | case spv::Op::OpLogicalOr: |
120 | 482 | return (static_cast<bool>(a)) || (static_cast<bool>(b)); |
121 | 607 | case spv::Op::OpLogicalAnd: |
122 | 607 | return (static_cast<bool>(a)) && (static_cast<bool>(b)); |
123 | | |
124 | | // Comparison |
125 | 15.4k | case spv::Op::OpIEqual: |
126 | 15.4k | return a == b; |
127 | 6.76k | case spv::Op::OpINotEqual: |
128 | 6.76k | return a != b; |
129 | 20.0k | case spv::Op::OpULessThan: |
130 | 20.0k | return a < b; |
131 | 152k | case spv::Op::OpSLessThan: |
132 | 152k | return (static_cast<int32_t>(a)) < (static_cast<int32_t>(b)); |
133 | 13.0k | case spv::Op::OpUGreaterThan: |
134 | 13.0k | return a > b; |
135 | 31.3k | case spv::Op::OpSGreaterThan: |
136 | 31.3k | return (static_cast<int32_t>(a)) > (static_cast<int32_t>(b)); |
137 | 7.48k | case spv::Op::OpULessThanEqual: |
138 | 7.48k | return a <= b; |
139 | 138k | case spv::Op::OpSLessThanEqual: |
140 | 138k | return (static_cast<int32_t>(a)) <= (static_cast<int32_t>(b)); |
141 | 5.97k | case spv::Op::OpUGreaterThanEqual: |
142 | 5.97k | return a >= b; |
143 | 12.5k | case spv::Op::OpSGreaterThanEqual: |
144 | 12.5k | return (static_cast<int32_t>(a)) >= (static_cast<int32_t>(b)); |
145 | 0 | default: |
146 | 0 | assert(false && |
147 | 0 | "Unsupported binary operation for OpSpecConstantOp instruction"); |
148 | 0 | return 0u; |
149 | 524k | } |
150 | 524k | } |
151 | | |
152 | | uint32_t InstructionFolder::TernaryOperate(spv::Op opcode, uint32_t a, |
153 | 3.40k | uint32_t b, uint32_t c) const { |
154 | 3.40k | switch (opcode) { |
155 | 3.40k | case spv::Op::OpSelect: |
156 | 3.40k | return (static_cast<bool>(a)) ? b : c; |
157 | 0 | default: |
158 | 0 | assert(false && |
159 | 0 | "Unsupported ternary operation for OpSpecConstantOp instruction"); |
160 | 0 | return 0u; |
161 | 3.40k | } |
162 | 3.40k | } |
163 | | |
164 | | uint32_t InstructionFolder::OperateWords( |
165 | 529k | spv::Op opcode, const std::vector<uint32_t>& operand_words) const { |
166 | 529k | switch (operand_words.size()) { |
167 | 1.30k | case 1: |
168 | 1.30k | return UnaryOperate(opcode, operand_words.front()); |
169 | 524k | case 2: |
170 | 524k | return BinaryOperate(opcode, operand_words.front(), operand_words.back()); |
171 | 3.40k | case 3: |
172 | 3.40k | return TernaryOperate(opcode, operand_words[0], operand_words[1], |
173 | 3.40k | operand_words[2]); |
174 | 0 | default: |
175 | 0 | assert(false && "Invalid number of operands"); |
176 | 0 | return 0; |
177 | 529k | } |
178 | 529k | } |
179 | | |
180 | 12.3M | bool InstructionFolder::FoldInstructionInternal(Instruction* inst) const { |
181 | 12.3M | auto identity_map = [](uint32_t id) { return id; }; |
182 | 12.3M | Instruction* folded_inst = FoldInstructionToConstant(inst, identity_map); |
183 | 12.3M | if (folded_inst != nullptr) { |
184 | 2.24M | inst->SetOpcode(spv::Op::OpCopyObject); |
185 | 2.24M | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {folded_inst->result_id()}}}); |
186 | 2.24M | return true; |
187 | 2.24M | } |
188 | | |
189 | 10.0M | analysis::ConstantManager* const_manager = context_->get_constant_mgr(); |
190 | 10.0M | std::vector<const analysis::Constant*> constants = |
191 | 10.0M | const_manager->GetOperandConstants(inst); |
192 | | |
193 | 10.0M | for (const FoldingRule& rule : |
194 | 10.0M | GetFoldingRules().GetRulesForInstruction(inst)) { |
195 | 8.55M | if (rule(context_, inst, constants)) { |
196 | 281k | return true; |
197 | 281k | } |
198 | 8.55M | } |
199 | 9.80M | return false; |
200 | 10.0M | } |
201 | | |
202 | | // Returns the result of performing an operation on scalar constant operands. |
203 | | // This function extracts the operand values as 32 bit words and returns the |
204 | | // result in 32 bit word. Scalar constants with longer than 32-bit width are |
205 | | // not accepted in this function. |
206 | | uint32_t InstructionFolder::FoldScalars( |
207 | | spv::Op opcode, |
208 | 528k | const std::vector<const analysis::Constant*>& operands) const { |
209 | 528k | assert(IsFoldableOpcode(opcode) && |
210 | 528k | "Unhandled instruction opcode in FoldScalars"); |
211 | 528k | std::vector<uint32_t> operand_values_in_raw_words; |
212 | 1.05M | for (const auto& operand : operands) { |
213 | 1.05M | if (const analysis::ScalarConstant* scalar = operand->AsScalarConstant()) { |
214 | 1.05M | const auto& scalar_words = scalar->words(); |
215 | 1.05M | assert(scalar_words.size() == 1 && |
216 | 1.05M | "Scalar constants with longer than 32-bit width are not allowed " |
217 | 1.05M | "in FoldScalars()"); |
218 | 1.05M | operand_values_in_raw_words.push_back(scalar_words.front()); |
219 | 1.05M | } else if (operand->AsNullConstant()) { |
220 | 79 | operand_values_in_raw_words.push_back(0u); |
221 | 79 | } else { |
222 | 0 | assert(false && |
223 | 0 | "FoldScalars() only accepts ScalarConst or NullConst type of " |
224 | 0 | "constant"); |
225 | 0 | } |
226 | 1.05M | } |
227 | 528k | return OperateWords(opcode, operand_values_in_raw_words); |
228 | 528k | } |
229 | | |
230 | | bool InstructionFolder::FoldBinaryIntegerOpToConstant( |
231 | | Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map, |
232 | 463k | uint32_t* result) const { |
233 | 463k | spv::Op opcode = inst->opcode(); |
234 | 463k | analysis::ConstantManager* const_manger = context_->get_constant_mgr(); |
235 | | |
236 | 463k | uint32_t ids[2]; |
237 | 463k | const analysis::IntConstant* constants[2]; |
238 | 1.39M | for (uint32_t i = 0; i < 2; i++) { |
239 | 927k | const Operand* operand = &inst->GetInOperand(i); |
240 | 927k | if (operand->type != SPV_OPERAND_TYPE_ID) { |
241 | 0 | return false; |
242 | 0 | } |
243 | 927k | ids[i] = id_map(operand->words[0]); |
244 | 927k | const analysis::Constant* constant = |
245 | 927k | const_manger->FindDeclaredConstant(ids[i]); |
246 | 927k | constants[i] = (constant != nullptr ? constant->AsIntConstant() : nullptr); |
247 | 927k | } |
248 | | |
249 | 463k | switch (opcode) { |
250 | | // Arthimetics |
251 | 16.1k | case spv::Op::OpIMul: |
252 | 47.6k | for (uint32_t i = 0; i < 2; i++) { |
253 | 32.1k | if (constants[i] != nullptr && constants[i]->IsZero()) { |
254 | 594 | *result = 0; |
255 | 594 | return true; |
256 | 594 | } |
257 | 32.1k | } |
258 | 15.5k | break; |
259 | 15.5k | case spv::Op::OpUDiv: |
260 | 12.1k | case spv::Op::OpSDiv: |
261 | 25.6k | case spv::Op::OpSRem: |
262 | 32.7k | case spv::Op::OpSMod: |
263 | 32.8k | case spv::Op::OpUMod: |
264 | | // This changes undefined behaviour (ie divide by 0) into a 0. |
265 | 96.3k | for (uint32_t i = 0; i < 2; i++) { |
266 | 65.5k | if (constants[i] != nullptr && constants[i]->IsZero()) { |
267 | 2.04k | *result = 0; |
268 | 2.04k | return true; |
269 | 2.04k | } |
270 | 65.5k | } |
271 | 30.8k | break; |
272 | | |
273 | | // Shifting |
274 | 30.8k | case spv::Op::OpShiftRightLogical: |
275 | 5.65k | case spv::Op::OpShiftLeftLogical: |
276 | 5.65k | if (constants[1] != nullptr) { |
277 | | // When shifting by a value larger than the size of the result, the |
278 | | // result is undefined. We are setting the undefined behaviour to a |
279 | | // result of 0. If the shift amount is the same as the size of the |
280 | | // result, then the result is defined, and it 0. |
281 | 3.49k | uint32_t shift_amount = constants[1]->GetU32BitValue(); |
282 | 3.49k | if (shift_amount >= 32) { |
283 | 1.54k | *result = 0; |
284 | 1.54k | return true; |
285 | 1.54k | } |
286 | 3.49k | } |
287 | 4.10k | break; |
288 | | |
289 | | // Bitwise operations |
290 | 48.9k | case spv::Op::OpBitwiseOr: |
291 | 132k | for (uint32_t i = 0; i < 2; i++) { |
292 | 91.8k | if (constants[i] != nullptr) { |
293 | | // TODO: Change the mask against a value based on the bit width of the |
294 | | // instruction result type. This way we can handle say 16-bit values |
295 | | // as well. |
296 | 16.7k | uint32_t mask = constants[i]->GetU32BitValue(); |
297 | 16.7k | if (mask == 0xFFFFFFFF) { |
298 | 7.91k | *result = 0xFFFFFFFF; |
299 | 7.91k | return true; |
300 | 7.91k | } |
301 | 16.7k | } |
302 | 91.8k | } |
303 | 41.0k | break; |
304 | 41.0k | case spv::Op::OpBitwiseAnd: |
305 | 39.2k | for (uint32_t i = 0; i < 2; i++) { |
306 | 26.5k | if (constants[i] != nullptr) { |
307 | 8.91k | if (constants[i]->IsZero()) { |
308 | 690 | *result = 0; |
309 | 690 | return true; |
310 | 690 | } |
311 | 8.91k | } |
312 | 26.5k | } |
313 | 12.7k | break; |
314 | | |
315 | | // Comparison |
316 | 12.7k | case spv::Op::OpULessThan: |
317 | 2.81k | if (constants[0] != nullptr && |
318 | 898 | constants[0]->GetU32BitValue() == UINT32_MAX) { |
319 | 18 | *result = false; |
320 | 18 | return true; |
321 | 18 | } |
322 | 2.80k | if (constants[1] != nullptr && constants[1]->GetU32BitValue() == 0) { |
323 | 295 | *result = false; |
324 | 295 | return true; |
325 | 295 | } |
326 | 2.50k | break; |
327 | 36.5k | case spv::Op::OpSLessThan: |
328 | 36.5k | if (constants[0] != nullptr && |
329 | 1.03k | constants[0]->GetS32BitValue() == INT32_MAX) { |
330 | 0 | *result = false; |
331 | 0 | return true; |
332 | 0 | } |
333 | 36.5k | if (constants[1] != nullptr && |
334 | 28.1k | constants[1]->GetS32BitValue() == INT32_MIN) { |
335 | 28 | *result = false; |
336 | 28 | return true; |
337 | 28 | } |
338 | 36.4k | break; |
339 | 36.4k | case spv::Op::OpUGreaterThan: |
340 | 9.13k | if (constants[0] != nullptr && constants[0]->IsZero()) { |
341 | 241 | *result = false; |
342 | 241 | return true; |
343 | 241 | } |
344 | 8.89k | if (constants[1] != nullptr && |
345 | 7.56k | constants[1]->GetU32BitValue() == UINT32_MAX) { |
346 | 33 | *result = false; |
347 | 33 | return true; |
348 | 33 | } |
349 | 8.85k | break; |
350 | 27.9k | case spv::Op::OpSGreaterThan: |
351 | 27.9k | if (constants[0] != nullptr && |
352 | 6.83k | constants[0]->GetS32BitValue() == INT32_MIN) { |
353 | 10 | *result = false; |
354 | 10 | return true; |
355 | 10 | } |
356 | 27.9k | if (constants[1] != nullptr && |
357 | 13.7k | constants[1]->GetS32BitValue() == INT32_MAX) { |
358 | 3 | *result = false; |
359 | 3 | return true; |
360 | 3 | } |
361 | 27.9k | break; |
362 | 27.9k | case spv::Op::OpULessThanEqual: |
363 | 6.47k | if (constants[0] != nullptr && constants[0]->IsZero()) { |
364 | 233 | *result = true; |
365 | 233 | return true; |
366 | 233 | } |
367 | 6.23k | if (constants[1] != nullptr && |
368 | 3.19k | constants[1]->GetU32BitValue() == UINT32_MAX) { |
369 | 39 | *result = true; |
370 | 39 | return true; |
371 | 39 | } |
372 | 6.19k | break; |
373 | 38.9k | case spv::Op::OpSLessThanEqual: |
374 | 38.9k | if (constants[0] != nullptr && |
375 | 10.9k | constants[0]->GetS32BitValue() == INT32_MIN) { |
376 | 1 | *result = true; |
377 | 1 | return true; |
378 | 1 | } |
379 | 38.9k | if (constants[1] != nullptr && |
380 | 13.0k | constants[1]->GetS32BitValue() == INT32_MAX) { |
381 | 24 | *result = true; |
382 | 24 | return true; |
383 | 24 | } |
384 | 38.9k | break; |
385 | 38.9k | case spv::Op::OpUGreaterThanEqual: |
386 | 4.57k | if (constants[0] != nullptr && |
387 | 1.00k | constants[0]->GetU32BitValue() == UINT32_MAX) { |
388 | 17 | *result = true; |
389 | 17 | return true; |
390 | 17 | } |
391 | 4.55k | if (constants[1] != nullptr && constants[1]->GetU32BitValue() == 0) { |
392 | 525 | *result = true; |
393 | 525 | return true; |
394 | 525 | } |
395 | 4.02k | break; |
396 | 7.99k | case spv::Op::OpSGreaterThanEqual: |
397 | 7.99k | if (constants[0] != nullptr && |
398 | 1.10k | constants[0]->GetS32BitValue() == INT32_MAX) { |
399 | 4 | *result = true; |
400 | 4 | return true; |
401 | 4 | } |
402 | 7.98k | if (constants[1] != nullptr && |
403 | 4.33k | constants[1]->GetS32BitValue() == INT32_MIN) { |
404 | 6 | *result = true; |
405 | 6 | return true; |
406 | 6 | } |
407 | 7.98k | break; |
408 | 212k | default: |
409 | 212k | break; |
410 | 463k | } |
411 | 449k | return false; |
412 | 463k | } |
413 | | |
414 | | bool InstructionFolder::FoldBinaryBooleanOpToConstant( |
415 | | Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map, |
416 | 449k | uint32_t* result) const { |
417 | 449k | spv::Op opcode = inst->opcode(); |
418 | 449k | analysis::ConstantManager* const_manger = context_->get_constant_mgr(); |
419 | | |
420 | 449k | uint32_t ids[2]; |
421 | 449k | const analysis::BoolConstant* constants[2]; |
422 | 1.34M | for (uint32_t i = 0; i < 2; i++) { |
423 | 899k | const Operand* operand = &inst->GetInOperand(i); |
424 | 899k | if (operand->type != SPV_OPERAND_TYPE_ID) { |
425 | 0 | return false; |
426 | 0 | } |
427 | 899k | ids[i] = id_map(operand->words[0]); |
428 | 899k | const analysis::Constant* constant = |
429 | 899k | const_manger->FindDeclaredConstant(ids[i]); |
430 | 899k | constants[i] = (constant != nullptr ? constant->AsBoolConstant() : nullptr); |
431 | 899k | } |
432 | | |
433 | 449k | switch (opcode) { |
434 | | // Logical |
435 | 948 | case spv::Op::OpLogicalOr: |
436 | 2.33k | for (uint32_t i = 0; i < 2; i++) { |
437 | 1.86k | if (constants[i] != nullptr) { |
438 | 596 | if (constants[i]->value()) { |
439 | 474 | *result = true; |
440 | 474 | return true; |
441 | 474 | } |
442 | 596 | } |
443 | 1.86k | } |
444 | 474 | break; |
445 | 7.84k | case spv::Op::OpLogicalAnd: |
446 | 23.0k | for (uint32_t i = 0; i < 2; i++) { |
447 | 15.5k | if (constants[i] != nullptr) { |
448 | 4.56k | if (!constants[i]->value()) { |
449 | 274 | *result = false; |
450 | 274 | return true; |
451 | 274 | } |
452 | 4.56k | } |
453 | 15.5k | } |
454 | 7.56k | break; |
455 | | |
456 | 440k | default: |
457 | 440k | break; |
458 | 449k | } |
459 | 448k | return false; |
460 | 449k | } |
461 | | |
462 | | bool InstructionFolder::FoldIntegerOpToConstant( |
463 | | Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map, |
464 | 499k | uint32_t* result) const { |
465 | 499k | assert(IsFoldableOpcode(inst->opcode()) && |
466 | 499k | "Unhandled instruction opcode in FoldScalars"); |
467 | 499k | switch (inst->NumInOperands()) { |
468 | 463k | case 2: |
469 | 463k | return FoldBinaryIntegerOpToConstant(inst, id_map, result) || |
470 | 449k | FoldBinaryBooleanOpToConstant(inst, id_map, result); |
471 | 36.0k | default: |
472 | 36.0k | return false; |
473 | 499k | } |
474 | 499k | } |
475 | | |
476 | | std::vector<uint32_t> InstructionFolder::FoldVectors( |
477 | | spv::Op opcode, uint32_t num_dims, |
478 | 541 | const std::vector<const analysis::Constant*>& operands) const { |
479 | 541 | assert(IsFoldableOpcode(opcode) && |
480 | 541 | "Unhandled instruction opcode in FoldVectors"); |
481 | 541 | std::vector<uint32_t> result; |
482 | 1.69k | for (uint32_t d = 0; d < num_dims; d++) { |
483 | 1.15k | std::vector<uint32_t> operand_values_for_one_dimension; |
484 | 3.29k | for (const auto& operand : operands) { |
485 | 3.29k | if (const analysis::VectorConstant* vector_operand = |
486 | 3.29k | operand->AsVectorConstant()) { |
487 | | // Extract the raw value of the scalar component constants |
488 | | // in 32-bit words here. The reason of not using FoldScalars() here |
489 | | // is that we do not create temporary null constants as components |
490 | | // when the vector operand is a NullConstant because Constant creation |
491 | | // may need extra checks for the validity and that is not managed in |
492 | | // here. |
493 | 3.29k | if (const analysis::ScalarConstant* scalar_component = |
494 | 3.29k | vector_operand->GetComponents().at(d)->AsScalarConstant()) { |
495 | 3.29k | const auto& scalar_words = scalar_component->words(); |
496 | 3.29k | assert( |
497 | 3.29k | scalar_words.size() == 1 && |
498 | 3.29k | "Vector components with longer than 32-bit width are not allowed " |
499 | 3.29k | "in FoldVectors()"); |
500 | 3.29k | operand_values_for_one_dimension.push_back(scalar_words.front()); |
501 | 3.29k | } else if (operand->AsNullConstant()) { |
502 | 0 | operand_values_for_one_dimension.push_back(0u); |
503 | 0 | } else { |
504 | 0 | assert(false && |
505 | 0 | "VectorConst should only has ScalarConst or NullConst as " |
506 | 0 | "components"); |
507 | 0 | } |
508 | 3.29k | } else if (operand->AsNullConstant()) { |
509 | 0 | operand_values_for_one_dimension.push_back(0u); |
510 | 0 | } else { |
511 | 0 | assert(false && |
512 | 0 | "FoldVectors() only accepts VectorConst or NullConst type of " |
513 | 0 | "constant"); |
514 | 0 | } |
515 | 3.29k | } |
516 | 1.15k | result.push_back(OperateWords(opcode, operand_values_for_one_dimension)); |
517 | 1.15k | } |
518 | 541 | return result; |
519 | 541 | } |
520 | | |
521 | 32.0M | bool InstructionFolder::IsFoldableOpcode(spv::Op opcode) const { |
522 | | // NOTE: Extend to more opcodes as new cases are handled in the folder |
523 | | // functions. |
524 | 32.0M | switch (opcode) { |
525 | 115k | case spv::Op::OpBitwiseAnd: |
526 | 308k | case spv::Op::OpBitwiseOr: |
527 | 470k | case spv::Op::OpBitwiseXor: |
528 | 1.25M | case spv::Op::OpIAdd: |
529 | 1.47M | case spv::Op::OpIEqual: |
530 | 1.60M | case spv::Op::OpIMul: |
531 | 1.67M | case spv::Op::OpINotEqual: |
532 | 2.08M | case spv::Op::OpISub: |
533 | 2.11M | case spv::Op::OpLogicalAnd: |
534 | 2.11M | case spv::Op::OpLogicalEqual: |
535 | 2.17M | case spv::Op::OpLogicalNot: |
536 | 2.17M | case spv::Op::OpLogicalNotEqual: |
537 | 2.17M | case spv::Op::OpLogicalOr: |
538 | 2.19M | case spv::Op::OpNot: |
539 | 2.28M | case spv::Op::OpSDiv: |
540 | 2.34M | case spv::Op::OpSelect: |
541 | 2.52M | case spv::Op::OpSGreaterThan: |
542 | 2.59M | case spv::Op::OpSGreaterThanEqual: |
543 | 2.61M | case spv::Op::OpShiftLeftLogical: |
544 | 2.65M | case spv::Op::OpShiftRightArithmetic: |
545 | 2.72M | case spv::Op::OpShiftRightLogical: |
546 | 3.30M | case spv::Op::OpSLessThan: |
547 | 3.85M | case spv::Op::OpSLessThanEqual: |
548 | 3.87M | case spv::Op::OpSMod: |
549 | 3.89M | case spv::Op::OpSNegate: |
550 | 4.00M | case spv::Op::OpSRem: |
551 | 4.00M | case spv::Op::OpSConvert: |
552 | 4.00M | case spv::Op::OpUConvert: |
553 | 4.00M | case spv::Op::OpUDiv: |
554 | 4.07M | case spv::Op::OpUGreaterThan: |
555 | 4.10M | case spv::Op::OpUGreaterThanEqual: |
556 | 4.17M | case spv::Op::OpULessThan: |
557 | 4.21M | case spv::Op::OpULessThanEqual: |
558 | 4.21M | case spv::Op::OpUMod: |
559 | 4.21M | return true; |
560 | 27.8M | default: |
561 | 27.8M | return false; |
562 | 32.0M | } |
563 | 32.0M | } |
564 | | |
565 | | bool InstructionFolder::IsFoldableConstant( |
566 | 0 | const analysis::Constant* cst) const { |
567 | | // Currently supported constants are 32-bit values or null constants. |
568 | 0 | if (const analysis::ScalarConstant* scalar = cst->AsScalarConstant()) |
569 | 0 | return scalar->words().size() == 1; |
570 | 0 | else |
571 | 0 | return cst->AsNullConstant() != nullptr; |
572 | 0 | } |
573 | | |
574 | | Instruction* InstructionFolder::FoldInstructionToConstant( |
575 | 12.8M | Instruction* inst, std::function<uint32_t(uint32_t)> id_map) const { |
576 | 12.8M | analysis::ConstantManager* const_mgr = context_->get_constant_mgr(); |
577 | | |
578 | 12.8M | if (!inst->IsFoldableByFoldScalar() && !inst->IsFoldableByFoldVector() && |
579 | 10.8M | !GetConstantFoldingRules().HasFoldingRule(inst)) { |
580 | 7.85M | return nullptr; |
581 | 7.85M | } |
582 | | // Collect the values of the constant parameters. |
583 | 4.95M | std::vector<const analysis::Constant*> constants; |
584 | 4.95M | bool missing_constants = false; |
585 | 4.95M | inst->ForEachInId([&constants, &missing_constants, const_mgr, |
586 | 9.37M | &id_map](uint32_t* op_id) { |
587 | 9.37M | uint32_t id = id_map(*op_id); |
588 | 9.37M | const analysis::Constant* const_op = const_mgr->FindDeclaredConstant(id); |
589 | 9.37M | if (!const_op) { |
590 | 3.98M | constants.push_back(nullptr); |
591 | 3.98M | missing_constants = true; |
592 | 5.39M | } else { |
593 | 5.39M | constants.push_back(const_op); |
594 | 5.39M | } |
595 | 9.37M | }); |
596 | | |
597 | 4.95M | const analysis::Constant* folded_const = nullptr; |
598 | 4.95M | for (auto rule : GetConstantFoldingRules().GetRulesForInstruction(inst)) { |
599 | 4.51M | folded_const = rule(context_, inst, constants); |
600 | 4.51M | if (folded_const == nullptr && inst->context()->id_overflow()) { |
601 | 0 | return nullptr; |
602 | 0 | } |
603 | 4.51M | if (folded_const != nullptr) { |
604 | 1.81M | Instruction* const_inst = |
605 | 1.81M | const_mgr->GetDefiningInstruction(folded_const, inst->type_id()); |
606 | 1.81M | if (const_inst == nullptr) { |
607 | 2 | return nullptr; |
608 | 2 | } |
609 | 1.81M | assert(const_inst->type_id() == inst->type_id()); |
610 | | // May be a new instruction that needs to be analysed. |
611 | 1.81M | context_->UpdateDefUse(const_inst); |
612 | 1.81M | return const_inst; |
613 | 1.81M | } |
614 | 4.51M | } |
615 | | |
616 | 3.14M | bool successful = false; |
617 | | |
618 | | // If all parameters are constant, fold the instruction to a constant. |
619 | 3.14M | if (inst->IsFoldableByFoldScalar()) { |
620 | 1.02M | uint32_t result_val = 0; |
621 | | |
622 | 1.02M | if (!missing_constants) { |
623 | 528k | result_val = FoldScalars(inst->opcode(), constants); |
624 | 528k | successful = true; |
625 | 528k | } |
626 | | |
627 | 1.02M | if (!successful) { |
628 | 499k | successful = FoldIntegerOpToConstant(inst, id_map, &result_val); |
629 | 499k | } |
630 | | |
631 | 1.02M | if (successful) { |
632 | 543k | const analysis::Constant* result_const = |
633 | 543k | const_mgr->GetConstant(const_mgr->GetType(inst), {result_val}); |
634 | 543k | Instruction* folded_inst = |
635 | 543k | const_mgr->GetDefiningInstruction(result_const, inst->type_id()); |
636 | 543k | return folded_inst; |
637 | 543k | } |
638 | 2.11M | } else if (inst->IsFoldableByFoldVector()) { |
639 | 1.57k | std::vector<uint32_t> result_val; |
640 | | |
641 | 1.57k | if (!missing_constants) { |
642 | 541 | if (Instruction* inst_type = |
643 | 541 | context_->get_def_use_mgr()->GetDef(inst->type_id())) { |
644 | 541 | result_val = FoldVectors( |
645 | 541 | inst->opcode(), inst_type->GetSingleWordInOperand(1), constants); |
646 | 541 | successful = true; |
647 | 541 | } |
648 | 541 | } |
649 | | |
650 | 1.57k | if (successful) { |
651 | 541 | const analysis::Constant* result_const = |
652 | 541 | const_mgr->GetNumericVectorConstantWithWords( |
653 | 541 | const_mgr->GetType(inst)->AsVector(), result_val); |
654 | 541 | Instruction* folded_inst = |
655 | 541 | const_mgr->GetDefiningInstruction(result_const, inst->type_id()); |
656 | 541 | return folded_inst; |
657 | 541 | } |
658 | 1.57k | } |
659 | | |
660 | 2.60M | return nullptr; |
661 | 3.14M | } |
662 | | |
663 | 0 | bool InstructionFolder::IsFoldableType(Instruction* type_inst) const { |
664 | 0 | return IsFoldableScalarType(type_inst) || IsFoldableVectorType(type_inst); |
665 | 0 | } |
666 | | |
667 | 9.48M | bool InstructionFolder::IsFoldableScalarType(Instruction* type_inst) const { |
668 | | // Support 32-bit integers. |
669 | 9.48M | if (type_inst->opcode() == spv::Op::OpTypeInt) { |
670 | 7.98M | return type_inst->GetSingleWordInOperand(0) == 32; |
671 | 7.98M | } |
672 | | // Support booleans. |
673 | 1.50M | if (type_inst->opcode() == spv::Op::OpTypeBool) { |
674 | 1.49M | return true; |
675 | 1.49M | } |
676 | | // Nothing else yet. |
677 | 11.3k | return false; |
678 | 1.50M | } |
679 | | |
680 | 21.0k | bool InstructionFolder::IsFoldableVectorType(Instruction* type_inst) const { |
681 | | // Support vectors with foldable components |
682 | 21.0k | if (type_inst->opcode() == spv::Op::OpTypeVector) { |
683 | 13.5k | uint32_t component_type_id = type_inst->GetSingleWordInOperand(0); |
684 | 13.5k | Instruction* def_component_type = |
685 | 13.5k | context_->get_def_use_mgr()->GetDef(component_type_id); |
686 | 13.5k | return def_component_type != nullptr && |
687 | 13.5k | IsFoldableScalarType(def_component_type); |
688 | 13.5k | } |
689 | | // Nothing else yet. |
690 | 7.56k | return false; |
691 | 21.0k | } |
692 | | |
693 | 12.2M | bool InstructionFolder::FoldInstruction(Instruction* inst) const { |
694 | 12.2M | bool modified = false; |
695 | 12.2M | Instruction* folded_inst(inst); |
696 | 14.7M | while (folded_inst->opcode() != spv::Op::OpCopyObject && |
697 | 12.3M | FoldInstructionInternal(&*folded_inst)) { |
698 | 2.52M | modified = true; |
699 | 2.52M | } |
700 | 12.2M | return modified; |
701 | 12.2M | } |
702 | | |
703 | | } // namespace opt |
704 | | } // namespace spvtools |