/src/spirv-tools/source/opt/const_folding_rules.cpp
Line | Count | Source |
1 | | // Copyright (c) 2018 Google LLC |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | #include "source/opt/const_folding_rules.h" |
16 | | |
17 | | #include <optional> |
18 | | |
19 | | #include "source/opt/ir_context.h" |
20 | | |
21 | | namespace spvtools { |
22 | | namespace opt { |
23 | | namespace { |
// In-operand index of the composite being read by an OpCompositeExtract; the
// literal element indexes occupy the in-operands that follow it.
constexpr uint32_t kExtractCompositeIdInIdx{0};
25 | | |
26 | | // Returns a constants with the value NaN of the given type. Only works for |
27 | | // 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs. |
28 | | const analysis::Constant* GetNan(const analysis::Type* type, |
29 | 1.41k | analysis::ConstantManager* const_mgr) { |
30 | 1.41k | const analysis::Float* float_type = type->AsFloat(); |
31 | 1.41k | if (float_type == nullptr) { |
32 | 0 | return nullptr; |
33 | 0 | } |
34 | | |
35 | 1.41k | switch (float_type->width()) { |
36 | 1.41k | case 32: |
37 | 1.41k | return const_mgr->GetFloatConst(std::numeric_limits<float>::quiet_NaN()); |
38 | 0 | case 64: |
39 | 0 | return const_mgr->GetDoubleConst( |
40 | 0 | std::numeric_limits<double>::quiet_NaN()); |
41 | 0 | default: |
42 | 0 | return nullptr; |
43 | 1.41k | } |
44 | 1.41k | } |
45 | | |
46 | | // Returns a constants with the value INF of the given type. Only works for |
47 | | // 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs. |
48 | | const analysis::Constant* GetInf(const analysis::Type* type, |
49 | 1.06k | analysis::ConstantManager* const_mgr) { |
50 | 1.06k | const analysis::Float* float_type = type->AsFloat(); |
51 | 1.06k | if (float_type == nullptr) { |
52 | 0 | return nullptr; |
53 | 0 | } |
54 | | |
55 | 1.06k | switch (float_type->width()) { |
56 | 1.06k | case 32: |
57 | 1.06k | return const_mgr->GetFloatConst(std::numeric_limits<float>::infinity()); |
58 | 0 | case 64: |
59 | 0 | return const_mgr->GetDoubleConst(std::numeric_limits<double>::infinity()); |
60 | 0 | default: |
61 | 0 | return nullptr; |
62 | 1.06k | } |
63 | 1.06k | } |
64 | | |
65 | | // Returns true if |type| is Float or a vector of Float. |
66 | 522 | bool HasFloatingPoint(const analysis::Type* type) { |
67 | 522 | if (type->AsFloat()) { |
68 | 0 | return true; |
69 | 522 | } else if (const analysis::Vector* vec_type = type->AsVector()) { |
70 | 522 | return vec_type->element_type()->AsFloat() != nullptr; |
71 | 522 | } |
72 | | |
73 | 0 | return false; |
74 | 522 | } |
75 | | |
76 | | // Returns a constants with the value |-val| of the given type. Only works for |
77 | | // 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs. |
78 | | const analysis::Constant* NegateFPConst(const analysis::Type* result_type, |
79 | | const analysis::Constant* val, |
80 | 911 | analysis::ConstantManager* const_mgr) { |
81 | 911 | const analysis::Float* float_type = result_type->AsFloat(); |
82 | 911 | assert(float_type != nullptr); |
83 | 911 | if (float_type->width() == 32) { |
84 | 911 | float fa = val->GetFloat(); |
85 | 911 | return const_mgr->GetFloatConst(-fa); |
86 | 911 | } else if (float_type->width() == 64) { |
87 | 0 | double da = val->GetDouble(); |
88 | 0 | return const_mgr->GetDoubleConst(-da); |
89 | 0 | } |
90 | 0 | return nullptr; |
91 | 911 | } |
92 | | |
93 | | // Returns a constants with the value |-val| of the given type. |
94 | | const analysis::Constant* NegateIntConst(const analysis::Type* result_type, |
95 | | const analysis::Constant* val, |
96 | 2.81k | analysis::ConstantManager* const_mgr) { |
97 | 2.81k | const analysis::Integer* int_type = result_type->AsInteger(); |
98 | 2.81k | assert(int_type != nullptr); |
99 | | |
100 | 2.81k | if (val->AsNullConstant()) { |
101 | 7 | return val; |
102 | 7 | } |
103 | | |
104 | 2.80k | uint64_t new_value = static_cast<uint64_t>(-val->GetSignExtendedValue()); |
105 | 2.80k | return const_mgr->GetIntConst(new_value, int_type->width(), |
106 | 2.80k | int_type->IsSigned()); |
107 | 2.81k | } |
108 | | |
109 | | // Folds an OpcompositeExtract where input is a composite constant. |
110 | 16.0k | ConstantFoldingRule FoldExtractWithConstants() { |
111 | 16.0k | return [](IRContext* context, Instruction* inst, |
112 | 16.0k | const std::vector<const analysis::Constant*>& constants) |
113 | 472k | -> const analysis::Constant* { |
114 | 472k | const analysis::Constant* c = constants[kExtractCompositeIdInIdx]; |
115 | 472k | if (c == nullptr) { |
116 | 427k | return nullptr; |
117 | 427k | } |
118 | | |
119 | 88.1k | for (uint32_t i = 1; i < inst->NumInOperands(); ++i) { |
120 | 44.1k | uint32_t element_index = inst->GetSingleWordInOperand(i); |
121 | 44.1k | if (c->AsNullConstant()) { |
122 | | // Return Null for the return type. |
123 | 181 | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
124 | 181 | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
125 | 181 | return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), {}); |
126 | 181 | } |
127 | | |
128 | 43.9k | auto cc = c->AsCompositeConstant(); |
129 | 43.9k | assert(cc != nullptr); |
130 | 43.9k | auto components = cc->GetComponents(); |
131 | | // Protect against invalid IR. Refuse to fold if the index is out |
132 | | // of bounds. |
133 | 43.9k | if (element_index >= components.size()) return nullptr; |
134 | 43.9k | c = components[element_index]; |
135 | 43.9k | } |
136 | 43.9k | return c; |
137 | 44.1k | }; |
138 | 16.0k | } |
139 | | |
140 | | // Folds an OpcompositeInsert where input is a composite constant. |
141 | 16.0k | ConstantFoldingRule FoldInsertWithConstants() { |
142 | 16.0k | return [](IRContext* context, Instruction* inst, |
143 | 16.0k | const std::vector<const analysis::Constant*>& constants) |
144 | 102k | -> const analysis::Constant* { |
145 | 102k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
146 | 102k | const analysis::Constant* object = constants[0]; |
147 | 102k | const analysis::Constant* composite = constants[1]; |
148 | 102k | if (object == nullptr || composite == nullptr) { |
149 | 97.5k | return nullptr; |
150 | 97.5k | } |
151 | | |
152 | | // If there is more than 1 index, then each additional constant used by the |
153 | | // index will need to be recreated to use the inserted object. |
154 | 4.86k | std::vector<const analysis::Constant*> chain; |
155 | 4.86k | std::vector<const analysis::Constant*> components; |
156 | 4.86k | const analysis::Type* type = nullptr; |
157 | 4.86k | const uint32_t final_index = (inst->NumInOperands() - 1); |
158 | | |
159 | | // Work down hierarchy of all indexes |
160 | 9.74k | for (uint32_t i = 2; i < inst->NumInOperands(); ++i) { |
161 | 4.87k | type = composite->type(); |
162 | | |
163 | 4.87k | if (composite->AsNullConstant()) { |
164 | | // Make new composite so it can be inserted in the index with the |
165 | | // non-null value |
166 | 22 | if (const auto new_composite = |
167 | 22 | const_mgr->GetNullCompositeConstant(type)) { |
168 | | // Keep track of any indexes along the way to last index |
169 | 22 | if (i != final_index) { |
170 | 0 | chain.push_back(new_composite); |
171 | 0 | } |
172 | 22 | components = new_composite->AsCompositeConstant()->GetComponents(); |
173 | 22 | } else { |
174 | | // Unsupported input type (such as structs) |
175 | 0 | return nullptr; |
176 | 0 | } |
177 | 4.85k | } else { |
178 | | // Keep track of any indexes along the way to last index |
179 | 4.85k | if (i != final_index) { |
180 | 4 | chain.push_back(composite); |
181 | 4 | } |
182 | 4.85k | components = composite->AsCompositeConstant()->GetComponents(); |
183 | 4.85k | } |
184 | 4.87k | const uint32_t index = inst->GetSingleWordInOperand(i); |
185 | 4.87k | composite = components[index]; |
186 | 4.87k | } |
187 | | |
188 | | // Final index in hierarchy is inserted with new object. |
189 | 4.86k | const uint32_t final_operand = inst->GetSingleWordInOperand(final_index); |
190 | 4.86k | std::vector<uint32_t> ids; |
191 | 14.7k | for (size_t i = 0; i < components.size(); i++) { |
192 | 9.90k | const analysis::Constant* constant = |
193 | 9.90k | (i == final_operand) ? object : components[i]; |
194 | 9.90k | Instruction* member_inst = const_mgr->GetDefiningInstruction(constant); |
195 | 9.90k | ids.push_back(member_inst->result_id()); |
196 | 9.90k | } |
197 | 4.86k | const analysis::Constant* new_constant = const_mgr->GetConstant(type, ids); |
198 | | |
199 | | // Work backwards up the chain and replace each index with new constant. |
200 | 4.87k | for (size_t i = chain.size(); i > 0; i--) { |
201 | | // Need to insert any previous instruction into the module first. |
202 | | // Can't just insert in types_values_begin() because it will move above |
203 | | // where the types are declared. |
204 | | // Can't compare with location of inst because not all new added |
205 | | // instructions are added to types_values_ |
206 | 4 | auto iter = context->types_values_end(); |
207 | 4 | Module::inst_iterator* pos = &iter; |
208 | 4 | const_mgr->BuildInstructionAndAddToModule(new_constant, pos); |
209 | | |
210 | 4 | composite = chain[i - 1]; |
211 | 4 | components = composite->AsCompositeConstant()->GetComponents(); |
212 | 4 | type = composite->type(); |
213 | 4 | ids.clear(); |
214 | 8 | for (size_t k = 0; k < components.size(); k++) { |
215 | 4 | const uint32_t index = |
216 | 4 | inst->GetSingleWordInOperand(1 + static_cast<uint32_t>(i)); |
217 | 4 | const analysis::Constant* constant = |
218 | 4 | (k == index) ? new_constant : components[k]; |
219 | 4 | const uint32_t constant_id = |
220 | 4 | const_mgr->FindDeclaredConstant(constant, 0); |
221 | 4 | ids.push_back(constant_id); |
222 | 4 | } |
223 | 4 | new_constant = const_mgr->GetConstant(type, ids); |
224 | 4 | } |
225 | | |
226 | | // If multiple constants were created, only need to return the top index. |
227 | 4.86k | return new_constant; |
228 | 4.86k | }; |
229 | 16.0k | } |
230 | | |
231 | 16.0k | ConstantFoldingRule FoldVectorShuffleWithConstants() { |
232 | 16.0k | return [](IRContext* context, Instruction* inst, |
233 | 16.0k | const std::vector<const analysis::Constant*>& constants) |
234 | 34.9k | -> const analysis::Constant* { |
235 | 34.9k | assert(inst->opcode() == spv::Op::OpVectorShuffle); |
236 | 34.9k | const analysis::Constant* c1 = constants[0]; |
237 | 34.9k | const analysis::Constant* c2 = constants[1]; |
238 | 34.9k | if (c1 == nullptr || c2 == nullptr) { |
239 | 34.6k | return nullptr; |
240 | 34.6k | } |
241 | | |
242 | 293 | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
243 | 293 | const analysis::Type* element_type = c1->type()->AsVector()->element_type(); |
244 | | |
245 | 293 | std::vector<const analysis::Constant*> c1_components; |
246 | 293 | if (const analysis::VectorConstant* vec_const = c1->AsVectorConstant()) { |
247 | 253 | c1_components = vec_const->GetComponents(); |
248 | 253 | } else { |
249 | 40 | assert(c1->AsNullConstant()); |
250 | 40 | const analysis::Constant* element = |
251 | 40 | const_mgr->GetConstant(element_type, {}); |
252 | 40 | c1_components.resize(c1->type()->AsVector()->element_count(), element); |
253 | 40 | } |
254 | 293 | std::vector<const analysis::Constant*> c2_components; |
255 | 293 | if (const analysis::VectorConstant* vec_const = c2->AsVectorConstant()) { |
256 | 217 | c2_components = vec_const->GetComponents(); |
257 | 217 | } else { |
258 | 76 | assert(c2->AsNullConstant()); |
259 | 76 | const analysis::Constant* element = |
260 | 76 | const_mgr->GetConstant(element_type, {}); |
261 | 76 | c2_components.resize(c2->type()->AsVector()->element_count(), element); |
262 | 76 | } |
263 | | |
264 | 293 | std::vector<uint32_t> ids; |
265 | 293 | const uint32_t undef_literal_value = 0xffffffff; |
266 | 656 | for (uint32_t i = 2; i < inst->NumInOperands(); ++i) { |
267 | 508 | uint32_t index = inst->GetSingleWordInOperand(i); |
268 | 508 | if (index == undef_literal_value) { |
269 | | // Don't fold shuffle with undef literal value. |
270 | 145 | return nullptr; |
271 | 363 | } else if (index < c1_components.size()) { |
272 | 274 | Instruction* member_inst = |
273 | 274 | const_mgr->GetDefiningInstruction(c1_components[index]); |
274 | 274 | ids.push_back(member_inst->result_id()); |
275 | 274 | } else { |
276 | 89 | Instruction* member_inst = const_mgr->GetDefiningInstruction( |
277 | 89 | c2_components[index - c1_components.size()]); |
278 | 89 | ids.push_back(member_inst->result_id()); |
279 | 89 | } |
280 | 508 | } |
281 | | |
282 | 148 | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
283 | 148 | return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids); |
284 | 293 | }; |
285 | 16.0k | } |
286 | | |
287 | 16.0k | ConstantFoldingRule FoldVectorTimesScalar() { |
288 | 16.0k | return [](IRContext* context, Instruction* inst, |
289 | 16.0k | const std::vector<const analysis::Constant*>& constants) |
290 | 398k | -> const analysis::Constant* { |
291 | 398k | assert(inst->opcode() == spv::Op::OpVectorTimesScalar); |
292 | 398k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
293 | 398k | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
294 | | |
295 | 398k | if (!inst->IsFloatingPointFoldingAllowed()) { |
296 | 522 | if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) { |
297 | 522 | return nullptr; |
298 | 522 | } |
299 | 522 | } |
300 | | |
301 | 397k | const analysis::Constant* c1 = constants[0]; |
302 | 397k | const analysis::Constant* c2 = constants[1]; |
303 | | |
304 | 397k | if (c1 && c1->IsZero()) { |
305 | 1.76k | return c1; |
306 | 1.76k | } |
307 | | |
308 | 395k | if (c2 && c2->IsZero()) { |
309 | | // Get or create the NullConstant for this type. |
310 | 4.64k | std::vector<uint32_t> ids; |
311 | 4.64k | return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids); |
312 | 4.64k | } |
313 | | |
314 | 391k | if (c1 == nullptr || c2 == nullptr) { |
315 | 345k | return nullptr; |
316 | 345k | } |
317 | | |
318 | | // Check result type. |
319 | 45.9k | const analysis::Type* result_type = type_mgr->GetType(inst->type_id()); |
320 | 45.9k | const analysis::Vector* vector_type = result_type->AsVector(); |
321 | 45.9k | assert(vector_type != nullptr); |
322 | 45.9k | const analysis::Type* element_type = vector_type->element_type(); |
323 | 45.9k | assert(element_type != nullptr); |
324 | 45.9k | const analysis::Float* float_type = element_type->AsFloat(); |
325 | 45.9k | assert(float_type != nullptr); |
326 | | |
327 | | // Check types of c1 and c2. |
328 | 45.9k | assert(c1->type()->AsVector() == vector_type); |
329 | 45.9k | assert(c1->type()->AsVector()->element_type() == element_type && |
330 | 45.9k | c2->type() == element_type); |
331 | | |
332 | | // Get a float vector that is the result of vector-times-scalar. |
333 | 45.9k | std::vector<const analysis::Constant*> c1_components = |
334 | 45.9k | c1->GetVectorComponents(const_mgr); |
335 | 45.9k | std::vector<uint32_t> ids; |
336 | 45.9k | if (float_type->width() == 32) { |
337 | 45.9k | float scalar = c2->GetFloat(); |
338 | 227k | for (uint32_t i = 0; i < c1_components.size(); ++i) { |
339 | 181k | utils::FloatProxy<float> result(c1_components[i]->GetFloat() * scalar); |
340 | 181k | std::vector<uint32_t> words = result.GetWords(); |
341 | 181k | const analysis::Constant* new_elem = |
342 | 181k | const_mgr->GetConstant(float_type, words); |
343 | 181k | ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id()); |
344 | 181k | } |
345 | 45.9k | return const_mgr->GetConstant(vector_type, ids); |
346 | 45.9k | } else if (float_type->width() == 64) { |
347 | 0 | double scalar = c2->GetDouble(); |
348 | 0 | for (uint32_t i = 0; i < c1_components.size(); ++i) { |
349 | 0 | utils::FloatProxy<double> result(c1_components[i]->GetDouble() * |
350 | 0 | scalar); |
351 | 0 | std::vector<uint32_t> words = result.GetWords(); |
352 | 0 | const analysis::Constant* new_elem = |
353 | 0 | const_mgr->GetConstant(float_type, words); |
354 | 0 | ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id()); |
355 | 0 | } |
356 | 0 | return const_mgr->GetConstant(vector_type, ids); |
357 | 0 | } |
358 | 0 | return nullptr; |
359 | 45.9k | }; |
360 | 16.0k | } |
361 | | |
362 | | // Returns to the constant that results from tranposing |matrix|. The result |
363 | | // will have type |result_type|, and |matrix| must exist in |context|. The |
364 | | // result constant will also exist in |context|. |
365 | | const analysis::Constant* TransposeMatrix(const analysis::Constant* matrix, |
366 | | analysis::Matrix* result_type, |
367 | 0 | IRContext* context) { |
368 | 0 | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
369 | 0 | if (matrix->AsNullConstant() != nullptr) { |
370 | 0 | return const_mgr->GetNullCompositeConstant(result_type); |
371 | 0 | } |
372 | | |
373 | 0 | const auto& columns = matrix->AsMatrixConstant()->GetComponents(); |
374 | 0 | uint32_t number_of_rows = columns[0]->type()->AsVector()->element_count(); |
375 | | |
376 | | // Collect the ids of the elements in their new positions. |
377 | 0 | std::vector<std::vector<uint32_t>> result_elements(number_of_rows); |
378 | 0 | for (const analysis::Constant* column : columns) { |
379 | 0 | if (column->AsNullConstant()) { |
380 | 0 | column = const_mgr->GetNullCompositeConstant(column->type()); |
381 | 0 | } |
382 | 0 | const auto& column_components = column->AsVectorConstant()->GetComponents(); |
383 | |
|
384 | 0 | for (uint32_t row = 0; row < number_of_rows; ++row) { |
385 | 0 | result_elements[row].push_back( |
386 | 0 | const_mgr->GetDefiningInstruction(column_components[row]) |
387 | 0 | ->result_id()); |
388 | 0 | } |
389 | 0 | } |
390 | | |
391 | | // Create the constant for each row in the result, and collect the ids. |
392 | 0 | std::vector<uint32_t> result_columns(number_of_rows); |
393 | 0 | for (uint32_t col = 0; col < number_of_rows; ++col) { |
394 | 0 | auto* element = const_mgr->GetConstant(result_type->element_type(), |
395 | 0 | result_elements[col]); |
396 | 0 | result_columns[col] = |
397 | 0 | const_mgr->GetDefiningInstruction(element)->result_id(); |
398 | 0 | } |
399 | | |
400 | | // Create the matrix constant from the row ids, and return it. |
401 | 0 | return const_mgr->GetConstant(result_type, result_columns); |
402 | 0 | } |
403 | | |
404 | | const analysis::Constant* FoldTranspose( |
405 | | IRContext* context, Instruction* inst, |
406 | 0 | const std::vector<const analysis::Constant*>& constants) { |
407 | 0 | assert(inst->opcode() == spv::Op::OpTranspose); |
408 | | |
409 | 0 | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
410 | 0 | if (!inst->IsFloatingPointFoldingAllowed()) { |
411 | 0 | if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) { |
412 | 0 | return nullptr; |
413 | 0 | } |
414 | 0 | } |
415 | | |
416 | 0 | const analysis::Constant* matrix = constants[0]; |
417 | 0 | if (matrix == nullptr) { |
418 | 0 | return nullptr; |
419 | 0 | } |
420 | | |
421 | 0 | auto* result_type = type_mgr->GetType(inst->type_id()); |
422 | 0 | return TransposeMatrix(matrix, result_type->AsMatrix(), context); |
423 | 0 | } |
424 | | |
425 | 16.0k | ConstantFoldingRule FoldVectorTimesMatrix() { |
426 | 16.0k | return [](IRContext* context, Instruction* inst, |
427 | 16.0k | const std::vector<const analysis::Constant*>& constants) |
428 | 16.0k | -> const analysis::Constant* { |
429 | 6.24k | assert(inst->opcode() == spv::Op::OpVectorTimesMatrix); |
430 | 6.24k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
431 | 6.24k | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
432 | | |
433 | 6.24k | if (!inst->IsFloatingPointFoldingAllowed()) { |
434 | 0 | if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) { |
435 | 0 | return nullptr; |
436 | 0 | } |
437 | 0 | } |
438 | | |
439 | 6.24k | const analysis::Constant* c1 = constants[0]; |
440 | 6.24k | const analysis::Constant* c2 = constants[1]; |
441 | | |
442 | 6.24k | if (c1 == nullptr || c2 == nullptr) { |
443 | 6.18k | return nullptr; |
444 | 6.18k | } |
445 | | |
446 | | // Check result type. |
447 | 61 | const analysis::Type* result_type = type_mgr->GetType(inst->type_id()); |
448 | 61 | const analysis::Vector* vector_type = result_type->AsVector(); |
449 | 61 | assert(vector_type != nullptr); |
450 | 61 | const analysis::Type* element_type = vector_type->element_type(); |
451 | 61 | assert(element_type != nullptr); |
452 | 61 | const analysis::Float* float_type = element_type->AsFloat(); |
453 | 61 | assert(float_type != nullptr); |
454 | | |
455 | | // Check types of c1 and c2. |
456 | 61 | assert(c1->type()->AsVector() == vector_type); |
457 | 61 | assert(c1->type()->AsVector()->element_type() == element_type && |
458 | 61 | c2->type()->AsMatrix()->element_type() == vector_type); |
459 | | |
460 | 61 | uint32_t resultVectorSize = result_type->AsVector()->element_count(); |
461 | 61 | std::vector<uint32_t> ids; |
462 | | |
463 | 61 | if ((c1 && c1->IsZero()) || (c2 && c2->IsZero())) { |
464 | 43 | std::vector<uint32_t> words(float_type->width() / 32, 0); |
465 | 129 | for (uint32_t i = 0; i < resultVectorSize; ++i) { |
466 | 86 | const analysis::Constant* new_elem = |
467 | 86 | const_mgr->GetConstant(float_type, words); |
468 | 86 | ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id()); |
469 | 86 | } |
470 | 43 | return const_mgr->GetConstant(vector_type, ids); |
471 | 43 | } |
472 | | |
473 | | // Get a float vector that is the result of vector-times-matrix. |
474 | 18 | std::vector<const analysis::Constant*> c1_components = |
475 | 18 | c1->GetVectorComponents(const_mgr); |
476 | 18 | std::vector<const analysis::Constant*> c2_components = |
477 | 18 | c2->AsMatrixConstant()->GetComponents(); |
478 | | |
479 | 18 | if (float_type->width() == 32) { |
480 | 54 | for (uint32_t i = 0; i < resultVectorSize; ++i) { |
481 | 36 | float result_scalar = 0.0f; |
482 | 36 | if (!c2_components[i]->AsNullConstant()) { |
483 | 28 | const analysis::VectorConstant* c2_vec = |
484 | 28 | c2_components[i]->AsVectorConstant(); |
485 | 84 | for (uint32_t j = 0; j < c2_vec->GetComponents().size(); ++j) { |
486 | 56 | float c1_scalar = c1_components[j]->GetFloat(); |
487 | 56 | float c2_scalar = c2_vec->GetComponents()[j]->GetFloat(); |
488 | 56 | result_scalar += c1_scalar * c2_scalar; |
489 | 56 | } |
490 | 28 | } |
491 | 36 | utils::FloatProxy<float> result(result_scalar); |
492 | 36 | std::vector<uint32_t> words = result.GetWords(); |
493 | 36 | const analysis::Constant* new_elem = |
494 | 36 | const_mgr->GetConstant(float_type, words); |
495 | 36 | ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id()); |
496 | 36 | } |
497 | 18 | return const_mgr->GetConstant(vector_type, ids); |
498 | 18 | } else if (float_type->width() == 64) { |
499 | 0 | for (uint32_t i = 0; i < c2_components.size(); ++i) { |
500 | 0 | double result_scalar = 0.0; |
501 | 0 | if (!c2_components[i]->AsNullConstant()) { |
502 | 0 | const analysis::VectorConstant* c2_vec = |
503 | 0 | c2_components[i]->AsVectorConstant(); |
504 | 0 | for (uint32_t j = 0; j < c2_vec->GetComponents().size(); ++j) { |
505 | 0 | double c1_scalar = c1_components[j]->GetDouble(); |
506 | 0 | double c2_scalar = c2_vec->GetComponents()[j]->GetDouble(); |
507 | 0 | result_scalar += c1_scalar * c2_scalar; |
508 | 0 | } |
509 | 0 | } |
510 | 0 | utils::FloatProxy<double> result(result_scalar); |
511 | 0 | std::vector<uint32_t> words = result.GetWords(); |
512 | 0 | const analysis::Constant* new_elem = |
513 | 0 | const_mgr->GetConstant(float_type, words); |
514 | 0 | ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id()); |
515 | 0 | } |
516 | 0 | return const_mgr->GetConstant(vector_type, ids); |
517 | 0 | } |
518 | 0 | return nullptr; |
519 | 18 | }; |
520 | 16.0k | } |
521 | | |
522 | 16.0k | ConstantFoldingRule FoldMatrixTimesVector() { |
523 | 16.0k | return [](IRContext* context, Instruction* inst, |
524 | 16.0k | const std::vector<const analysis::Constant*>& constants) |
525 | 16.0k | -> const analysis::Constant* { |
526 | 185 | assert(inst->opcode() == spv::Op::OpMatrixTimesVector); |
527 | 185 | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
528 | 185 | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
529 | | |
530 | 185 | if (!inst->IsFloatingPointFoldingAllowed()) { |
531 | 0 | if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) { |
532 | 0 | return nullptr; |
533 | 0 | } |
534 | 0 | } |
535 | | |
536 | 185 | const analysis::Constant* c1 = constants[0]; |
537 | 185 | const analysis::Constant* c2 = constants[1]; |
538 | | |
539 | 185 | if (c1 == nullptr || c2 == nullptr) { |
540 | 183 | return nullptr; |
541 | 183 | } |
542 | | |
543 | | // Check result type. |
544 | 2 | const analysis::Type* result_type = type_mgr->GetType(inst->type_id()); |
545 | 2 | const analysis::Vector* vector_type = result_type->AsVector(); |
546 | 2 | assert(vector_type != nullptr); |
547 | 2 | const analysis::Type* element_type = vector_type->element_type(); |
548 | 2 | assert(element_type != nullptr); |
549 | 2 | const analysis::Float* float_type = element_type->AsFloat(); |
550 | 2 | assert(float_type != nullptr); |
551 | | |
552 | | // Check types of c1 and c2. |
553 | 2 | assert(c1->type()->AsMatrix()->element_type() == vector_type); |
554 | 2 | assert(c2->type()->AsVector()->element_type() == element_type); |
555 | | |
556 | 2 | uint32_t resultVectorSize = result_type->AsVector()->element_count(); |
557 | 2 | std::vector<uint32_t> ids; |
558 | | |
559 | 2 | if ((c1 && c1->IsZero()) || (c2 && c2->IsZero())) { |
560 | 2 | std::vector<uint32_t> words(float_type->width() / 32, 0); |
561 | 6 | for (uint32_t i = 0; i < resultVectorSize; ++i) { |
562 | 4 | const analysis::Constant* new_elem = |
563 | 4 | const_mgr->GetConstant(float_type, words); |
564 | 4 | ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id()); |
565 | 4 | } |
566 | 2 | return const_mgr->GetConstant(vector_type, ids); |
567 | 2 | } |
568 | | |
569 | | // Get a float vector that is the result of matrix-times-vector. |
570 | 0 | std::vector<const analysis::Constant*> c1_components = |
571 | 0 | c1->AsMatrixConstant()->GetComponents(); |
572 | 0 | std::vector<const analysis::Constant*> c2_components = |
573 | 0 | c2->GetVectorComponents(const_mgr); |
574 | |
|
575 | 0 | if (float_type->width() == 32) { |
576 | 0 | for (uint32_t i = 0; i < resultVectorSize; ++i) { |
577 | 0 | float result_scalar = 0.0f; |
578 | 0 | for (uint32_t j = 0; j < c1_components.size(); ++j) { |
579 | 0 | if (!c1_components[j]->AsNullConstant()) { |
580 | 0 | float c1_scalar = c1_components[j] |
581 | 0 | ->AsVectorConstant() |
582 | 0 | ->GetComponents()[i] |
583 | 0 | ->GetFloat(); |
584 | 0 | float c2_scalar = c2_components[j]->GetFloat(); |
585 | 0 | result_scalar += c1_scalar * c2_scalar; |
586 | 0 | } |
587 | 0 | } |
588 | 0 | utils::FloatProxy<float> result(result_scalar); |
589 | 0 | std::vector<uint32_t> words = result.GetWords(); |
590 | 0 | const analysis::Constant* new_elem = |
591 | 0 | const_mgr->GetConstant(float_type, words); |
592 | 0 | ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id()); |
593 | 0 | } |
594 | 0 | return const_mgr->GetConstant(vector_type, ids); |
595 | 0 | } else if (float_type->width() == 64) { |
596 | 0 | for (uint32_t i = 0; i < resultVectorSize; ++i) { |
597 | 0 | double result_scalar = 0.0; |
598 | 0 | for (uint32_t j = 0; j < c1_components.size(); ++j) { |
599 | 0 | if (!c1_components[j]->AsNullConstant()) { |
600 | 0 | double c1_scalar = c1_components[j] |
601 | 0 | ->AsVectorConstant() |
602 | 0 | ->GetComponents()[i] |
603 | 0 | ->GetDouble(); |
604 | 0 | double c2_scalar = c2_components[j]->GetDouble(); |
605 | 0 | result_scalar += c1_scalar * c2_scalar; |
606 | 0 | } |
607 | 0 | } |
608 | 0 | utils::FloatProxy<double> result(result_scalar); |
609 | 0 | std::vector<uint32_t> words = result.GetWords(); |
610 | 0 | const analysis::Constant* new_elem = |
611 | 0 | const_mgr->GetConstant(float_type, words); |
612 | 0 | ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id()); |
613 | 0 | } |
614 | 0 | return const_mgr->GetConstant(vector_type, ids); |
615 | 0 | } |
616 | 0 | return nullptr; |
617 | 0 | }; |
618 | 16.0k | } |
619 | | |
620 | 16.0k | ConstantFoldingRule FoldCompositeWithConstants() { |
621 | | // Folds an OpCompositeConstruct where all of the inputs are constants to a |
622 | | // constant. A new constant is created if necessary. |
623 | 16.0k | return [](IRContext* context, Instruction* inst, |
624 | 16.0k | const std::vector<const analysis::Constant*>& constants) |
625 | 282k | -> const analysis::Constant* { |
626 | 282k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
627 | 282k | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
628 | 282k | const analysis::Type* new_type = type_mgr->GetType(inst->type_id()); |
629 | 282k | Instruction* type_inst = |
630 | 282k | context->get_def_use_mgr()->GetDef(inst->type_id()); |
631 | | |
632 | 282k | std::vector<uint32_t> ids; |
633 | 655k | for (uint32_t i = 0; i < constants.size(); ++i) { |
634 | 497k | const analysis::Constant* element_const = constants[i]; |
635 | 497k | if (element_const == nullptr) { |
636 | 124k | return nullptr; |
637 | 124k | } |
638 | | |
639 | 373k | uint32_t component_type_id = 0; |
640 | 373k | if (type_inst->opcode() == spv::Op::OpTypeStruct) { |
641 | 23.6k | component_type_id = type_inst->GetSingleWordInOperand(i); |
642 | 349k | } else if (type_inst->opcode() == spv::Op::OpTypeArray) { |
643 | 552 | component_type_id = type_inst->GetSingleWordInOperand(0); |
644 | 552 | } |
645 | | |
646 | 373k | uint32_t element_id = |
647 | 373k | const_mgr->FindDeclaredConstant(element_const, component_type_id); |
648 | 373k | if (element_id == 0) { |
649 | 178 | return nullptr; |
650 | 178 | } |
651 | 373k | ids.push_back(element_id); |
652 | 373k | } |
653 | 157k | return const_mgr->GetConstant(new_type, ids); |
654 | 282k | }; |
655 | 16.0k | } |
656 | | |
657 | | // The interface for a function that returns the result of applying a scalar |
658 | | // floating-point binary operation on |a| and |b|. The type of the return value |
659 | | // will be |type|. The input constants must also be of type |type|. |
660 | | using UnaryScalarFoldingRule = std::function<const analysis::Constant*( |
661 | | const analysis::Type* result_type, const analysis::Constant* a, |
662 | | analysis::ConstantManager*)>; |
663 | | |
664 | | // The interface for a function that returns the result of applying a scalar |
665 | | // floating-point binary operation on |a| and |b|. The type of the return value |
666 | | // will be |type|. The input constants must also be of type |type|. |
667 | | using BinaryScalarFoldingRule = std::function<const analysis::Constant*( |
668 | | const analysis::Type* result_type, const analysis::Constant* a, |
669 | | const analysis::Constant* b, analysis::ConstantManager*)>; |
670 | | |
671 | | // Returns a |ConstantFoldingRule| that folds unary scalar ops |
672 | | // using |scalar_rule| and unary vectors ops by applying |
673 | | // |scalar_rule| to the elements of the vector. The |ConstantFoldingRule| |
674 | | // that is returned assumes that |constants| contains 1 entry. If they are |
675 | | // not |nullptr|, then their type is either |Float| or |Integer| or a |Vector| |
676 | | // whose element type is |Float| or |Integer|. |
677 | 250k | ConstantFoldingRule FoldUnaryOp(UnaryScalarFoldingRule scalar_rule) { |
678 | 250k | return [scalar_rule](IRContext* context, Instruction* inst, |
679 | 250k | const std::vector<const analysis::Constant*>& constants) |
680 | 455k | -> const analysis::Constant* { |
681 | 455k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
682 | 455k | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
683 | 455k | const analysis::Type* result_type = type_mgr->GetType(inst->type_id()); |
684 | 455k | const analysis::Vector* vector_type = result_type->AsVector(); |
685 | | |
686 | 455k | const analysis::Constant* arg = |
687 | 455k | (inst->opcode() == spv::Op::OpExtInst) ? constants[1] : constants[0]; |
688 | | |
689 | 455k | if (arg == nullptr) { |
690 | 119k | return nullptr; |
691 | 119k | } |
692 | | |
693 | 335k | if (vector_type != nullptr) { |
694 | 1.50k | std::vector<const analysis::Constant*> a_components; |
695 | 1.50k | std::vector<const analysis::Constant*> results_components; |
696 | | |
697 | 1.50k | a_components = arg->GetVectorComponents(const_mgr); |
698 | | |
699 | | // Fold each component of the vector. |
700 | 6.78k | for (uint32_t i = 0; i < a_components.size(); ++i) { |
701 | 5.28k | results_components.push_back(scalar_rule(vector_type->element_type(), |
702 | 5.28k | a_components[i], const_mgr)); |
703 | 5.28k | if (results_components[i] == nullptr) { |
704 | 0 | return nullptr; |
705 | 0 | } |
706 | 5.28k | } |
707 | | |
708 | | // Build the constant object and return it. |
709 | 1.50k | std::vector<uint32_t> ids; |
710 | 5.28k | for (const analysis::Constant* member : results_components) { |
711 | 5.28k | ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id()); |
712 | 5.28k | } |
713 | 1.50k | return const_mgr->GetConstant(vector_type, ids); |
714 | 334k | } else { |
715 | 334k | return scalar_rule(result_type, arg, const_mgr); |
716 | 334k | } |
717 | 335k | }; |
718 | 250k | } |
719 | | |
720 | | // Returns a |ConstantFoldingRule| that folds binary scalar ops |
721 | | // using |scalar_rule| and binary vectors ops by applying |
722 | | // |scalar_rule| to the elements of the vector. The folding rule assumes that op |
723 | | // has two inputs. For regular instruction, those are in operands 0 and 1. For |
724 | | // extended instruction, they are in operands 1 and 2. If an element in |
725 | | // |constants| is not nullprt, then the constant's type is |Float|, |Integer|, |
726 | | // or |Vector| whose element type is |Float| or |Integer|. |
727 | 128k | ConstantFoldingRule FoldBinaryOp(BinaryScalarFoldingRule scalar_rule) { |
728 | 128k | return [scalar_rule](IRContext* context, Instruction* inst, |
729 | 128k | const std::vector<const analysis::Constant*>& constants) |
730 | 1.08M | -> const analysis::Constant* { |
731 | 1.08M | assert(constants.size() == inst->NumInOperands()); |
732 | 1.08M | assert(constants.size() == (inst->opcode() == spv::Op::OpExtInst ? 3 : 2)); |
733 | 1.08M | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
734 | 1.08M | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
735 | 1.08M | const analysis::Type* result_type = type_mgr->GetType(inst->type_id()); |
736 | 1.08M | const analysis::Vector* vector_type = result_type->AsVector(); |
737 | | |
738 | 1.08M | const analysis::Constant* arg1 = |
739 | 1.08M | (inst->opcode() == spv::Op::OpExtInst) ? constants[1] : constants[0]; |
740 | 1.08M | const analysis::Constant* arg2 = |
741 | 1.08M | (inst->opcode() == spv::Op::OpExtInst) ? constants[2] : constants[1]; |
742 | | |
743 | 1.08M | if (arg1 == nullptr || arg2 == nullptr) { |
744 | 222k | return nullptr; |
745 | 222k | } |
746 | | |
747 | 857k | if (vector_type == nullptr) { |
748 | 857k | return scalar_rule(result_type, arg1, arg2, const_mgr); |
749 | 857k | } |
750 | | |
751 | 811 | std::vector<const analysis::Constant*> a_components; |
752 | 811 | std::vector<const analysis::Constant*> b_components; |
753 | 811 | std::vector<const analysis::Constant*> results_components; |
754 | | |
755 | 811 | a_components = arg1->GetVectorComponents(const_mgr); |
756 | 811 | b_components = arg2->GetVectorComponents(const_mgr); |
757 | 811 | assert(a_components.size() == b_components.size()); |
758 | | |
759 | | // Fold each component of the vector. |
760 | 4.03k | for (uint32_t i = 0; i < a_components.size(); ++i) { |
761 | 3.22k | results_components.push_back(scalar_rule(vector_type->element_type(), |
762 | 3.22k | a_components[i], b_components[i], |
763 | 3.22k | const_mgr)); |
764 | 3.22k | if (results_components[i] == nullptr) { |
765 | 0 | return nullptr; |
766 | 0 | } |
767 | 3.22k | } |
768 | | |
769 | | // Build the constant object and return it. |
770 | 811 | std::vector<uint32_t> ids; |
771 | 3.22k | for (const analysis::Constant* member : results_components) { |
772 | 3.22k | ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id()); |
773 | 3.22k | } |
774 | 811 | return const_mgr->GetConstant(vector_type, ids); |
775 | 811 | }; |
776 | 128k | } |
777 | | |
778 | | // Returns a |ConstantFoldingRule| that folds unary floating point scalar ops |
779 | | // using |scalar_rule| and unary float point vectors ops by applying |
780 | | // |scalar_rule| to the elements of the vector. The |ConstantFoldingRule| |
781 | | // that is returned assumes that |constants| contains 1 entry. If they are |
782 | | // not |nullptr|, then their type is either |Float| or |Integer| or a |Vector| |
783 | | // whose element type is |Float| or |Integer|. |
784 | 202k | ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) { |
785 | 202k | auto folding_rule = FoldUnaryOp(scalar_rule); |
786 | 202k | return [folding_rule](IRContext* context, Instruction* inst, |
787 | 202k | const std::vector<const analysis::Constant*>& constants) |
788 | 448k | -> const analysis::Constant* { |
789 | 448k | if (!inst->IsFloatingPointFoldingAllowed()) { |
790 | 72 | return nullptr; |
791 | 72 | } |
792 | | |
793 | 447k | return folding_rule(context, inst, constants); |
794 | 448k | }; |
795 | 202k | } |
796 | | |
797 | | // Returns the result of folding the constants in |constants| according the |
798 | | // |scalar_rule|. If |result_type| is a vector, then |scalar_rule| is applied |
799 | | // per component. |
800 | | const analysis::Constant* FoldFPBinaryOp( |
801 | | BinaryScalarFoldingRule scalar_rule, uint32_t result_type_id, |
802 | | const std::vector<const analysis::Constant*>& constants, |
803 | 1.47M | IRContext* context) { |
804 | 1.47M | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
805 | 1.47M | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
806 | 1.47M | const analysis::Type* result_type = type_mgr->GetType(result_type_id); |
807 | 1.47M | const analysis::Vector* vector_type = result_type->AsVector(); |
808 | | |
809 | 1.47M | if (constants[0] == nullptr || constants[1] == nullptr) { |
810 | 1.14M | return nullptr; |
811 | 1.14M | } |
812 | | |
813 | 338k | if (vector_type != nullptr) { |
814 | 60.4k | std::vector<const analysis::Constant*> a_components; |
815 | 60.4k | std::vector<const analysis::Constant*> b_components; |
816 | 60.4k | std::vector<const analysis::Constant*> results_components; |
817 | | |
818 | 60.4k | a_components = constants[0]->GetVectorComponents(const_mgr); |
819 | 60.4k | b_components = constants[1]->GetVectorComponents(const_mgr); |
820 | | |
821 | | // Fold each component of the vector. |
822 | 280k | for (uint32_t i = 0; i < a_components.size(); ++i) { |
823 | 220k | results_components.push_back(scalar_rule(vector_type->element_type(), |
824 | 220k | a_components[i], b_components[i], |
825 | 220k | const_mgr)); |
826 | 220k | if (results_components[i] == nullptr) { |
827 | 0 | return nullptr; |
828 | 0 | } |
829 | 220k | } |
830 | | |
831 | | // Build the constant object and return it. |
832 | 60.4k | std::vector<uint32_t> ids; |
833 | 220k | for (const analysis::Constant* member : results_components) { |
834 | 220k | Instruction* def = const_mgr->GetDefiningInstruction(member); |
835 | 220k | if (!def) return nullptr; |
836 | 220k | ids.push_back(def->result_id()); |
837 | 220k | } |
838 | 60.4k | return const_mgr->GetConstant(vector_type, ids); |
839 | 277k | } else { |
840 | 277k | return scalar_rule(result_type, constants[0], constants[1], const_mgr); |
841 | 277k | } |
842 | 338k | } |
843 | | |
844 | | // Returns a |ConstantFoldingRule| that folds floating point scalars using |
845 | | // |scalar_rule| and vectors of floating point by applying |scalar_rule| to the |
846 | | // elements of the vector. The |ConstantFoldingRule| that is returned assumes |
847 | | // that |constants| contains 2 entries. If they are not |nullptr|, then their |
848 | | // type is either |Float| or a |Vector| whose element type is |Float|. |
849 | 353k | ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) { |
850 | 353k | return [scalar_rule](IRContext* context, Instruction* inst, |
851 | 353k | const std::vector<const analysis::Constant*>& constants) |
852 | 1.47M | -> const analysis::Constant* { |
853 | 1.47M | if (!inst->IsFloatingPointFoldingAllowed()) { |
854 | 4.58k | return nullptr; |
855 | 4.58k | } |
856 | 1.47M | if (inst->opcode() == spv::Op::OpExtInst) { |
857 | 16.0k | return FoldFPBinaryOp(scalar_rule, inst->type_id(), |
858 | 16.0k | {constants[1], constants[2]}, context); |
859 | 16.0k | } |
860 | 1.45M | return FoldFPBinaryOp(scalar_rule, inst->type_id(), constants, context); |
861 | 1.47M | }; |
862 | 353k | } |
863 | | |
864 | | // This macro defines a |UnaryScalarFoldingRule| that performs float to |
865 | | // integer conversion. |
866 | | // TODO(greg-lunarg): Support for 64-bit integer types. |
867 | 32.1k | UnaryScalarFoldingRule FoldFToIOp() { |
868 | 32.1k | return [](const analysis::Type* result_type, const analysis::Constant* a, |
869 | 32.1k | analysis::ConstantManager* const_mgr) -> const analysis::Constant* { |
870 | 1.70k | assert(result_type != nullptr && a != nullptr); |
871 | 1.70k | const analysis::Integer* integer_type = result_type->AsInteger(); |
872 | 1.70k | const analysis::Float* float_type = a->type()->AsFloat(); |
873 | 1.70k | assert(float_type != nullptr); |
874 | 1.70k | assert(integer_type != nullptr); |
875 | 1.70k | if (integer_type->width() != 32) return nullptr; |
876 | 1.70k | if (float_type->width() == 32) { |
877 | 1.70k | float fa = a->GetFloat(); |
878 | 1.70k | uint32_t result = integer_type->IsSigned() |
879 | 1.70k | ? static_cast<uint32_t>(static_cast<int32_t>(fa)) |
880 | 1.70k | : static_cast<uint32_t>(fa); |
881 | 1.70k | std::vector<uint32_t> words = {result}; |
882 | 1.70k | return const_mgr->GetConstant(result_type, words); |
883 | 1.70k | } else if (float_type->width() == 64) { |
884 | 0 | double fa = a->GetDouble(); |
885 | 0 | uint32_t result = integer_type->IsSigned() |
886 | 0 | ? static_cast<uint32_t>(static_cast<int32_t>(fa)) |
887 | 0 | : static_cast<uint32_t>(fa); |
888 | 0 | std::vector<uint32_t> words = {result}; |
889 | 0 | return const_mgr->GetConstant(result_type, words); |
890 | 0 | } |
891 | 0 | return nullptr; |
892 | 1.70k | }; |
893 | 32.1k | } |
894 | | |
895 | | // This function defines a |UnaryScalarFoldingRule| that performs integer to |
896 | | // float conversion. |
897 | | // TODO(greg-lunarg): Support for 64-bit integer types. |
898 | 32.1k | UnaryScalarFoldingRule FoldIToFOp() { |
899 | 32.1k | return [](const analysis::Type* result_type, const analysis::Constant* a, |
900 | 329k | analysis::ConstantManager* const_mgr) -> const analysis::Constant* { |
901 | 329k | assert(result_type != nullptr && a != nullptr); |
902 | 329k | const analysis::Integer* integer_type = a->type()->AsInteger(); |
903 | 329k | const analysis::Float* float_type = result_type->AsFloat(); |
904 | 329k | assert(float_type != nullptr); |
905 | 329k | assert(integer_type != nullptr); |
906 | 329k | if (integer_type->width() != 32) return nullptr; |
907 | 329k | uint32_t ua = a->GetU32(); |
908 | 329k | if (float_type->width() == 32) { |
909 | 329k | float result_val = integer_type->IsSigned() |
910 | 329k | ? static_cast<float>(static_cast<int32_t>(ua)) |
911 | 329k | : static_cast<float>(ua); |
912 | 329k | utils::FloatProxy<float> result(result_val); |
913 | 329k | std::vector<uint32_t> words = {result.data()}; |
914 | 329k | return const_mgr->GetConstant(result_type, words); |
915 | 329k | } else if (float_type->width() == 64) { |
916 | 0 | double result_val = integer_type->IsSigned() |
917 | 0 | ? static_cast<double>(static_cast<int32_t>(ua)) |
918 | 0 | : static_cast<double>(ua); |
919 | 0 | utils::FloatProxy<double> result(result_val); |
920 | 0 | std::vector<uint32_t> words = result.GetWords(); |
921 | 0 | return const_mgr->GetConstant(result_type, words); |
922 | 0 | } |
923 | 0 | return nullptr; |
924 | 329k | }; |
925 | 32.1k | } |
926 | | |
927 | | // This defines a |UnaryScalarFoldingRule| that performs |OpQuantizeToF16|. |
928 | 16.0k | UnaryScalarFoldingRule FoldQuantizeToF16Scalar() { |
929 | 16.0k | return [](const analysis::Type* result_type, const analysis::Constant* a, |
930 | 16.0k | analysis::ConstantManager* const_mgr) -> const analysis::Constant* { |
931 | 1.85k | assert(result_type != nullptr && a != nullptr); |
932 | 1.85k | const analysis::Float* float_type = a->type()->AsFloat(); |
933 | 1.85k | assert(float_type != nullptr); |
934 | 1.85k | if (float_type->width() != 32) { |
935 | 0 | return nullptr; |
936 | 0 | } |
937 | | |
938 | 1.85k | float fa = a->GetFloat(); |
939 | 1.85k | utils::HexFloat<utils::FloatProxy<float>> orignal(fa); |
940 | 1.85k | utils::HexFloat<utils::FloatProxy<utils::Float16>> quantized(0); |
941 | 1.85k | utils::HexFloat<utils::FloatProxy<float>> result(0.0f); |
942 | 1.85k | orignal.castTo(quantized, utils::round_direction::kToZero); |
943 | 1.85k | quantized.castTo(result, utils::round_direction::kToZero); |
944 | 1.85k | std::vector<uint32_t> words = {result.getBits()}; |
945 | 1.85k | return const_mgr->GetConstant(result_type, words); |
946 | 1.85k | }; |
947 | 16.0k | } |
948 | | |
949 | | // This macro defines a |BinaryScalarFoldingRule| that applies |op|. The |
950 | | // operator |op| must work for both float and double, and use syntax "f1 op f2". |
951 | | #define FOLD_FPARITH_OP(op) \ |
952 | 68.7k | [](const analysis::Type* result_type_in_macro, const analysis::Constant* a, \ |
953 | 68.7k | const analysis::Constant* b, \ |
954 | 68.7k | analysis::ConstantManager* const_mgr_in_macro) \ |
955 | 465k | -> const analysis::Constant* { \ |
956 | 465k | assert(result_type_in_macro != nullptr && a != nullptr && b != nullptr); \ |
957 | 465k | assert(result_type_in_macro == a->type() && \ |
958 | 465k | result_type_in_macro == b->type()); \ |
959 | 465k | const analysis::Float* float_type_in_macro = \ |
960 | 465k | result_type_in_macro->AsFloat(); \ |
961 | 465k | assert(float_type_in_macro != nullptr); \ |
962 | 465k | if (float_type_in_macro->width() == 32) { \ |
963 | 465k | float fa = a->GetFloat(); \ |
964 | 465k | float fb = b->GetFloat(); \ |
965 | 465k | utils::FloatProxy<float> result_in_macro(fa op fb); \ |
966 | 465k | std::vector<uint32_t> words_in_macro = result_in_macro.GetWords(); \ |
967 | 465k | return const_mgr_in_macro->GetConstant(result_type_in_macro, \ |
968 | 465k | words_in_macro); \ |
969 | 465k | } else if (float_type_in_macro->width() == 64) { \ |
970 | 0 | double fa = a->GetDouble(); \ |
971 | 0 | double fb = b->GetDouble(); \ |
972 | 0 | utils::FloatProxy<double> result_in_macro(fa op fb); \ |
973 | 0 | std::vector<uint32_t> words_in_macro = result_in_macro.GetWords(); \ |
974 | 0 | return const_mgr_in_macro->GetConstant(result_type_in_macro, \ |
975 | 0 | words_in_macro); \ |
976 | 0 | } \ |
977 | 465k | return nullptr; \ |
978 | 465k | } |
979 | | |
980 | | // Define the folding rule for conversion between floating point and integer |
981 | 32.1k | ConstantFoldingRule FoldFToI() { return FoldFPUnaryOp(FoldFToIOp()); } |
982 | 32.1k | ConstantFoldingRule FoldIToF() { return FoldFPUnaryOp(FoldIToFOp()); } |
983 | 16.0k | ConstantFoldingRule FoldQuantizeToF16() { |
984 | 16.0k | return FoldFPUnaryOp(FoldQuantizeToF16Scalar()); |
985 | 16.0k | } |
986 | | |
987 | | // Define the folding rules for subtraction, addition, multiplication, and |
988 | | // division for floating point values. |
989 | 42.0k | ConstantFoldingRule FoldFSub() { return FoldFPBinaryOp(FOLD_FPARITH_OP(-)); } |
990 | 207k | ConstantFoldingRule FoldFAdd() { return FoldFPBinaryOp(FOLD_FPARITH_OP(+)); } |
991 | 251k | ConstantFoldingRule FoldFMul() { return FoldFPBinaryOp(FOLD_FPARITH_OP(*)); } |
992 | | |
993 | | // x - x = 0 |
994 | 32.1k | ConstantFoldingRule FoldRedundantSub() { |
995 | 32.1k | return [](IRContext* context, Instruction* inst, |
996 | 32.1k | const std::vector<const analysis::Constant*>&) |
997 | 187k | -> const analysis::Constant* { |
998 | 187k | assert(inst->opcode() == spv::Op::OpFSub || |
999 | 187k | inst->opcode() == spv::Op::OpISub); |
1000 | | |
1001 | 187k | if (inst->GetSingleWordInOperand(0) == inst->GetSingleWordInOperand(1)) { |
1002 | 2.43k | bool use_float = inst->opcode() == spv::Op::OpFSub; |
1003 | 2.43k | if (use_float && !inst->IsFloatingPointFoldingAllowed()) { |
1004 | 2 | return nullptr; |
1005 | 2 | } |
1006 | 2.43k | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
1007 | 2.43k | const analysis::Type* type = type_mgr->GetType(inst->type_id()); |
1008 | 2.43k | if (type->IsCooperativeMatrix()) { |
1009 | 0 | return nullptr; |
1010 | 0 | } |
1011 | 2.43k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
1012 | 2.43k | uint32_t null_id = const_mgr->GetNullConstId(type); |
1013 | 2.43k | return const_mgr->FindDeclaredConstant(null_id); |
1014 | 2.43k | } |
1015 | 185k | return nullptr; |
1016 | 187k | }; |
1017 | 32.1k | } |
1018 | | |
1019 | | // Returns the constant that results from evaluating |numerator| / 0.0. Returns |
1020 | | // |nullptr| if the result could not be evaluated. |
1021 | | const analysis::Constant* FoldFPScalarDivideByZero( |
1022 | | const analysis::Type* result_type, const analysis::Constant* numerator, |
1023 | 2.48k | analysis::ConstantManager* const_mgr) { |
1024 | 2.48k | if (numerator == nullptr) { |
1025 | 0 | return nullptr; |
1026 | 0 | } |
1027 | | |
1028 | 2.48k | if (numerator->IsZero()) { |
1029 | 1.41k | return GetNan(result_type, const_mgr); |
1030 | 1.41k | } |
1031 | | |
1032 | 1.06k | const analysis::Constant* result = GetInf(result_type, const_mgr); |
1033 | 1.06k | if (result == nullptr) { |
1034 | 0 | return nullptr; |
1035 | 0 | } |
1036 | | |
1037 | 1.06k | if (numerator->AsFloatConstant()->GetValueAsDouble() < 0.0) { |
1038 | 233 | result = NegateFPConst(result_type, result, const_mgr); |
1039 | 233 | } |
1040 | 1.06k | return result; |
1041 | 1.06k | } |
1042 | | |
1043 | | // Returns the result of folding |numerator| / |denominator|. Returns |nullptr| |
1044 | | // if it cannot be folded. |
1045 | | const analysis::Constant* FoldScalarFPDivide( |
1046 | | const analysis::Type* result_type, const analysis::Constant* numerator, |
1047 | | const analysis::Constant* denominator, |
1048 | 12.1k | analysis::ConstantManager* const_mgr) { |
1049 | 12.1k | if (denominator == nullptr) { |
1050 | 0 | return nullptr; |
1051 | 0 | } |
1052 | | |
1053 | 12.1k | if (denominator->IsZero()) { |
1054 | 2.13k | return FoldFPScalarDivideByZero(result_type, numerator, const_mgr); |
1055 | 2.13k | } |
1056 | | |
1057 | 10.0k | uint32_t width = denominator->type()->AsFloat()->width(); |
1058 | 10.0k | if (width != 32 && width != 64) { |
1059 | 0 | return nullptr; |
1060 | 0 | } |
1061 | | |
1062 | 10.0k | const analysis::FloatConstant* denominator_float = |
1063 | 10.0k | denominator->AsFloatConstant(); |
1064 | 10.0k | if (denominator_float && denominator->GetValueAsDouble() == -0.0) { |
1065 | 342 | const analysis::Constant* result = |
1066 | 342 | FoldFPScalarDivideByZero(result_type, numerator, const_mgr); |
1067 | 342 | if (result != nullptr) |
1068 | 342 | result = NegateFPConst(result_type, result, const_mgr); |
1069 | 342 | return result; |
1070 | 9.70k | } else { |
1071 | 19.4k | return FOLD_FPARITH_OP(/)(result_type, numerator, denominator, const_mgr); |
1072 | 9.70k | } |
1073 | 10.0k | } |
1074 | | |
1075 | | // Returns the constant folding rule to fold |OpFDiv| with two constants. |
1076 | 16.0k | ConstantFoldingRule FoldFDiv() { return FoldFPBinaryOp(FoldScalarFPDivide); } |
1077 | | |
1078 | | // Get a singular uniform value, which is repeated when the |type| is a vector. |
1079 | | const analysis::Constant* GetConstantUniformValue( |
1080 | | analysis::ConstantManager* const_mgr, const analysis::Type* type, |
1081 | 1.75k | std::optional<double> f = {}, std::optional<uint64_t> i = {}) { |
1082 | 1.75k | const analysis::Constant* uniform = nullptr; |
1083 | 1.75k | bool is_vector = false; |
1084 | 1.75k | const analysis::Type* base_type = type; |
1085 | | |
1086 | 1.75k | if (base_type->AsVector()) { |
1087 | 1.57k | is_vector = true; |
1088 | 1.57k | base_type = base_type->AsVector()->element_type(); |
1089 | 1.57k | } |
1090 | | |
1091 | 1.75k | if (f && base_type->AsFloat()) { |
1092 | 1.68k | if (base_type->AsFloat()->width() == 32) { |
1093 | 1.68k | uniform = const_mgr->GetConstant( |
1094 | 1.68k | base_type, utils::FloatProxy<float>((float)f.value()).GetWords()); |
1095 | 1.68k | } else if (base_type->AsFloat()->width() == 64) { |
1096 | 0 | uniform = const_mgr->GetConstant( |
1097 | 0 | base_type, utils::FloatProxy<double>(f.value()).GetWords()); |
1098 | 0 | } |
1099 | 1.68k | } else if (i && base_type->AsInteger()) { |
1100 | 73 | uniform = |
1101 | 73 | const_mgr->GenerateIntegerConstant(base_type->AsInteger(), i.value()); |
1102 | 73 | } |
1103 | | |
1104 | 1.75k | if (!uniform) { |
1105 | 0 | return nullptr; |
1106 | 0 | } |
1107 | | |
1108 | 1.75k | if (is_vector) { |
1109 | 1.57k | Instruction* uniform_inst = const_mgr->GetDefiningInstruction(uniform); |
1110 | 1.57k | if (!uniform_inst) return nullptr; |
1111 | | |
1112 | 1.57k | uint32_t uniform_id = uniform_inst->result_id(); |
1113 | 1.57k | uniform = |
1114 | 1.57k | const_mgr->GetConstant(type, std::vector<uint32_t>(4, uniform_id)); |
1115 | 1.57k | } |
1116 | | |
1117 | 1.75k | return uniform; |
1118 | 1.75k | } |
1119 | | |
1120 | | // x / x = 1 |
1121 | | // -x / x = -1 |
1122 | | // x / -x = -1 |
1123 | 48.2k | ConstantFoldingRule FoldRedundantDiv() { |
1124 | 48.2k | return [](IRContext* context, Instruction* inst, |
1125 | 48.2k | const std::vector<const analysis::Constant*>& constants) |
1126 | 177k | -> const analysis::Constant* { |
1127 | 177k | assert(inst->opcode() == spv::Op::OpFDiv || |
1128 | 177k | inst->opcode() == spv::Op::OpSDiv || |
1129 | 177k | inst->opcode() == spv::Op::OpUDiv); |
1130 | | |
1131 | 177k | if (constants[0] || constants[1]) { |
1132 | 138k | return nullptr; |
1133 | 138k | } |
1134 | | |
1135 | 38.6k | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
1136 | 38.6k | const analysis::Type* type = type_mgr->GetType(inst->type_id()); |
1137 | | |
1138 | 38.6k | if (type->IsCooperativeMatrix()) { |
1139 | 0 | return nullptr; |
1140 | 0 | } |
1141 | | |
1142 | 38.6k | bool use_float = inst->opcode() == spv::Op::OpFDiv; |
1143 | 38.6k | if (use_float && !inst->IsFloatingPointFoldingAllowed()) { |
1144 | 30 | return nullptr; |
1145 | 30 | } |
1146 | | |
1147 | 38.6k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
1148 | | |
1149 | 38.6k | if (inst->GetSingleWordInOperand(0) == inst->GetSingleWordInOperand(1)) { |
1150 | 1.75k | return GetConstantUniformValue(const_mgr, type, 1.0, 1); |
1151 | 1.75k | } |
1152 | | |
1153 | 36.8k | if (inst->opcode() == spv::Op::OpUDiv) { |
1154 | 40 | return nullptr; |
1155 | 40 | } |
1156 | | |
1157 | 36.8k | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
1158 | | |
1159 | 36.8k | Instruction* lhs = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); |
1160 | 36.8k | if ((lhs->opcode() == spv::Op::OpSNegate || |
1161 | 36.8k | lhs->opcode() == spv::Op::OpFNegate) && |
1162 | 37 | lhs->GetSingleWordInOperand(0) == inst->GetSingleWordInOperand(1) && |
1163 | 1 | (!use_float || lhs->IsFloatingPointFoldingAllowed())) { |
1164 | 1 | return GetConstantUniformValue(const_mgr, type, -1.0, UINT64_MAX); |
1165 | 1 | } |
1166 | | |
1167 | 36.8k | Instruction* rhs = def_use_mgr->GetDef(inst->GetSingleWordInOperand(1)); |
1168 | 36.8k | if ((rhs->opcode() == spv::Op::OpSNegate || |
1169 | 36.8k | rhs->opcode() == spv::Op::OpFNegate) && |
1170 | 32 | rhs->GetSingleWordInOperand(0) == inst->GetSingleWordInOperand(0) && |
1171 | 2 | (!use_float || rhs->IsFloatingPointFoldingAllowed())) { |
1172 | 2 | return GetConstantUniformValue(const_mgr, type, -1.0, UINT64_MAX); |
1173 | 2 | } |
1174 | | |
1175 | 36.8k | return nullptr; |
1176 | 36.8k | }; |
1177 | 48.2k | } |
1178 | | |
1179 | | bool CompareFloatingPoint(bool op_result, bool op_unordered, |
1180 | 14.8k | bool need_ordered) { |
1181 | 14.8k | if (need_ordered) { |
1182 | | // operands are ordered and Operand 1 is |op| Operand 2 |
1183 | 6.02k | return !op_unordered && op_result; |
1184 | 8.83k | } else { |
1185 | | // operands are unordered or Operand 1 is |op| Operand 2 |
1186 | 8.83k | return op_unordered || op_result; |
1187 | 8.83k | } |
1188 | 14.8k | } |
1189 | | |
1190 | | // This macro defines a |BinaryScalarFoldingRule| that applies |op|. The |
1191 | | // operator |op| must work for both float and double, and use syntax "f1 op f2". |
1192 | | #define FOLD_FPCMP_OP(op, ord) \ |
1193 | 193k | [](const analysis::Type* result_type, const analysis::Constant* a, \ |
1194 | 193k | const analysis::Constant* b, \ |
1195 | 193k | analysis::ConstantManager* const_mgr) -> const analysis::Constant* { \ |
1196 | 14.8k | assert(result_type != nullptr && a != nullptr && b != nullptr); \ |
1197 | 14.8k | assert(result_type->AsBool()); \ |
1198 | 14.8k | assert(a->type() == b->type()); \ |
1199 | 14.8k | const analysis::Float* float_type = a->type()->AsFloat(); \ |
1200 | 14.8k | assert(float_type != nullptr); \ |
1201 | 14.8k | if (float_type->width() == 32) { \ |
1202 | 14.8k | float fa = a->GetFloat(); \ |
1203 | 14.8k | float fb = b->GetFloat(); \ |
1204 | 14.8k | bool result = CompareFloatingPoint( \ |
1205 | 14.8k | fa op fb, std::isnan(fa) || std::isnan(fb), ord); \ |
1206 | 14.8k | std::vector<uint32_t> words = {uint32_t(result)}; \ |
1207 | 14.8k | return const_mgr->GetConstant(result_type, words); \ |
1208 | 14.8k | } else if (float_type->width() == 64) { \ |
1209 | 0 | double fa = a->GetDouble(); \ |
1210 | 0 | double fb = b->GetDouble(); \ |
1211 | 0 | bool result = CompareFloatingPoint( \ |
1212 | 0 | fa op fb, std::isnan(fa) || std::isnan(fb), ord); \ |
1213 | 0 | std::vector<uint32_t> words = {uint32_t(result)}; \ |
1214 | 0 | return const_mgr->GetConstant(result_type, words); \ |
1215 | 0 | } \ |
1216 | 14.8k | return nullptr; \ |
1217 | 14.8k | } |
1218 | | |
1219 | | // Define the folding rules for ordered and unordered comparison for floating |
1220 | | // point values. |
1221 | 16.0k | ConstantFoldingRule FoldFOrdEqual() { |
1222 | 16.8k | return FoldFPBinaryOp(FOLD_FPCMP_OP(==, true)); |
1223 | 16.0k | } |
1224 | 16.0k | ConstantFoldingRule FoldFUnordEqual() { |
1225 | 16.5k | return FoldFPBinaryOp(FOLD_FPCMP_OP(==, false)); |
1226 | 16.0k | } |
1227 | 16.0k | ConstantFoldingRule FoldFOrdNotEqual() { |
1228 | 16.8k | return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, true)); |
1229 | 16.0k | } |
1230 | 16.0k | ConstantFoldingRule FoldFUnordNotEqual() { |
1231 | 17.6k | return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, false)); |
1232 | 16.0k | } |
1233 | 16.0k | ConstantFoldingRule FoldFOrdLessThan() { |
1234 | 18.2k | return FoldFPBinaryOp(FOLD_FPCMP_OP(<, true)); |
1235 | 16.0k | } |
1236 | 16.0k | ConstantFoldingRule FoldFUnordLessThan() { |
1237 | 17.7k | return FoldFPBinaryOp(FOLD_FPCMP_OP(<, false)); |
1238 | 16.0k | } |
1239 | 16.0k | ConstantFoldingRule FoldFOrdGreaterThan() { |
1240 | 16.6k | return FoldFPBinaryOp(FOLD_FPCMP_OP(>, true)); |
1241 | 16.0k | } |
1242 | 16.0k | ConstantFoldingRule FoldFUnordGreaterThan() { |
1243 | 17.0k | return FoldFPBinaryOp(FOLD_FPCMP_OP(>, false)); |
1244 | 16.0k | } |
1245 | 16.0k | ConstantFoldingRule FoldFOrdLessThanEqual() { |
1246 | 16.9k | return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, true)); |
1247 | 16.0k | } |
1248 | 16.0k | ConstantFoldingRule FoldFUnordLessThanEqual() { |
1249 | 17.3k | return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, false)); |
1250 | 16.0k | } |
1251 | 16.0k | ConstantFoldingRule FoldFOrdGreaterThanEqual() { |
1252 | 17.0k | return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, true)); |
1253 | 16.0k | } |
1254 | 16.0k | ConstantFoldingRule FoldFUnordGreaterThanEqual() { |
1255 | 18.9k | return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, false)); |
1256 | 16.0k | } |
1257 | | |
1258 | 16.0k | ConstantFoldingRule FoldInvariantSelect() { |
1259 | 16.0k | return [](IRContext*, Instruction* inst, |
1260 | 16.0k | const std::vector<const analysis::Constant*>& constants) |
1261 | 23.2k | -> const analysis::Constant* { |
1262 | 23.2k | assert(inst->opcode() == spv::Op::OpSelect); |
1263 | 23.2k | (void)inst; |
1264 | | |
1265 | 23.2k | if (!constants[1] || !constants[2]) { |
1266 | 11.8k | return nullptr; |
1267 | 11.8k | } |
1268 | 11.4k | if (constants[1] == constants[2]) { |
1269 | 1.45k | return constants[1]; |
1270 | 1.45k | } |
1271 | 9.96k | if (constants[1]->IsZero() && constants[2]->IsZero()) { |
1272 | 4 | return constants[1]; |
1273 | 4 | } |
1274 | 9.96k | return nullptr; |
1275 | 9.96k | }; |
1276 | 16.0k | } |
1277 | | |
1278 | | // Folds an OpDot where all of the inputs are constants to a |
1279 | | // constant. A new constant is created if necessary. |
1280 | 16.0k | ConstantFoldingRule FoldOpDotWithConstants() { |
1281 | 16.0k | return [](IRContext* context, Instruction* inst, |
1282 | 16.0k | const std::vector<const analysis::Constant*>& constants) |
1283 | 16.0k | -> const analysis::Constant* { |
1284 | 77 | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
1285 | 77 | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
1286 | 77 | const analysis::Type* new_type = type_mgr->GetType(inst->type_id()); |
1287 | 77 | assert(new_type->AsFloat() && "OpDot should have a float return type."); |
1288 | 77 | const analysis::Float* float_type = new_type->AsFloat(); |
1289 | | |
1290 | 77 | if (!inst->IsFloatingPointFoldingAllowed()) { |
1291 | 0 | return nullptr; |
1292 | 0 | } |
1293 | | |
1294 | | // If one of the operands is 0, then the result is 0. |
1295 | 77 | bool has_zero_operand = false; |
1296 | | |
1297 | 227 | for (int i = 0; i < 2; ++i) { |
1298 | 152 | if (constants[i]) { |
1299 | 67 | if (constants[i]->AsNullConstant() || |
1300 | 67 | constants[i]->AsVectorConstant()->IsZero()) { |
1301 | 2 | has_zero_operand = true; |
1302 | 2 | break; |
1303 | 2 | } |
1304 | 67 | } |
1305 | 152 | } |
1306 | | |
1307 | 77 | if (has_zero_operand) { |
1308 | 2 | if (float_type->width() == 32) { |
1309 | 2 | utils::FloatProxy<float> result(0.0f); |
1310 | 2 | std::vector<uint32_t> words = result.GetWords(); |
1311 | 2 | return const_mgr->GetConstant(float_type, words); |
1312 | 2 | } |
1313 | 0 | if (float_type->width() == 64) { |
1314 | 0 | utils::FloatProxy<double> result(0.0); |
1315 | 0 | std::vector<uint32_t> words = result.GetWords(); |
1316 | 0 | return const_mgr->GetConstant(float_type, words); |
1317 | 0 | } |
1318 | 0 | return nullptr; |
1319 | 0 | } |
1320 | | |
1321 | 75 | if (constants[0] == nullptr || constants[1] == nullptr) { |
1322 | 75 | return nullptr; |
1323 | 75 | } |
1324 | | |
1325 | 0 | std::vector<const analysis::Constant*> a_components; |
1326 | 0 | std::vector<const analysis::Constant*> b_components; |
1327 | |
|
1328 | 0 | a_components = constants[0]->GetVectorComponents(const_mgr); |
1329 | 0 | b_components = constants[1]->GetVectorComponents(const_mgr); |
1330 | |
|
1331 | 0 | utils::FloatProxy<double> result(0.0); |
1332 | 0 | std::vector<uint32_t> words = result.GetWords(); |
1333 | 0 | const analysis::Constant* result_const = |
1334 | 0 | const_mgr->GetConstant(float_type, words); |
1335 | 0 | for (uint32_t i = 0; i < a_components.size() && result_const != nullptr; |
1336 | 0 | ++i) { |
1337 | 0 | if (a_components[i] == nullptr || b_components[i] == nullptr) { |
1338 | 0 | return nullptr; |
1339 | 0 | } |
1340 | | |
1341 | 0 | const analysis::Constant* component = FOLD_FPARITH_OP(*)( |
1342 | 0 | new_type, a_components[i], b_components[i], const_mgr); |
1343 | 0 | if (component == nullptr) { |
1344 | 0 | return nullptr; |
1345 | 0 | } |
1346 | 0 | result_const = |
1347 | 0 | FOLD_FPARITH_OP(+)(new_type, result_const, component, const_mgr); |
1348 | 0 | } |
1349 | 0 | return result_const; |
1350 | 0 | }; |
1351 | 16.0k | } |
1352 | | |
1353 | 16.0k | ConstantFoldingRule FoldFNegate() { return FoldFPUnaryOp(NegateFPConst); } |
1354 | 16.0k | ConstantFoldingRule FoldSNegate() { return FoldUnaryOp(NegateIntConst); } |
1355 | | |
1356 | 128k | ConstantFoldingRule FoldFClampFeedingCompare(spv::Op cmp_opcode) { |
1357 | 128k | return [cmp_opcode](IRContext* context, Instruction* inst, |
1358 | 128k | const std::vector<const analysis::Constant*>& constants) |
1359 | 128k | -> const analysis::Constant* { |
1360 | 102k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
1361 | 102k | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
1362 | | |
1363 | 102k | if (!inst->IsFloatingPointFoldingAllowed()) { |
1364 | 36 | return nullptr; |
1365 | 36 | } |
1366 | | |
1367 | 102k | uint32_t non_const_idx = (constants[0] ? 1 : 0); |
1368 | 102k | uint32_t operand_id = inst->GetSingleWordInOperand(non_const_idx); |
1369 | 102k | Instruction* operand_inst = def_use_mgr->GetDef(operand_id); |
1370 | | |
1371 | 102k | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
1372 | 102k | const analysis::Type* operand_type = |
1373 | 102k | type_mgr->GetType(operand_inst->type_id()); |
1374 | | |
1375 | 102k | if (!operand_type->AsFloat()) { |
1376 | 0 | return nullptr; |
1377 | 0 | } |
1378 | | |
1379 | 102k | if (operand_type->AsFloat()->width() != 32 && |
1380 | 0 | operand_type->AsFloat()->width() != 64) { |
1381 | 0 | return nullptr; |
1382 | 0 | } |
1383 | | |
1384 | 102k | if (operand_inst->opcode() != spv::Op::OpExtInst) { |
1385 | 92.3k | return nullptr; |
1386 | 92.3k | } |
1387 | | |
1388 | 10.0k | if (operand_inst->GetSingleWordInOperand(1) != GLSLstd450FClamp) { |
1389 | 10.0k | return nullptr; |
1390 | 10.0k | } |
1391 | | |
1392 | 0 | if (constants[1] == nullptr && constants[0] == nullptr) { |
1393 | 0 | return nullptr; |
1394 | 0 | } |
1395 | | |
1396 | 0 | uint32_t max_id = operand_inst->GetSingleWordInOperand(4); |
1397 | 0 | const analysis::Constant* max_const = |
1398 | 0 | const_mgr->FindDeclaredConstant(max_id); |
1399 | |
|
1400 | 0 | uint32_t min_id = operand_inst->GetSingleWordInOperand(3); |
1401 | 0 | const analysis::Constant* min_const = |
1402 | 0 | const_mgr->FindDeclaredConstant(min_id); |
1403 | |
|
1404 | 0 | bool found_result = false; |
1405 | 0 | bool result = false; |
1406 | |
|
1407 | 0 | switch (cmp_opcode) { |
1408 | 0 | case spv::Op::OpFOrdLessThan: |
1409 | 0 | case spv::Op::OpFUnordLessThan: |
1410 | 0 | case spv::Op::OpFOrdGreaterThanEqual: |
1411 | 0 | case spv::Op::OpFUnordGreaterThanEqual: |
1412 | 0 | if (constants[0]) { |
1413 | 0 | if (min_const) { |
1414 | 0 | if (constants[0]->GetValueAsDouble() < |
1415 | 0 | min_const->GetValueAsDouble()) { |
1416 | 0 | found_result = true; |
1417 | 0 | result = (cmp_opcode == spv::Op::OpFOrdLessThan || |
1418 | 0 | cmp_opcode == spv::Op::OpFUnordLessThan); |
1419 | 0 | } |
1420 | 0 | } |
1421 | 0 | if (max_const) { |
1422 | 0 | if (constants[0]->GetValueAsDouble() >= |
1423 | 0 | max_const->GetValueAsDouble()) { |
1424 | 0 | found_result = true; |
1425 | 0 | result = !(cmp_opcode == spv::Op::OpFOrdLessThan || |
1426 | 0 | cmp_opcode == spv::Op::OpFUnordLessThan); |
1427 | 0 | } |
1428 | 0 | } |
1429 | 0 | } |
1430 | |
|
1431 | 0 | if (constants[1]) { |
1432 | 0 | if (max_const) { |
1433 | 0 | if (max_const->GetValueAsDouble() < |
1434 | 0 | constants[1]->GetValueAsDouble()) { |
1435 | 0 | found_result = true; |
1436 | 0 | result = (cmp_opcode == spv::Op::OpFOrdLessThan || |
1437 | 0 | cmp_opcode == spv::Op::OpFUnordLessThan); |
1438 | 0 | } |
1439 | 0 | } |
1440 | |
|
1441 | 0 | if (min_const) { |
1442 | 0 | if (min_const->GetValueAsDouble() >= |
1443 | 0 | constants[1]->GetValueAsDouble()) { |
1444 | 0 | found_result = true; |
1445 | 0 | result = !(cmp_opcode == spv::Op::OpFOrdLessThan || |
1446 | 0 | cmp_opcode == spv::Op::OpFUnordLessThan); |
1447 | 0 | } |
1448 | 0 | } |
1449 | 0 | } |
1450 | 0 | break; |
1451 | 0 | case spv::Op::OpFOrdGreaterThan: |
1452 | 0 | case spv::Op::OpFUnordGreaterThan: |
1453 | 0 | case spv::Op::OpFOrdLessThanEqual: |
1454 | 0 | case spv::Op::OpFUnordLessThanEqual: |
1455 | 0 | if (constants[0]) { |
1456 | 0 | if (min_const) { |
1457 | 0 | if (constants[0]->GetValueAsDouble() <= |
1458 | 0 | min_const->GetValueAsDouble()) { |
1459 | 0 | found_result = true; |
1460 | 0 | result = (cmp_opcode == spv::Op::OpFOrdLessThanEqual || |
1461 | 0 | cmp_opcode == spv::Op::OpFUnordLessThanEqual); |
1462 | 0 | } |
1463 | 0 | } |
1464 | 0 | if (max_const) { |
1465 | 0 | if (constants[0]->GetValueAsDouble() > |
1466 | 0 | max_const->GetValueAsDouble()) { |
1467 | 0 | found_result = true; |
1468 | 0 | result = !(cmp_opcode == spv::Op::OpFOrdLessThanEqual || |
1469 | 0 | cmp_opcode == spv::Op::OpFUnordLessThanEqual); |
1470 | 0 | } |
1471 | 0 | } |
1472 | 0 | } |
1473 | |
|
1474 | 0 | if (constants[1]) { |
1475 | 0 | if (max_const) { |
1476 | 0 | if (max_const->GetValueAsDouble() <= |
1477 | 0 | constants[1]->GetValueAsDouble()) { |
1478 | 0 | found_result = true; |
1479 | 0 | result = (cmp_opcode == spv::Op::OpFOrdLessThanEqual || |
1480 | 0 | cmp_opcode == spv::Op::OpFUnordLessThanEqual); |
1481 | 0 | } |
1482 | 0 | } |
1483 | |
|
1484 | 0 | if (min_const) { |
1485 | 0 | if (min_const->GetValueAsDouble() > |
1486 | 0 | constants[1]->GetValueAsDouble()) { |
1487 | 0 | found_result = true; |
1488 | 0 | result = !(cmp_opcode == spv::Op::OpFOrdLessThanEqual || |
1489 | 0 | cmp_opcode == spv::Op::OpFUnordLessThanEqual); |
1490 | 0 | } |
1491 | 0 | } |
1492 | 0 | } |
1493 | 0 | break; |
1494 | 0 | default: |
1495 | 0 | return nullptr; |
1496 | 0 | } |
1497 | | |
1498 | 0 | if (!found_result) { |
1499 | 0 | return nullptr; |
1500 | 0 | } |
1501 | | |
1502 | 0 | const analysis::Type* bool_type = |
1503 | 0 | context->get_type_mgr()->GetType(inst->type_id()); |
1504 | 0 | const analysis::Constant* result_const = |
1505 | 0 | const_mgr->GetConstant(bool_type, {static_cast<uint32_t>(result)}); |
1506 | 0 | assert(result_const); |
1507 | 0 | return result_const; |
1508 | 0 | }; |
1509 | 128k | } |
1510 | | |
1511 | 9.59k | ConstantFoldingRule FoldFMix() { |
1512 | 9.59k | return [](IRContext* context, Instruction* inst, |
1513 | 9.59k | const std::vector<const analysis::Constant*>& constants) |
1514 | 12.8k | -> const analysis::Constant* { |
1515 | 12.8k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
1516 | 12.8k | assert(inst->opcode() == spv::Op::OpExtInst && |
1517 | 12.8k | "Expecting an extended instruction."); |
1518 | 12.8k | assert(inst->GetSingleWordInOperand(0) == |
1519 | 12.8k | context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() && |
1520 | 12.8k | "Expecting a GLSLstd450 extended instruction."); |
1521 | 12.8k | assert(inst->GetSingleWordInOperand(1) == GLSLstd450FMix && |
1522 | 12.8k | "Expecting and FMix instruction."); |
1523 | | |
1524 | 12.8k | if (!inst->IsFloatingPointFoldingAllowed()) { |
1525 | 0 | return nullptr; |
1526 | 0 | } |
1527 | | |
1528 | | // Make sure all FMix operands are constants. |
1529 | 14.5k | for (uint32_t i = 1; i < 4; i++) { |
1530 | 14.2k | if (constants[i] == nullptr) { |
1531 | 12.5k | return nullptr; |
1532 | 12.5k | } |
1533 | 14.2k | } |
1534 | | |
1535 | 272 | const analysis::Constant* one; |
1536 | 272 | bool is_vector = false; |
1537 | 272 | const analysis::Type* result_type = constants[1]->type(); |
1538 | 272 | const analysis::Type* base_type = result_type; |
1539 | 272 | if (base_type->AsVector()) { |
1540 | 228 | is_vector = true; |
1541 | 228 | base_type = base_type->AsVector()->element_type(); |
1542 | 228 | } |
1543 | 272 | assert(base_type->AsFloat() != nullptr && |
1544 | 272 | "FMix is suppose to act on floats or vectors of floats."); |
1545 | | |
1546 | 272 | if (base_type->AsFloat()->width() == 32) { |
1547 | 272 | one = const_mgr->GetConstant(base_type, |
1548 | 272 | utils::FloatProxy<float>(1.0f).GetWords()); |
1549 | 272 | } else if (base_type->AsFloat()->width() == 64) { |
1550 | 0 | one = const_mgr->GetConstant(base_type, |
1551 | 0 | utils::FloatProxy<double>(1.0).GetWords()); |
1552 | 0 | } else { |
1553 | | // We won't support folding half types. |
1554 | 0 | return nullptr; |
1555 | 0 | } |
1556 | | |
1557 | 272 | if (is_vector) { |
1558 | 228 | Instruction* one_inst = const_mgr->GetDefiningInstruction(one); |
1559 | 228 | if (one_inst == nullptr) return nullptr; |
1560 | 228 | uint32_t one_id = one_inst->result_id(); |
1561 | 228 | one = |
1562 | 228 | const_mgr->GetConstant(result_type, std::vector<uint32_t>(4, one_id)); |
1563 | 228 | } |
1564 | | |
1565 | 272 | const analysis::Constant* temp1 = FoldFPBinaryOp( |
1566 | 1.00k | FOLD_FPARITH_OP(-), inst->type_id(), {one, constants[3]}, context); |
1567 | 272 | if (temp1 == nullptr) { |
1568 | 0 | return nullptr; |
1569 | 0 | } |
1570 | | |
1571 | 272 | const analysis::Constant* temp2 = FoldFPBinaryOp( |
1572 | 1.00k | FOLD_FPARITH_OP(*), inst->type_id(), {constants[1], temp1}, context); |
1573 | 272 | if (temp2 == nullptr) { |
1574 | 0 | return nullptr; |
1575 | 0 | } |
1576 | 272 | const analysis::Constant* temp3 = |
1577 | 1.00k | FoldFPBinaryOp(FOLD_FPARITH_OP(*), inst->type_id(), |
1578 | 272 | {constants[2], constants[3]}, context); |
1579 | 272 | if (temp3 == nullptr) { |
1580 | 0 | return nullptr; |
1581 | 0 | } |
1582 | 1.00k | return FoldFPBinaryOp(FOLD_FPARITH_OP(+), inst->type_id(), {temp2, temp3}, |
1583 | 272 | context); |
1584 | 272 | }; |
1585 | 9.59k | } |
1586 | | |
// Strict "less than" for floating-point values that orders -0.0 before +0.0
// (the built-in operator treats the two zeros as equal). Every other pair,
// including comparisons involving NaN (which yield false), defers to
// operator<.
template <typename FloatType>
static bool NegZeroAwareLessThan(FloatType a, FloatType b) {
  const bool both_zero = (a == 0.0) && (b == 0.0);
  if (both_zero && std::signbit(a) && !std::signbit(b)) {
    return true;
  }
  return a < b;
}
1598 | | |
1599 | | const analysis::Constant* FoldMin(const analysis::Type* result_type, |
1600 | | const analysis::Constant* a, |
1601 | | const analysis::Constant* b, |
1602 | 6.30k | analysis::ConstantManager*) { |
1603 | 6.30k | if (const analysis::Integer* int_type = result_type->AsInteger()) { |
1604 | 87 | if (int_type->width() <= 32) { |
1605 | 87 | assert( |
1606 | 87 | (a->AsIntConstant() != nullptr || a->AsNullConstant() != nullptr) && |
1607 | 87 | "Must be an integer or null constant."); |
1608 | 87 | assert( |
1609 | 87 | (b->AsIntConstant() != nullptr || b->AsNullConstant() != nullptr) && |
1610 | 87 | "Must be an integer or null constant."); |
1611 | | |
1612 | 87 | if (int_type->IsSigned()) { |
1613 | 64 | int32_t va = (a->AsIntConstant() != nullptr) |
1614 | 64 | ? a->AsIntConstant()->GetS32BitValue() |
1615 | 64 | : 0; |
1616 | 64 | int32_t vb = (b->AsIntConstant() != nullptr) |
1617 | 64 | ? b->AsIntConstant()->GetS32BitValue() |
1618 | 64 | : 0; |
1619 | 64 | return (va < vb ? a : b); |
1620 | 64 | } else { |
1621 | 23 | uint32_t va = (a->AsIntConstant() != nullptr) |
1622 | 23 | ? a->AsIntConstant()->GetU32BitValue() |
1623 | 23 | : 0; |
1624 | 23 | uint32_t vb = (b->AsIntConstant() != nullptr) |
1625 | 23 | ? b->AsIntConstant()->GetU32BitValue() |
1626 | 23 | : 0; |
1627 | 23 | return (va < vb ? a : b); |
1628 | 23 | } |
1629 | 87 | } else if (int_type->width() == 64) { |
1630 | 0 | if (int_type->IsSigned()) { |
1631 | 0 | int64_t va = a->GetS64(); |
1632 | 0 | int64_t vb = b->GetS64(); |
1633 | 0 | return (va < vb ? a : b); |
1634 | 0 | } else { |
1635 | 0 | uint64_t va = a->GetU64(); |
1636 | 0 | uint64_t vb = b->GetU64(); |
1637 | 0 | return (va < vb ? a : b); |
1638 | 0 | } |
1639 | 0 | } |
1640 | 6.21k | } else if (const analysis::Float* float_type = result_type->AsFloat()) { |
1641 | 6.21k | if (float_type->width() == 32) { |
1642 | 6.21k | float va = a->GetFloat(); |
1643 | 6.21k | float vb = b->GetFloat(); |
1644 | 6.21k | return NegZeroAwareLessThan(va, vb) ? a : b; |
1645 | 6.21k | } else if (float_type->width() == 64) { |
1646 | 0 | double va = a->GetDouble(); |
1647 | 0 | double vb = b->GetDouble(); |
1648 | 0 | return NegZeroAwareLessThan(va, vb) ? a : b; |
1649 | 0 | } |
1650 | 6.21k | } |
1651 | 0 | return nullptr; |
1652 | 6.30k | } |
1653 | | |
1654 | | const analysis::Constant* FoldMax(const analysis::Type* result_type, |
1655 | | const analysis::Constant* a, |
1656 | | const analysis::Constant* b, |
1657 | 8.03k | analysis::ConstantManager*) { |
1658 | 8.03k | if (const analysis::Integer* int_type = result_type->AsInteger()) { |
1659 | 36 | if (int_type->width() <= 32) { |
1660 | 36 | assert( |
1661 | 36 | (a->AsIntConstant() != nullptr || a->AsNullConstant() != nullptr) && |
1662 | 36 | "Must be an integer or null constant."); |
1663 | 36 | assert( |
1664 | 36 | (b->AsIntConstant() != nullptr || b->AsNullConstant() != nullptr) && |
1665 | 36 | "Must be an integer or null constant."); |
1666 | | |
1667 | 36 | if (int_type->IsSigned()) { |
1668 | 36 | int32_t va = (a->AsIntConstant() != nullptr) |
1669 | 36 | ? a->AsIntConstant()->GetS32BitValue() |
1670 | 36 | : 0; |
1671 | 36 | int32_t vb = (b->AsIntConstant() != nullptr) |
1672 | 36 | ? b->AsIntConstant()->GetS32BitValue() |
1673 | 36 | : 0; |
1674 | 36 | return (va > vb ? a : b); |
1675 | 36 | } else { |
1676 | 0 | uint32_t va = (a->AsIntConstant() != nullptr) |
1677 | 0 | ? a->AsIntConstant()->GetU32BitValue() |
1678 | 0 | : 0; |
1679 | 0 | uint32_t vb = (b->AsIntConstant() != nullptr) |
1680 | 0 | ? b->AsIntConstant()->GetU32BitValue() |
1681 | 0 | : 0; |
1682 | 0 | return (va > vb ? a : b); |
1683 | 0 | } |
1684 | 36 | } else if (int_type->width() == 64) { |
1685 | 0 | if (int_type->IsSigned()) { |
1686 | 0 | int64_t va = a->GetS64(); |
1687 | 0 | int64_t vb = b->GetS64(); |
1688 | 0 | return (va > vb ? a : b); |
1689 | 0 | } else { |
1690 | 0 | uint64_t va = a->GetU64(); |
1691 | 0 | uint64_t vb = b->GetU64(); |
1692 | 0 | return (va > vb ? a : b); |
1693 | 0 | } |
1694 | 0 | } |
1695 | 8.00k | } else if (const analysis::Float* float_type = result_type->AsFloat()) { |
1696 | 8.00k | if (float_type->width() == 32) { |
1697 | 8.00k | float va = a->GetFloat(); |
1698 | 8.00k | float vb = b->GetFloat(); |
1699 | 8.00k | return NegZeroAwareLessThan(vb, va) ? a : b; |
1700 | 8.00k | } else if (float_type->width() == 64) { |
1701 | 0 | double va = a->GetDouble(); |
1702 | 0 | double vb = b->GetDouble(); |
1703 | 0 | return NegZeroAwareLessThan(vb, va) ? a : b; |
1704 | 0 | } |
1705 | 8.00k | } |
1706 | 0 | return nullptr; |
1707 | 8.03k | } |
1708 | | |
1709 | | const analysis::Constant* FoldNMin(const analysis::Type* result_type, |
1710 | | const analysis::Constant* a, |
1711 | | const analysis::Constant* b, |
1712 | 661 | analysis::ConstantManager*) { |
1713 | 661 | if (const analysis::Float* float_type = result_type->AsFloat()) { |
1714 | 661 | if (float_type->width() == 32) { |
1715 | 661 | float va = a->GetFloat(); |
1716 | 661 | float vb = b->GetFloat(); |
1717 | 661 | if (std::isnan(va)) { |
1718 | 39 | return b; |
1719 | 39 | } |
1720 | 622 | if (std::isnan(vb)) { |
1721 | 34 | return a; |
1722 | 34 | } |
1723 | 588 | return NegZeroAwareLessThan(va, vb) ? a : b; |
1724 | 622 | } else if (float_type->width() == 64) { |
1725 | 0 | double va = a->GetDouble(); |
1726 | 0 | double vb = b->GetDouble(); |
1727 | 0 | if (std::isnan(va)) { |
1728 | 0 | return b; |
1729 | 0 | } |
1730 | 0 | if (std::isnan(vb)) { |
1731 | 0 | return a; |
1732 | 0 | } |
1733 | 0 | return NegZeroAwareLessThan(va, vb) ? a : b; |
1734 | 0 | } |
1735 | 661 | } |
1736 | 0 | return nullptr; |
1737 | 661 | } |
1738 | | |
1739 | | const analysis::Constant* FoldNMax(const analysis::Type* result_type, |
1740 | | const analysis::Constant* a, |
1741 | | const analysis::Constant* b, |
1742 | 663 | analysis::ConstantManager*) { |
1743 | 663 | if (const analysis::Float* float_type = result_type->AsFloat()) { |
1744 | 663 | if (float_type->width() == 32) { |
1745 | 663 | float va = a->GetFloat(); |
1746 | 663 | float vb = b->GetFloat(); |
1747 | 663 | if (std::isnan(va)) { |
1748 | 68 | return b; |
1749 | 68 | } |
1750 | 595 | if (std::isnan(vb)) { |
1751 | 51 | return a; |
1752 | 51 | } |
1753 | 544 | return NegZeroAwareLessThan(vb, va) ? a : b; |
1754 | 595 | } else if (float_type->width() == 64) { |
1755 | 0 | double va = a->GetDouble(); |
1756 | 0 | double vb = b->GetDouble(); |
1757 | 0 | if (std::isnan(va)) { |
1758 | 0 | return b; |
1759 | 0 | } |
1760 | 0 | if (std::isnan(vb)) { |
1761 | 0 | return a; |
1762 | 0 | } |
1763 | 0 | return NegZeroAwareLessThan(vb, va) ? a : b; |
1764 | 0 | } |
1765 | 663 | } |
1766 | 0 | return nullptr; |
1767 | 663 | } |
1768 | | |
1769 | | // Fold an clamp instruction when all three operands are constant. |
1770 | | const analysis::Constant* FoldClamp1( |
1771 | | IRContext* context, Instruction* inst, |
1772 | 12.7k | const std::vector<const analysis::Constant*>& constants) { |
1773 | 12.7k | assert(inst->opcode() == spv::Op::OpExtInst && |
1774 | 12.7k | "Expecting an extended instruction."); |
1775 | 12.7k | assert(inst->GetSingleWordInOperand(0) == |
1776 | 12.7k | context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() && |
1777 | 12.7k | "Expecting a GLSLstd450 extended instruction."); |
1778 | | |
1779 | | // Make sure all Clamp operands are constants. |
1780 | 24.3k | for (uint32_t i = 1; i < 4; i++) { |
1781 | 22.5k | if (constants[i] == nullptr) { |
1782 | 10.9k | return nullptr; |
1783 | 10.9k | } |
1784 | 22.5k | } |
1785 | | |
1786 | 1.83k | const analysis::Constant* temp = FoldFPBinaryOp( |
1787 | 1.83k | FoldMax, inst->type_id(), {constants[1], constants[2]}, context); |
1788 | 1.83k | if (temp == nullptr) { |
1789 | 0 | return nullptr; |
1790 | 0 | } |
1791 | 1.83k | return FoldFPBinaryOp(FoldMin, inst->type_id(), {temp, constants[3]}, |
1792 | 1.83k | context); |
1793 | 1.83k | } |
1794 | | |
1795 | | // Fold a clamp instruction when |x <= min_val|. |
1796 | | const analysis::Constant* FoldClamp2( |
1797 | | IRContext* context, Instruction* inst, |
1798 | 10.9k | const std::vector<const analysis::Constant*>& constants) { |
1799 | 10.9k | assert(inst->opcode() == spv::Op::OpExtInst && |
1800 | 10.9k | "Expecting an extended instruction."); |
1801 | 10.9k | assert(inst->GetSingleWordInOperand(0) == |
1802 | 10.9k | context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() && |
1803 | 10.9k | "Expecting a GLSLstd450 extended instruction."); |
1804 | | |
1805 | 10.9k | const analysis::Constant* x = constants[1]; |
1806 | 10.9k | const analysis::Constant* min_val = constants[2]; |
1807 | | |
1808 | 10.9k | if (x == nullptr || min_val == nullptr) { |
1809 | 8.15k | return nullptr; |
1810 | 8.15k | } |
1811 | | |
1812 | 2.77k | const analysis::Constant* temp = |
1813 | 2.77k | FoldFPBinaryOp(FoldMax, inst->type_id(), {x, min_val}, context); |
1814 | 2.77k | if (temp == min_val) { |
1815 | | // We can assume that |min_val| is less than |max_val|. Therefore, if the |
1816 | | // result of the max operation is |min_val|, we know the result of the min |
1817 | | // operation, even if |max_val| is not a constant. |
1818 | 2.05k | return min_val; |
1819 | 2.05k | } |
1820 | 714 | return nullptr; |
1821 | 2.77k | } |
1822 | | |
1823 | | // Fold a clamp instruction when |x >= max_val|. |
1824 | | const analysis::Constant* FoldClamp3( |
1825 | | IRContext* context, Instruction* inst, |
1826 | 8.86k | const std::vector<const analysis::Constant*>& constants) { |
1827 | 8.86k | assert(inst->opcode() == spv::Op::OpExtInst && |
1828 | 8.86k | "Expecting an extended instruction."); |
1829 | 8.86k | assert(inst->GetSingleWordInOperand(0) == |
1830 | 8.86k | context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() && |
1831 | 8.86k | "Expecting a GLSLstd450 extended instruction."); |
1832 | | |
1833 | 8.86k | const analysis::Constant* x = constants[1]; |
1834 | 8.86k | const analysis::Constant* max_val = constants[3]; |
1835 | | |
1836 | 8.86k | if (x == nullptr || max_val == nullptr) { |
1837 | 8.46k | return nullptr; |
1838 | 8.46k | } |
1839 | | |
1840 | 404 | const analysis::Constant* temp = |
1841 | 404 | FoldFPBinaryOp(FoldMin, inst->type_id(), {x, max_val}, context); |
1842 | 404 | if (temp == max_val) { |
1843 | | // We can assume that |min_val| is less than |max_val|. Therefore, if the |
1844 | | // result of the max operation is |min_val|, we know the result of the min |
1845 | | // operation, even if |max_val| is not a constant. |
1846 | 137 | return max_val; |
1847 | 137 | } |
1848 | 267 | return nullptr; |
1849 | 404 | } |
1850 | | |
1851 | | // Fold an clamp instruction when all three operands are constant. |
1852 | | const analysis::Constant* FoldNClamp1( |
1853 | | IRContext* context, Instruction* inst, |
1854 | 607 | const std::vector<const analysis::Constant*>& constants) { |
1855 | 607 | assert(inst->opcode() == spv::Op::OpExtInst && |
1856 | 607 | "Expecting an extended instruction."); |
1857 | 607 | assert(inst->GetSingleWordInOperand(0) == |
1858 | 607 | context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() && |
1859 | 607 | "Expecting a GLSLstd450 extended instruction."); |
1860 | | |
1861 | | // Make sure all Clamp operands are constants. |
1862 | 1.41k | for (uint32_t i = 1; i < 4; i++) { |
1863 | 1.27k | if (constants[i] == nullptr) { |
1864 | 468 | return nullptr; |
1865 | 468 | } |
1866 | 1.27k | } |
1867 | | |
1868 | 139 | const analysis::Constant* temp = FoldFPBinaryOp( |
1869 | 139 | FoldNMax, inst->type_id(), {constants[1], constants[2]}, context); |
1870 | 139 | if (temp == nullptr) { |
1871 | 0 | return nullptr; |
1872 | 0 | } |
1873 | 139 | return FoldFPBinaryOp(FoldNMin, inst->type_id(), {temp, constants[3]}, |
1874 | 139 | context); |
1875 | 139 | } |
1876 | | |
1877 | | // Fold a clamp instruction when |x <= min_val|. |
1878 | | const analysis::Constant* FoldNClamp2( |
1879 | | IRContext* context, Instruction* inst, |
1880 | 468 | const std::vector<const analysis::Constant*>& constants) { |
1881 | 468 | assert(inst->opcode() == spv::Op::OpExtInst && |
1882 | 468 | "Expecting an extended instruction."); |
1883 | 468 | assert(inst->GetSingleWordInOperand(0) == |
1884 | 468 | context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() && |
1885 | 468 | "Expecting a GLSLstd450 extended instruction."); |
1886 | | |
1887 | 468 | const analysis::Constant* x = constants[1]; |
1888 | 468 | const analysis::Constant* min_val = constants[2]; |
1889 | | |
1890 | 468 | if (x == nullptr || min_val == nullptr) { |
1891 | 327 | return nullptr; |
1892 | 327 | } |
1893 | | |
1894 | 141 | const analysis::Constant* temp = |
1895 | 141 | FoldFPBinaryOp(FoldNMax, inst->type_id(), {x, min_val}, context); |
1896 | 141 | if (temp == min_val) { |
1897 | | // We can assume that |min_val| is less than |max_val|. Therefore, if the |
1898 | | // result of the max operation is |min_val|, we know the result of the min |
1899 | | // operation, even if |max_val| is not a constant. |
1900 | 65 | return min_val; |
1901 | 65 | } |
1902 | 76 | return nullptr; |
1903 | 141 | } |
1904 | | |
1905 | | // Fold a clamp instruction when |x >= max_val|. |
1906 | | const analysis::Constant* FoldNClamp3( |
1907 | | IRContext* context, Instruction* inst, |
1908 | 403 | const std::vector<const analysis::Constant*>& constants) { |
1909 | 403 | assert(inst->opcode() == spv::Op::OpExtInst && |
1910 | 403 | "Expecting an extended instruction."); |
1911 | 403 | assert(inst->GetSingleWordInOperand(0) == |
1912 | 403 | context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() && |
1913 | 403 | "Expecting a GLSLstd450 extended instruction."); |
1914 | | |
1915 | 403 | const analysis::Constant* x = constants[1]; |
1916 | 403 | const analysis::Constant* max_val = constants[3]; |
1917 | | |
1918 | 403 | if (x == nullptr || max_val == nullptr) { |
1919 | 334 | return nullptr; |
1920 | 334 | } |
1921 | | |
1922 | 69 | const analysis::Constant* temp = |
1923 | 69 | FoldFPBinaryOp(FoldNMin, inst->type_id(), {x, max_val}, context); |
1924 | 69 | if (temp == max_val) { |
1925 | | // We can assume that |min_val| is less than |max_val|. Therefore, if the |
1926 | | // result of the max operation is |min_val|, we know the result of the min |
1927 | | // operation, even if |max_val| is not a constant. |
1928 | 20 | return max_val; |
1929 | 20 | } |
1930 | 49 | return nullptr; |
1931 | 69 | } |
1932 | | |
1933 | 105k | UnaryScalarFoldingRule FoldFTranscendentalUnary(double (*fp)(double)) { |
1934 | 105k | return |
1935 | 105k | [fp](const analysis::Type* result_type, const analysis::Constant* a, |
1936 | 105k | analysis::ConstantManager* const_mgr) -> const analysis::Constant* { |
1937 | 3.05k | assert(result_type != nullptr && a != nullptr); |
1938 | 3.05k | const analysis::Float* float_type = a->type()->AsFloat(); |
1939 | 3.05k | assert(float_type != nullptr); |
1940 | 3.05k | assert(float_type == result_type->AsFloat()); |
1941 | 3.05k | if (float_type->width() == 32) { |
1942 | 3.05k | float fa = a->GetFloat(); |
1943 | 3.05k | float res = static_cast<float>(fp(fa)); |
1944 | 3.05k | utils::FloatProxy<float> result(res); |
1945 | 3.05k | std::vector<uint32_t> words = result.GetWords(); |
1946 | 3.05k | return const_mgr->GetConstant(result_type, words); |
1947 | 3.05k | } else if (float_type->width() == 64) { |
1948 | 0 | double fa = a->GetDouble(); |
1949 | 0 | double res = fp(fa); |
1950 | 0 | utils::FloatProxy<double> result(res); |
1951 | 0 | std::vector<uint32_t> words = result.GetWords(); |
1952 | 0 | return const_mgr->GetConstant(result_type, words); |
1953 | 0 | } |
1954 | 0 | return nullptr; |
1955 | 3.05k | }; |
1956 | 105k | } |
1957 | | |
1958 | | BinaryScalarFoldingRule FoldFTranscendentalBinary(double (*fp)(double, |
1959 | 19.1k | double)) { |
1960 | 19.1k | return |
1961 | 19.1k | [fp](const analysis::Type* result_type, const analysis::Constant* a, |
1962 | 19.1k | const analysis::Constant* b, |
1963 | 19.1k | analysis::ConstantManager* const_mgr) -> const analysis::Constant* { |
1964 | 217 | assert(result_type != nullptr && a != nullptr); |
1965 | 217 | const analysis::Float* float_type = a->type()->AsFloat(); |
1966 | 217 | assert(float_type != nullptr); |
1967 | 217 | assert(float_type == result_type->AsFloat()); |
1968 | 217 | assert(float_type == b->type()->AsFloat()); |
1969 | 217 | if (float_type->width() == 32) { |
1970 | 217 | float fa = a->GetFloat(); |
1971 | 217 | float fb = b->GetFloat(); |
1972 | 217 | float res = static_cast<float>(fp(fa, fb)); |
1973 | 217 | utils::FloatProxy<float> result(res); |
1974 | 217 | std::vector<uint32_t> words = result.GetWords(); |
1975 | 217 | return const_mgr->GetConstant(result_type, words); |
1976 | 217 | } else if (float_type->width() == 64) { |
1977 | 0 | double fa = a->GetDouble(); |
1978 | 0 | double fb = b->GetDouble(); |
1979 | 0 | double res = fp(fa, fb); |
1980 | 0 | utils::FloatProxy<double> result(res); |
1981 | 0 | std::vector<uint32_t> words = result.GetWords(); |
1982 | 0 | return const_mgr->GetConstant(result_type, words); |
1983 | 0 | } |
1984 | 0 | return nullptr; |
1985 | 217 | }; |
1986 | 19.1k | } |
1987 | | |
// How integer operands are widened to 64 bits before a binary integer fold:
// Signed selects sign extension, Unsigned selects zero extension.
enum Sign {
  Signed = 0,
  Unsigned = 1,
};
1989 | | |
1990 | | // Returns a BinaryScalarFoldingRule that applies `op` to the scalars. |
1991 | | // The `signedness` is used to determine if the operands should be interpreted |
1992 | | // as signed or unsigned. If the operands are signed, the value will be sign |
1993 | | // extended before the value is passed to `op`. Otherwise the values will be |
1994 | | // zero extended. |
1995 | | template <Sign signedness> |
1996 | | BinaryScalarFoldingRule FoldBinaryIntegerOperation(uint64_t (*op)(uint64_t, |
1997 | 128k | uint64_t)) { |
1998 | 128k | return |
1999 | 128k | [op](const analysis::Type* result_type, const analysis::Constant* a, |
2000 | 128k | const analysis::Constant* b, |
2001 | 860k | analysis::ConstantManager* const_mgr) -> const analysis::Constant* { |
2002 | 860k | assert(result_type != nullptr && a != nullptr && b != nullptr); |
2003 | 860k | const analysis::Integer* integer_type = result_type->AsInteger(); |
2004 | 860k | assert(integer_type != nullptr); |
2005 | 860k | assert(a->type()->kind() == analysis::Type::kInteger); |
2006 | 860k | assert(b->type()->kind() == analysis::Type::kInteger); |
2007 | 860k | assert(integer_type->width() == a->type()->AsInteger()->width()); |
2008 | 860k | assert(integer_type->width() == b->type()->AsInteger()->width()); |
2009 | | |
2010 | | // In SPIR-V, all operations support unsigned types, but the way they |
2011 | | // are interpreted depends on the opcode. This is why we use the |
2012 | | // template argument to determine how to interpret the operands. |
2013 | 860k | uint64_t ia = (signedness == Signed ? a->GetSignExtendedValue() |
2014 | 860k | : a->GetZeroExtendedValue()); |
2015 | 860k | uint64_t ib = (signedness == Signed ? b->GetSignExtendedValue() |
2016 | 860k | : b->GetZeroExtendedValue()); |
2017 | 860k | uint64_t result = op(ia, ib); |
2018 | | |
2019 | 860k | const analysis::Constant* result_constant = |
2020 | 860k | const_mgr->GenerateIntegerConstant(integer_type, result); |
2021 | 860k | return result_constant; |
2022 | 860k | }; const_folding_rules.cpp:spvtools::opt::(anonymous namespace)::FoldBinaryIntegerOperation<(spvtools::opt::(anonymous namespace)::Sign)1>(unsigned long (*)(unsigned long, unsigned long))::{lambda(spvtools::opt::analysis::Type const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::ConstantManager*)#1}::operator()(spvtools::opt::analysis::Type const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::ConstantManager*) constLine | Count | Source | 2001 | 750k | analysis::ConstantManager* const_mgr) -> const analysis::Constant* { | 2002 | 750k | assert(result_type != nullptr && a != nullptr && b != nullptr); | 2003 | 750k | const analysis::Integer* integer_type = result_type->AsInteger(); | 2004 | 750k | assert(integer_type != nullptr); | 2005 | 750k | assert(a->type()->kind() == analysis::Type::kInteger); | 2006 | 750k | assert(b->type()->kind() == analysis::Type::kInteger); | 2007 | 750k | assert(integer_type->width() == a->type()->AsInteger()->width()); | 2008 | 750k | assert(integer_type->width() == b->type()->AsInteger()->width()); | 2009 | | | 2010 | | // In SPIR-V, all operations support unsigned types, but the way they | 2011 | | // are interpreted depends on the opcode. This is why we use the | 2012 | | // template argument to determine how to interpret the operands. | 2013 | 750k | uint64_t ia = (signedness == Signed ? a->GetSignExtendedValue() | 2014 | 750k | : a->GetZeroExtendedValue()); | 2015 | 750k | uint64_t ib = (signedness == Signed ? b->GetSignExtendedValue() | 2016 | 750k | : b->GetZeroExtendedValue()); | 2017 | 750k | uint64_t result = op(ia, ib); | 2018 | | | 2019 | 750k | const analysis::Constant* result_constant = | 2020 | 750k | const_mgr->GenerateIntegerConstant(integer_type, result); | 2021 | 750k | return result_constant; | 2022 | 750k | }; |
const_folding_rules.cpp:spvtools::opt::(anonymous namespace)::FoldBinaryIntegerOperation<(spvtools::opt::(anonymous namespace)::Sign)0>(unsigned long (*)(unsigned long, unsigned long))::{lambda(spvtools::opt::analysis::Type const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::ConstantManager*)#1}::operator()(spvtools::opt::analysis::Type const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::ConstantManager*) constLine | Count | Source | 2001 | 110k | analysis::ConstantManager* const_mgr) -> const analysis::Constant* { | 2002 | 110k | assert(result_type != nullptr && a != nullptr && b != nullptr); | 2003 | 110k | const analysis::Integer* integer_type = result_type->AsInteger(); | 2004 | 110k | assert(integer_type != nullptr); | 2005 | 110k | assert(a->type()->kind() == analysis::Type::kInteger); | 2006 | 110k | assert(b->type()->kind() == analysis::Type::kInteger); | 2007 | 110k | assert(integer_type->width() == a->type()->AsInteger()->width()); | 2008 | 110k | assert(integer_type->width() == b->type()->AsInteger()->width()); | 2009 | | | 2010 | | // In SPIR-V, all operations support unsigned types, but the way they | 2011 | | // are interpreted depends on the opcode. This is why we use the | 2012 | | // template argument to determine how to interpret the operands. | 2013 | 110k | uint64_t ia = (signedness == Signed ? a->GetSignExtendedValue() | 2014 | 110k | : a->GetZeroExtendedValue()); | 2015 | 110k | uint64_t ib = (signedness == Signed ? b->GetSignExtendedValue() | 2016 | 110k | : b->GetZeroExtendedValue()); | 2017 | 110k | uint64_t result = op(ia, ib); | 2018 | | | 2019 | 110k | const analysis::Constant* result_constant = | 2020 | 110k | const_mgr->GenerateIntegerConstant(integer_type, result); | 2021 | 110k | return result_constant; | 2022 | 110k | }; |
|
2023 | 128k | } const_folding_rules.cpp:std::__1::function<spvtools::opt::analysis::Constant const* (spvtools::opt::analysis::Type const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::ConstantManager*)> spvtools::opt::(anonymous namespace)::FoldBinaryIntegerOperation<(spvtools::opt::(anonymous namespace)::Sign)1>(unsigned long (*)(unsigned long, unsigned long)) Line | Count | Source | 1997 | 80.4k | uint64_t)) { | 1998 | 80.4k | return | 1999 | 80.4k | [op](const analysis::Type* result_type, const analysis::Constant* a, | 2000 | 80.4k | const analysis::Constant* b, | 2001 | 80.4k | analysis::ConstantManager* const_mgr) -> const analysis::Constant* { | 2002 | 80.4k | assert(result_type != nullptr && a != nullptr && b != nullptr); | 2003 | 80.4k | const analysis::Integer* integer_type = result_type->AsInteger(); | 2004 | 80.4k | assert(integer_type != nullptr); | 2005 | 80.4k | assert(a->type()->kind() == analysis::Type::kInteger); | 2006 | 80.4k | assert(b->type()->kind() == analysis::Type::kInteger); | 2007 | 80.4k | assert(integer_type->width() == a->type()->AsInteger()->width()); | 2008 | 80.4k | assert(integer_type->width() == b->type()->AsInteger()->width()); | 2009 | | | 2010 | | // In SPIR-V, all operations support unsigned types, but the way they | 2011 | | // are interpreted depends on the opcode. This is why we use the | 2012 | | // template argument to determine how to interpret the operands. | 2013 | 80.4k | uint64_t ia = (signedness == Signed ? a->GetSignExtendedValue() | 2014 | 80.4k | : a->GetZeroExtendedValue()); | 2015 | 80.4k | uint64_t ib = (signedness == Signed ? 
b->GetSignExtendedValue() | 2016 | 80.4k | : b->GetZeroExtendedValue()); | 2017 | 80.4k | uint64_t result = op(ia, ib); | 2018 | | | 2019 | 80.4k | const analysis::Constant* result_constant = | 2020 | 80.4k | const_mgr->GenerateIntegerConstant(integer_type, result); | 2021 | 80.4k | return result_constant; | 2022 | 80.4k | }; | 2023 | 80.4k | } |
const_folding_rules.cpp:std::__1::function<spvtools::opt::analysis::Constant const* (spvtools::opt::analysis::Type const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::ConstantManager*)> spvtools::opt::(anonymous namespace)::FoldBinaryIntegerOperation<(spvtools::opt::(anonymous namespace)::Sign)0>(unsigned long (*)(unsigned long, unsigned long)) Line | Count | Source | 1997 | 48.2k | uint64_t)) { | 1998 | 48.2k | return | 1999 | 48.2k | [op](const analysis::Type* result_type, const analysis::Constant* a, | 2000 | 48.2k | const analysis::Constant* b, | 2001 | 48.2k | analysis::ConstantManager* const_mgr) -> const analysis::Constant* { | 2002 | 48.2k | assert(result_type != nullptr && a != nullptr && b != nullptr); | 2003 | 48.2k | const analysis::Integer* integer_type = result_type->AsInteger(); | 2004 | 48.2k | assert(integer_type != nullptr); | 2005 | 48.2k | assert(a->type()->kind() == analysis::Type::kInteger); | 2006 | 48.2k | assert(b->type()->kind() == analysis::Type::kInteger); | 2007 | 48.2k | assert(integer_type->width() == a->type()->AsInteger()->width()); | 2008 | 48.2k | assert(integer_type->width() == b->type()->AsInteger()->width()); | 2009 | | | 2010 | | // In SPIR-V, all operations support unsigned types, but the way they | 2011 | | // are interpreted depends on the opcode. This is why we use the | 2012 | | // template argument to determine how to interpret the operands. | 2013 | 48.2k | uint64_t ia = (signedness == Signed ? a->GetSignExtendedValue() | 2014 | 48.2k | : a->GetZeroExtendedValue()); | 2015 | 48.2k | uint64_t ib = (signedness == Signed ? 
b->GetSignExtendedValue() | 2016 | 48.2k | : b->GetZeroExtendedValue()); | 2017 | 48.2k | uint64_t result = op(ia, ib); | 2018 | | | 2019 | 48.2k | const analysis::Constant* result_constant = | 2020 | 48.2k | const_mgr->GenerateIntegerConstant(integer_type, result); | 2021 | 48.2k | return result_constant; | 2022 | 48.2k | }; | 2023 | 48.2k | } |
|
2024 | | |
2025 | | // A scalar folding rule that folds OpSConvert. |
2026 | | const analysis::Constant* FoldScalarSConvert( |
2027 | | const analysis::Type* result_type, const analysis::Constant* a, |
2028 | 0 | analysis::ConstantManager* const_mgr) { |
2029 | 0 | assert(result_type != nullptr); |
2030 | 0 | assert(a != nullptr); |
2031 | 0 | assert(const_mgr != nullptr); |
2032 | 0 | const analysis::Integer* integer_type = result_type->AsInteger(); |
2033 | 0 | assert(integer_type && "The result type of an SConvert"); |
2034 | 0 | int64_t value = a->GetSignExtendedValue(); |
2035 | 0 | return const_mgr->GenerateIntegerConstant(integer_type, value); |
2036 | 0 | } |
2037 | | |
2038 | | // A scalar folding rule that folds OpUConvert. |
2039 | | const analysis::Constant* FoldScalarUConvert( |
2040 | | const analysis::Type* result_type, const analysis::Constant* a, |
2041 | 0 | analysis::ConstantManager* const_mgr) { |
2042 | 0 | assert(result_type != nullptr); |
2043 | 0 | assert(a != nullptr); |
2044 | 0 | assert(const_mgr != nullptr); |
2045 | 0 | const analysis::Integer* integer_type = result_type->AsInteger(); |
2046 | 0 | assert(integer_type && "The result type of an UConvert"); |
2047 | 0 | uint64_t value = a->GetZeroExtendedValue(); |
2048 | | |
2049 | | // If the operand was an unsigned value with less than 32-bit, it would have |
2050 | | // been sign extended earlier, and we need to clear those bits. |
2051 | 0 | auto* operand_type = a->type()->AsInteger(); |
2052 | 0 | value = utils::ClearHighBits(value, 64 - operand_type->width()); |
2053 | 0 | return const_mgr->GenerateIntegerConstant(integer_type, value); |
2054 | 0 | } |
2055 | | } // namespace |
2056 | | |
2057 | 16.0k | void ConstantFoldingRules::AddFoldingRules() { |
2058 | | // Add all folding rules to the list for the opcodes to which they apply. |
2059 | | // Note that the order in which rules are added to the list matters. If a rule |
2060 | | // applies to the instruction, the rest of the rules will not be attempted. |
2061 | | // Take that into consideration. |
2062 | | |
2063 | 16.0k | rules_[spv::Op::OpCompositeConstruct].push_back(FoldCompositeWithConstants()); |
2064 | | |
2065 | 16.0k | rules_[spv::Op::OpCompositeExtract].push_back(FoldExtractWithConstants()); |
2066 | 16.0k | rules_[spv::Op::OpCompositeInsert].push_back(FoldInsertWithConstants()); |
2067 | | |
2068 | 16.0k | rules_[spv::Op::OpConvertFToS].push_back(FoldFToI()); |
2069 | 16.0k | rules_[spv::Op::OpConvertFToU].push_back(FoldFToI()); |
2070 | 16.0k | rules_[spv::Op::OpConvertSToF].push_back(FoldIToF()); |
2071 | 16.0k | rules_[spv::Op::OpConvertUToF].push_back(FoldIToF()); |
2072 | 16.0k | rules_[spv::Op::OpSConvert].push_back(FoldUnaryOp(FoldScalarSConvert)); |
2073 | 16.0k | rules_[spv::Op::OpUConvert].push_back(FoldUnaryOp(FoldScalarUConvert)); |
2074 | | |
2075 | 16.0k | rules_[spv::Op::OpDot].push_back(FoldOpDotWithConstants()); |
2076 | 16.0k | rules_[spv::Op::OpFAdd].push_back(FoldFAdd()); |
2077 | | |
2078 | 16.0k | rules_[spv::Op::OpFDiv].push_back(FoldFDiv()); |
2079 | 16.0k | rules_[spv::Op::OpFDiv].push_back(FoldRedundantDiv()); |
2080 | | |
2081 | 16.0k | rules_[spv::Op::OpFMul].push_back(FoldFMul()); |
2082 | | |
2083 | 16.0k | rules_[spv::Op::OpFSub].push_back(FoldFSub()); |
2084 | 16.0k | rules_[spv::Op::OpFSub].push_back(FoldRedundantSub()); |
2085 | | |
2086 | 16.0k | rules_[spv::Op::OpSelect].push_back(FoldInvariantSelect()); |
2087 | | |
2088 | 16.0k | rules_[spv::Op::OpFOrdEqual].push_back(FoldFOrdEqual()); |
2089 | | |
2090 | 16.0k | rules_[spv::Op::OpFUnordEqual].push_back(FoldFUnordEqual()); |
2091 | | |
2092 | 16.0k | rules_[spv::Op::OpFOrdNotEqual].push_back(FoldFOrdNotEqual()); |
2093 | | |
2094 | 16.0k | rules_[spv::Op::OpFUnordNotEqual].push_back(FoldFUnordNotEqual()); |
2095 | | |
2096 | 16.0k | rules_[spv::Op::OpFOrdLessThan].push_back(FoldFOrdLessThan()); |
2097 | 16.0k | rules_[spv::Op::OpFOrdLessThan].push_back( |
2098 | 16.0k | FoldFClampFeedingCompare(spv::Op::OpFOrdLessThan)); |
2099 | | |
2100 | 16.0k | rules_[spv::Op::OpFUnordLessThan].push_back(FoldFUnordLessThan()); |
2101 | 16.0k | rules_[spv::Op::OpFUnordLessThan].push_back( |
2102 | 16.0k | FoldFClampFeedingCompare(spv::Op::OpFUnordLessThan)); |
2103 | | |
2104 | 16.0k | rules_[spv::Op::OpFOrdGreaterThan].push_back(FoldFOrdGreaterThan()); |
2105 | 16.0k | rules_[spv::Op::OpFOrdGreaterThan].push_back( |
2106 | 16.0k | FoldFClampFeedingCompare(spv::Op::OpFOrdGreaterThan)); |
2107 | | |
2108 | 16.0k | rules_[spv::Op::OpFUnordGreaterThan].push_back(FoldFUnordGreaterThan()); |
2109 | 16.0k | rules_[spv::Op::OpFUnordGreaterThan].push_back( |
2110 | 16.0k | FoldFClampFeedingCompare(spv::Op::OpFUnordGreaterThan)); |
2111 | | |
2112 | 16.0k | rules_[spv::Op::OpFOrdLessThanEqual].push_back(FoldFOrdLessThanEqual()); |
2113 | 16.0k | rules_[spv::Op::OpFOrdLessThanEqual].push_back( |
2114 | 16.0k | FoldFClampFeedingCompare(spv::Op::OpFOrdLessThanEqual)); |
2115 | | |
2116 | 16.0k | rules_[spv::Op::OpFUnordLessThanEqual].push_back(FoldFUnordLessThanEqual()); |
2117 | 16.0k | rules_[spv::Op::OpFUnordLessThanEqual].push_back( |
2118 | 16.0k | FoldFClampFeedingCompare(spv::Op::OpFUnordLessThanEqual)); |
2119 | | |
2120 | 16.0k | rules_[spv::Op::OpFOrdGreaterThanEqual].push_back(FoldFOrdGreaterThanEqual()); |
2121 | 16.0k | rules_[spv::Op::OpFOrdGreaterThanEqual].push_back( |
2122 | 16.0k | FoldFClampFeedingCompare(spv::Op::OpFOrdGreaterThanEqual)); |
2123 | | |
2124 | 16.0k | rules_[spv::Op::OpFUnordGreaterThanEqual].push_back( |
2125 | 16.0k | FoldFUnordGreaterThanEqual()); |
2126 | 16.0k | rules_[spv::Op::OpFUnordGreaterThanEqual].push_back( |
2127 | 16.0k | FoldFClampFeedingCompare(spv::Op::OpFUnordGreaterThanEqual)); |
2128 | | |
2129 | 16.0k | rules_[spv::Op::OpVectorShuffle].push_back(FoldVectorShuffleWithConstants()); |
2130 | 16.0k | rules_[spv::Op::OpVectorTimesScalar].push_back(FoldVectorTimesScalar()); |
2131 | 16.0k | rules_[spv::Op::OpVectorTimesMatrix].push_back(FoldVectorTimesMatrix()); |
2132 | 16.0k | rules_[spv::Op::OpMatrixTimesVector].push_back(FoldMatrixTimesVector()); |
2133 | 16.0k | rules_[spv::Op::OpTranspose].push_back(FoldTranspose); |
2134 | | |
2135 | 16.0k | rules_[spv::Op::OpFNegate].push_back(FoldFNegate()); |
2136 | 16.0k | rules_[spv::Op::OpSNegate].push_back(FoldSNegate()); |
2137 | 16.0k | rules_[spv::Op::OpQuantizeToF16].push_back(FoldQuantizeToF16()); |
2138 | | |
2139 | 16.0k | rules_[spv::Op::OpIAdd].push_back( |
2140 | 16.0k | FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>( |
2141 | 466k | [](uint64_t a, uint64_t b) { return a + b; }))); |
2142 | | |
2143 | 16.0k | rules_[spv::Op::OpISub].push_back( |
2144 | 16.0k | FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>( |
2145 | 218k | [](uint64_t a, uint64_t b) { return a - b; }))); |
2146 | 16.0k | rules_[spv::Op::OpISub].push_back(FoldRedundantSub()); |
2147 | | |
2148 | 16.0k | rules_[spv::Op::OpIMul].push_back( |
2149 | 16.0k | FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>( |
2150 | 64.5k | [](uint64_t a, uint64_t b) { return a * b; }))); |
2151 | | |
2152 | 16.0k | rules_[spv::Op::OpUDiv].push_back( |
2153 | 16.0k | FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>( |
2154 | 16.0k | [](uint64_t a, uint64_t b) { return (b != 0 ? a / b : 0); }))); |
2155 | 16.0k | rules_[spv::Op::OpUDiv].push_back(FoldRedundantDiv()); |
2156 | | |
2157 | 16.0k | rules_[spv::Op::OpSDiv].push_back(FoldBinaryOp( |
2158 | 38.8k | FoldBinaryIntegerOperation<Signed>([](uint64_t a, uint64_t b) { |
2159 | 38.8k | return (b != 0 ? static_cast<uint64_t>(static_cast<int64_t>(a) / |
2160 | 37.9k | static_cast<int64_t>(b)) |
2161 | 38.8k | : 0); |
2162 | 38.8k | }))); |
2163 | 16.0k | rules_[spv::Op::OpSDiv].push_back(FoldRedundantDiv()); |
2164 | | |
2165 | 16.0k | rules_[spv::Op::OpUMod].push_back( |
2166 | 16.0k | FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>( |
2167 | 16.0k | [](uint64_t a, uint64_t b) { return (b != 0 ? a % b : 0); }))); |
2168 | | |
2169 | 16.0k | rules_[spv::Op::OpSRem].push_back(FoldBinaryOp( |
2170 | 65.9k | FoldBinaryIntegerOperation<Signed>([](uint64_t a, uint64_t b) { |
2171 | 65.9k | return (b != 0 ? static_cast<uint64_t>(static_cast<int64_t>(a) % |
2172 | 65.1k | static_cast<int64_t>(b)) |
2173 | 65.9k | : 0); |
2174 | 65.9k | }))); |
2175 | | |
2176 | 16.0k | rules_[spv::Op::OpSMod].push_back(FoldBinaryOp( |
2177 | 16.0k | FoldBinaryIntegerOperation<Signed>([](uint64_t a, uint64_t b) { |
2178 | 5.37k | if (b == 0) return static_cast<uint64_t>(0ull); |
2179 | | |
2180 | 5.14k | int64_t signed_a = static_cast<int64_t>(a); |
2181 | 5.14k | int64_t signed_b = static_cast<int64_t>(b); |
2182 | 5.14k | int64_t result = signed_a % signed_b; |
2183 | 5.14k | if ((signed_b < 0) != (result < 0)) result += signed_b; |
2184 | 5.14k | return static_cast<uint64_t>(result); |
2185 | 5.37k | }))); |
2186 | | |
2187 | | // Add rules for GLSLstd450 |
2188 | 16.0k | FeatureManager* feature_manager = context_->get_feature_mgr(); |
2189 | 16.0k | uint32_t ext_inst_glslstd450_id = |
2190 | 16.0k | feature_manager->GetExtInstImportId_GLSLstd450(); |
2191 | 16.0k | if (ext_inst_glslstd450_id != 0) { |
2192 | 9.59k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(FoldFMix()); |
2193 | 9.59k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SMin}].push_back( |
2194 | 9.59k | FoldFPBinaryOp(FoldMin)); |
2195 | 9.59k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UMin}].push_back( |
2196 | 9.59k | FoldFPBinaryOp(FoldMin)); |
2197 | 9.59k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMin}].push_back( |
2198 | 9.59k | FoldFPBinaryOp(FoldMin)); |
2199 | 9.59k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450NMin}].push_back( |
2200 | 9.59k | FoldFPBinaryOp(FoldNMin)); |
2201 | 9.59k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SMax}].push_back( |
2202 | 9.59k | FoldFPBinaryOp(FoldMax)); |
2203 | 9.59k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UMax}].push_back( |
2204 | 9.59k | FoldFPBinaryOp(FoldMax)); |
2205 | 9.59k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMax}].push_back( |
2206 | 9.59k | FoldFPBinaryOp(FoldMax)); |
2207 | 9.59k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450NMax}].push_back( |
2208 | 9.59k | FoldFPBinaryOp(FoldNMax)); |
2209 | 9.59k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back( |
2210 | 9.59k | FoldClamp1); |
2211 | 9.59k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back( |
2212 | 9.59k | FoldClamp2); |
2213 | 9.59k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back( |
2214 | 9.59k | FoldClamp3); |
2215 | 9.59k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back( |
2216 | 9.59k | FoldClamp1); |
2217 | 9.59k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back( |
2218 | 9.59k | FoldClamp2); |
2219 | 9.59k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back( |
2220 | 9.59k | FoldClamp3); |
2221 | 9.59k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back( |
2222 | 9.59k | FoldClamp1); |
2223 | 9.59k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back( |
2224 | 9.59k | FoldClamp2); |
2225 | 9.59k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back( |
2226 | 9.59k | FoldClamp3); |
2227 | 9.59k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450NClamp}].push_back( |
2228 | 9.59k | FoldNClamp1); |
2229 | 9.59k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450NClamp}].push_back( |
2230 | 9.59k | FoldNClamp2); |
2231 | 9.59k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450NClamp}].push_back( |
2232 | 9.59k | FoldNClamp3); |
2233 | 9.59k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Sin}].push_back( |
2234 | 9.59k | FoldFPUnaryOp(FoldFTranscendentalUnary(std::sin))); |
2235 | 9.59k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Cos}].push_back( |
2236 | 9.59k | FoldFPUnaryOp(FoldFTranscendentalUnary(std::cos))); |
2237 | 9.59k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Tan}].push_back( |
2238 | 9.59k | FoldFPUnaryOp(FoldFTranscendentalUnary(std::tan))); |
2239 | 9.59k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Asin}].push_back( |
2240 | 9.59k | FoldFPUnaryOp(FoldFTranscendentalUnary(std::asin))); |
2241 | 9.59k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Acos}].push_back( |
2242 | 9.59k | FoldFPUnaryOp(FoldFTranscendentalUnary(std::acos))); |
2243 | 9.59k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Atan}].push_back( |
2244 | 9.59k | FoldFPUnaryOp(FoldFTranscendentalUnary(std::atan))); |
2245 | 9.59k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp}].push_back( |
2246 | 9.59k | FoldFPUnaryOp(FoldFTranscendentalUnary(std::exp))); |
2247 | 9.59k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log}].push_back( |
2248 | 9.59k | FoldFPUnaryOp(FoldFTranscendentalUnary(std::log))); |
2249 | | |
2250 | | #ifdef __ANDROID__ |
2251 | | // Android NDK r15c targeting ABI 15 doesn't have full support for C++11 |
2252 | | // (no std::exp2/log2). ::exp2 is available from C99 but ::log2 isn't |
2253 | | // available up until ABI 18 so we use a shim |
2254 | | auto log2_shim = [](double v) -> double { return log(v) / log(2.0); }; |
2255 | | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp2}].push_back( |
2256 | | FoldFPUnaryOp(FoldFTranscendentalUnary(::exp2))); |
2257 | | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log2}].push_back( |
2258 | | FoldFPUnaryOp(FoldFTranscendentalUnary(log2_shim))); |
2259 | | #else |
2260 | 9.59k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp2}].push_back( |
2261 | 9.59k | FoldFPUnaryOp(FoldFTranscendentalUnary(std::exp2))); |
2262 | 9.59k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log2}].push_back( |
2263 | 9.59k | FoldFPUnaryOp(FoldFTranscendentalUnary(std::log2))); |
2264 | 9.59k | #endif |
2265 | | |
2266 | 9.59k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Sqrt}].push_back( |
2267 | 9.59k | FoldFPUnaryOp(FoldFTranscendentalUnary(std::sqrt))); |
2268 | 9.59k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Atan2}].push_back( |
2269 | 9.59k | FoldFPBinaryOp(FoldFTranscendentalBinary(std::atan2))); |
2270 | 9.59k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Pow}].push_back( |
2271 | 9.59k | FoldFPBinaryOp(FoldFTranscendentalBinary(std::pow))); |
2272 | 9.59k | } |
2273 | 16.0k | } |
2274 | | } // namespace opt |
2275 | | } // namespace spvtools |