/src/spirv-tools/source/opt/const_folding_rules.cpp
Line | Count | Source |
1 | | // Copyright (c) 2018 Google LLC |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | #include "source/opt/const_folding_rules.h" |
16 | | |
17 | | #include "source/opt/ir_context.h" |
18 | | |
19 | | namespace spvtools { |
20 | | namespace opt { |
21 | | namespace { |
// Position, within the |constants| vector handed to the extract folding rule,
// of the composite operand being extracted from.
constexpr uint32_t kExtractCompositeIdInIdx = 0;
23 | | |
24 | | // Returns a constants with the value NaN of the given type. Only works for |
25 | | // 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs. |
26 | | const analysis::Constant* GetNan(const analysis::Type* type, |
27 | 1.12k | analysis::ConstantManager* const_mgr) { |
28 | 1.12k | const analysis::Float* float_type = type->AsFloat(); |
29 | 1.12k | if (float_type == nullptr) { |
30 | 0 | return nullptr; |
31 | 0 | } |
32 | | |
33 | 1.12k | switch (float_type->width()) { |
34 | 1.12k | case 32: |
35 | 1.12k | return const_mgr->GetFloatConst(std::numeric_limits<float>::quiet_NaN()); |
36 | 0 | case 64: |
37 | 0 | return const_mgr->GetDoubleConst( |
38 | 0 | std::numeric_limits<double>::quiet_NaN()); |
39 | 0 | default: |
40 | 0 | return nullptr; |
41 | 1.12k | } |
42 | 1.12k | } |
43 | | |
44 | | // Returns a constants with the value INF of the given type. Only works for |
45 | | // 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs. |
46 | | const analysis::Constant* GetInf(const analysis::Type* type, |
47 | 965 | analysis::ConstantManager* const_mgr) { |
48 | 965 | const analysis::Float* float_type = type->AsFloat(); |
49 | 965 | if (float_type == nullptr) { |
50 | 0 | return nullptr; |
51 | 0 | } |
52 | | |
53 | 965 | switch (float_type->width()) { |
54 | 965 | case 32: |
55 | 965 | return const_mgr->GetFloatConst(std::numeric_limits<float>::infinity()); |
56 | 0 | case 64: |
57 | 0 | return const_mgr->GetDoubleConst(std::numeric_limits<double>::infinity()); |
58 | 0 | default: |
59 | 0 | return nullptr; |
60 | 965 | } |
61 | 965 | } |
62 | | |
63 | | // Returns true if |type| is Float or a vector of Float. |
64 | 80 | bool HasFloatingPoint(const analysis::Type* type) { |
65 | 80 | if (type->AsFloat()) { |
66 | 0 | return true; |
67 | 80 | } else if (const analysis::Vector* vec_type = type->AsVector()) { |
68 | 80 | return vec_type->element_type()->AsFloat() != nullptr; |
69 | 80 | } |
70 | | |
71 | 0 | return false; |
72 | 80 | } |
73 | | |
74 | | // Returns a constants with the value |-val| of the given type. Only works for |
75 | | // 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs. |
76 | | const analysis::Constant* NegateFPConst(const analysis::Type* result_type, |
77 | | const analysis::Constant* val, |
78 | 712 | analysis::ConstantManager* const_mgr) { |
79 | 712 | const analysis::Float* float_type = result_type->AsFloat(); |
80 | 712 | assert(float_type != nullptr); |
81 | 712 | if (float_type->width() == 32) { |
82 | 712 | float fa = val->GetFloat(); |
83 | 712 | return const_mgr->GetFloatConst(-fa); |
84 | 712 | } else if (float_type->width() == 64) { |
85 | 0 | double da = val->GetDouble(); |
86 | 0 | return const_mgr->GetDoubleConst(-da); |
87 | 0 | } |
88 | 0 | return nullptr; |
89 | 712 | } |
90 | | |
91 | | // Returns a constants with the value |-val| of the given type. |
92 | | const analysis::Constant* NegateIntConst(const analysis::Type* result_type, |
93 | | const analysis::Constant* val, |
94 | 1.24k | analysis::ConstantManager* const_mgr) { |
95 | 1.24k | const analysis::Integer* int_type = result_type->AsInteger(); |
96 | 1.24k | assert(int_type != nullptr); |
97 | | |
98 | 1.24k | if (val->AsNullConstant()) { |
99 | 2 | return val; |
100 | 2 | } |
101 | | |
102 | 1.24k | uint64_t new_value = static_cast<uint64_t>(-val->GetSignExtendedValue()); |
103 | 1.24k | return const_mgr->GetIntConst(new_value, int_type->width(), |
104 | 1.24k | int_type->IsSigned()); |
105 | 1.24k | } |
106 | | |
107 | | // Folds an OpcompositeExtract where input is a composite constant. |
108 | 11.6k | ConstantFoldingRule FoldExtractWithConstants() { |
109 | 11.6k | return [](IRContext* context, Instruction* inst, |
110 | 11.6k | const std::vector<const analysis::Constant*>& constants) |
111 | 336k | -> const analysis::Constant* { |
112 | 336k | const analysis::Constant* c = constants[kExtractCompositeIdInIdx]; |
113 | 336k | if (c == nullptr) { |
114 | 301k | return nullptr; |
115 | 301k | } |
116 | | |
117 | 70.7k | for (uint32_t i = 1; i < inst->NumInOperands(); ++i) { |
118 | 35.4k | uint32_t element_index = inst->GetSingleWordInOperand(i); |
119 | 35.4k | if (c->AsNullConstant()) { |
120 | | // Return Null for the return type. |
121 | 141 | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
122 | 141 | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
123 | 141 | return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), {}); |
124 | 141 | } |
125 | | |
126 | 35.2k | auto cc = c->AsCompositeConstant(); |
127 | 35.2k | assert(cc != nullptr); |
128 | 35.2k | auto components = cc->GetComponents(); |
129 | | // Protect against invalid IR. Refuse to fold if the index is out |
130 | | // of bounds. |
131 | 35.2k | if (element_index >= components.size()) return nullptr; |
132 | 35.2k | c = components[element_index]; |
133 | 35.2k | } |
134 | 35.2k | return c; |
135 | 35.4k | }; |
136 | 11.6k | } |
137 | | |
138 | | // Folds an OpcompositeInsert where input is a composite constant. |
139 | 11.6k | ConstantFoldingRule FoldInsertWithConstants() { |
140 | 11.6k | return [](IRContext* context, Instruction* inst, |
141 | 11.6k | const std::vector<const analysis::Constant*>& constants) |
142 | 73.8k | -> const analysis::Constant* { |
143 | 73.8k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
144 | 73.8k | const analysis::Constant* object = constants[0]; |
145 | 73.8k | const analysis::Constant* composite = constants[1]; |
146 | 73.8k | if (object == nullptr || composite == nullptr) { |
147 | 69.1k | return nullptr; |
148 | 69.1k | } |
149 | | |
150 | | // If there is more than 1 index, then each additional constant used by the |
151 | | // index will need to be recreated to use the inserted object. |
152 | 4.71k | std::vector<const analysis::Constant*> chain; |
153 | 4.71k | std::vector<const analysis::Constant*> components; |
154 | 4.71k | const analysis::Type* type = nullptr; |
155 | 4.71k | const uint32_t final_index = (inst->NumInOperands() - 1); |
156 | | |
157 | | // Work down hierarchy of all indexes |
158 | 9.42k | for (uint32_t i = 2; i < inst->NumInOperands(); ++i) { |
159 | 4.71k | type = composite->type(); |
160 | | |
161 | 4.71k | if (composite->AsNullConstant()) { |
162 | | // Make new composite so it can be inserted in the index with the |
163 | | // non-null value |
164 | 77 | if (const auto new_composite = |
165 | 77 | const_mgr->GetNullCompositeConstant(type)) { |
166 | | // Keep track of any indexes along the way to last index |
167 | 77 | if (i != final_index) { |
168 | 0 | chain.push_back(new_composite); |
169 | 0 | } |
170 | 77 | components = new_composite->AsCompositeConstant()->GetComponents(); |
171 | 77 | } else { |
172 | | // Unsupported input type (such as structs) |
173 | 0 | return nullptr; |
174 | 0 | } |
175 | 4.63k | } else { |
176 | | // Keep track of any indexes along the way to last index |
177 | 4.63k | if (i != final_index) { |
178 | 0 | chain.push_back(composite); |
179 | 0 | } |
180 | 4.63k | components = composite->AsCompositeConstant()->GetComponents(); |
181 | 4.63k | } |
182 | 4.71k | const uint32_t index = inst->GetSingleWordInOperand(i); |
183 | 4.71k | composite = components[index]; |
184 | 4.71k | } |
185 | | |
186 | | // Final index in hierarchy is inserted with new object. |
187 | 4.71k | const uint32_t final_operand = inst->GetSingleWordInOperand(final_index); |
188 | 4.71k | std::vector<uint32_t> ids; |
189 | 14.2k | for (size_t i = 0; i < components.size(); i++) { |
190 | 9.51k | const analysis::Constant* constant = |
191 | 9.51k | (i == final_operand) ? object : components[i]; |
192 | 9.51k | Instruction* member_inst = const_mgr->GetDefiningInstruction(constant); |
193 | 9.51k | ids.push_back(member_inst->result_id()); |
194 | 9.51k | } |
195 | 4.71k | const analysis::Constant* new_constant = const_mgr->GetConstant(type, ids); |
196 | | |
197 | | // Work backwards up the chain and replace each index with new constant. |
198 | 4.71k | for (size_t i = chain.size(); i > 0; i--) { |
199 | | // Need to insert any previous instruction into the module first. |
200 | | // Can't just insert in types_values_begin() because it will move above |
201 | | // where the types are declared. |
202 | | // Can't compare with location of inst because not all new added |
203 | | // instructions are added to types_values_ |
204 | 0 | auto iter = context->types_values_end(); |
205 | 0 | Module::inst_iterator* pos = &iter; |
206 | 0 | const_mgr->BuildInstructionAndAddToModule(new_constant, pos); |
207 | |
|
208 | 0 | composite = chain[i - 1]; |
209 | 0 | components = composite->AsCompositeConstant()->GetComponents(); |
210 | 0 | type = composite->type(); |
211 | 0 | ids.clear(); |
212 | 0 | for (size_t k = 0; k < components.size(); k++) { |
213 | 0 | const uint32_t index = |
214 | 0 | inst->GetSingleWordInOperand(1 + static_cast<uint32_t>(i)); |
215 | 0 | const analysis::Constant* constant = |
216 | 0 | (k == index) ? new_constant : components[k]; |
217 | 0 | const uint32_t constant_id = |
218 | 0 | const_mgr->FindDeclaredConstant(constant, 0); |
219 | 0 | ids.push_back(constant_id); |
220 | 0 | } |
221 | 0 | new_constant = const_mgr->GetConstant(type, ids); |
222 | 0 | } |
223 | | |
224 | | // If multiple constants were created, only need to return the top index. |
225 | 4.71k | return new_constant; |
226 | 4.71k | }; |
227 | 11.6k | } |
228 | | |
229 | 11.6k | ConstantFoldingRule FoldVectorShuffleWithConstants() { |
230 | 11.6k | return [](IRContext* context, Instruction* inst, |
231 | 11.6k | const std::vector<const analysis::Constant*>& constants) |
232 | 23.9k | -> const analysis::Constant* { |
233 | 23.9k | assert(inst->opcode() == spv::Op::OpVectorShuffle); |
234 | 23.9k | const analysis::Constant* c1 = constants[0]; |
235 | 23.9k | const analysis::Constant* c2 = constants[1]; |
236 | 23.9k | if (c1 == nullptr || c2 == nullptr) { |
237 | 23.6k | return nullptr; |
238 | 23.6k | } |
239 | | |
240 | 295 | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
241 | 295 | const analysis::Type* element_type = c1->type()->AsVector()->element_type(); |
242 | | |
243 | 295 | std::vector<const analysis::Constant*> c1_components; |
244 | 295 | if (const analysis::VectorConstant* vec_const = c1->AsVectorConstant()) { |
245 | 243 | c1_components = vec_const->GetComponents(); |
246 | 243 | } else { |
247 | 52 | assert(c1->AsNullConstant()); |
248 | 52 | const analysis::Constant* element = |
249 | 52 | const_mgr->GetConstant(element_type, {}); |
250 | 52 | c1_components.resize(c1->type()->AsVector()->element_count(), element); |
251 | 52 | } |
252 | 295 | std::vector<const analysis::Constant*> c2_components; |
253 | 295 | if (const analysis::VectorConstant* vec_const = c2->AsVectorConstant()) { |
254 | 205 | c2_components = vec_const->GetComponents(); |
255 | 205 | } else { |
256 | 90 | assert(c2->AsNullConstant()); |
257 | 90 | const analysis::Constant* element = |
258 | 90 | const_mgr->GetConstant(element_type, {}); |
259 | 90 | c2_components.resize(c2->type()->AsVector()->element_count(), element); |
260 | 90 | } |
261 | | |
262 | 295 | std::vector<uint32_t> ids; |
263 | 295 | const uint32_t undef_literal_value = 0xffffffff; |
264 | 607 | for (uint32_t i = 2; i < inst->NumInOperands(); ++i) { |
265 | 474 | uint32_t index = inst->GetSingleWordInOperand(i); |
266 | 474 | if (index == undef_literal_value) { |
267 | | // Don't fold shuffle with undef literal value. |
268 | 162 | return nullptr; |
269 | 312 | } else if (index < c1_components.size()) { |
270 | 223 | Instruction* member_inst = |
271 | 223 | const_mgr->GetDefiningInstruction(c1_components[index]); |
272 | 223 | ids.push_back(member_inst->result_id()); |
273 | 223 | } else { |
274 | 89 | Instruction* member_inst = const_mgr->GetDefiningInstruction( |
275 | 89 | c2_components[index - c1_components.size()]); |
276 | 89 | ids.push_back(member_inst->result_id()); |
277 | 89 | } |
278 | 474 | } |
279 | | |
280 | 133 | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
281 | 133 | return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids); |
282 | 295 | }; |
283 | 11.6k | } |
284 | | |
285 | 11.6k | ConstantFoldingRule FoldVectorTimesScalar() { |
286 | 11.6k | return [](IRContext* context, Instruction* inst, |
287 | 11.6k | const std::vector<const analysis::Constant*>& constants) |
288 | 414k | -> const analysis::Constant* { |
289 | 414k | assert(inst->opcode() == spv::Op::OpVectorTimesScalar); |
290 | 414k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
291 | 414k | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
292 | | |
293 | 414k | if (!inst->IsFloatingPointFoldingAllowed()) { |
294 | 80 | if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) { |
295 | 80 | return nullptr; |
296 | 80 | } |
297 | 80 | } |
298 | | |
299 | 414k | const analysis::Constant* c1 = constants[0]; |
300 | 414k | const analysis::Constant* c2 = constants[1]; |
301 | | |
302 | 414k | if (c1 && c1->IsZero()) { |
303 | 1.40k | return c1; |
304 | 1.40k | } |
305 | | |
306 | 412k | if (c2 && c2->IsZero()) { |
307 | | // Get or create the NullConstant for this type. |
308 | 5.65k | std::vector<uint32_t> ids; |
309 | 5.65k | return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids); |
310 | 5.65k | } |
311 | | |
312 | 407k | if (c1 == nullptr || c2 == nullptr) { |
313 | 377k | return nullptr; |
314 | 377k | } |
315 | | |
316 | | // Check result type. |
317 | 29.7k | const analysis::Type* result_type = type_mgr->GetType(inst->type_id()); |
318 | 29.7k | const analysis::Vector* vector_type = result_type->AsVector(); |
319 | 29.7k | assert(vector_type != nullptr); |
320 | 29.7k | const analysis::Type* element_type = vector_type->element_type(); |
321 | 29.7k | assert(element_type != nullptr); |
322 | 29.7k | const analysis::Float* float_type = element_type->AsFloat(); |
323 | 29.7k | assert(float_type != nullptr); |
324 | | |
325 | | // Check types of c1 and c2. |
326 | 29.7k | assert(c1->type()->AsVector() == vector_type); |
327 | 29.7k | assert(c1->type()->AsVector()->element_type() == element_type && |
328 | 29.7k | c2->type() == element_type); |
329 | | |
330 | | // Get a float vector that is the result of vector-times-scalar. |
331 | 29.7k | std::vector<const analysis::Constant*> c1_components = |
332 | 29.7k | c1->GetVectorComponents(const_mgr); |
333 | 29.7k | std::vector<uint32_t> ids; |
334 | 29.7k | if (float_type->width() == 32) { |
335 | 29.7k | float scalar = c2->GetFloat(); |
336 | 147k | for (uint32_t i = 0; i < c1_components.size(); ++i) { |
337 | 118k | utils::FloatProxy<float> result(c1_components[i]->GetFloat() * scalar); |
338 | 118k | std::vector<uint32_t> words = result.GetWords(); |
339 | 118k | const analysis::Constant* new_elem = |
340 | 118k | const_mgr->GetConstant(float_type, words); |
341 | 118k | ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id()); |
342 | 118k | } |
343 | 29.7k | return const_mgr->GetConstant(vector_type, ids); |
344 | 29.7k | } else if (float_type->width() == 64) { |
345 | 0 | double scalar = c2->GetDouble(); |
346 | 0 | for (uint32_t i = 0; i < c1_components.size(); ++i) { |
347 | 0 | utils::FloatProxy<double> result(c1_components[i]->GetDouble() * |
348 | 0 | scalar); |
349 | 0 | std::vector<uint32_t> words = result.GetWords(); |
350 | 0 | const analysis::Constant* new_elem = |
351 | 0 | const_mgr->GetConstant(float_type, words); |
352 | 0 | ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id()); |
353 | 0 | } |
354 | 0 | return const_mgr->GetConstant(vector_type, ids); |
355 | 0 | } |
356 | 0 | return nullptr; |
357 | 29.7k | }; |
358 | 11.6k | } |
359 | | |
360 | | // Returns to the constant that results from tranposing |matrix|. The result |
361 | | // will have type |result_type|, and |matrix| must exist in |context|. The |
362 | | // result constant will also exist in |context|. |
363 | | const analysis::Constant* TransposeMatrix(const analysis::Constant* matrix, |
364 | | analysis::Matrix* result_type, |
365 | 0 | IRContext* context) { |
366 | 0 | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
367 | 0 | if (matrix->AsNullConstant() != nullptr) { |
368 | 0 | return const_mgr->GetNullCompositeConstant(result_type); |
369 | 0 | } |
370 | | |
371 | 0 | const auto& columns = matrix->AsMatrixConstant()->GetComponents(); |
372 | 0 | uint32_t number_of_rows = columns[0]->type()->AsVector()->element_count(); |
373 | | |
374 | | // Collect the ids of the elements in their new positions. |
375 | 0 | std::vector<std::vector<uint32_t>> result_elements(number_of_rows); |
376 | 0 | for (const analysis::Constant* column : columns) { |
377 | 0 | if (column->AsNullConstant()) { |
378 | 0 | column = const_mgr->GetNullCompositeConstant(column->type()); |
379 | 0 | } |
380 | 0 | const auto& column_components = column->AsVectorConstant()->GetComponents(); |
381 | |
|
382 | 0 | for (uint32_t row = 0; row < number_of_rows; ++row) { |
383 | 0 | result_elements[row].push_back( |
384 | 0 | const_mgr->GetDefiningInstruction(column_components[row]) |
385 | 0 | ->result_id()); |
386 | 0 | } |
387 | 0 | } |
388 | | |
389 | | // Create the constant for each row in the result, and collect the ids. |
390 | 0 | std::vector<uint32_t> result_columns(number_of_rows); |
391 | 0 | for (uint32_t col = 0; col < number_of_rows; ++col) { |
392 | 0 | auto* element = const_mgr->GetConstant(result_type->element_type(), |
393 | 0 | result_elements[col]); |
394 | 0 | result_columns[col] = |
395 | 0 | const_mgr->GetDefiningInstruction(element)->result_id(); |
396 | 0 | } |
397 | | |
398 | | // Create the matrix constant from the row ids, and return it. |
399 | 0 | return const_mgr->GetConstant(result_type, result_columns); |
400 | 0 | } |
401 | | |
402 | | const analysis::Constant* FoldTranspose( |
403 | | IRContext* context, Instruction* inst, |
404 | 0 | const std::vector<const analysis::Constant*>& constants) { |
405 | 0 | assert(inst->opcode() == spv::Op::OpTranspose); |
406 | | |
407 | 0 | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
408 | 0 | if (!inst->IsFloatingPointFoldingAllowed()) { |
409 | 0 | if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) { |
410 | 0 | return nullptr; |
411 | 0 | } |
412 | 0 | } |
413 | | |
414 | 0 | const analysis::Constant* matrix = constants[0]; |
415 | 0 | if (matrix == nullptr) { |
416 | 0 | return nullptr; |
417 | 0 | } |
418 | | |
419 | 0 | auto* result_type = type_mgr->GetType(inst->type_id()); |
420 | 0 | return TransposeMatrix(matrix, result_type->AsMatrix(), context); |
421 | 0 | } |
422 | | |
423 | 11.6k | ConstantFoldingRule FoldVectorTimesMatrix() { |
424 | 11.6k | return [](IRContext* context, Instruction* inst, |
425 | 11.6k | const std::vector<const analysis::Constant*>& constants) |
426 | 11.6k | -> const analysis::Constant* { |
427 | 2.84k | assert(inst->opcode() == spv::Op::OpVectorTimesMatrix); |
428 | 2.84k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
429 | 2.84k | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
430 | | |
431 | 2.84k | if (!inst->IsFloatingPointFoldingAllowed()) { |
432 | 0 | if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) { |
433 | 0 | return nullptr; |
434 | 0 | } |
435 | 0 | } |
436 | | |
437 | 2.84k | const analysis::Constant* c1 = constants[0]; |
438 | 2.84k | const analysis::Constant* c2 = constants[1]; |
439 | | |
440 | 2.84k | if (c1 == nullptr || c2 == nullptr) { |
441 | 2.81k | return nullptr; |
442 | 2.81k | } |
443 | | |
444 | | // Check result type. |
445 | 24 | const analysis::Type* result_type = type_mgr->GetType(inst->type_id()); |
446 | 24 | const analysis::Vector* vector_type = result_type->AsVector(); |
447 | 24 | assert(vector_type != nullptr); |
448 | 24 | const analysis::Type* element_type = vector_type->element_type(); |
449 | 24 | assert(element_type != nullptr); |
450 | 24 | const analysis::Float* float_type = element_type->AsFloat(); |
451 | 24 | assert(float_type != nullptr); |
452 | | |
453 | | // Check types of c1 and c2. |
454 | 24 | assert(c1->type()->AsVector() == vector_type); |
455 | 24 | assert(c1->type()->AsVector()->element_type() == element_type && |
456 | 24 | c2->type()->AsMatrix()->element_type() == vector_type); |
457 | | |
458 | 24 | uint32_t resultVectorSize = result_type->AsVector()->element_count(); |
459 | 24 | std::vector<uint32_t> ids; |
460 | | |
461 | 24 | if ((c1 && c1->IsZero()) || (c2 && c2->IsZero())) { |
462 | 11 | std::vector<uint32_t> words(float_type->width() / 32, 0); |
463 | 33 | for (uint32_t i = 0; i < resultVectorSize; ++i) { |
464 | 22 | const analysis::Constant* new_elem = |
465 | 22 | const_mgr->GetConstant(float_type, words); |
466 | 22 | ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id()); |
467 | 22 | } |
468 | 11 | return const_mgr->GetConstant(vector_type, ids); |
469 | 11 | } |
470 | | |
471 | | // Get a float vector that is the result of vector-times-matrix. |
472 | 13 | std::vector<const analysis::Constant*> c1_components = |
473 | 13 | c1->GetVectorComponents(const_mgr); |
474 | 13 | std::vector<const analysis::Constant*> c2_components = |
475 | 13 | c2->AsMatrixConstant()->GetComponents(); |
476 | | |
477 | 13 | if (float_type->width() == 32) { |
478 | 39 | for (uint32_t i = 0; i < resultVectorSize; ++i) { |
479 | 26 | float result_scalar = 0.0f; |
480 | 26 | if (!c2_components[i]->AsNullConstant()) { |
481 | 18 | const analysis::VectorConstant* c2_vec = |
482 | 18 | c2_components[i]->AsVectorConstant(); |
483 | 54 | for (uint32_t j = 0; j < c2_vec->GetComponents().size(); ++j) { |
484 | 36 | float c1_scalar = c1_components[j]->GetFloat(); |
485 | 36 | float c2_scalar = c2_vec->GetComponents()[j]->GetFloat(); |
486 | 36 | result_scalar += c1_scalar * c2_scalar; |
487 | 36 | } |
488 | 18 | } |
489 | 26 | utils::FloatProxy<float> result(result_scalar); |
490 | 26 | std::vector<uint32_t> words = result.GetWords(); |
491 | 26 | const analysis::Constant* new_elem = |
492 | 26 | const_mgr->GetConstant(float_type, words); |
493 | 26 | ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id()); |
494 | 26 | } |
495 | 13 | return const_mgr->GetConstant(vector_type, ids); |
496 | 13 | } else if (float_type->width() == 64) { |
497 | 0 | for (uint32_t i = 0; i < c2_components.size(); ++i) { |
498 | 0 | double result_scalar = 0.0; |
499 | 0 | if (!c2_components[i]->AsNullConstant()) { |
500 | 0 | const analysis::VectorConstant* c2_vec = |
501 | 0 | c2_components[i]->AsVectorConstant(); |
502 | 0 | for (uint32_t j = 0; j < c2_vec->GetComponents().size(); ++j) { |
503 | 0 | double c1_scalar = c1_components[j]->GetDouble(); |
504 | 0 | double c2_scalar = c2_vec->GetComponents()[j]->GetDouble(); |
505 | 0 | result_scalar += c1_scalar * c2_scalar; |
506 | 0 | } |
507 | 0 | } |
508 | 0 | utils::FloatProxy<double> result(result_scalar); |
509 | 0 | std::vector<uint32_t> words = result.GetWords(); |
510 | 0 | const analysis::Constant* new_elem = |
511 | 0 | const_mgr->GetConstant(float_type, words); |
512 | 0 | ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id()); |
513 | 0 | } |
514 | 0 | return const_mgr->GetConstant(vector_type, ids); |
515 | 0 | } |
516 | 0 | return nullptr; |
517 | 13 | }; |
518 | 11.6k | } |
519 | | |
520 | 11.6k | ConstantFoldingRule FoldMatrixTimesVector() { |
521 | 11.6k | return [](IRContext* context, Instruction* inst, |
522 | 11.6k | const std::vector<const analysis::Constant*>& constants) |
523 | 11.6k | -> const analysis::Constant* { |
524 | 139 | assert(inst->opcode() == spv::Op::OpMatrixTimesVector); |
525 | 139 | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
526 | 139 | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
527 | | |
528 | 139 | if (!inst->IsFloatingPointFoldingAllowed()) { |
529 | 0 | if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) { |
530 | 0 | return nullptr; |
531 | 0 | } |
532 | 0 | } |
533 | | |
534 | 139 | const analysis::Constant* c1 = constants[0]; |
535 | 139 | const analysis::Constant* c2 = constants[1]; |
536 | | |
537 | 139 | if (c1 == nullptr || c2 == nullptr) { |
538 | 137 | return nullptr; |
539 | 137 | } |
540 | | |
541 | | // Check result type. |
542 | 2 | const analysis::Type* result_type = type_mgr->GetType(inst->type_id()); |
543 | 2 | const analysis::Vector* vector_type = result_type->AsVector(); |
544 | 2 | assert(vector_type != nullptr); |
545 | 2 | const analysis::Type* element_type = vector_type->element_type(); |
546 | 2 | assert(element_type != nullptr); |
547 | 2 | const analysis::Float* float_type = element_type->AsFloat(); |
548 | 2 | assert(float_type != nullptr); |
549 | | |
550 | | // Check types of c1 and c2. |
551 | 2 | assert(c1->type()->AsMatrix()->element_type() == vector_type); |
552 | 2 | assert(c2->type()->AsVector()->element_type() == element_type); |
553 | | |
554 | 2 | uint32_t resultVectorSize = result_type->AsVector()->element_count(); |
555 | 2 | std::vector<uint32_t> ids; |
556 | | |
557 | 2 | if ((c1 && c1->IsZero()) || (c2 && c2->IsZero())) { |
558 | 2 | std::vector<uint32_t> words(float_type->width() / 32, 0); |
559 | 6 | for (uint32_t i = 0; i < resultVectorSize; ++i) { |
560 | 4 | const analysis::Constant* new_elem = |
561 | 4 | const_mgr->GetConstant(float_type, words); |
562 | 4 | ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id()); |
563 | 4 | } |
564 | 2 | return const_mgr->GetConstant(vector_type, ids); |
565 | 2 | } |
566 | | |
567 | | // Get a float vector that is the result of matrix-times-vector. |
568 | 0 | std::vector<const analysis::Constant*> c1_components = |
569 | 0 | c1->AsMatrixConstant()->GetComponents(); |
570 | 0 | std::vector<const analysis::Constant*> c2_components = |
571 | 0 | c2->GetVectorComponents(const_mgr); |
572 | |
|
573 | 0 | if (float_type->width() == 32) { |
574 | 0 | for (uint32_t i = 0; i < resultVectorSize; ++i) { |
575 | 0 | float result_scalar = 0.0f; |
576 | 0 | for (uint32_t j = 0; j < c1_components.size(); ++j) { |
577 | 0 | if (!c1_components[j]->AsNullConstant()) { |
578 | 0 | float c1_scalar = c1_components[j] |
579 | 0 | ->AsVectorConstant() |
580 | 0 | ->GetComponents()[i] |
581 | 0 | ->GetFloat(); |
582 | 0 | float c2_scalar = c2_components[j]->GetFloat(); |
583 | 0 | result_scalar += c1_scalar * c2_scalar; |
584 | 0 | } |
585 | 0 | } |
586 | 0 | utils::FloatProxy<float> result(result_scalar); |
587 | 0 | std::vector<uint32_t> words = result.GetWords(); |
588 | 0 | const analysis::Constant* new_elem = |
589 | 0 | const_mgr->GetConstant(float_type, words); |
590 | 0 | ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id()); |
591 | 0 | } |
592 | 0 | return const_mgr->GetConstant(vector_type, ids); |
593 | 0 | } else if (float_type->width() == 64) { |
594 | 0 | for (uint32_t i = 0; i < resultVectorSize; ++i) { |
595 | 0 | double result_scalar = 0.0; |
596 | 0 | for (uint32_t j = 0; j < c1_components.size(); ++j) { |
597 | 0 | if (!c1_components[j]->AsNullConstant()) { |
598 | 0 | double c1_scalar = c1_components[j] |
599 | 0 | ->AsVectorConstant() |
600 | 0 | ->GetComponents()[i] |
601 | 0 | ->GetDouble(); |
602 | 0 | double c2_scalar = c2_components[j]->GetDouble(); |
603 | 0 | result_scalar += c1_scalar * c2_scalar; |
604 | 0 | } |
605 | 0 | } |
606 | 0 | utils::FloatProxy<double> result(result_scalar); |
607 | 0 | std::vector<uint32_t> words = result.GetWords(); |
608 | 0 | const analysis::Constant* new_elem = |
609 | 0 | const_mgr->GetConstant(float_type, words); |
610 | 0 | ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id()); |
611 | 0 | } |
612 | 0 | return const_mgr->GetConstant(vector_type, ids); |
613 | 0 | } |
614 | 0 | return nullptr; |
615 | 0 | }; |
616 | 11.6k | } |
617 | | |
618 | 11.6k | ConstantFoldingRule FoldCompositeWithConstants() { |
619 | | // Folds an OpCompositeConstruct where all of the inputs are constants to a |
620 | | // constant. A new constant is created if necessary. |
621 | 11.6k | return [](IRContext* context, Instruction* inst, |
622 | 11.6k | const std::vector<const analysis::Constant*>& constants) |
623 | 223k | -> const analysis::Constant* { |
624 | 223k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
625 | 223k | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
626 | 223k | const analysis::Type* new_type = type_mgr->GetType(inst->type_id()); |
627 | 223k | Instruction* type_inst = |
628 | 223k | context->get_def_use_mgr()->GetDef(inst->type_id()); |
629 | | |
630 | 223k | std::vector<uint32_t> ids; |
631 | 536k | for (uint32_t i = 0; i < constants.size(); ++i) { |
632 | 396k | const analysis::Constant* element_const = constants[i]; |
633 | 396k | if (element_const == nullptr) { |
634 | 83.4k | return nullptr; |
635 | 83.4k | } |
636 | | |
637 | 313k | uint32_t component_type_id = 0; |
638 | 313k | if (type_inst->opcode() == spv::Op::OpTypeStruct) { |
639 | 20.7k | component_type_id = type_inst->GetSingleWordInOperand(i); |
640 | 292k | } else if (type_inst->opcode() == spv::Op::OpTypeArray) { |
641 | 67 | component_type_id = type_inst->GetSingleWordInOperand(0); |
642 | 67 | } |
643 | | |
644 | 313k | uint32_t element_id = |
645 | 313k | const_mgr->FindDeclaredConstant(element_const, component_type_id); |
646 | 313k | if (element_id == 0) { |
647 | 118 | return nullptr; |
648 | 118 | } |
649 | 313k | ids.push_back(element_id); |
650 | 313k | } |
651 | 139k | return const_mgr->GetConstant(new_type, ids); |
652 | 223k | }; |
653 | 11.6k | } |
654 | | |
// The interface for a function that returns the result of applying a scalar
// unary operation on |a|. The type of the return value will be |result_type|.
// The function may return |nullptr| when it cannot produce a result; callers
// in this file treat that as "do not fold".
using UnaryScalarFoldingRule = std::function<const analysis::Constant*(
    const analysis::Type* result_type, const analysis::Constant* a,
    analysis::ConstantManager*)>;
