/src/shaderc/third_party/spirv-tools/source/opt/fold.cpp
Line  | Count  | Source  | 
1  |  | // Copyright (c) 2017 Google Inc.  | 
2  |  | //  | 
3  |  | // Licensed under the Apache License, Version 2.0 (the "License");  | 
4  |  | // you may not use this file except in compliance with the License.  | 
5  |  | // You may obtain a copy of the License at  | 
6  |  | //  | 
7  |  | //     http://www.apache.org/licenses/LICENSE-2.0  | 
8  |  | //  | 
9  |  | // Unless required by applicable law or agreed to in writing, software  | 
10  |  | // distributed under the License is distributed on an "AS IS" BASIS,  | 
11  |  | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  | 
12  |  | // See the License for the specific language governing permissions and  | 
13  |  | // limitations under the License.  | 
14  |  |  | 
15  |  | #include "source/opt/fold.h"  | 
16  |  |  | 
17  |  | #include <cassert>  | 
18  |  | #include <cstdint>  | 
19  |  | #include <vector>  | 
20  |  |  | 
21  |  | #include "source/opt/const_folding_rules.h"  | 
22  |  | #include "source/opt/def_use_manager.h"  | 
23  |  | #include "source/opt/folding_rules.h"  | 
24  |  | #include "source/opt/ir_context.h"  | 
25  |  |  | 
26  |  | namespace spvtools { | 
27  |  | namespace opt { | 
namespace {

// Fallback definitions of the 32-bit integer limit macros.  Every definition
// is guarded by #ifndef, so a macro that is already defined at this point is
// left untouched.

#ifndef INT32_MIN
// Written as (-2147483647 - 1) because the literal 2147483648 does not fit in
// a 32-bit int; negating it would give the macro a type wider than int32_t.
#define INT32_MIN (-2147483647 - 1)
#endif

#ifndef INT32_MAX
#define INT32_MAX 2147483647
#endif

#ifndef UINT32_MAX
#define UINT32_MAX 0xffffffff /* 4294967295U */
#endif

}  // namespace
43  |  |  | 
44  |  | uint32_t InstructionFolder::UnaryOperate(spv::Op opcode,  | 
45  | 6  |                                          uint32_t operand) const { | 
46  | 6  |   switch (opcode) { | 
47  |  |     // Arthimetics  | 
48  | 0  |     case spv::Op::OpSNegate: { | 
49  | 0  |       int32_t s_operand = static_cast<int32_t>(operand);  | 
50  | 0  |       if (s_operand == std::numeric_limits<int32_t>::min()) { | 
51  | 0  |         return s_operand;  | 
52  | 0  |       }  | 
53  | 0  |       return static_cast<uint32_t>(-s_operand);  | 
54  | 0  |     }  | 
55  | 6  |     case spv::Op::OpNot:  | 
56  | 6  |       return ~operand;  | 
57  | 0  |     case spv::Op::OpLogicalNot:  | 
58  | 0  |       return !static_cast<bool>(operand);  | 
59  | 0  |     case spv::Op::OpUConvert:  | 
60  | 0  |       return operand;  | 
61  | 0  |     case spv::Op::OpSConvert:  | 
62  | 0  |       return operand;  | 
63  | 0  |     default:  | 
64  | 0  |       assert(false &&  | 
65  | 0  |              "Unsupported unary operation for OpSpecConstantOp instruction");  | 
66  | 0  |       return 0u;  | 
67  | 6  |   }  | 
68  | 6  | }  | 
69  |  |  | 
70  |  | uint32_t InstructionFolder::BinaryOperate(spv::Op opcode, uint32_t a,  | 
71  | 1.28k  |                                           uint32_t b) const { | 
72  | 1.28k  |   switch (opcode) { | 
73  |  |     // Shifting  | 
74  | 0  |     case spv::Op::OpShiftRightLogical:  | 
75  | 0  |       if (b >= 32) { | 
76  |  |         // This is undefined behaviour when |b| > 32.  Choose 0 for consistency.  | 
77  |  |         // When |b| == 32, doing the shift in C++ in undefined, but the result  | 
78  |  |         // will be 0, so just return that value.  | 
79  | 0  |         return 0;  | 
80  | 0  |       }  | 
81  | 0  |       return a >> b;  | 
82  | 0  |     case spv::Op::OpShiftRightArithmetic:  | 
83  | 0  |       if (b > 32) { | 
84  |  |         // This is undefined behaviour.  Choose 0 for consistency.  | 
85  | 0  |         return 0;  | 
86  | 0  |       }  | 
87  | 0  |       if (b == 32) { | 
88  |  |         // Doing the shift in C++ is undefined, but the result is defined in the  | 
89  |  |         // spir-v spec.  Find that value another way.  | 
90  | 0  |         if (static_cast<int32_t>(a) >= 0) { | 
91  | 0  |           return 0;  | 
92  | 0  |         } else { | 
93  | 0  |           return static_cast<uint32_t>(-1);  | 
94  | 0  |         }  | 
95  | 0  |       }  | 
96  | 0  |       return (static_cast<int32_t>(a)) >> b;  | 
97  | 0  |     case spv::Op::OpShiftLeftLogical:  | 
98  | 0  |       if (b >= 32) { | 
99  |  |         // This is undefined behaviour when |b| > 32.  Choose 0 for consistency.  | 
100  |  |         // When |b| == 32, doing the shift in C++ in undefined, but the result  | 
101  |  |         // will be 0, so just return that value.  | 
102  | 0  |         return 0;  | 
103  | 0  |       }  | 
104  | 0  |       return a << b;  | 
105  |  |  | 
106  |  |     // Bitwise operations  | 
107  | 0  |     case spv::Op::OpBitwiseOr:  | 
108  | 0  |       return a | b;  | 
109  | 0  |     case spv::Op::OpBitwiseAnd:  | 
110  | 0  |       return a & b;  | 
111  | 0  |     case spv::Op::OpBitwiseXor:  | 
112  | 0  |       return a ^ b;  | 
113  |  |  | 
114  |  |     // Logical  | 
115  | 0  |     case spv::Op::OpLogicalEqual:  | 
116  | 0  |       return (static_cast<bool>(a)) == (static_cast<bool>(b));  | 
117  | 0  |     case spv::Op::OpLogicalNotEqual:  | 
118  | 0  |       return (static_cast<bool>(a)) != (static_cast<bool>(b));  | 
119  | 0  |     case spv::Op::OpLogicalOr:  | 
120  | 0  |       return (static_cast<bool>(a)) || (static_cast<bool>(b));  | 
121  | 0  |     case spv::Op::OpLogicalAnd:  | 
122  | 0  |       return (static_cast<bool>(a)) && (static_cast<bool>(b));  | 
123  |  |  | 
124  |  |     // Comparison  | 
125  | 0  |     case spv::Op::OpIEqual:  | 
126  | 0  |       return a == b;  | 
127  | 0  |     case spv::Op::OpINotEqual:  | 
128  | 0  |       return a != b;  | 
129  | 0  |     case spv::Op::OpULessThan:  | 
130  | 0  |       return a < b;  | 
131  | 638  |     case spv::Op::OpSLessThan:  | 
132  | 638  |       return (static_cast<int32_t>(a)) < (static_cast<int32_t>(b));  | 
133  | 0  |     case spv::Op::OpUGreaterThan:  | 
134  | 0  |       return a > b;  | 
135  | 0  |     case spv::Op::OpSGreaterThan:  | 
136  | 0  |       return (static_cast<int32_t>(a)) > (static_cast<int32_t>(b));  | 
137  | 0  |     case spv::Op::OpULessThanEqual:  | 
138  | 0  |       return a <= b;  | 
139  | 648  |     case spv::Op::OpSLessThanEqual:  | 
140  | 648  |       return (static_cast<int32_t>(a)) <= (static_cast<int32_t>(b));  | 
141  | 0  |     case spv::Op::OpUGreaterThanEqual:  | 
142  | 0  |       return a >= b;  | 
143  | 0  |     case spv::Op::OpSGreaterThanEqual:  | 
144  | 0  |       return (static_cast<int32_t>(a)) >= (static_cast<int32_t>(b));  | 
145  | 0  |     default:  | 
146  | 0  |       assert(false &&  | 
147  | 0  |              "Unsupported binary operation for OpSpecConstantOp instruction");  | 
148  | 0  |       return 0u;  | 
149  | 1.28k  |   }  | 
150  | 1.28k  | }  | 
151  |  |  | 
152  |  | uint32_t InstructionFolder::TernaryOperate(spv::Op opcode, uint32_t a,  | 
153  | 12  |                                            uint32_t b, uint32_t c) const { | 
154  | 12  |   switch (opcode) { | 
155  | 12  |     case spv::Op::OpSelect:  | 
156  | 12  |       return (static_cast<bool>(a)) ? b : c;  | 
157  | 0  |     default:  | 
158  | 0  |       assert(false &&  | 
159  | 0  |              "Unsupported ternary operation for OpSpecConstantOp instruction");  | 
160  | 0  |       return 0u;  | 
161  | 12  |   }  | 
162  | 12  | }  | 
163  |  |  | 
164  |  | uint32_t InstructionFolder::OperateWords(  | 
165  | 1.30k  |     spv::Op opcode, const std::vector<uint32_t>& operand_words) const { | 
166  | 1.30k  |   switch (operand_words.size()) { | 
167  | 6  |     case 1:  | 
168  | 6  |       return UnaryOperate(opcode, operand_words.front());  | 
169  | 1.28k  |     case 2:  | 
170  | 1.28k  |       return BinaryOperate(opcode, operand_words.front(), operand_words.back());  | 
171  | 12  |     case 3:  | 
172  | 12  |       return TernaryOperate(opcode, operand_words[0], operand_words[1],  | 
173  | 12  |                             operand_words[2]);  | 
174  | 0  |     default:  | 
175  | 0  |       assert(false && "Invalid number of operands");  | 
176  | 0  |       return 0;  | 
177  | 1.30k  |   }  | 
178  | 1.30k  | }  | 
179  |  |  | 
180  | 262k  | bool InstructionFolder::FoldInstructionInternal(Instruction* inst) const { | 
181  | 262k  |   auto identity_map = [](uint32_t id) { return id; }; | 
182  | 262k  |   Instruction* folded_inst = FoldInstructionToConstant(inst, identity_map);  | 
183  | 262k  |   if (folded_inst != nullptr) { | 
184  | 344  |     inst->SetOpcode(spv::Op::OpCopyObject);  | 
185  | 344  |     inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {folded_inst->result_id()}}}); | 
186  | 344  |     return true;  | 
187  | 344  |   }  | 
188  |  |  | 
189  | 262k  |   analysis::ConstantManager* const_manager = context_->get_constant_mgr();  | 
190  | 262k  |   std::vector<const analysis::Constant*> constants =  | 
191  | 262k  |       const_manager->GetOperandConstants(inst);  | 
192  |  |  | 
193  | 262k  |   for (const FoldingRule& rule :  | 
194  | 298k  |        GetFoldingRules().GetRulesForInstruction(inst)) { | 
195  | 298k  |     if (rule(context_, inst, constants)) { | 
196  | 7.09k  |       return true;  | 
197  | 7.09k  |     }  | 
198  | 298k  |   }  | 
199  | 254k  |   return false;  | 
200  | 262k  | }  | 
201  |  |  | 
202  |  | // Returns the result of performing an operation on scalar constant operands.  | 
203  |  | // This function extracts the operand values as 32 bit words and returns the  | 
204  |  | // result in 32 bit word. Scalar constants with longer than 32-bit width are  | 
205  |  | // not accepted in this function.  | 
206  |  | uint32_t InstructionFolder::FoldScalars(  | 
207  |  |     spv::Op opcode,  | 
208  | 1.29k  |     const std::vector<const analysis::Constant*>& operands) const { | 
209  | 1.29k  |   assert(IsFoldableOpcode(opcode) &&  | 
210  | 1.29k  |          "Unhandled instruction opcode in FoldScalars");  | 
211  | 1.29k  |   std::vector<uint32_t> operand_values_in_raw_words;  | 
212  | 2.57k  |   for (const auto& operand : operands) { | 
213  | 2.57k  |     if (const analysis::ScalarConstant* scalar = operand->AsScalarConstant()) { | 
214  | 2.57k  |       const auto& scalar_words = scalar->words();  | 
215  | 2.57k  |       assert(scalar_words.size() == 1 &&  | 
216  | 2.57k  |              "Scalar constants with longer than 32-bit width are not allowed "  | 
217  | 2.57k  |              "in FoldScalars()");  | 
218  | 2.57k  |       operand_values_in_raw_words.push_back(scalar_words.front());  | 
219  | 2.57k  |     } else if (operand->AsNullConstant()) { | 
220  | 0  |       operand_values_in_raw_words.push_back(0u);  | 
221  | 0  |     } else { | 
222  | 0  |       assert(false &&  | 
223  | 0  |              "FoldScalars() only accepts ScalarConst or NullConst type of "  | 
224  | 0  |              "constant");  | 
225  | 0  |     }  | 
226  | 2.57k  |   }  | 
227  | 1.29k  |   return OperateWords(opcode, operand_values_in_raw_words);  | 
228  | 1.29k  | }  | 
229  |  |  | 
230  |  | bool InstructionFolder::FoldBinaryIntegerOpToConstant(  | 
231  |  |     Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map,  | 
232  | 16.3k  |     uint32_t* result) const { | 
233  | 16.3k  |   spv::Op opcode = inst->opcode();  | 
234  | 16.3k  |   analysis::ConstantManager* const_manger = context_->get_constant_mgr();  | 
235  |  |  | 
236  | 16.3k  |   uint32_t ids[2];  | 
237  | 16.3k  |   const analysis::IntConstant* constants[2];  | 
238  | 49.1k  |   for (uint32_t i = 0; i < 2; i++) { | 
239  | 32.7k  |     const Operand* operand = &inst->GetInOperand(i);  | 
240  | 32.7k  |     if (operand->type != SPV_OPERAND_TYPE_ID) { | 
241  | 0  |       return false;  | 
242  | 0  |     }  | 
243  | 32.7k  |     ids[i] = id_map(operand->words[0]);  | 
244  | 32.7k  |     const analysis::Constant* constant =  | 
245  | 32.7k  |         const_manger->FindDeclaredConstant(ids[i]);  | 
246  | 32.7k  |     constants[i] = (constant != nullptr ? constant->AsIntConstant() : nullptr);  | 
247  | 32.7k  |   }  | 
248  |  |  | 
249  | 16.3k  |   switch (opcode) { | 
250  |  |     // Arthimetics  | 
251  | 0  |     case spv::Op::OpIMul:  | 
252  | 0  |       for (uint32_t i = 0; i < 2; i++) { | 
253  | 0  |         if (constants[i] != nullptr && constants[i]->IsZero()) { | 
254  | 0  |           *result = 0;  | 
255  | 0  |           return true;  | 
256  | 0  |         }  | 
257  | 0  |       }  | 
258  | 0  |       break;  | 
259  | 0  |     case spv::Op::OpUDiv:  | 
260  | 0  |     case spv::Op::OpSDiv:  | 
261  | 0  |     case spv::Op::OpSRem:  | 
262  | 126  |     case spv::Op::OpSMod:  | 
263  | 126  |     case spv::Op::OpUMod:  | 
264  |  |       // This changes undefined behaviour (ie divide by 0) into a 0.  | 
265  | 378  |       for (uint32_t i = 0; i < 2; i++) { | 
266  | 252  |         if (constants[i] != nullptr && constants[i]->IsZero()) { | 
267  | 0  |           *result = 0;  | 
268  | 0  |           return true;  | 
269  | 0  |         }  | 
270  | 252  |       }  | 
271  | 126  |       break;  | 
272  |  |  | 
273  |  |     // Shifting  | 
274  | 126  |     case spv::Op::OpShiftRightLogical:  | 
275  | 114  |     case spv::Op::OpShiftLeftLogical:  | 
276  | 114  |       if (constants[1] != nullptr) { | 
277  |  |         // When shifting by a value larger than the size of the result, the  | 
278  |  |         // result is undefined.  We are setting the undefined behaviour to a  | 
279  |  |         // result of 0.  If the shift amount is the same as the size of the  | 
280  |  |         // result, then the result is defined, and it 0.  | 
281  | 6  |         uint32_t shift_amount = constants[1]->GetU32BitValue();  | 
282  | 6  |         if (shift_amount >= 32) { | 
283  | 0  |           *result = 0;  | 
284  | 0  |           return true;  | 
285  | 0  |         }  | 
286  | 6  |       }  | 
287  | 114  |       break;  | 
288  |  |  | 
289  |  |     // Bitwise operations  | 
290  | 118  |     case spv::Op::OpBitwiseOr:  | 
291  | 354  |       for (uint32_t i = 0; i < 2; i++) { | 
292  | 236  |         if (constants[i] != nullptr) { | 
293  |  |           // TODO: Change the mask against a value based on the bit width of the  | 
294  |  |           // instruction result type.  This way we can handle say 16-bit values  | 
295  |  |           // as well.  | 
296  | 4  |           uint32_t mask = constants[i]->GetU32BitValue();  | 
297  | 4  |           if (mask == 0xFFFFFFFF) { | 
298  | 0  |             *result = 0xFFFFFFFF;  | 
299  | 0  |             return true;  | 
300  | 0  |           }  | 
301  | 4  |         }  | 
302  | 236  |       }  | 
303  | 118  |       break;  | 
304  | 118  |     case spv::Op::OpBitwiseAnd:  | 
305  | 342  |       for (uint32_t i = 0; i < 2; i++) { | 
306  | 228  |         if (constants[i] != nullptr) { | 
307  | 0  |           if (constants[i]->IsZero()) { | 
308  | 0  |             *result = 0;  | 
309  | 0  |             return true;  | 
310  | 0  |           }  | 
311  | 0  |         }  | 
312  | 228  |       }  | 
313  | 114  |       break;  | 
314  |  |  | 
315  |  |     // Comparison  | 
316  | 114  |     case spv::Op::OpULessThan:  | 
317  | 0  |       if (constants[0] != nullptr &&  | 
318  | 0  |           constants[0]->GetU32BitValue() == UINT32_MAX) { | 
319  | 0  |         *result = false;  | 
320  | 0  |         return true;  | 
321  | 0  |       }  | 
322  | 0  |       if (constants[1] != nullptr && constants[1]->GetU32BitValue() == 0) { | 
323  | 0  |         *result = false;  | 
324  | 0  |         return true;  | 
325  | 0  |       }  | 
326  | 0  |       break;  | 
327  | 2.39k  |     case spv::Op::OpSLessThan:  | 
328  | 2.39k  |       if (constants[0] != nullptr &&  | 
329  | 0  |           constants[0]->GetS32BitValue() == INT32_MAX) { | 
330  | 0  |         *result = false;  | 
331  | 0  |         return true;  | 
332  | 0  |       }  | 
333  | 2.39k  |       if (constants[1] != nullptr &&  | 
334  | 2.39k  |           constants[1]->GetS32BitValue() == INT32_MIN) { | 
335  | 0  |         *result = false;  | 
336  | 0  |         return true;  | 
337  | 0  |       }  | 
338  | 2.39k  |       break;  | 
339  | 2.39k  |     case spv::Op::OpUGreaterThan:  | 
340  | 368  |       if (constants[0] != nullptr && constants[0]->IsZero()) { | 
341  | 0  |         *result = false;  | 
342  | 0  |         return true;  | 
343  | 0  |       }  | 
344  | 368  |       if (constants[1] != nullptr &&  | 
345  | 368  |           constants[1]->GetU32BitValue() == UINT32_MAX) { | 
346  | 0  |         *result = false;  | 
347  | 0  |         return true;  | 
348  | 0  |       }  | 
349  | 368  |       break;  | 
350  | 2.89k  |     case spv::Op::OpSGreaterThan:  | 
351  | 2.89k  |       if (constants[0] != nullptr &&  | 
352  | 0  |           constants[0]->GetS32BitValue() == INT32_MIN) { | 
353  | 0  |         *result = false;  | 
354  | 0  |         return true;  | 
355  | 0  |       }  | 
356  | 2.89k  |       if (constants[1] != nullptr &&  | 
357  | 2.89k  |           constants[1]->GetS32BitValue() == INT32_MAX) { | 
358  | 0  |         *result = false;  | 
359  | 0  |         return true;  | 
360  | 0  |       }  | 
361  | 2.89k  |       break;  | 
362  | 2.89k  |     case spv::Op::OpULessThanEqual:  | 
363  | 0  |       if (constants[0] != nullptr && constants[0]->IsZero()) { | 
364  | 0  |         *result = true;  | 
365  | 0  |         return true;  | 
366  | 0  |       }  | 
367  | 0  |       if (constants[1] != nullptr &&  | 
368  | 0  |           constants[1]->GetU32BitValue() == UINT32_MAX) { | 
369  | 0  |         *result = true;  | 
370  | 0  |         return true;  | 
371  | 0  |       }  | 
372  | 0  |       break;  | 
373  | 2.33k  |     case spv::Op::OpSLessThanEqual:  | 
374  | 2.33k  |       if (constants[0] != nullptr &&  | 
375  | 0  |           constants[0]->GetS32BitValue() == INT32_MIN) { | 
376  | 0  |         *result = true;  | 
377  | 0  |         return true;  | 
378  | 0  |       }  | 
379  | 2.33k  |       if (constants[1] != nullptr &&  | 
380  | 2.21k  |           constants[1]->GetS32BitValue() == INT32_MAX) { | 
381  | 0  |         *result = true;  | 
382  | 0  |         return true;  | 
383  | 0  |       }  | 
384  | 2.33k  |       break;  | 
385  | 2.33k  |     case spv::Op::OpUGreaterThanEqual:  | 
386  | 0  |       if (constants[0] != nullptr &&  | 
387  | 0  |           constants[0]->GetU32BitValue() == UINT32_MAX) { | 
388  | 0  |         *result = true;  | 
389  | 0  |         return true;  | 
390  | 0  |       }  | 
391  | 0  |       if (constants[1] != nullptr && constants[1]->GetU32BitValue() == 0) { | 
392  | 0  |         *result = true;  | 
393  | 0  |         return true;  | 
394  | 0  |       }  | 
395  | 0  |       break;  | 
396  | 0  |     case spv::Op::OpSGreaterThanEqual:  | 
397  | 0  |       if (constants[0] != nullptr &&  | 
398  | 0  |           constants[0]->GetS32BitValue() == INT32_MAX) { | 
399  | 0  |         *result = true;  | 
400  | 0  |         return true;  | 
401  | 0  |       }  | 
402  | 0  |       if (constants[1] != nullptr &&  | 
403  | 0  |           constants[1]->GetS32BitValue() == INT32_MIN) { | 
404  | 0  |         *result = true;  | 
405  | 0  |         return true;  | 
406  | 0  |       }  | 
407  | 0  |       break;  | 
408  | 7.91k  |     default:  | 
409  | 7.91k  |       break;  | 
410  | 16.3k  |   }  | 
411  | 16.3k  |   return false;  | 
412  | 16.3k  | }  | 
413  |  |  | 
414  |  | bool InstructionFolder::FoldBinaryBooleanOpToConstant(  | 
415  |  |     Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map,  | 
416  | 16.3k  |     uint32_t* result) const { | 
417  | 16.3k  |   spv::Op opcode = inst->opcode();  | 
418  | 16.3k  |   analysis::ConstantManager* const_manger = context_->get_constant_mgr();  | 
419  |  |  | 
420  | 16.3k  |   uint32_t ids[2];  | 
421  | 16.3k  |   const analysis::BoolConstant* constants[2];  | 
422  | 49.1k  |   for (uint32_t i = 0; i < 2; i++) { | 
423  | 32.7k  |     const Operand* operand = &inst->GetInOperand(i);  | 
424  | 32.7k  |     if (operand->type != SPV_OPERAND_TYPE_ID) { | 
425  | 0  |       return false;  | 
426  | 0  |     }  | 
427  | 32.7k  |     ids[i] = id_map(operand->words[0]);  | 
428  | 32.7k  |     const analysis::Constant* constant =  | 
429  | 32.7k  |         const_manger->FindDeclaredConstant(ids[i]);  | 
430  | 32.7k  |     constants[i] = (constant != nullptr ? constant->AsBoolConstant() : nullptr);  | 
431  | 32.7k  |   }  | 
432  |  |  | 
433  | 16.3k  |   switch (opcode) { | 
434  |  |     // Logical  | 
435  | 474  |     case spv::Op::OpLogicalOr:  | 
436  | 1.42k  |       for (uint32_t i = 0; i < 2; i++) { | 
437  | 948  |         if (constants[i] != nullptr) { | 
438  | 0  |           if (constants[i]->value()) { | 
439  | 0  |             *result = true;  | 
440  | 0  |             return true;  | 
441  | 0  |           }  | 
442  | 0  |         }  | 
443  | 948  |       }  | 
444  | 474  |       break;  | 
445  | 870  |     case spv::Op::OpLogicalAnd:  | 
446  | 2.61k  |       for (uint32_t i = 0; i < 2; i++) { | 
447  | 1.74k  |         if (constants[i] != nullptr) { | 
448  | 0  |           if (!constants[i]->value()) { | 
449  | 0  |             *result = false;  | 
450  | 0  |             return true;  | 
451  | 0  |           }  | 
452  | 0  |         }  | 
453  | 1.74k  |       }  | 
454  | 870  |       break;  | 
455  |  |  | 
456  | 15.0k  |     default:  | 
457  | 15.0k  |       break;  | 
458  | 16.3k  |   }  | 
459  | 16.3k  |   return false;  | 
460  | 16.3k  | }  | 
461  |  |  | 
462  |  | bool InstructionFolder::FoldIntegerOpToConstant(  | 
463  |  |     Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map,  | 
464  | 16.8k  |     uint32_t* result) const { | 
465  | 16.8k  |   assert(IsFoldableOpcode(inst->opcode()) &&  | 
466  | 16.8k  |          "Unhandled instruction opcode in FoldScalars");  | 
467  | 16.8k  |   switch (inst->NumInOperands()) { | 
468  | 16.3k  |     case 2:  | 
469  | 16.3k  |       return FoldBinaryIntegerOpToConstant(inst, id_map, result) ||  | 
470  | 16.3k  |              FoldBinaryBooleanOpToConstant(inst, id_map, result);  | 
471  | 466  |     default:  | 
472  | 466  |       return false;  | 
473  | 16.8k  |   }  | 
474  | 16.8k  | }  | 
475  |  |  | 
476  |  | std::vector<uint32_t> InstructionFolder::FoldVectors(  | 
477  |  |     spv::Op opcode, uint32_t num_dims,  | 
478  | 6  |     const std::vector<const analysis::Constant*>& operands) const { | 
479  | 6  |   assert(IsFoldableOpcode(opcode) &&  | 
480  | 6  |          "Unhandled instruction opcode in FoldVectors");  | 
481  | 6  |   std::vector<uint32_t> result;  | 
482  | 18  |   for (uint32_t d = 0; d < num_dims; d++) { | 
483  | 12  |     std::vector<uint32_t> operand_values_for_one_dimension;  | 
484  | 36  |     for (const auto& operand : operands) { | 
485  | 36  |       if (const analysis::VectorConstant* vector_operand =  | 
486  | 36  |               operand->AsVectorConstant()) { | 
487  |  |         // Extract the raw value of the scalar component constants  | 
488  |  |         // in 32-bit words here. The reason of not using FoldScalars() here  | 
489  |  |         // is that we do not create temporary null constants as components  | 
490  |  |         // when the vector operand is a NullConstant because Constant creation  | 
491  |  |         // may need extra checks for the validity and that is not managed in  | 
492  |  |         // here.  | 
493  | 36  |         if (const analysis::ScalarConstant* scalar_component =  | 
494  | 36  |                 vector_operand->GetComponents().at(d)->AsScalarConstant()) { | 
495  | 36  |           const auto& scalar_words = scalar_component->words();  | 
496  | 36  |           assert(  | 
497  | 36  |               scalar_words.size() == 1 &&  | 
498  | 36  |               "Vector components with longer than 32-bit width are not allowed "  | 
499  | 36  |               "in FoldVectors()");  | 
500  | 36  |           operand_values_for_one_dimension.push_back(scalar_words.front());  | 
501  | 36  |         } else if (operand->AsNullConstant()) { | 
502  | 0  |           operand_values_for_one_dimension.push_back(0u);  | 
503  | 0  |         } else { | 
504  | 0  |           assert(false &&  | 
505  | 0  |                  "VectorConst should only has ScalarConst or NullConst as "  | 
506  | 0  |                  "components");  | 
507  | 0  |         }  | 
508  | 36  |       } else if (operand->AsNullConstant()) { | 
509  | 0  |         operand_values_for_one_dimension.push_back(0u);  | 
510  | 0  |       } else { | 
511  | 0  |         assert(false &&  | 
512  | 0  |                "FoldVectors() only accepts VectorConst or NullConst type of "  | 
513  | 0  |                "constant");  | 
514  | 0  |       }  | 
515  | 36  |     }  | 
516  | 12  |     result.push_back(OperateWords(opcode, operand_values_for_one_dimension));  | 
517  | 12  |   }  | 
518  | 6  |   return result;  | 
519  | 6  | }  | 
520  |  |  | 
521  | 1.01M  | bool InstructionFolder::IsFoldableOpcode(spv::Op opcode) const { | 
522  |  |   // NOTE: Extend to more opcodes as new cases are handled in the folder  | 
523  |  |   // functions.  | 
524  | 1.01M  |   switch (opcode) { | 
525  | 1.44k  |     case spv::Op::OpBitwiseAnd:  | 
526  | 2.25k  |     case spv::Op::OpBitwiseOr:  | 
527  | 2.81k  |     case spv::Op::OpBitwiseXor:  | 
528  | 22.4k  |     case spv::Op::OpIAdd:  | 
529  | 22.4k  |     case spv::Op::OpIEqual:  | 
530  | 23.0k  |     case spv::Op::OpIMul:  | 
531  | 23.9k  |     case spv::Op::OpINotEqual:  | 
532  | 24.7k  |     case spv::Op::OpISub:  | 
533  | 26.8k  |     case spv::Op::OpLogicalAnd:  | 
534  | 26.8k  |     case spv::Op::OpLogicalEqual:  | 
535  | 27.1k  |     case spv::Op::OpLogicalNot:  | 
536  | 27.4k  |     case spv::Op::OpLogicalNotEqual:  | 
537  | 28.5k  |     case spv::Op::OpLogicalOr:  | 
538  | 30.2k  |     case spv::Op::OpNot:  | 
539  | 30.2k  |     case spv::Op::OpSDiv:  | 
540  | 31.4k  |     case spv::Op::OpSelect:  | 
541  | 37.9k  |     case spv::Op::OpSGreaterThan:  | 
542  | 37.9k  |     case spv::Op::OpSGreaterThanEqual:  | 
543  | 38.7k  |     case spv::Op::OpShiftLeftLogical:  | 
544  | 39.5k  |     case spv::Op::OpShiftRightArithmetic:  | 
545  | 39.5k  |     case spv::Op::OpShiftRightLogical:  | 
546  | 46.8k  |     case spv::Op::OpSLessThan:  | 
547  | 54.0k  |     case spv::Op::OpSLessThanEqual:  | 
548  | 54.3k  |     case spv::Op::OpSMod:  | 
549  | 55.6k  |     case spv::Op::OpSNegate:  | 
550  | 55.6k  |     case spv::Op::OpSRem:  | 
551  | 56.2k  |     case spv::Op::OpSConvert:  | 
552  | 56.4k  |     case spv::Op::OpUConvert:  | 
553  | 56.9k  |     case spv::Op::OpUDiv:  | 
554  | 57.7k  |     case spv::Op::OpUGreaterThan:  | 
555  | 57.7k  |     case spv::Op::OpUGreaterThanEqual:  | 
556  | 57.7k  |     case spv::Op::OpULessThan:  | 
557  | 57.7k  |     case spv::Op::OpULessThanEqual:  | 
558  | 58.3k  |     case spv::Op::OpUMod:  | 
559  | 58.3k  |       return true;  | 
560  | 956k  |     default:  | 
561  | 956k  |       return false;  | 
562  | 1.01M  |   }  | 
563  | 1.01M  | }  | 
564  |  |  | 
565  |  | bool InstructionFolder::IsFoldableConstant(  | 
566  | 0  |     const analysis::Constant* cst) const { | 
567  |  |   // Currently supported constants are 32-bit values or null constants.  | 
568  | 0  |   if (const analysis::ScalarConstant* scalar = cst->AsScalarConstant())  | 
569  | 0  |     return scalar->words().size() == 1;  | 
570  | 0  |   else  | 
571  | 0  |     return cst->AsNullConstant() != nullptr;  | 
572  | 0  | }  | 
573  |  |  | 
574  |  | Instruction* InstructionFolder::FoldInstructionToConstant(  | 
575  | 305k  |     Instruction* inst, std::function<uint32_t(uint32_t)> id_map) const { | 
576  | 305k  |   analysis::ConstantManager* const_mgr = context_->get_constant_mgr();  | 
577  |  |  | 
578  | 305k  |   if (!inst->IsFoldableByFoldScalar() && !inst->IsFoldableByFoldVector() &&  | 
579  | 283k  |       !GetConstantFoldingRules().HasFoldingRule(inst)) { | 
580  | 161k  |     return nullptr;  | 
581  | 161k  |   }  | 
582  |  |   // Collect the values of the constant parameters.  | 
583  | 144k  |   std::vector<const analysis::Constant*> constants;  | 
584  | 144k  |   bool missing_constants = false;  | 
585  | 144k  |   inst->ForEachInId([&constants, &missing_constants, const_mgr,  | 
586  | 248k  |                      &id_map](uint32_t* op_id) { | 
587  | 248k  |     uint32_t id = id_map(*op_id);  | 
588  | 248k  |     const analysis::Constant* const_op = const_mgr->FindDeclaredConstant(id);  | 
589  | 248k  |     if (!const_op) { | 
590  | 203k  |       constants.push_back(nullptr);  | 
591  | 203k  |       missing_constants = true;  | 
592  | 203k  |     } else { | 
593  | 45.7k  |       constants.push_back(const_op);  | 
594  | 45.7k  |     }  | 
595  | 248k  |   });  | 
596  |  |  | 
597  | 144k  |   const analysis::Constant* folded_const = nullptr;  | 
598  | 144k  |   for (auto rule : GetConstantFoldingRules().GetRulesForInstruction(inst)) { | 
599  | 138k  |     folded_const = rule(context_, inst, constants);  | 
600  | 138k  |     if (folded_const != nullptr) { | 
601  | 5.34k  |       Instruction* const_inst =  | 
602  | 5.34k  |           const_mgr->GetDefiningInstruction(folded_const, inst->type_id());  | 
603  | 5.34k  |       if (const_inst == nullptr) { | 
604  | 0  |         return nullptr;  | 
605  | 0  |       }  | 
606  | 5.34k  |       assert(const_inst->type_id() == inst->type_id());  | 
607  |  |       // May be a new instruction that needs to be analysed.  | 
608  | 5.34k  |       context_->UpdateDefUse(const_inst);  | 
609  | 5.34k  |       return const_inst;  | 
610  | 5.34k  |     }  | 
611  | 138k  |   }  | 
612  |  |  | 
613  | 139k  |   bool successful = false;  | 
614  |  |  | 
615  |  |   // If all parameters are constant, fold the instruction to a constant.  | 
616  | 139k  |   if (inst->IsFoldableByFoldScalar()) { | 
617  | 18.1k  |     uint32_t result_val = 0;  | 
618  |  |  | 
619  | 18.1k  |     if (!missing_constants) { | 
620  | 1.29k  |       result_val = FoldScalars(inst->opcode(), constants);  | 
621  | 1.29k  |       successful = true;  | 
622  | 1.29k  |     }  | 
623  |  |  | 
624  | 18.1k  |     if (!successful) { | 
625  | 16.8k  |       successful = FoldIntegerOpToConstant(inst, id_map, &result_val);  | 
626  | 16.8k  |     }  | 
627  |  |  | 
628  | 18.1k  |     if (successful) { | 
629  | 1.29k  |       const analysis::Constant* result_const =  | 
630  | 1.29k  |           const_mgr->GetConstant(const_mgr->GetType(inst), {result_val}); | 
631  | 1.29k  |       Instruction* folded_inst =  | 
632  | 1.29k  |           const_mgr->GetDefiningInstruction(result_const, inst->type_id());  | 
633  | 1.29k  |       return folded_inst;  | 
634  | 1.29k  |     }  | 
635  | 120k  |   } else if (inst->IsFoldableByFoldVector()) { | 
636  | 1.45k  |     std::vector<uint32_t> result_val;  | 
637  |  |  | 
638  | 1.45k  |     if (!missing_constants) { | 
639  | 6  |       if (Instruction* inst_type =  | 
640  | 6  |               context_->get_def_use_mgr()->GetDef(inst->type_id())) { | 
641  | 6  |         result_val = FoldVectors(  | 
642  | 6  |             inst->opcode(), inst_type->GetSingleWordInOperand(1), constants);  | 
643  | 6  |         successful = true;  | 
644  | 6  |       }  | 
645  | 6  |     }  | 
646  |  |  | 
647  | 1.45k  |     if (successful) { | 
648  | 6  |       const analysis::Constant* result_const =  | 
649  | 6  |           const_mgr->GetNumericVectorConstantWithWords(  | 
650  | 6  |               const_mgr->GetType(inst)->AsVector(), result_val);  | 
651  | 6  |       Instruction* folded_inst =  | 
652  | 6  |           const_mgr->GetDefiningInstruction(result_const, inst->type_id());  | 
653  | 6  |       return folded_inst;  | 
654  | 6  |     }  | 
655  | 1.45k  |   }  | 
656  |  |  | 
657  | 137k  |   return nullptr;  | 
658  | 139k  | }  | 
659  |  |  | 
660  | 0  | bool InstructionFolder::IsFoldableType(Instruction* type_inst) const { | 
661  | 0  |   return IsFoldableScalarType(type_inst) || IsFoldableVectorType(type_inst);  | 
662  | 0  | }  | 
663  |  |  | 
664  | 153k  | bool InstructionFolder::IsFoldableScalarType(Instruction* type_inst) const { | 
665  |  |   // Support 32-bit integers.  | 
666  | 153k  |   if (type_inst->opcode() == spv::Op::OpTypeInt) { | 
667  | 115k  |     return type_inst->GetSingleWordInOperand(0) == 32;  | 
668  | 115k  |   }  | 
669  |  |   // Support booleans.  | 
670  | 37.9k  |   if (type_inst->opcode() == spv::Op::OpTypeBool) { | 
671  | 34.0k  |     return true;  | 
672  | 34.0k  |   }  | 
673  |  |   // Nothing else yet.  | 
674  | 3.98k  |   return false;  | 
675  | 37.9k  | }  | 
676  |  |  | 
677  | 12.2k  | bool InstructionFolder::IsFoldableVectorType(Instruction* type_inst) const { | 
678  |  |   // Support vectors with foldable components  | 
679  | 12.2k  |   if (type_inst->opcode() == spv::Op::OpTypeVector) { | 
680  | 9.89k  |     uint32_t component_type_id = type_inst->GetSingleWordInOperand(0);  | 
681  | 9.89k  |     Instruction* def_component_type =  | 
682  | 9.89k  |         context_->get_def_use_mgr()->GetDef(component_type_id);  | 
683  | 9.89k  |     return def_component_type != nullptr &&  | 
684  | 9.89k  |            IsFoldableScalarType(def_component_type);  | 
685  | 9.89k  |   }  | 
686  |  |   // Nothing else yet.  | 
687  | 2.37k  |   return false;  | 
688  | 12.2k  | }  | 
689  |  |  | 
690  | 256k  | bool InstructionFolder::FoldInstruction(Instruction* inst) const { | 
691  | 256k  |   bool modified = false;  | 
692  | 256k  |   Instruction* folded_inst(inst);  | 
693  | 264k  |   while (folded_inst->opcode() != spv::Op::OpCopyObject &&  | 
694  | 262k  |          FoldInstructionInternal(&*folded_inst)) { | 
695  | 7.43k  |     modified = true;  | 
696  | 7.43k  |   }  | 
697  | 256k  |   return modified;  | 
698  | 256k  | }  | 
699  |  |  | 
700  |  | }  // namespace opt  | 
701  |  | }  // namespace spvtools  |