661 | | |
// The interface for a function that returns the result of applying a scalar
// binary operation on |a| and |b|. The type of the return value will be
// |result_type|.
// NOTE(review): the original comment also required the inputs to be
// floating-point and of the result type; no caller is visible here to confirm.
using BinaryScalarFoldingRule = std::function<const analysis::Constant*(
    const analysis::Type* result_type, const analysis::Constant* a,
    const analysis::Constant* b, analysis::ConstantManager*)>;
668 | | |
669 | | // Returns a |ConstantFoldingRule| that folds unary scalar ops |
670 | | // using |scalar_rule| and unary vectors ops by applying |
671 | | // |scalar_rule| to the elements of the vector. The |ConstantFoldingRule| |
672 | | // that is returned assumes that |constants| contains 1 entry. If they are |
673 | | // not |nullptr|, then their type is either |Float| or |Integer| or a |Vector| |
674 | | // whose element type is |Float| or |Integer|. |
675 | 180k | ConstantFoldingRule FoldUnaryOp(UnaryScalarFoldingRule scalar_rule) { |
676 | 180k | return [scalar_rule](IRContext* context, Instruction* inst, |
677 | 180k | const std::vector<const analysis::Constant*>& constants) |
678 | 421k | -> const analysis::Constant* { |
679 | 421k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
680 | 421k | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
681 | 421k | const analysis::Type* result_type = type_mgr->GetType(inst->type_id()); |
682 | 421k | const analysis::Vector* vector_type = result_type->AsVector(); |
683 | | |
684 | 421k | const analysis::Constant* arg = |
685 | 421k | (inst->opcode() == spv::Op::OpExtInst) ? constants[1] : constants[0]; |
686 | | |
687 | 421k | if (arg == nullptr) { |
688 | 88.6k | return nullptr; |
689 | 88.6k | } |
690 | | |
691 | 333k | if (vector_type != nullptr) { |
692 | 887 | std::vector<const analysis::Constant*> a_components; |
693 | 887 | std::vector<const analysis::Constant*> results_components; |
694 | | |
695 | 887 | a_components = arg->GetVectorComponents(const_mgr); |
696 | | |
697 | | // Fold each component of the vector. |
698 | 3.89k | for (uint32_t i = 0; i < a_components.size(); ++i) { |
699 | 3.01k | results_components.push_back(scalar_rule(vector_type->element_type(), |
700 | 3.01k | a_components[i], const_mgr)); |
701 | 3.01k | if (results_components[i] == nullptr) { |
702 | 0 | return nullptr; |
703 | 0 | } |
704 | 3.01k | } |
705 | | |
706 | | // Build the constant object and return it. |
707 | 887 | std::vector<uint32_t> ids; |
708 | 3.01k | for (const analysis::Constant* member : results_components) { |
709 | 3.01k | ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id()); |
710 | 3.01k | } |
711 | 887 | return const_mgr->GetConstant(vector_type, ids); |
712 | 332k | } else { |
713 | 332k | return scalar_rule(result_type, arg, const_mgr); |
714 | 332k | } |
715 | 333k | }; |
716 | 180k | } |
717 | | |
718 | | // Returns a |ConstantFoldingRule| that folds binary scalar ops |
719 | | // using |scalar_rule| and binary vectors ops by applying |
720 | | // |scalar_rule| to the elements of the vector. The folding rule assumes that op |
721 | | // has two inputs. For regular instruction, those are in operands 0 and 1. For |
722 | | // extended instruction, they are in operands 1 and 2. If an element in |
723 | | // |constants| is not nullprt, then the constant's type is |Float|, |Integer|, |
724 | | // or |Vector| whose element type is |Float| or |Integer|. |
725 | 92.9k | ConstantFoldingRule FoldBinaryOp(BinaryScalarFoldingRule scalar_rule) { |
726 | 92.9k | return [scalar_rule](IRContext* context, Instruction* inst, |
727 | 92.9k | const std::vector<const analysis::Constant*>& constants) |
728 | 1.03M | -> const analysis::Constant* { |
729 | 1.03M | assert(constants.size() == inst->NumInOperands()); |
730 | 1.03M | assert(constants.size() == (inst->opcode() == spv::Op::OpExtInst ? 3 : 2)); |
731 | 1.03M | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
732 | 1.03M | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
733 | 1.03M | const analysis::Type* result_type = type_mgr->GetType(inst->type_id()); |
734 | 1.03M | const analysis::Vector* vector_type = result_type->AsVector(); |
735 | | |
736 | 1.03M | const analysis::Constant* arg1 = |
737 | 1.03M | (inst->opcode() == spv::Op::OpExtInst) ? constants[1] : constants[0]; |
738 | 1.03M | const analysis::Constant* arg2 = |
739 | 1.03M | (inst->opcode() == spv::Op::OpExtInst) ? constants[2] : constants[1]; |
740 | | |
741 | 1.03M | if (arg1 == nullptr || arg2 == nullptr) { |
742 | 159k | return nullptr; |
743 | 159k | } |
744 | | |
745 | 877k | if (vector_type == nullptr) { |
746 | 877k | return scalar_rule(result_type, arg1, arg2, const_mgr); |
747 | 877k | } |
748 | | |
749 | 255 | std::vector<const analysis::Constant*> a_components; |
750 | 255 | std::vector<const analysis::Constant*> b_components; |
751 | 255 | std::vector<const analysis::Constant*> results_components; |
752 | | |
753 | 255 | a_components = arg1->GetVectorComponents(const_mgr); |
754 | 255 | b_components = arg2->GetVectorComponents(const_mgr); |
755 | 255 | assert(a_components.size() == b_components.size()); |
756 | | |
757 | | // Fold each component of the vector. |
758 | 1.25k | for (uint32_t i = 0; i < a_components.size(); ++i) { |
759 | 1.00k | results_components.push_back(scalar_rule(vector_type->element_type(), |
760 | 1.00k | a_components[i], b_components[i], |
761 | 1.00k | const_mgr)); |
762 | 1.00k | if (results_components[i] == nullptr) { |
763 | 0 | return nullptr; |
764 | 0 | } |
765 | 1.00k | } |
766 | | |
767 | | // Build the constant object and return it. |
768 | 255 | std::vector<uint32_t> ids; |
769 | 1.00k | for (const analysis::Constant* member : results_components) { |
770 | 1.00k | ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id()); |
771 | 1.00k | } |
772 | 255 | return const_mgr->GetConstant(vector_type, ids); |
773 | 255 | }; |
774 | 92.9k | } |
775 | | |
776 | | // Returns a |ConstantFoldingRule| that folds unary floating point scalar ops |
777 | | // using |scalar_rule| and unary float point vectors ops by applying |
778 | | // |scalar_rule| to the elements of the vector. The |ConstantFoldingRule| |
779 | | // that is returned assumes that |constants| contains 1 entry. If they are |
780 | | // not |nullptr|, then their type is either |Float| or |Integer| or a |Vector| |
781 | | // whose element type is |Float| or |Integer|. |
782 | 145k | ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) { |
783 | 145k | auto folding_rule = FoldUnaryOp(scalar_rule); |
784 | 145k | return [folding_rule](IRContext* context, Instruction* inst, |
785 | 145k | const std::vector<const analysis::Constant*>& constants) |
786 | 416k | -> const analysis::Constant* { |
787 | 416k | if (!inst->IsFloatingPointFoldingAllowed()) { |
788 | 23 | return nullptr; |
789 | 23 | } |
790 | | |
791 | 416k | return folding_rule(context, inst, constants); |
792 | 416k | }; |
793 | 145k | } |
794 | | |
795 | | // Returns the result of folding the constants in |constants| according the |
796 | | // |scalar_rule|. If |result_type| is a vector, then |scalar_rule| is applied |
797 | | // per component. |
798 | | const analysis::Constant* FoldFPBinaryOp( |
799 | | BinaryScalarFoldingRule scalar_rule, uint32_t result_type_id, |
800 | | const std::vector<const analysis::Constant*>& constants, |
801 | 1.25M | IRContext* context) { |
802 | 1.25M | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
803 | 1.25M | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
804 | 1.25M | const analysis::Type* result_type = type_mgr->GetType(result_type_id); |
805 | 1.25M | const analysis::Vector* vector_type = result_type->AsVector(); |
806 | | |
807 | 1.25M | if (constants[0] == nullptr || constants[1] == nullptr) { |
808 | 965k | return nullptr; |
809 | 965k | } |
810 | | |
811 | 291k | if (vector_type != nullptr) { |
812 | 36.9k | std::vector<const analysis::Constant*> a_components; |
813 | 36.9k | std::vector<const analysis::Constant*> b_components; |
814 | 36.9k | std::vector<const analysis::Constant*> results_components; |
815 | | |
816 | 36.9k | a_components = constants[0]->GetVectorComponents(const_mgr); |
817 | 36.9k | b_components = constants[1]->GetVectorComponents(const_mgr); |
818 | | |
819 | | // Fold each component of the vector. |
820 | 177k | for (uint32_t i = 0; i < a_components.size(); ++i) { |
821 | 140k | results_components.push_back(scalar_rule(vector_type->element_type(), |
822 | 140k | a_components[i], b_components[i], |
823 | 140k | const_mgr)); |
824 | 140k | if (results_components[i] == nullptr) { |
825 | 0 | return nullptr; |
826 | 0 | } |
827 | 140k | } |
828 | | |
829 | | // Build the constant object and return it. |
830 | 36.9k | std::vector<uint32_t> ids; |
831 | 140k | for (const analysis::Constant* member : results_components) { |
832 | 140k | Instruction* def = const_mgr->GetDefiningInstruction(member); |
833 | 140k | if (!def) return nullptr; |
834 | 140k | ids.push_back(def->result_id()); |
835 | 140k | } |
836 | 36.9k | return const_mgr->GetConstant(vector_type, ids); |
837 | 254k | } else { |
838 | 254k | return scalar_rule(result_type, constants[0], constants[1], const_mgr); |
839 | 254k | } |
840 | 291k | } |
841 | | |
842 | | // Returns a |ConstantFoldingRule| that folds floating point scalars using |
843 | | // |scalar_rule| and vectors of floating point by applying |scalar_rule| to the |
844 | | // elements of the vector. The |ConstantFoldingRule| that is returned assumes |
845 | | // that |constants| contains 2 entries. If they are not |nullptr|, then their |
846 | | // type is either |Float| or a |Vector| whose element type is |Float|. |
847 | 241k | ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) { |
848 | 241k | return [scalar_rule](IRContext* context, Instruction* inst, |
849 | 241k | const std::vector<const analysis::Constant*>& constants) |
850 | 1.25M | -> const analysis::Constant* { |
851 | 1.25M | if (!inst->IsFloatingPointFoldingAllowed()) { |
852 | 3.10k | return nullptr; |
853 | 3.10k | } |
854 | 1.25M | if (inst->opcode() == spv::Op::OpExtInst) { |
855 | 10.7k | return FoldFPBinaryOp(scalar_rule, inst->type_id(), |
856 | 10.7k | {constants[1], constants[2]}, context); |
857 | 10.7k | } |
858 | 1.24M | return FoldFPBinaryOp(scalar_rule, inst->type_id(), constants, context); |
859 | 1.25M | }; |
860 | 241k | } |
861 | | |
862 | | // This macro defines a |UnaryScalarFoldingRule| that performs float to |
863 | | // integer conversion. |
864 | | // TODO(greg-lunarg): Support for 64-bit integer types. |
865 | 23.2k | UnaryScalarFoldingRule FoldFToIOp() { |
866 | 23.2k | return [](const analysis::Type* result_type, const analysis::Constant* a, |
867 | 23.2k | analysis::ConstantManager* const_mgr) -> const analysis::Constant* { |
868 | 981 | assert(result_type != nullptr && a != nullptr); |
869 | 981 | const analysis::Integer* integer_type = result_type->AsInteger(); |
870 | 981 | const analysis::Float* float_type = a->type()->AsFloat(); |
871 | 981 | assert(float_type != nullptr); |
872 | 981 | assert(integer_type != nullptr); |
873 | 981 | if (integer_type->width() != 32) return nullptr; |
874 | 981 | if (float_type->width() == 32) { |
875 | 981 | float fa = a->GetFloat(); |
876 | 981 | uint32_t result = integer_type->IsSigned() |
877 | 981 | ? static_cast<uint32_t>(static_cast<int32_t>(fa)) |
878 | 981 | : static_cast<uint32_t>(fa); |
879 | 981 | std::vector<uint32_t> words = {result}; |
880 | 981 | return const_mgr->GetConstant(result_type, words); |
881 | 981 | } else if (float_type->width() == 64) { |
882 | 0 | double fa = a->GetDouble(); |
883 | 0 | uint32_t result = integer_type->IsSigned() |
884 | 0 | ? static_cast<uint32_t>(static_cast<int32_t>(fa)) |
885 | 0 | : static_cast<uint32_t>(fa); |
886 | 0 | std::vector<uint32_t> words = {result}; |
887 | 0 | return const_mgr->GetConstant(result_type, words); |
888 | 0 | } |
889 | 0 | return nullptr; |
890 | 981 | }; |
891 | 23.2k | } |
892 | | |
893 | | // This function defines a |UnaryScalarFoldingRule| that performs integer to |
894 | | // float conversion. |
895 | | // TODO(greg-lunarg): Support for 64-bit integer types. |
896 | 23.2k | UnaryScalarFoldingRule FoldIToFOp() { |
897 | 23.2k | return [](const analysis::Type* result_type, const analysis::Constant* a, |
898 | 330k | analysis::ConstantManager* const_mgr) -> const analysis::Constant* { |
899 | 330k | assert(result_type != nullptr && a != nullptr); |
900 | 330k | const analysis::Integer* integer_type = a->type()->AsInteger(); |
901 | 330k | const analysis::Float* float_type = result_type->AsFloat(); |
902 | 330k | assert(float_type != nullptr); |
903 | 330k | assert(integer_type != nullptr); |
904 | 330k | if (integer_type->width() != 32) return nullptr; |
905 | 330k | uint32_t ua = a->GetU32(); |
906 | 330k | if (float_type->width() == 32) { |
907 | 330k | float result_val = integer_type->IsSigned() |
908 | 330k | ? static_cast<float>(static_cast<int32_t>(ua)) |
909 | 330k | : static_cast<float>(ua); |
910 | 330k | utils::FloatProxy<float> result(result_val); |
911 | 330k | std::vector<uint32_t> words = {result.data()}; |
912 | 330k | return const_mgr->GetConstant(result_type, words); |
913 | 330k | } else if (float_type->width() == 64) { |
914 | 0 | double result_val = integer_type->IsSigned() |
915 | 0 | ? static_cast<double>(static_cast<int32_t>(ua)) |
916 | 0 | : static_cast<double>(ua); |
917 | 0 | utils::FloatProxy<double> result(result_val); |
918 | 0 | std::vector<uint32_t> words = result.GetWords(); |
919 | 0 | return const_mgr->GetConstant(result_type, words); |
920 | 0 | } |
921 | 0 | return nullptr; |
922 | 330k | }; |
923 | 23.2k | } |
924 | | |
925 | | // This defines a |UnaryScalarFoldingRule| that performs |OpQuantizeToF16|. |
926 | 11.6k | UnaryScalarFoldingRule FoldQuantizeToF16Scalar() { |
927 | 11.6k | return [](const analysis::Type* result_type, const analysis::Constant* a, |
928 | 11.6k | analysis::ConstantManager* const_mgr) -> const analysis::Constant* { |
929 | 1.73k | assert(result_type != nullptr && a != nullptr); |
930 | 1.73k | const analysis::Float* float_type = a->type()->AsFloat(); |
931 | 1.73k | assert(float_type != nullptr); |
932 | 1.73k | if (float_type->width() != 32) { |
933 | 0 | return nullptr; |
934 | 0 | } |
935 | | |
936 | 1.73k | float fa = a->GetFloat(); |
937 | 1.73k | utils::HexFloat<utils::FloatProxy<float>> orignal(fa); |
938 | 1.73k | utils::HexFloat<utils::FloatProxy<utils::Float16>> quantized(0); |
939 | 1.73k | utils::HexFloat<utils::FloatProxy<float>> result(0.0f); |
940 | 1.73k | orignal.castTo(quantized, utils::round_direction::kToZero); |
941 | 1.73k | quantized.castTo(result, utils::round_direction::kToZero); |
942 | 1.73k | std::vector<uint32_t> words = {result.getBits()}; |
943 | 1.73k | return const_mgr->GetConstant(result_type, words); |
944 | 1.73k | }; |
945 | 11.6k | } |
946 | | |
// This macro expands to a |BinaryScalarFoldingRule| lambda that computes
// "a op b" for two scalar float constants of the result type. The operator
// |op| must work for both float and double, and use syntax "f1 op f2". The
// lambda returns |nullptr| when the float width is neither 32 nor 64.
#define FOLD_FPARITH_OP(op)                                                   \
  [](const analysis::Type* result_type_in_macro, const analysis::Constant* a, \
     const analysis::Constant* b,                                             \
     analysis::ConstantManager* const_mgr_in_macro)                           \
      -> const analysis::Constant* {                                          \
    assert(result_type_in_macro != nullptr && a != nullptr && b != nullptr);  \
    assert(result_type_in_macro == a->type() &&                               \
           result_type_in_macro == b->type());                                \
    const analysis::Float* float_type_in_macro =                              \
        result_type_in_macro->AsFloat();                                      \
    assert(float_type_in_macro != nullptr);                                   \
    switch (float_type_in_macro->width()) {                                   \
      case 32: {                                                              \
        float lhs_in_macro = a->GetFloat();                                   \
        float rhs_in_macro = b->GetFloat();                                   \
        utils::FloatProxy<float> result_in_macro(lhs_in_macro op              \
                                                     rhs_in_macro);           \
        std::vector<uint32_t> words_in_macro = result_in_macro.GetWords();    \
        return const_mgr_in_macro->GetConstant(result_type_in_macro,          \
                                               words_in_macro);               \
      }                                                                       \
      case 64: {                                                              \
        double lhs_in_macro = a->GetDouble();                                 \
        double rhs_in_macro = b->GetDouble();                                 \
        utils::FloatProxy<double> result_in_macro(lhs_in_macro op             \
                                                      rhs_in_macro);          \
        std::vector<uint32_t> words_in_macro = result_in_macro.GetWords();    \
        return const_mgr_in_macro->GetConstant(result_type_in_macro,          \
                                               words_in_macro);               \
      }                                                                       \
      default:                                                                \
        return nullptr;                                                       \
    }                                                                         \
  }
977 | | |
978 | | // Define the folding rule for conversion between floating point and integer |
979 | 23.2k | ConstantFoldingRule FoldFToI() { return FoldFPUnaryOp(FoldFToIOp()); } |
980 | 23.2k | ConstantFoldingRule FoldIToF() { return FoldFPUnaryOp(FoldIToFOp()); } |
981 | 11.6k | ConstantFoldingRule FoldQuantizeToF16() { |
982 | 11.6k | return FoldFPUnaryOp(FoldQuantizeToF16Scalar()); |
983 | 11.6k | } |
984 | | |
985 | | // Define the folding rules for subtraction, addition, multiplication, and |
986 | | // division for floating point values. |
987 | 26.6k | ConstantFoldingRule FoldFSub() { return FoldFPBinaryOp(FOLD_FPARITH_OP(-)); } |
988 | 143k | ConstantFoldingRule FoldFAdd() { return FoldFPBinaryOp(FOLD_FPARITH_OP(+)); } |
989 | 231k | ConstantFoldingRule FoldFMul() { return FoldFPBinaryOp(FOLD_FPARITH_OP(*)); } |
990 | | |
991 | | // Returns the constant that results from evaluating |numerator| / 0.0. Returns |
992 | | // |nullptr| if the result could not be evaluated. |
993 | | const analysis::Constant* FoldFPScalarDivideByZero( |
994 | | const analysis::Type* result_type, const analysis::Constant* numerator, |
995 | 2.08k | analysis::ConstantManager* const_mgr) { |
996 | 2.08k | if (numerator == nullptr) { |
997 | 0 | return nullptr; |
998 | 0 | } |
999 | | |
1000 | 2.08k | if (numerator->IsZero()) { |
1001 | 1.12k | return GetNan(result_type, const_mgr); |
1002 | 1.12k | } |
1003 | | |
1004 | 965 | const analysis::Constant* result = GetInf(result_type, const_mgr); |
1005 | 965 | if (result == nullptr) { |
1006 | 0 | return nullptr; |
1007 | 0 | } |
1008 | | |
1009 | 965 | if (numerator->AsFloatConstant()->GetValueAsDouble() < 0.0) { |
1010 | 227 | result = NegateFPConst(result_type, result, const_mgr); |
1011 | 227 | } |
1012 | 965 | return result; |
1013 | 965 | } |
1014 | | |
1015 | | // Returns the result of folding |numerator| / |denominator|. Returns |nullptr| |
1016 | | // if it cannot be folded. |
1017 | | const analysis::Constant* FoldScalarFPDivide( |
1018 | | const analysis::Type* result_type, const analysis::Constant* numerator, |
1019 | | const analysis::Constant* denominator, |
1020 | 11.5k | analysis::ConstantManager* const_mgr) { |
1021 | 11.5k | if (denominator == nullptr) { |
1022 | 0 | return nullptr; |
1023 | 0 | } |
1024 | | |
1025 | 11.5k | if (denominator->IsZero()) { |
1026 | 1.76k | return FoldFPScalarDivideByZero(result_type, numerator, const_mgr); |
1027 | 1.76k | } |
1028 | | |
1029 | 9.81k | uint32_t width = denominator->type()->AsFloat()->width(); |
1030 | 9.81k | if (width != 32 && width != 64) { |
1031 | 0 | return nullptr; |
1032 | 0 | } |
1033 | | |
1034 | 9.81k | const analysis::FloatConstant* denominator_float = |
1035 | 9.81k | denominator->AsFloatConstant(); |
1036 | 9.81k | if (denominator_float && denominator->GetValueAsDouble() == -0.0) { |
1037 | 323 | const analysis::Constant* result = |
1038 | 323 | FoldFPScalarDivideByZero(result_type, numerator, const_mgr); |
1039 | 323 | if (result != nullptr) |
1040 | 323 | result = NegateFPConst(result_type, result, const_mgr); |
1041 | 323 | return result; |
1042 | 9.49k | } else { |
1043 | 18.9k | return FOLD_FPARITH_OP(/)(result_type, numerator, denominator, const_mgr); |
1044 | 9.49k | } |
1045 | 9.81k | } |
1046 | | |
1047 | | // Returns the constant folding rule to fold |OpFDiv| with two constants. |
1048 | 11.6k | ConstantFoldingRule FoldFDiv() { return FoldFPBinaryOp(FoldScalarFPDivide); } |
1049 | | |
// Combines the outcome of a floating point comparison with its orderedness.
// |op_result| is the raw result of the comparison and |op_unordered| tells
// whether the operands are unordered. When |need_ordered| is true the result
// is true only if the operands are ordered and |op_result| holds; otherwise
// the result is true if the operands are unordered or |op_result| holds.
bool CompareFloatingPoint(bool op_result, bool op_unordered,
                          bool need_ordered) {
  return need_ordered ? (!op_unordered && op_result)
                      : (op_unordered || op_result);
}
1060 | | |
// This macro expands to a |BinaryScalarFoldingRule| lambda that compares two
// scalar float constants with |op| and produces a boolean constant. The
// operator |op| must work for both float and double, and use syntax
// "f1 op f2". |ord| selects the ordered (true) or unordered (false) flavour,
// see |CompareFloatingPoint|. The lambda returns |nullptr| when the operand
// width is neither 32 nor 64.
#define FOLD_FPCMP_OP(op, ord)                                            \
  [](const analysis::Type* result_type, const analysis::Constant* a,      \
     const analysis::Constant* b,                                         \
     analysis::ConstantManager* const_mgr) -> const analysis::Constant* { \
    assert(result_type != nullptr && a != nullptr && b != nullptr);       \
    assert(result_type->AsBool());                                        \
    assert(a->type() == b->type());                                       \
    const analysis::Float* operand_type = a->type()->AsFloat();           \
    assert(operand_type != nullptr);                                      \
    const uint32_t bit_width = operand_type->width();                     \
    if (bit_width != 32 && bit_width != 64) {                             \
      return nullptr;                                                     \
    }                                                                     \
    bool outcome = false;                                                 \
    if (bit_width == 32) {                                                \
      float lhs = a->GetFloat();                                          \
      float rhs = b->GetFloat();                                          \
      outcome = CompareFloatingPoint(                                     \
          lhs op rhs, std::isnan(lhs) || std::isnan(rhs), ord);           \
    } else {                                                              \
      double lhs = a->GetDouble();                                        \
      double rhs = b->GetDouble();                                        \
      outcome = CompareFloatingPoint(                                     \
          lhs op rhs, std::isnan(lhs) || std::isnan(rhs), ord);           \
    }                                                                     \
    std::vector<uint32_t> words = {static_cast<uint32_t>(outcome)};       \
    return const_mgr->GetConstant(result_type, words);                    \
  }
1089 | | |
1090 | | // Define the folding rules for ordered and unordered comparison for floating |
1091 | | // point values. |
1092 | 11.6k | ConstantFoldingRule FoldFOrdEqual() { |
1093 | 12.1k | return FoldFPBinaryOp(FOLD_FPCMP_OP(==, true)); |
1094 | 11.6k | } |
1095 | 11.6k | ConstantFoldingRule FoldFUnordEqual() { |
1096 | 11.8k | return FoldFPBinaryOp(FOLD_FPCMP_OP(==, false)); |
1097 | 11.6k | } |
1098 | 11.6k | ConstantFoldingRule FoldFOrdNotEqual() { |
1099 | 12.1k | return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, true)); |
1100 | 11.6k | } |
1101 | 11.6k | ConstantFoldingRule FoldFUnordNotEqual() { |
1102 | 12.7k | return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, false)); |
1103 | 11.6k | } |
1104 | 11.6k | ConstantFoldingRule FoldFOrdLessThan() { |
1105 | 12.8k | return FoldFPBinaryOp(FOLD_FPCMP_OP(<, true)); |
1106 | 11.6k | } |
1107 | 11.6k | ConstantFoldingRule FoldFUnordLessThan() { |
1108 | 12.4k | return FoldFPBinaryOp(FOLD_FPCMP_OP(<, false)); |
1109 | 11.6k | } |
1110 | 11.6k | ConstantFoldingRule FoldFOrdGreaterThan() { |
1111 | 12.4k | return FoldFPBinaryOp(FOLD_FPCMP_OP(>, true)); |
1112 | 11.6k | } |
1113 | 11.6k | ConstantFoldingRule FoldFUnordGreaterThan() { |
1114 | 13.3k | return FoldFPBinaryOp(FOLD_FPCMP_OP(>, false)); |
1115 | 11.6k | } |
1116 | 11.6k | ConstantFoldingRule FoldFOrdLessThanEqual() { |
1117 | 12.3k | return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, true)); |
1118 | 11.6k | } |
1119 | 11.6k | ConstantFoldingRule FoldFUnordLessThanEqual() { |
1120 | 12.9k | return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, false)); |
1121 | 11.6k | } |
1122 | 11.6k | ConstantFoldingRule FoldFOrdGreaterThanEqual() { |
1123 | 12.3k | return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, true)); |
1124 | 11.6k | } |
1125 | 11.6k | ConstantFoldingRule FoldFUnordGreaterThanEqual() { |
1126 | 12.6k | return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, false)); |
1127 | 11.6k | } |
1128 | | |
1129 | | // Folds an OpDot where all of the inputs are constants to a |
1130 | | // constant. A new constant is created if necessary. |
1131 | 11.6k | ConstantFoldingRule FoldOpDotWithConstants() { |
1132 | 11.6k | return [](IRContext* context, Instruction* inst, |
1133 | 11.6k | const std::vector<const analysis::Constant*>& constants) |
1134 | 11.6k | -> const analysis::Constant* { |
1135 | 172 | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
1136 | 172 | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
1137 | 172 | const analysis::Type* new_type = type_mgr->GetType(inst->type_id()); |
1138 | 172 | assert(new_type->AsFloat() && "OpDot should have a float return type."); |
1139 | 172 | const analysis::Float* float_type = new_type->AsFloat(); |
1140 | | |
1141 | 172 | if (!inst->IsFloatingPointFoldingAllowed()) { |
1142 | 0 | return nullptr; |
1143 | 0 | } |
1144 | | |
1145 | | // If one of the operands is 0, then the result is 0. |
1146 | 172 | bool has_zero_operand = false; |
1147 | | |
1148 | 516 | for (int i = 0; i < 2; ++i) { |
1149 | 344 | if (constants[i]) { |
1150 | 78 | if (constants[i]->AsNullConstant() || |
1151 | 78 | constants[i]->AsVectorConstant()->IsZero()) { |
1152 | 0 | has_zero_operand = true; |
1153 | 0 | break; |
1154 | 0 | } |
1155 | 78 | } |
1156 | 344 | } |
1157 | | |
1158 | 172 | if (has_zero_operand) { |
1159 | 0 | if (float_type->width() == 32) { |
1160 | 0 | utils::FloatProxy<float> result(0.0f); |
1161 | 0 | std::vector<uint32_t> words = result.GetWords(); |
1162 | 0 | return const_mgr->GetConstant(float_type, words); |
1163 | 0 | } |
1164 | 0 | if (float_type->width() == 64) { |
1165 | 0 | utils::FloatProxy<double> result(0.0); |
1166 | 0 | std::vector<uint32_t> words = result.GetWords(); |
1167 | 0 | return const_mgr->GetConstant(float_type, words); |
1168 | 0 | } |
1169 | 0 | return nullptr; |
1170 | 0 | } |
1171 | | |
1172 | 172 | if (constants[0] == nullptr || constants[1] == nullptr) { |
1173 | 172 | return nullptr; |
1174 | 172 | } |
1175 | | |
1176 | 0 | std::vector<const analysis::Constant*> a_components; |
1177 | 0 | std::vector<const analysis::Constant*> b_components; |
1178 | |
|
1179 | 0 | a_components = constants[0]->GetVectorComponents(const_mgr); |
1180 | 0 | b_components = constants[1]->GetVectorComponents(const_mgr); |
1181 | |
|
1182 | 0 | utils::FloatProxy<double> result(0.0); |
1183 | 0 | std::vector<uint32_t> words = result.GetWords(); |
1184 | 0 | const analysis::Constant* result_const = |
1185 | 0 | const_mgr->GetConstant(float_type, words); |
1186 | 0 | for (uint32_t i = 0; i < a_components.size() && result_const != nullptr; |
1187 | 0 | ++i) { |
1188 | 0 | if (a_components[i] == nullptr || b_components[i] == nullptr) { |
1189 | 0 | return nullptr; |
1190 | 0 | } |
1191 | | |
1192 | 0 | const analysis::Constant* component = FOLD_FPARITH_OP(*)( |
1193 | 0 | new_type, a_components[i], b_components[i], const_mgr); |
1194 | 0 | if (component == nullptr) { |
1195 | 0 | return nullptr; |
1196 | 0 | } |
1197 | 0 | result_const = |
1198 | 0 | FOLD_FPARITH_OP(+)(new_type, result_const, component, const_mgr); |
1199 | 0 | } |
1200 | 0 | return result_const; |
1201 | 0 | }; |
1202 | 11.6k | } |
1203 | | |
1204 | 11.6k | ConstantFoldingRule FoldFNegate() { return FoldFPUnaryOp(NegateFPConst); } |
1205 | 11.6k | ConstantFoldingRule FoldSNegate() { return FoldUnaryOp(NegateIntConst); } |
1206 | | |
1207 | 92.9k | ConstantFoldingRule FoldFClampFeedingCompare(spv::Op cmp_opcode) { |
1208 | 92.9k | return [cmp_opcode](IRContext* context, Instruction* inst, |
1209 | 92.9k | const std::vector<const analysis::Constant*>& constants) |
1210 | 92.9k | -> const analysis::Constant* { |
1211 | 66.6k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
1212 | 66.6k | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
1213 | | |
1214 | 66.6k | if (!inst->IsFloatingPointFoldingAllowed()) { |
1215 | 19 | return nullptr; |
1216 | 19 | } |
1217 | | |
1218 | 66.6k | uint32_t non_const_idx = (constants[0] ? 1 : 0); |
1219 | 66.6k | uint32_t operand_id = inst->GetSingleWordInOperand(non_const_idx); |
1220 | 66.6k | Instruction* operand_inst = def_use_mgr->GetDef(operand_id); |
1221 | | |
1222 | 66.6k | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
1223 | 66.6k | const analysis::Type* operand_type = |
1224 | 66.6k | type_mgr->GetType(operand_inst->type_id()); |
1225 | | |
1226 | 66.6k | if (!operand_type->AsFloat()) { |
1227 | 0 | return nullptr; |
1228 | 0 | } |
1229 | | |
1230 | 66.6k | if (operand_type->AsFloat()->width() != 32 && |
1231 | 0 | operand_type->AsFloat()->width() != 64) { |
1232 | 0 | return nullptr; |
1233 | 0 | } |
1234 | | |
1235 | 66.6k | if (operand_inst->opcode() != spv::Op::OpExtInst) { |
1236 | 60.2k | return nullptr; |
1237 | 60.2k | } |
1238 | | |
1239 | 6.39k | if (operand_inst->GetSingleWordInOperand(1) != GLSLstd450FClamp) { |
1240 | 6.39k | return nullptr; |
1241 | 6.39k | } |
1242 | | |
1243 | 0 | if (constants[1] == nullptr && constants[0] == nullptr) { |
1244 | 0 | return nullptr; |
1245 | 0 | } |
1246 | | |
1247 | 0 | uint32_t max_id = operand_inst->GetSingleWordInOperand(4); |
1248 | 0 | const analysis::Constant* max_const = |
1249 | 0 | const_mgr->FindDeclaredConstant(max_id); |
1250 | |
|
1251 | 0 | uint32_t min_id = operand_inst->GetSingleWordInOperand(3); |
1252 | 0 | const analysis::Constant* min_const = |
1253 | 0 | const_mgr->FindDeclaredConstant(min_id); |
1254 | |
|
1255 | 0 | bool found_result = false; |
1256 | 0 | bool result = false; |
1257 | |
|
1258 | 0 | switch (cmp_opcode) { |
1259 | 0 | case spv::Op::OpFOrdLessThan: |
1260 | 0 | case spv::Op::OpFUnordLessThan: |
1261 | 0 | case spv::Op::OpFOrdGreaterThanEqual: |
1262 | 0 | case spv::Op::OpFUnordGreaterThanEqual: |
1263 | 0 | if (constants[0]) { |
1264 | 0 | if (min_const) { |
1265 | 0 | if (constants[0]->GetValueAsDouble() < |
1266 | 0 | min_const->GetValueAsDouble()) { |
1267 | 0 | found_result = true; |
1268 | 0 | result = (cmp_opcode == spv::Op::OpFOrdLessThan || |
1269 | 0 | cmp_opcode == spv::Op::OpFUnordLessThan); |
1270 | 0 | } |
1271 | 0 | } |
1272 | 0 | if (max_const) { |
1273 | 0 | if (constants[0]->GetValueAsDouble() >= |
1274 | 0 | max_const->GetValueAsDouble()) { |
1275 | 0 | found_result = true; |
1276 | 0 | result = !(cmp_opcode == spv::Op::OpFOrdLessThan || |
1277 | 0 | cmp_opcode == spv::Op::OpFUnordLessThan); |
1278 | 0 | } |
1279 | 0 | } |
1280 | 0 | } |
1281 | |
|
1282 | 0 | if (constants[1]) { |
1283 | 0 | if (max_const) { |
1284 | 0 | if (max_const->GetValueAsDouble() < |
1285 | 0 | constants[1]->GetValueAsDouble()) { |
1286 | 0 | found_result = true; |
1287 | 0 | result = (cmp_opcode == spv::Op::OpFOrdLessThan || |
1288 | 0 | cmp_opcode == spv::Op::OpFUnordLessThan); |
1289 | 0 | } |
1290 | 0 | } |
1291 | |
|
1292 | 0 | if (min_const) { |
1293 | 0 | if (min_const->GetValueAsDouble() >= |
1294 | 0 | constants[1]->GetValueAsDouble()) { |
1295 | 0 | found_result = true; |
1296 | 0 | result = !(cmp_opcode == spv::Op::OpFOrdLessThan || |
1297 | 0 | cmp_opcode == spv::Op::OpFUnordLessThan); |
1298 | 0 | } |
1299 | 0 | } |
1300 | 0 | } |
1301 | 0 | break; |
1302 | 0 | case spv::Op::OpFOrdGreaterThan: |
1303 | 0 | case spv::Op::OpFUnordGreaterThan: |
1304 | 0 | case spv::Op::OpFOrdLessThanEqual: |
1305 | 0 | case spv::Op::OpFUnordLessThanEqual: |
1306 | 0 | if (constants[0]) { |
1307 | 0 | if (min_const) { |
1308 | 0 | if (constants[0]->GetValueAsDouble() <= |
1309 | 0 | min_const->GetValueAsDouble()) { |
1310 | 0 | found_result = true; |
1311 | 0 | result = (cmp_opcode == spv::Op::OpFOrdLessThanEqual || |
1312 | 0 | cmp_opcode == spv::Op::OpFUnordLessThanEqual); |
1313 | 0 | } |
1314 | 0 | } |
1315 | 0 | if (max_const) { |
1316 | 0 | if (constants[0]->GetValueAsDouble() > |
1317 | 0 | max_const->GetValueAsDouble()) { |
1318 | 0 | found_result = true; |
1319 | 0 | result = !(cmp_opcode == spv::Op::OpFOrdLessThanEqual || |
1320 | 0 | cmp_opcode == spv::Op::OpFUnordLessThanEqual); |
1321 | 0 | } |
1322 | 0 | } |
1323 | 0 | } |
1324 | |
|
1325 | 0 | if (constants[1]) { |
1326 | 0 | if (max_const) { |
1327 | 0 | if (max_const->GetValueAsDouble() <= |
1328 | 0 | constants[1]->GetValueAsDouble()) { |
1329 | 0 | found_result = true; |
1330 | 0 | result = (cmp_opcode == spv::Op::OpFOrdLessThanEqual || |
1331 | 0 | cmp_opcode == spv::Op::OpFUnordLessThanEqual); |
1332 | 0 | } |
1333 | 0 | } |
1334 | |
|
1335 | 0 | if (min_const) { |
1336 | 0 | if (min_const->GetValueAsDouble() > |
1337 | 0 | constants[1]->GetValueAsDouble()) { |
1338 | 0 | found_result = true; |
1339 | 0 | result = !(cmp_opcode == spv::Op::OpFOrdLessThanEqual || |
1340 | 0 | cmp_opcode == spv::Op::OpFUnordLessThanEqual); |
1341 | 0 | } |
1342 | 0 | } |
1343 | 0 | } |
1344 | 0 | break; |
1345 | 0 | default: |
1346 | 0 | return nullptr; |
1347 | 0 | } |
1348 | | |
1349 | 0 | if (!found_result) { |
1350 | 0 | return nullptr; |
1351 | 0 | } |
1352 | | |
1353 | 0 | const analysis::Type* bool_type = |
1354 | 0 | context->get_type_mgr()->GetType(inst->type_id()); |
1355 | 0 | const analysis::Constant* result_const = |
1356 | 0 | const_mgr->GetConstant(bool_type, {static_cast<uint32_t>(result)}); |
1357 | 0 | assert(result_const); |
1358 | 0 | return result_const; |
1359 | 0 | }; |
1360 | 92.9k | } |
1361 | | |
1362 | 6.92k | ConstantFoldingRule FoldFMix() { |
1363 | 6.92k | return [](IRContext* context, Instruction* inst, |
1364 | 6.92k | const std::vector<const analysis::Constant*>& constants) |
1365 | 9.04k | -> const analysis::Constant* { |
1366 | 9.04k | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
1367 | 9.04k | assert(inst->opcode() == spv::Op::OpExtInst && |
1368 | 9.04k | "Expecting an extended instruction."); |
1369 | 9.04k | assert(inst->GetSingleWordInOperand(0) == |
1370 | 9.04k | context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() && |
1371 | 9.04k | "Expecting a GLSLstd450 extended instruction."); |
1372 | 9.04k | assert(inst->GetSingleWordInOperand(1) == GLSLstd450FMix && |
1373 | 9.04k | "Expecting and FMix instruction."); |
1374 | | |
1375 | 9.04k | if (!inst->IsFloatingPointFoldingAllowed()) { |
1376 | 0 | return nullptr; |
1377 | 0 | } |
1378 | | |
1379 | | // Make sure all FMix operands are constants. |
1380 | 9.74k | for (uint32_t i = 1; i < 4; i++) { |
1381 | 9.64k | if (constants[i] == nullptr) { |
1382 | 8.94k | return nullptr; |
1383 | 8.94k | } |
1384 | 9.64k | } |
1385 | | |
1386 | 103 | const analysis::Constant* one; |
1387 | 103 | bool is_vector = false; |
1388 | 103 | const analysis::Type* result_type = constants[1]->type(); |
1389 | 103 | const analysis::Type* base_type = result_type; |
1390 | 103 | if (base_type->AsVector()) { |
1391 | 85 | is_vector = true; |
1392 | 85 | base_type = base_type->AsVector()->element_type(); |
1393 | 85 | } |
1394 | 103 | assert(base_type->AsFloat() != nullptr && |
1395 | 103 | "FMix is suppose to act on floats or vectors of floats."); |
1396 | | |
1397 | 103 | if (base_type->AsFloat()->width() == 32) { |
1398 | 103 | one = const_mgr->GetConstant(base_type, |
1399 | 103 | utils::FloatProxy<float>(1.0f).GetWords()); |
1400 | 103 | } else if (base_type->AsFloat()->width() == 64) { |
1401 | 0 | one = const_mgr->GetConstant(base_type, |
1402 | 0 | utils::FloatProxy<double>(1.0).GetWords()); |
1403 | 0 | } else { |
1404 | | // We won't support folding half types. |
1405 | 0 | return nullptr; |
1406 | 0 | } |
1407 | | |
1408 | 103 | if (is_vector) { |
1409 | 85 | Instruction* one_inst = const_mgr->GetDefiningInstruction(one); |
1410 | 85 | if (one_inst == nullptr) return nullptr; |
1411 | 85 | uint32_t one_id = one_inst->result_id(); |
1412 | 85 | one = |
1413 | 85 | const_mgr->GetConstant(result_type, std::vector<uint32_t>(4, one_id)); |
1414 | 85 | } |
1415 | | |
1416 | 103 | const analysis::Constant* temp1 = FoldFPBinaryOp( |
1417 | 374 | FOLD_FPARITH_OP(-), inst->type_id(), {one, constants[3]}, context); |
1418 | 103 | if (temp1 == nullptr) { |
1419 | 0 | return nullptr; |
1420 | 0 | } |
1421 | | |
1422 | 103 | const analysis::Constant* temp2 = FoldFPBinaryOp( |
1423 | 374 | FOLD_FPARITH_OP(*), inst->type_id(), {constants[1], temp1}, context); |
1424 | 103 | if (temp2 == nullptr) { |
1425 | 0 | return nullptr; |
1426 | 0 | } |
1427 | 103 | const analysis::Constant* temp3 = |
1428 | 374 | FoldFPBinaryOp(FOLD_FPARITH_OP(*), inst->type_id(), |
1429 | 103 | {constants[2], constants[3]}, context); |
1430 | 103 | if (temp3 == nullptr) { |
1431 | 0 | return nullptr; |
1432 | 0 | } |
1433 | 374 | return FoldFPBinaryOp(FOLD_FPARITH_OP(+), inst->type_id(), {temp2, temp3}, |
1434 | 103 | context); |
1435 | 103 | }; |
1436 | 6.92k | } |
1437 | | |
1438 | | const analysis::Constant* FoldMin(const analysis::Type* result_type, |
1439 | | const analysis::Constant* a, |
1440 | | const analysis::Constant* b, |
1441 | 1.97k | analysis::ConstantManager*) { |
1442 | 1.97k | if (const analysis::Integer* int_type = result_type->AsInteger()) { |
1443 | 102 | if (int_type->width() <= 32) { |
1444 | 102 | assert( |
1445 | 102 | (a->AsIntConstant() != nullptr || a->AsNullConstant() != nullptr) && |
1446 | 102 | "Must be an integer or null constant."); |
1447 | 102 | assert( |
1448 | 102 | (b->AsIntConstant() != nullptr || b->AsNullConstant() != nullptr) && |
1449 | 102 | "Must be an integer or null constant."); |
1450 | | |
1451 | 102 | if (int_type->IsSigned()) { |
1452 | 93 | int32_t va = (a->AsIntConstant() != nullptr) |
1453 | 93 | ? a->AsIntConstant()->GetS32BitValue() |
1454 | 93 | : 0; |
1455 | 93 | int32_t vb = (b->AsIntConstant() != nullptr) |
1456 | 93 | ? b->AsIntConstant()->GetS32BitValue() |
1457 | 93 | : 0; |
1458 | 93 | return (va < vb ? a : b); |
1459 | 93 | } else { |
1460 | 9 | uint32_t va = (a->AsIntConstant() != nullptr) |
1461 | 9 | ? a->AsIntConstant()->GetU32BitValue() |
1462 | 9 | : 0; |
1463 | 9 | uint32_t vb = (b->AsIntConstant() != nullptr) |
1464 | 9 | ? b->AsIntConstant()->GetU32BitValue() |
1465 | 9 | : 0; |
1466 | 9 | return (va < vb ? a : b); |
1467 | 9 | } |
1468 | 102 | } else if (int_type->width() == 64) { |
1469 | 0 | if (int_type->IsSigned()) { |
1470 | 0 | int64_t va = a->GetS64(); |
1471 | 0 | int64_t vb = b->GetS64(); |
1472 | 0 | return (va < vb ? a : b); |
1473 | 0 | } else { |
1474 | 0 | uint64_t va = a->GetU64(); |
1475 | 0 | uint64_t vb = b->GetU64(); |
1476 | 0 | return (va < vb ? a : b); |
1477 | 0 | } |
1478 | 0 | } |
1479 | 1.86k | } else if (const analysis::Float* float_type = result_type->AsFloat()) { |
1480 | 1.86k | if (float_type->width() == 32) { |
1481 | 1.86k | float va = a->GetFloat(); |
1482 | 1.86k | float vb = b->GetFloat(); |
1483 | 1.86k | return (va < vb ? a : b); |
1484 | 1.86k | } else if (float_type->width() == 64) { |
1485 | 0 | double va = a->GetDouble(); |
1486 | 0 | double vb = b->GetDouble(); |
1487 | 0 | return (va < vb ? a : b); |
1488 | 0 | } |
1489 | 1.86k | } |
1490 | 0 | return nullptr; |
1491 | 1.97k | } |
1492 | | |
1493 | | const analysis::Constant* FoldMax(const analysis::Type* result_type, |
1494 | | const analysis::Constant* a, |
1495 | | const analysis::Constant* b, |
1496 | 2.86k | analysis::ConstantManager*) { |
1497 | 2.86k | if (const analysis::Integer* int_type = result_type->AsInteger()) { |
1498 | 11 | if (int_type->width() <= 32) { |
1499 | 11 | assert( |
1500 | 11 | (a->AsIntConstant() != nullptr || a->AsNullConstant() != nullptr) && |
1501 | 11 | "Must be an integer or null constant."); |
1502 | 11 | assert( |
1503 | 11 | (b->AsIntConstant() != nullptr || b->AsNullConstant() != nullptr) && |
1504 | 11 | "Must be an integer or null constant."); |
1505 | | |
1506 | 11 | if (int_type->IsSigned()) { |
1507 | 11 | int32_t va = (a->AsIntConstant() != nullptr) |
1508 | 11 | ? a->AsIntConstant()->GetS32BitValue() |
1509 | 11 | : 0; |
1510 | 11 | int32_t vb = (b->AsIntConstant() != nullptr) |
1511 | 11 | ? b->AsIntConstant()->GetS32BitValue() |
1512 | 11 | : 0; |
1513 | 11 | return (va > vb ? a : b); |
1514 | 11 | } else { |
1515 | 0 | uint32_t va = (a->AsIntConstant() != nullptr) |
1516 | 0 | ? a->AsIntConstant()->GetU32BitValue() |
1517 | 0 | : 0; |
1518 | 0 | uint32_t vb = (b->AsIntConstant() != nullptr) |
1519 | 0 | ? b->AsIntConstant()->GetU32BitValue() |
1520 | 0 | : 0; |
1521 | 0 | return (va > vb ? a : b); |
1522 | 0 | } |
1523 | 11 | } else if (int_type->width() == 64) { |
1524 | 0 | if (int_type->IsSigned()) { |
1525 | 0 | int64_t va = a->GetS64(); |
1526 | 0 | int64_t vb = b->GetS64(); |
1527 | 0 | return (va > vb ? a : b); |
1528 | 0 | } else { |
1529 | 0 | uint64_t va = a->GetU64(); |
1530 | 0 | uint64_t vb = b->GetU64(); |
1531 | 0 | return (va > vb ? a : b); |
1532 | 0 | } |
1533 | 0 | } |
1534 | 2.85k | } else if (const analysis::Float* float_type = result_type->AsFloat()) { |
1535 | 2.85k | if (float_type->width() == 32) { |
1536 | 2.85k | float va = a->GetFloat(); |
1537 | 2.85k | float vb = b->GetFloat(); |
1538 | 2.85k | return (va > vb ? a : b); |
1539 | 2.85k | } else if (float_type->width() == 64) { |
1540 | 0 | double va = a->GetDouble(); |
1541 | 0 | double vb = b->GetDouble(); |
1542 | 0 | return (va > vb ? a : b); |
1543 | 0 | } |
1544 | 2.85k | } |
1545 | 0 | return nullptr; |
1546 | 2.86k | } |
1547 | | |
1548 | | // Fold an clamp instruction when all three operands are constant. |
1549 | | const analysis::Constant* FoldClamp1( |
1550 | | IRContext* context, Instruction* inst, |
1551 | 8.03k | const std::vector<const analysis::Constant*>& constants) { |
1552 | 8.03k | assert(inst->opcode() == spv::Op::OpExtInst && |
1553 | 8.03k | "Expecting an extended instruction."); |
1554 | 8.03k | assert(inst->GetSingleWordInOperand(0) == |
1555 | 8.03k | context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() && |
1556 | 8.03k | "Expecting a GLSLstd450 extended instruction."); |
1557 | | |
1558 | | // Make sure all Clamp operands are constants. |
1559 | 12.7k | for (uint32_t i = 1; i < 4; i++) { |
1560 | 12.1k | if (constants[i] == nullptr) { |
1561 | 7.49k | return nullptr; |
1562 | 7.49k | } |
1563 | 12.1k | } |
1564 | | |
1565 | 534 | const analysis::Constant* temp = FoldFPBinaryOp( |
1566 | 534 | FoldMax, inst->type_id(), {constants[1], constants[2]}, context); |
1567 | 534 | if (temp == nullptr) { |
1568 | 0 | return nullptr; |
1569 | 0 | } |
1570 | 534 | return FoldFPBinaryOp(FoldMin, inst->type_id(), {temp, constants[3]}, |
1571 | 534 | context); |
1572 | 534 | } |
1573 | | |
1574 | | // Fold a clamp instruction when |x <= min_val|. |
1575 | | const analysis::Constant* FoldClamp2( |
1576 | | IRContext* context, Instruction* inst, |
1577 | 7.49k | const std::vector<const analysis::Constant*>& constants) { |
1578 | 7.49k | assert(inst->opcode() == spv::Op::OpExtInst && |
1579 | 7.49k | "Expecting an extended instruction."); |
1580 | 7.49k | assert(inst->GetSingleWordInOperand(0) == |
1581 | 7.49k | context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() && |
1582 | 7.49k | "Expecting a GLSLstd450 extended instruction."); |
1583 | | |
1584 | 7.49k | const analysis::Constant* x = constants[1]; |
1585 | 7.49k | const analysis::Constant* min_val = constants[2]; |
1586 | | |
1587 | 7.49k | if (x == nullptr || min_val == nullptr) { |
1588 | 6.07k | return nullptr; |
1589 | 6.07k | } |
1590 | | |
1591 | 1.42k | const analysis::Constant* temp = |
1592 | 1.42k | FoldFPBinaryOp(FoldMax, inst->type_id(), {x, min_val}, context); |
1593 | 1.42k | if (temp == min_val) { |
1594 | | // We can assume that |min_val| is less than |max_val|. Therefore, if the |
1595 | | // result of the max operation is |min_val|, we know the result of the min |
1596 | | // operation, even if |max_val| is not a constant. |
1597 | 1.18k | return min_val; |
1598 | 1.18k | } |
1599 | 237 | return nullptr; |
1600 | 1.42k | } |
1601 | | |
1602 | | // Fold a clamp instruction when |x >= max_val|. |
1603 | | const analysis::Constant* FoldClamp3( |
1604 | | IRContext* context, Instruction* inst, |
1605 | 6.31k | const std::vector<const analysis::Constant*>& constants) { |
1606 | 6.31k | assert(inst->opcode() == spv::Op::OpExtInst && |
1607 | 6.31k | "Expecting an extended instruction."); |
1608 | 6.31k | assert(inst->GetSingleWordInOperand(0) == |
1609 | 6.31k | context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() && |
1610 | 6.31k | "Expecting a GLSLstd450 extended instruction."); |
1611 | | |
1612 | 6.31k | const analysis::Constant* x = constants[1]; |
1613 | 6.31k | const analysis::Constant* max_val = constants[3]; |
1614 | | |
1615 | 6.31k | if (x == nullptr || max_val == nullptr) { |
1616 | 6.14k | return nullptr; |
1617 | 6.14k | } |
1618 | | |
1619 | 166 | const analysis::Constant* temp = |
1620 | 166 | FoldFPBinaryOp(FoldMin, inst->type_id(), {x, max_val}, context); |
1621 | 166 | if (temp == max_val) { |
1622 | | // We can assume that |min_val| is less than |max_val|. Therefore, if the |
1623 | | // result of the max operation is |min_val|, we know the result of the min |
1624 | | // operation, even if |max_val| is not a constant. |
1625 | 63 | return max_val; |
1626 | 63 | } |
1627 | 103 | return nullptr; |
1628 | 166 | } |
1629 | | |
1630 | 76.1k | UnaryScalarFoldingRule FoldFTranscendentalUnary(double (*fp)(double)) { |
1631 | 76.1k | return |
1632 | 76.1k | [fp](const analysis::Type* result_type, const analysis::Constant* a, |
1633 | 76.1k | analysis::ConstantManager* const_mgr) -> const analysis::Constant* { |
1634 | 1.15k | assert(result_type != nullptr && a != nullptr); |
1635 | 1.15k | const analysis::Float* float_type = a->type()->AsFloat(); |
1636 | 1.15k | assert(float_type != nullptr); |
1637 | 1.15k | assert(float_type == result_type->AsFloat()); |
1638 | 1.15k | if (float_type->width() == 32) { |
1639 | 1.15k | float fa = a->GetFloat(); |
1640 | 1.15k | float res = static_cast<float>(fp(fa)); |
1641 | 1.15k | utils::FloatProxy<float> result(res); |
1642 | 1.15k | std::vector<uint32_t> words = result.GetWords(); |
1643 | 1.15k | return const_mgr->GetConstant(result_type, words); |
1644 | 1.15k | } else if (float_type->width() == 64) { |
1645 | 0 | double fa = a->GetDouble(); |
1646 | 0 | double res = fp(fa); |
1647 | 0 | utils::FloatProxy<double> result(res); |
1648 | 0 | std::vector<uint32_t> words = result.GetWords(); |
1649 | 0 | return const_mgr->GetConstant(result_type, words); |
1650 | 0 | } |
1651 | 0 | return nullptr; |
1652 | 1.15k | }; |
1653 | 76.1k | } |
1654 | | |
1655 | | BinaryScalarFoldingRule FoldFTranscendentalBinary(double (*fp)(double, |
1656 | 13.8k | double)) { |
1657 | 13.8k | return |
1658 | 13.8k | [fp](const analysis::Type* result_type, const analysis::Constant* a, |
1659 | 13.8k | const analysis::Constant* b, |
1660 | 13.8k | analysis::ConstantManager* const_mgr) -> const analysis::Constant* { |
1661 | 135 | assert(result_type != nullptr && a != nullptr); |
1662 | 135 | const analysis::Float* float_type = a->type()->AsFloat(); |
1663 | 135 | assert(float_type != nullptr); |
1664 | 135 | assert(float_type == result_type->AsFloat()); |
1665 | 135 | assert(float_type == b->type()->AsFloat()); |
1666 | 135 | if (float_type->width() == 32) { |
1667 | 135 | float fa = a->GetFloat(); |
1668 | 135 | float fb = b->GetFloat(); |
1669 | 135 | float res = static_cast<float>(fp(fa, fb)); |
1670 | 135 | utils::FloatProxy<float> result(res); |
1671 | 135 | std::vector<uint32_t> words = result.GetWords(); |
1672 | 135 | return const_mgr->GetConstant(result_type, words); |
1673 | 135 | } else if (float_type->width() == 64) { |
1674 | 0 | double fa = a->GetDouble(); |
1675 | 0 | double fb = b->GetDouble(); |
1676 | 0 | double res = fp(fa, fb); |
1677 | 0 | utils::FloatProxy<double> result(res); |
1678 | 0 | std::vector<uint32_t> words = result.GetWords(); |
1679 | 0 | return const_mgr->GetConstant(result_type, words); |
1680 | 0 | } |
1681 | 0 | return nullptr; |
1682 | 135 | }; |
1683 | 13.8k | } |
1684 | | |
1685 | | enum Sign { Signed, Unsigned }; |
1686 | | |
1687 | | // Returns a BinaryScalarFoldingRule that applies `op` to the scalars. |
1688 | | // The `signedness` is used to determine if the operands should be interpreted |
1689 | | // as signed or unsigned. If the operands are signed, the value will be sign |
1690 | | // extended before the value is passed to `op`. Otherwise the values will be |
1691 | | // zero extended. |
1692 | | template <Sign signedness> |
1693 | | BinaryScalarFoldingRule FoldBinaryIntegerOperation(uint64_t (*op)(uint64_t, |
1694 | 92.9k | uint64_t)) { |
1695 | 92.9k | return |
1696 | 92.9k | [op](const analysis::Type* result_type, const analysis::Constant* a, |
1697 | 92.9k | const analysis::Constant* b, |
1698 | 878k | analysis::ConstantManager* const_mgr) -> const analysis::Constant* { |
1699 | 878k | assert(result_type != nullptr && a != nullptr && b != nullptr); |
1700 | 878k | const analysis::Integer* integer_type = result_type->AsInteger(); |
1701 | 878k | assert(integer_type != nullptr); |
1702 | 878k | assert(a->type()->kind() == analysis::Type::kInteger); |
1703 | 878k | assert(b->type()->kind() == analysis::Type::kInteger); |
1704 | 878k | assert(integer_type->width() == a->type()->AsInteger()->width()); |
1705 | 878k | assert(integer_type->width() == b->type()->AsInteger()->width()); |
1706 | | |
1707 | | // In SPIR-V, all operations support unsigned types, but the way they |
1708 | | // are interpreted depends on the opcode. This is why we use the |
1709 | | // template argument to determine how to interpret the operands. |
1710 | 878k | uint64_t ia = (signedness == Signed ? a->GetSignExtendedValue() |
1711 | 878k | : a->GetZeroExtendedValue()); |
1712 | 878k | uint64_t ib = (signedness == Signed ? b->GetSignExtendedValue() |
1713 | 878k | : b->GetZeroExtendedValue()); |
1714 | 878k | uint64_t result = op(ia, ib); |
1715 | | |
1716 | 878k | const analysis::Constant* result_constant = |
1717 | 878k | const_mgr->GenerateIntegerConstant(integer_type, result); |
1718 | 878k | return result_constant; |
1719 | 878k | }; const_folding_rules.cpp:spvtools::opt::(anonymous namespace)::FoldBinaryIntegerOperation<(spvtools::opt::(anonymous namespace)::Sign)1>(unsigned long (*)(unsigned long, unsigned long))::{lambda(spvtools::opt::analysis::Type const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::ConstantManager*)#1}::operator()(spvtools::opt::analysis::Type const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::ConstantManager*) constLine | Count | Source | 1698 | 777k | analysis::ConstantManager* const_mgr) -> const analysis::Constant* { | 1699 | 777k | assert(result_type != nullptr && a != nullptr && b != nullptr); | 1700 | 777k | const analysis::Integer* integer_type = result_type->AsInteger(); | 1701 | 777k | assert(integer_type != nullptr); | 1702 | 777k | assert(a->type()->kind() == analysis::Type::kInteger); | 1703 | 777k | assert(b->type()->kind() == analysis::Type::kInteger); | 1704 | 777k | assert(integer_type->width() == a->type()->AsInteger()->width()); | 1705 | 777k | assert(integer_type->width() == b->type()->AsInteger()->width()); | 1706 | | | 1707 | | // In SPIR-V, all operations support unsigned types, but the way they | 1708 | | // are interpreted depends on the opcode. This is why we use the | 1709 | | // template argument to determine how to interpret the operands. | 1710 | 777k | uint64_t ia = (signedness == Signed ? a->GetSignExtendedValue() | 1711 | 777k | : a->GetZeroExtendedValue()); | 1712 | 777k | uint64_t ib = (signedness == Signed ? b->GetSignExtendedValue() | 1713 | 777k | : b->GetZeroExtendedValue()); | 1714 | 777k | uint64_t result = op(ia, ib); | 1715 | | | 1716 | 777k | const analysis::Constant* result_constant = | 1717 | 777k | const_mgr->GenerateIntegerConstant(integer_type, result); | 1718 | 777k | return result_constant; | 1719 | 777k | }; |
const_folding_rules.cpp:spvtools::opt::(anonymous namespace)::FoldBinaryIntegerOperation<(spvtools::opt::(anonymous namespace)::Sign)0>(unsigned long (*)(unsigned long, unsigned long))::{lambda(spvtools::opt::analysis::Type const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::ConstantManager*)#1}::operator()(spvtools::opt::analysis::Type const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::ConstantManager*) constLine | Count | Source | 1698 | 100k | analysis::ConstantManager* const_mgr) -> const analysis::Constant* { | 1699 | 100k | assert(result_type != nullptr && a != nullptr && b != nullptr); | 1700 | 100k | const analysis::Integer* integer_type = result_type->AsInteger(); | 1701 | 100k | assert(integer_type != nullptr); | 1702 | 100k | assert(a->type()->kind() == analysis::Type::kInteger); | 1703 | 100k | assert(b->type()->kind() == analysis::Type::kInteger); | 1704 | 100k | assert(integer_type->width() == a->type()->AsInteger()->width()); | 1705 | 100k | assert(integer_type->width() == b->type()->AsInteger()->width()); | 1706 | | | 1707 | | // In SPIR-V, all operations support unsigned types, but the way they | 1708 | | // are interpreted depends on the opcode. This is why we use the | 1709 | | // template argument to determine how to interpret the operands. | 1710 | 100k | uint64_t ia = (signedness == Signed ? a->GetSignExtendedValue() | 1711 | 100k | : a->GetZeroExtendedValue()); | 1712 | 100k | uint64_t ib = (signedness == Signed ? b->GetSignExtendedValue() | 1713 | 100k | : b->GetZeroExtendedValue()); | 1714 | 100k | uint64_t result = op(ia, ib); | 1715 | | | 1716 | 100k | const analysis::Constant* result_constant = | 1717 | 100k | const_mgr->GenerateIntegerConstant(integer_type, result); | 1718 | 100k | return result_constant; | 1719 | 100k | }; |
|
1720 | 92.9k | } const_folding_rules.cpp:std::__1::function<spvtools::opt::analysis::Constant const* (spvtools::opt::analysis::Type const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::ConstantManager*)> spvtools::opt::(anonymous namespace)::FoldBinaryIntegerOperation<(spvtools::opt::(anonymous namespace)::Sign)1>(unsigned long (*)(unsigned long, unsigned long)) Line | Count | Source | 1694 | 58.0k | uint64_t)) { | 1695 | 58.0k | return | 1696 | 58.0k | [op](const analysis::Type* result_type, const analysis::Constant* a, | 1697 | 58.0k | const analysis::Constant* b, | 1698 | 58.0k | analysis::ConstantManager* const_mgr) -> const analysis::Constant* { | 1699 | 58.0k | assert(result_type != nullptr && a != nullptr && b != nullptr); | 1700 | 58.0k | const analysis::Integer* integer_type = result_type->AsInteger(); | 1701 | 58.0k | assert(integer_type != nullptr); | 1702 | 58.0k | assert(a->type()->kind() == analysis::Type::kInteger); | 1703 | 58.0k | assert(b->type()->kind() == analysis::Type::kInteger); | 1704 | 58.0k | assert(integer_type->width() == a->type()->AsInteger()->width()); | 1705 | 58.0k | assert(integer_type->width() == b->type()->AsInteger()->width()); | 1706 | | | 1707 | | // In SPIR-V, all operations support unsigned types, but the way they | 1708 | | // are interpreted depends on the opcode. This is why we use the | 1709 | | // template argument to determine how to interpret the operands. | 1710 | 58.0k | uint64_t ia = (signedness == Signed ? a->GetSignExtendedValue() | 1711 | 58.0k | : a->GetZeroExtendedValue()); | 1712 | 58.0k | uint64_t ib = (signedness == Signed ? 
b->GetSignExtendedValue() | 1713 | 58.0k | : b->GetZeroExtendedValue()); | 1714 | 58.0k | uint64_t result = op(ia, ib); | 1715 | | | 1716 | 58.0k | const analysis::Constant* result_constant = | 1717 | 58.0k | const_mgr->GenerateIntegerConstant(integer_type, result); | 1718 | 58.0k | return result_constant; | 1719 | 58.0k | }; | 1720 | 58.0k | } |
const_folding_rules.cpp:std::__1::function<spvtools::opt::analysis::Constant const* (spvtools::opt::analysis::Type const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::Constant const*, spvtools::opt::analysis::ConstantManager*)> spvtools::opt::(anonymous namespace)::FoldBinaryIntegerOperation<(spvtools::opt::(anonymous namespace)::Sign)0>(unsigned long (*)(unsigned long, unsigned long)) Line | Count | Source | 1694 | 34.8k | uint64_t)) { | 1695 | 34.8k | return | 1696 | 34.8k | [op](const analysis::Type* result_type, const analysis::Constant* a, | 1697 | 34.8k | const analysis::Constant* b, | 1698 | 34.8k | analysis::ConstantManager* const_mgr) -> const analysis::Constant* { | 1699 | 34.8k | assert(result_type != nullptr && a != nullptr && b != nullptr); | 1700 | 34.8k | const analysis::Integer* integer_type = result_type->AsInteger(); | 1701 | 34.8k | assert(integer_type != nullptr); | 1702 | 34.8k | assert(a->type()->kind() == analysis::Type::kInteger); | 1703 | 34.8k | assert(b->type()->kind() == analysis::Type::kInteger); | 1704 | 34.8k | assert(integer_type->width() == a->type()->AsInteger()->width()); | 1705 | 34.8k | assert(integer_type->width() == b->type()->AsInteger()->width()); | 1706 | | | 1707 | | // In SPIR-V, all operations support unsigned types, but the way they | 1708 | | // are interpreted depends on the opcode. This is why we use the | 1709 | | // template argument to determine how to interpret the operands. | 1710 | 34.8k | uint64_t ia = (signedness == Signed ? a->GetSignExtendedValue() | 1711 | 34.8k | : a->GetZeroExtendedValue()); | 1712 | 34.8k | uint64_t ib = (signedness == Signed ? 
b->GetSignExtendedValue() | 1713 | 34.8k | : b->GetZeroExtendedValue()); | 1714 | 34.8k | uint64_t result = op(ia, ib); | 1715 | | | 1716 | 34.8k | const analysis::Constant* result_constant = | 1717 | 34.8k | const_mgr->GenerateIntegerConstant(integer_type, result); | 1718 | 34.8k | return result_constant; | 1719 | 34.8k | }; | 1720 | 34.8k | } |
|
1721 | | |
1722 | | // A scalar folding rule that folds OpSConvert. |
1723 | | const analysis::Constant* FoldScalarSConvert( |
1724 | | const analysis::Type* result_type, const analysis::Constant* a, |
1725 | 0 | analysis::ConstantManager* const_mgr) { |
1726 | 0 | assert(result_type != nullptr); |
1727 | 0 | assert(a != nullptr); |
1728 | 0 | assert(const_mgr != nullptr); |
1729 | 0 | const analysis::Integer* integer_type = result_type->AsInteger(); |
1730 | 0 | assert(integer_type && "The result type of an SConvert"); |
1731 | 0 | int64_t value = a->GetSignExtendedValue(); |
1732 | 0 | return const_mgr->GenerateIntegerConstant(integer_type, value); |
1733 | 0 | } |
1734 | | |
1735 | | // A scalar folding rule that folds OpUConvert. |
1736 | | const analysis::Constant* FoldScalarUConvert( |
1737 | | const analysis::Type* result_type, const analysis::Constant* a, |
1738 | 0 | analysis::ConstantManager* const_mgr) { |
1739 | 0 | assert(result_type != nullptr); |
1740 | 0 | assert(a != nullptr); |
1741 | 0 | assert(const_mgr != nullptr); |
1742 | 0 | const analysis::Integer* integer_type = result_type->AsInteger(); |
1743 | 0 | assert(integer_type && "The result type of an UConvert"); |
1744 | 0 | uint64_t value = a->GetZeroExtendedValue(); |
1745 | | |
1746 | | // If the operand was an unsigned value with less than 32-bit, it would have |
1747 | | // been sign extended earlier, and we need to clear those bits. |
1748 | 0 | auto* operand_type = a->type()->AsInteger(); |
1749 | 0 | value = utils::ClearHighBits(value, 64 - operand_type->width()); |
1750 | 0 | return const_mgr->GenerateIntegerConstant(integer_type, value); |
1751 | 0 | } |
1752 | | } // namespace |
1753 | | |
1754 | 11.6k | void ConstantFoldingRules::AddFoldingRules() { |
1755 | | // Add all folding rules to the list for the opcodes to which they apply. |
1756 | | // Note that the order in which rules are added to the list matters. If a rule |
1757 | | // applies to the instruction, the rest of the rules will not be attempted. |
1758 | | // Take that into consideration. |
1759 | | |
1760 | 11.6k | rules_[spv::Op::OpCompositeConstruct].push_back(FoldCompositeWithConstants()); |
1761 | | |
1762 | 11.6k | rules_[spv::Op::OpCompositeExtract].push_back(FoldExtractWithConstants()); |
1763 | 11.6k | rules_[spv::Op::OpCompositeInsert].push_back(FoldInsertWithConstants()); |
1764 | | |
1765 | 11.6k | rules_[spv::Op::OpConvertFToS].push_back(FoldFToI()); |
1766 | 11.6k | rules_[spv::Op::OpConvertFToU].push_back(FoldFToI()); |
1767 | 11.6k | rules_[spv::Op::OpConvertSToF].push_back(FoldIToF()); |
1768 | 11.6k | rules_[spv::Op::OpConvertUToF].push_back(FoldIToF()); |
1769 | 11.6k | rules_[spv::Op::OpSConvert].push_back(FoldUnaryOp(FoldScalarSConvert)); |
1770 | 11.6k | rules_[spv::Op::OpUConvert].push_back(FoldUnaryOp(FoldScalarUConvert)); |
1771 | | |
1772 | 11.6k | rules_[spv::Op::OpDot].push_back(FoldOpDotWithConstants()); |
1773 | 11.6k | rules_[spv::Op::OpFAdd].push_back(FoldFAdd()); |
1774 | 11.6k | rules_[spv::Op::OpFDiv].push_back(FoldFDiv()); |
1775 | 11.6k | rules_[spv::Op::OpFMul].push_back(FoldFMul()); |
1776 | 11.6k | rules_[spv::Op::OpFSub].push_back(FoldFSub()); |
1777 | | |
1778 | 11.6k | rules_[spv::Op::OpFOrdEqual].push_back(FoldFOrdEqual()); |
1779 | | |
1780 | 11.6k | rules_[spv::Op::OpFUnordEqual].push_back(FoldFUnordEqual()); |
1781 | | |
1782 | 11.6k | rules_[spv::Op::OpFOrdNotEqual].push_back(FoldFOrdNotEqual()); |
1783 | | |
1784 | 11.6k | rules_[spv::Op::OpFUnordNotEqual].push_back(FoldFUnordNotEqual()); |
1785 | | |
1786 | 11.6k | rules_[spv::Op::OpFOrdLessThan].push_back(FoldFOrdLessThan()); |
1787 | 11.6k | rules_[spv::Op::OpFOrdLessThan].push_back( |
1788 | 11.6k | FoldFClampFeedingCompare(spv::Op::OpFOrdLessThan)); |
1789 | | |
1790 | 11.6k | rules_[spv::Op::OpFUnordLessThan].push_back(FoldFUnordLessThan()); |
1791 | 11.6k | rules_[spv::Op::OpFUnordLessThan].push_back( |
1792 | 11.6k | FoldFClampFeedingCompare(spv::Op::OpFUnordLessThan)); |
1793 | | |
1794 | 11.6k | rules_[spv::Op::OpFOrdGreaterThan].push_back(FoldFOrdGreaterThan()); |
1795 | 11.6k | rules_[spv::Op::OpFOrdGreaterThan].push_back( |
1796 | 11.6k | FoldFClampFeedingCompare(spv::Op::OpFOrdGreaterThan)); |
1797 | | |
1798 | 11.6k | rules_[spv::Op::OpFUnordGreaterThan].push_back(FoldFUnordGreaterThan()); |
1799 | 11.6k | rules_[spv::Op::OpFUnordGreaterThan].push_back( |
1800 | 11.6k | FoldFClampFeedingCompare(spv::Op::OpFUnordGreaterThan)); |
1801 | | |
1802 | 11.6k | rules_[spv::Op::OpFOrdLessThanEqual].push_back(FoldFOrdLessThanEqual()); |
1803 | 11.6k | rules_[spv::Op::OpFOrdLessThanEqual].push_back( |
1804 | 11.6k | FoldFClampFeedingCompare(spv::Op::OpFOrdLessThanEqual)); |
1805 | | |
1806 | 11.6k | rules_[spv::Op::OpFUnordLessThanEqual].push_back(FoldFUnordLessThanEqual()); |
1807 | 11.6k | rules_[spv::Op::OpFUnordLessThanEqual].push_back( |
1808 | 11.6k | FoldFClampFeedingCompare(spv::Op::OpFUnordLessThanEqual)); |
1809 | | |
1810 | 11.6k | rules_[spv::Op::OpFOrdGreaterThanEqual].push_back(FoldFOrdGreaterThanEqual()); |
1811 | 11.6k | rules_[spv::Op::OpFOrdGreaterThanEqual].push_back( |
1812 | 11.6k | FoldFClampFeedingCompare(spv::Op::OpFOrdGreaterThanEqual)); |
1813 | | |
1814 | 11.6k | rules_[spv::Op::OpFUnordGreaterThanEqual].push_back( |
1815 | 11.6k | FoldFUnordGreaterThanEqual()); |
1816 | 11.6k | rules_[spv::Op::OpFUnordGreaterThanEqual].push_back( |
1817 | 11.6k | FoldFClampFeedingCompare(spv::Op::OpFUnordGreaterThanEqual)); |
1818 | | |
1819 | 11.6k | rules_[spv::Op::OpVectorShuffle].push_back(FoldVectorShuffleWithConstants()); |
1820 | 11.6k | rules_[spv::Op::OpVectorTimesScalar].push_back(FoldVectorTimesScalar()); |
1821 | 11.6k | rules_[spv::Op::OpVectorTimesMatrix].push_back(FoldVectorTimesMatrix()); |
1822 | 11.6k | rules_[spv::Op::OpMatrixTimesVector].push_back(FoldMatrixTimesVector()); |
1823 | 11.6k | rules_[spv::Op::OpTranspose].push_back(FoldTranspose); |
1824 | | |
1825 | 11.6k | rules_[spv::Op::OpFNegate].push_back(FoldFNegate()); |
1826 | 11.6k | rules_[spv::Op::OpSNegate].push_back(FoldSNegate()); |
1827 | 11.6k | rules_[spv::Op::OpQuantizeToF16].push_back(FoldQuantizeToF16()); |
1828 | | |
1829 | 11.6k | rules_[spv::Op::OpIAdd].push_back( |
1830 | 11.6k | FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>( |
1831 | 458k | [](uint64_t a, uint64_t b) { return a + b; }))); |
1832 | 11.6k | rules_[spv::Op::OpISub].push_back( |
1833 | 11.6k | FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>( |
1834 | 226k | [](uint64_t a, uint64_t b) { return a - b; }))); |
1835 | 11.6k | rules_[spv::Op::OpIMul].push_back( |
1836 | 11.6k | FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>( |
1837 | 92.7k | [](uint64_t a, uint64_t b) { return a * b; }))); |
1838 | 11.6k | rules_[spv::Op::OpUDiv].push_back( |
1839 | 11.6k | FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>( |
1840 | 11.6k | [](uint64_t a, uint64_t b) { return (b != 0 ? a / b : 0); }))); |
1841 | 11.6k | rules_[spv::Op::OpSDiv].push_back(FoldBinaryOp( |
1842 | 37.4k | FoldBinaryIntegerOperation<Signed>([](uint64_t a, uint64_t b) { |
1843 | 37.4k | return (b != 0 ? static_cast<uint64_t>(static_cast<int64_t>(a) / |
1844 | 36.5k | static_cast<int64_t>(b)) |
1845 | 37.4k | : 0); |
1846 | 37.4k | }))); |
1847 | 11.6k | rules_[spv::Op::OpUMod].push_back( |
1848 | 11.6k | FoldBinaryOp(FoldBinaryIntegerOperation<Unsigned>( |
1849 | 11.6k | [](uint64_t a, uint64_t b) { return (b != 0 ? a % b : 0); }))); |
1850 | | |
1851 | 11.6k | rules_[spv::Op::OpSRem].push_back(FoldBinaryOp( |
1852 | 61.3k | FoldBinaryIntegerOperation<Signed>([](uint64_t a, uint64_t b) { |
1853 | 61.3k | return (b != 0 ? static_cast<uint64_t>(static_cast<int64_t>(a) % |
1854 | 60.7k | static_cast<int64_t>(b)) |
1855 | 61.3k | : 0); |
1856 | 61.3k | }))); |
1857 | | |
1858 | 11.6k | rules_[spv::Op::OpSMod].push_back(FoldBinaryOp( |
1859 | 11.6k | FoldBinaryIntegerOperation<Signed>([](uint64_t a, uint64_t b) { |
1860 | 1.82k | if (b == 0) return static_cast<uint64_t>(0ull); |
1861 | | |
1862 | 1.73k | int64_t signed_a = static_cast<int64_t>(a); |
1863 | 1.73k | int64_t signed_b = static_cast<int64_t>(b); |
1864 | 1.73k | int64_t result = signed_a % signed_b; |
1865 | 1.73k | if ((signed_b < 0) != (result < 0)) result += signed_b; |
1866 | 1.73k | return static_cast<uint64_t>(result); |
1867 | 1.82k | }))); |
1868 | | |
1869 | | // Add rules for GLSLstd450 |
1870 | 11.6k | FeatureManager* feature_manager = context_->get_feature_mgr(); |
1871 | 11.6k | uint32_t ext_inst_glslstd450_id = |
1872 | 11.6k | feature_manager->GetExtInstImportId_GLSLstd450(); |
1873 | 11.6k | if (ext_inst_glslstd450_id != 0) { |
1874 | 6.92k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(FoldFMix()); |
1875 | 6.92k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SMin}].push_back( |
1876 | 6.92k | FoldFPBinaryOp(FoldMin)); |
1877 | 6.92k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UMin}].push_back( |
1878 | 6.92k | FoldFPBinaryOp(FoldMin)); |
1879 | 6.92k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMin}].push_back( |
1880 | 6.92k | FoldFPBinaryOp(FoldMin)); |
1881 | 6.92k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SMax}].push_back( |
1882 | 6.92k | FoldFPBinaryOp(FoldMax)); |
1883 | 6.92k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UMax}].push_back( |
1884 | 6.92k | FoldFPBinaryOp(FoldMax)); |
1885 | 6.92k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMax}].push_back( |
1886 | 6.92k | FoldFPBinaryOp(FoldMax)); |
1887 | 6.92k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back( |
1888 | 6.92k | FoldClamp1); |
1889 | 6.92k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back( |
1890 | 6.92k | FoldClamp2); |
1891 | 6.92k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back( |
1892 | 6.92k | FoldClamp3); |
1893 | 6.92k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back( |
1894 | 6.92k | FoldClamp1); |
1895 | 6.92k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back( |
1896 | 6.92k | FoldClamp2); |
1897 | 6.92k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back( |
1898 | 6.92k | FoldClamp3); |
1899 | 6.92k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back( |
1900 | 6.92k | FoldClamp1); |
1901 | 6.92k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back( |
1902 | 6.92k | FoldClamp2); |
1903 | 6.92k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back( |
1904 | 6.92k | FoldClamp3); |
1905 | 6.92k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Sin}].push_back( |
1906 | 6.92k | FoldFPUnaryOp(FoldFTranscendentalUnary(std::sin))); |
1907 | 6.92k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Cos}].push_back( |
1908 | 6.92k | FoldFPUnaryOp(FoldFTranscendentalUnary(std::cos))); |
1909 | 6.92k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Tan}].push_back( |
1910 | 6.92k | FoldFPUnaryOp(FoldFTranscendentalUnary(std::tan))); |
1911 | 6.92k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Asin}].push_back( |
1912 | 6.92k | FoldFPUnaryOp(FoldFTranscendentalUnary(std::asin))); |
1913 | 6.92k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Acos}].push_back( |
1914 | 6.92k | FoldFPUnaryOp(FoldFTranscendentalUnary(std::acos))); |
1915 | 6.92k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Atan}].push_back( |
1916 | 6.92k | FoldFPUnaryOp(FoldFTranscendentalUnary(std::atan))); |
1917 | 6.92k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp}].push_back( |
1918 | 6.92k | FoldFPUnaryOp(FoldFTranscendentalUnary(std::exp))); |
1919 | 6.92k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log}].push_back( |
1920 | 6.92k | FoldFPUnaryOp(FoldFTranscendentalUnary(std::log))); |
1921 | | |
1922 | | #ifdef __ANDROID__ |
1923 | | // Android NDK r15c targeting API level 15 lacks full C++11 support (no |
1924 | | // std::exp2/std::log2). ::exp2 is available from C99, but ::log2 is not |
1925 | | // available until API level 18, so we use a shim for it. |
1926 | | auto log2_shim = [](double v) -> double { return log(v) / log(2.0); }; |
1927 | | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp2}].push_back( |
1928 | | FoldFPUnaryOp(FoldFTranscendentalUnary(::exp2))); |
1929 | | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log2}].push_back( |
1930 | | FoldFPUnaryOp(FoldFTranscendentalUnary(log2_shim))); |
1931 | | #else |
1932 | 6.92k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp2}].push_back( |
1933 | 6.92k | FoldFPUnaryOp(FoldFTranscendentalUnary(std::exp2))); |
1934 | 6.92k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log2}].push_back( |
1935 | 6.92k | FoldFPUnaryOp(FoldFTranscendentalUnary(std::log2))); |
1936 | 6.92k | #endif |
1937 | | |
1938 | 6.92k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Sqrt}].push_back( |
1939 | 6.92k | FoldFPUnaryOp(FoldFTranscendentalUnary(std::sqrt))); |
1940 | 6.92k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Atan2}].push_back( |
1941 | 6.92k | FoldFPBinaryOp(FoldFTranscendentalBinary(std::atan2))); |
1942 | 6.92k | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Pow}].push_back( |
1943 | 6.92k | FoldFPBinaryOp(FoldFTranscendentalBinary(std::pow))); |
1944 | 6.92k | } |
1945 | 11.6k | } |
1946 | | } // namespace opt |
1947 | | } // namespace spvtools |