/src/spirv-tools/source/opt/scalar_replacement_pass.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | // Copyright (c) 2017 Google Inc. |
2 | | // Modifications Copyright (C) 2024 Advanced Micro Devices, Inc. All rights |
3 | | // reserved. |
4 | | // |
5 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
6 | | // you may not use this file except in compliance with the License. |
7 | | // You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, software |
12 | | // distributed under the License is distributed on an "AS IS" BASIS, |
13 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
14 | | // See the License for the specific language governing permissions and |
15 | | // limitations under the License. |
16 | | |
17 | | #include "source/opt/scalar_replacement_pass.h" |
18 | | |
19 | | #include <algorithm> |
20 | | #include <queue> |
21 | | #include <tuple> |
22 | | #include <utility> |
23 | | |
24 | | #include "source/extensions.h" |
25 | | #include "source/opt/reflect.h" |
26 | | #include "source/opt/types.h" |
27 | | #include "source/util/make_unique.h" |
28 | | |
29 | | namespace spvtools { |
30 | | namespace opt { |
namespace {
// Operand positions used when rewriting debug-info extension instructions.

// Position of the 'Value' operand in a DebugValue instruction.
constexpr uint32_t kDebugValueOperandValueIndex{5};
// Position of the 'Expression' operand in a DebugValue instruction.
constexpr uint32_t kDebugValueOperandExpressionIndex{6};
// Position of the variable operand in a DebugDeclare instruction.
constexpr uint32_t kDebugDeclareOperandVariableIndex{5};
}  // namespace
36 | | |
37 | 22.7k | Pass::Status ScalarReplacementPass::Process() { |
38 | 22.7k | Status status = Status::SuccessWithoutChange; |
39 | 22.7k | for (auto& f : *get_module()) { |
40 | 22.0k | if (f.IsDeclaration()) { |
41 | 0 | continue; |
42 | 0 | } |
43 | | |
44 | 22.0k | Status functionStatus = ProcessFunction(&f); |
45 | 22.0k | if (functionStatus == Status::Failure) |
46 | 0 | return functionStatus; |
47 | 22.0k | else if (functionStatus == Status::SuccessWithChange) |
48 | 2.63k | status = functionStatus; |
49 | 22.0k | } |
50 | | |
51 | 22.7k | return status; |
52 | 22.7k | } |
53 | | |
54 | 22.0k | Pass::Status ScalarReplacementPass::ProcessFunction(Function* function) { |
55 | 22.0k | std::queue<Instruction*> worklist; |
56 | 22.0k | BasicBlock& entry = *function->begin(); |
57 | 198k | for (auto iter = entry.begin(); iter != entry.end(); ++iter) { |
58 | | // Function storage class OpVariables must appear as the first instructions |
59 | | // of the entry block. |
60 | 198k | if (iter->opcode() != spv::Op::OpVariable) break; |
61 | | |
62 | 176k | Instruction* varInst = &*iter; |
63 | 176k | if (CanReplaceVariable(varInst)) { |
64 | 15.5k | worklist.push(varInst); |
65 | 15.5k | } |
66 | 176k | } |
67 | | |
68 | 22.0k | Status status = Status::SuccessWithoutChange; |
69 | 37.6k | while (!worklist.empty()) { |
70 | 15.6k | Instruction* varInst = worklist.front(); |
71 | 15.6k | worklist.pop(); |
72 | | |
73 | 15.6k | Status var_status = ReplaceVariable(varInst, &worklist); |
74 | 15.6k | if (var_status == Status::Failure) |
75 | 0 | return var_status; |
76 | 15.6k | else if (var_status == Status::SuccessWithChange) |
77 | 15.6k | status = var_status; |
78 | 15.6k | } |
79 | | |
80 | 22.0k | return status; |
81 | 22.0k | } |
82 | | |
83 | | Pass::Status ScalarReplacementPass::ReplaceVariable( |
84 | 15.6k | Instruction* inst, std::queue<Instruction*>* worklist) { |
85 | 15.6k | std::vector<Instruction*> replacements; |
86 | 15.6k | if (!CreateReplacementVariables(inst, &replacements)) { |
87 | 0 | return Status::Failure; |
88 | 0 | } |
89 | | |
90 | 15.6k | std::vector<Instruction*> dead; |
91 | 15.6k | bool replaced_all_uses = get_def_use_mgr()->WhileEachUser( |
92 | 91.9k | inst, [this, &replacements, &dead](Instruction* user) { |
93 | 91.9k | if (user->GetCommonDebugOpcode() == CommonDebugInfoDebugDeclare) { |
94 | 0 | if (ReplaceWholeDebugDeclare(user, replacements)) { |
95 | 0 | dead.push_back(user); |
96 | 0 | return true; |
97 | 0 | } |
98 | 0 | return false; |
99 | 0 | } |
100 | 91.9k | if (user->GetCommonDebugOpcode() == CommonDebugInfoDebugValue) { |
101 | 0 | if (ReplaceWholeDebugValue(user, replacements)) { |
102 | 0 | dead.push_back(user); |
103 | 0 | return true; |
104 | 0 | } |
105 | 0 | return false; |
106 | 0 | } |
107 | 91.9k | if (!IsAnnotationInst(user->opcode())) { |
108 | 91.9k | switch (user->opcode()) { |
109 | 9.83k | case spv::Op::OpLoad: |
110 | 9.83k | if (ReplaceWholeLoad(user, replacements)) { |
111 | 9.83k | dead.push_back(user); |
112 | 9.83k | } else { |
113 | 0 | return false; |
114 | 0 | } |
115 | 9.83k | break; |
116 | 26.3k | case spv::Op::OpStore: |
117 | 26.3k | if (ReplaceWholeStore(user, replacements)) { |
118 | 26.3k | dead.push_back(user); |
119 | 26.3k | } else { |
120 | 0 | return false; |
121 | 0 | } |
122 | 26.3k | break; |
123 | 52.8k | case spv::Op::OpAccessChain: |
124 | 54.6k | case spv::Op::OpInBoundsAccessChain: |
125 | 54.6k | if (ReplaceAccessChain(user, replacements)) |
126 | 54.6k | dead.push_back(user); |
127 | 0 | else |
128 | 0 | return false; |
129 | 54.6k | break; |
130 | 54.6k | case spv::Op::OpName: |
131 | 1.09k | case spv::Op::OpMemberName: |
132 | 1.09k | break; |
133 | 0 | default: |
134 | 0 | assert(false && "Unexpected opcode"); |
135 | 0 | break; |
136 | 91.9k | } |
137 | 91.9k | } |
138 | 91.9k | return true; |
139 | 91.9k | }); |
140 | | |
141 | 15.6k | if (replaced_all_uses) { |
142 | 15.6k | dead.push_back(inst); |
143 | 15.6k | } else { |
144 | 0 | return Status::Failure; |
145 | 0 | } |
146 | | |
147 | | // If there are no dead instructions to clean up, return with no changes. |
148 | 15.6k | if (dead.empty()) return Status::SuccessWithoutChange; |
149 | | |
150 | | // Clean up some dead code. |
151 | 122k | while (!dead.empty()) { |
152 | 106k | Instruction* toKill = dead.back(); |
153 | 106k | dead.pop_back(); |
154 | 106k | context()->KillInst(toKill); |
155 | 106k | } |
156 | | |
157 | | // Attempt to further scalarize. |
158 | 129k | for (auto var : replacements) { |
159 | 129k | if (var->opcode() == spv::Op::OpVariable) { |
160 | 79.0k | if (get_def_use_mgr()->NumUsers(var) == 0) { |
161 | 11 | context()->KillInst(var); |
162 | 79.0k | } else if (CanReplaceVariable(var)) { |
163 | 57 | worklist->push(var); |
164 | 57 | } |
165 | 79.0k | } |
166 | 129k | } |
167 | | |
168 | 15.6k | return Status::SuccessWithChange; |
169 | 15.6k | } |
170 | | |
171 | | bool ScalarReplacementPass::ReplaceWholeDebugDeclare( |
172 | 0 | Instruction* dbg_decl, const std::vector<Instruction*>& replacements) { |
173 | | // Insert Deref operation to the front of the operation list of |dbg_decl|. |
174 | 0 | Instruction* dbg_expr = context()->get_def_use_mgr()->GetDef( |
175 | 0 | dbg_decl->GetSingleWordOperand(kDebugValueOperandExpressionIndex)); |
176 | 0 | auto* deref_expr = |
177 | 0 | context()->get_debug_info_mgr()->DerefDebugExpression(dbg_expr); |
178 | | |
179 | | // Add DebugValue instruction with Indexes operand and Deref operation. |
180 | 0 | int32_t idx = 0; |
181 | 0 | for (const auto* var : replacements) { |
182 | 0 | Instruction* insert_before = var->NextNode(); |
183 | 0 | while (insert_before->opcode() == spv::Op::OpVariable) |
184 | 0 | insert_before = insert_before->NextNode(); |
185 | 0 | assert(insert_before != nullptr && "unexpected end of list"); |
186 | 0 | Instruction* added_dbg_value = |
187 | 0 | context()->get_debug_info_mgr()->AddDebugValueForDecl( |
188 | 0 | dbg_decl, /*value_id=*/var->result_id(), |
189 | 0 | /*insert_before=*/insert_before, /*line=*/dbg_decl); |
190 | |
|
191 | 0 | if (added_dbg_value == nullptr) return false; |
192 | 0 | added_dbg_value->AddOperand( |
193 | 0 | {SPV_OPERAND_TYPE_ID, |
194 | 0 | {context()->get_constant_mgr()->GetSIntConstId(idx)}}); |
195 | 0 | added_dbg_value->SetOperand(kDebugValueOperandExpressionIndex, |
196 | 0 | {deref_expr->result_id()}); |
197 | 0 | if (context()->AreAnalysesValid(IRContext::Analysis::kAnalysisDefUse)) { |
198 | 0 | context()->get_def_use_mgr()->AnalyzeInstUse(added_dbg_value); |
199 | 0 | } |
200 | 0 | ++idx; |
201 | 0 | } |
202 | 0 | return true; |
203 | 0 | } |
204 | | |
205 | | bool ScalarReplacementPass::ReplaceWholeDebugValue( |
206 | 0 | Instruction* dbg_value, const std::vector<Instruction*>& replacements) { |
207 | 0 | int32_t idx = 0; |
208 | 0 | BasicBlock* block = context()->get_instr_block(dbg_value); |
209 | 0 | for (auto var : replacements) { |
210 | | // Clone the DebugValue. |
211 | 0 | std::unique_ptr<Instruction> new_dbg_value(dbg_value->Clone(context())); |
212 | 0 | uint32_t new_id = TakeNextId(); |
213 | 0 | if (new_id == 0) return false; |
214 | 0 | new_dbg_value->SetResultId(new_id); |
215 | | // Update 'Value' operand to the |replacements|. |
216 | 0 | new_dbg_value->SetOperand(kDebugValueOperandValueIndex, {var->result_id()}); |
217 | | // Append 'Indexes' operand. |
218 | 0 | new_dbg_value->AddOperand( |
219 | 0 | {SPV_OPERAND_TYPE_ID, |
220 | 0 | {context()->get_constant_mgr()->GetSIntConstId(idx)}}); |
221 | | // Insert the new DebugValue to the basic block. |
222 | 0 | auto* added_instr = dbg_value->InsertBefore(std::move(new_dbg_value)); |
223 | 0 | get_def_use_mgr()->AnalyzeInstDefUse(added_instr); |
224 | 0 | context()->set_instr_block(added_instr, block); |
225 | 0 | ++idx; |
226 | 0 | } |
227 | 0 | return true; |
228 | 0 | } |
229 | | |
230 | | bool ScalarReplacementPass::ReplaceWholeLoad( |
231 | 9.83k | Instruction* load, const std::vector<Instruction*>& replacements) { |
232 | | // Replaces the load of the entire composite with a load from each replacement |
233 | | // variable followed by a composite construction. |
234 | 9.83k | BasicBlock* block = context()->get_instr_block(load); |
235 | 9.83k | std::vector<Instruction*> loads; |
236 | 9.83k | loads.reserve(replacements.size()); |
237 | 9.83k | BasicBlock::iterator where(load); |
238 | 96.1k | for (auto var : replacements) { |
239 | | // Create a load of each replacement variable. |
240 | 96.1k | if (var->opcode() != spv::Op::OpVariable) { |
241 | 8.12k | loads.push_back(var); |
242 | 8.12k | continue; |
243 | 8.12k | } |
244 | | |
245 | 87.9k | Instruction* type = GetStorageType(var); |
246 | 87.9k | uint32_t loadId = TakeNextId(); |
247 | 87.9k | if (loadId == 0) { |
248 | 0 | return false; |
249 | 0 | } |
250 | 87.9k | std::unique_ptr<Instruction> newLoad( |
251 | 87.9k | new Instruction(context(), spv::Op::OpLoad, type->result_id(), loadId, |
252 | 87.9k | std::initializer_list<Operand>{ |
253 | 87.9k | {SPV_OPERAND_TYPE_ID, {var->result_id()}}})); |
254 | | // Copy memory access attributes which start at index 1. Index 0 is the |
255 | | // pointer to load. |
256 | 87.9k | for (uint32_t i = 1; i < load->NumInOperands(); ++i) { |
257 | 0 | Operand copy(load->GetInOperand(i)); |
258 | 0 | newLoad->AddOperand(std::move(copy)); |
259 | 0 | } |
260 | 87.9k | where = where.InsertBefore(std::move(newLoad)); |
261 | 87.9k | get_def_use_mgr()->AnalyzeInstDefUse(&*where); |
262 | 87.9k | context()->set_instr_block(&*where, block); |
263 | 87.9k | where->UpdateDebugInfoFrom(load); |
264 | 87.9k | loads.push_back(&*where); |
265 | 87.9k | } |
266 | | |
267 | | // Construct a new composite. |
268 | 9.83k | uint32_t compositeId = TakeNextId(); |
269 | 9.83k | if (compositeId == 0) { |
270 | 0 | return false; |
271 | 0 | } |
272 | 9.83k | where = load; |
273 | 9.83k | std::unique_ptr<Instruction> compositeConstruct( |
274 | 9.83k | new Instruction(context(), spv::Op::OpCompositeConstruct, load->type_id(), |
275 | 9.83k | compositeId, {})); |
276 | 96.1k | for (auto l : loads) { |
277 | 96.1k | Operand op(SPV_OPERAND_TYPE_ID, |
278 | 96.1k | std::initializer_list<uint32_t>{l->result_id()}); |
279 | 96.1k | compositeConstruct->AddOperand(std::move(op)); |
280 | 96.1k | } |
281 | 9.83k | where = where.InsertBefore(std::move(compositeConstruct)); |
282 | 9.83k | get_def_use_mgr()->AnalyzeInstDefUse(&*where); |
283 | 9.83k | where->UpdateDebugInfoFrom(load); |
284 | 9.83k | context()->set_instr_block(&*where, block); |
285 | 9.83k | context()->ReplaceAllUsesWith(load->result_id(), compositeId); |
286 | 9.83k | return true; |
287 | 9.83k | } |
288 | | |
289 | | bool ScalarReplacementPass::ReplaceWholeStore( |
290 | 26.3k | Instruction* store, const std::vector<Instruction*>& replacements) { |
291 | | // Replaces a store to the whole composite with a series of extract and stores |
292 | | // to each element. |
293 | 26.3k | uint32_t storeInput = store->GetSingleWordInOperand(1u); |
294 | 26.3k | BasicBlock* block = context()->get_instr_block(store); |
295 | 26.3k | BasicBlock::iterator where(store); |
296 | 26.3k | uint32_t elementIndex = 0; |
297 | 225k | for (auto var : replacements) { |
298 | | // Create the extract. |
299 | 225k | if (var->opcode() != spv::Op::OpVariable) { |
300 | 135k | elementIndex++; |
301 | 135k | continue; |
302 | 135k | } |
303 | | |
304 | 89.4k | Instruction* type = GetStorageType(var); |
305 | 89.4k | uint32_t extractId = TakeNextId(); |
306 | 89.4k | if (extractId == 0) { |
307 | 0 | return false; |
308 | 0 | } |
309 | 89.4k | std::unique_ptr<Instruction> extract(new Instruction( |
310 | 89.4k | context(), spv::Op::OpCompositeExtract, type->result_id(), extractId, |
311 | 89.4k | std::initializer_list<Operand>{ |
312 | 89.4k | {SPV_OPERAND_TYPE_ID, {storeInput}}, |
313 | 89.4k | {SPV_OPERAND_TYPE_LITERAL_INTEGER, {elementIndex++}}})); |
314 | 89.4k | auto iter = where.InsertBefore(std::move(extract)); |
315 | 89.4k | iter->UpdateDebugInfoFrom(store); |
316 | 89.4k | get_def_use_mgr()->AnalyzeInstDefUse(&*iter); |
317 | 89.4k | context()->set_instr_block(&*iter, block); |
318 | | |
319 | | // Create the store. |
320 | 89.4k | std::unique_ptr<Instruction> newStore( |
321 | 89.4k | new Instruction(context(), spv::Op::OpStore, 0, 0, |
322 | 89.4k | std::initializer_list<Operand>{ |
323 | 89.4k | {SPV_OPERAND_TYPE_ID, {var->result_id()}}, |
324 | 89.4k | {SPV_OPERAND_TYPE_ID, {extractId}}})); |
325 | | // Copy memory access attributes which start at index 2. Index 0 is the |
326 | | // pointer and index 1 is the data. |
327 | 89.4k | for (uint32_t i = 2; i < store->NumInOperands(); ++i) { |
328 | 14 | Operand copy(store->GetInOperand(i)); |
329 | 14 | newStore->AddOperand(std::move(copy)); |
330 | 14 | } |
331 | 89.4k | iter = where.InsertBefore(std::move(newStore)); |
332 | 89.4k | iter->UpdateDebugInfoFrom(store); |
333 | 89.4k | get_def_use_mgr()->AnalyzeInstDefUse(&*iter); |
334 | 89.4k | context()->set_instr_block(&*iter, block); |
335 | 89.4k | } |
336 | 26.3k | return true; |
337 | 26.3k | } |
338 | | |
339 | | bool ScalarReplacementPass::ReplaceAccessChain( |
340 | 54.6k | Instruction* chain, const std::vector<Instruction*>& replacements) { |
341 | | // Replaces the access chain with either another access chain (with one fewer |
342 | | // indexes) or a direct use of the replacement variable. |
343 | 54.6k | uint32_t indexId = chain->GetSingleWordInOperand(1u); |
344 | 54.6k | const Instruction* index = get_def_use_mgr()->GetDef(indexId); |
345 | 54.6k | int64_t indexValue = context() |
346 | 54.6k | ->get_constant_mgr() |
347 | 54.6k | ->GetConstantFromInst(index) |
348 | 54.6k | ->GetSignExtendedValue(); |
349 | 54.6k | if (indexValue < 0 || |
350 | 54.6k | indexValue >= static_cast<int64_t>(replacements.size())) { |
351 | | // Out of bounds access, this is illegal IR. Notice that OpAccessChain |
352 | | // indexing is 0-based, so we should also reject index == size-of-array. |
353 | 0 | return false; |
354 | 54.6k | } else { |
355 | 54.6k | const Instruction* var = replacements[static_cast<size_t>(indexValue)]; |
356 | 54.6k | if (chain->NumInOperands() > 2) { |
357 | | // Replace input access chain with another access chain. |
358 | 2.18k | BasicBlock::iterator chainIter(chain); |
359 | 2.18k | uint32_t replacementId = TakeNextId(); |
360 | 2.18k | if (replacementId == 0) { |
361 | 0 | return false; |
362 | 0 | } |
363 | 2.18k | std::unique_ptr<Instruction> replacementChain(new Instruction( |
364 | 2.18k | context(), chain->opcode(), chain->type_id(), replacementId, |
365 | 2.18k | std::initializer_list<Operand>{ |
366 | 2.18k | {SPV_OPERAND_TYPE_ID, {var->result_id()}}})); |
367 | | // Add the remaining indexes. |
368 | 4.36k | for (uint32_t i = 2; i < chain->NumInOperands(); ++i) { |
369 | 2.18k | Operand copy(chain->GetInOperand(i)); |
370 | 2.18k | replacementChain->AddOperand(std::move(copy)); |
371 | 2.18k | } |
372 | 2.18k | replacementChain->UpdateDebugInfoFrom(chain); |
373 | 2.18k | auto iter = chainIter.InsertBefore(std::move(replacementChain)); |
374 | 2.18k | get_def_use_mgr()->AnalyzeInstDefUse(&*iter); |
375 | 2.18k | context()->set_instr_block(&*iter, context()->get_instr_block(chain)); |
376 | 2.18k | context()->ReplaceAllUsesWith(chain->result_id(), replacementId); |
377 | 52.4k | } else { |
378 | | // Replace with a use of the variable. |
379 | 52.4k | context()->ReplaceAllUsesWith(chain->result_id(), var->result_id()); |
380 | 52.4k | } |
381 | 54.6k | } |
382 | | |
383 | 54.6k | return true; |
384 | 54.6k | } |
385 | | |
386 | | bool ScalarReplacementPass::CreateReplacementVariables( |
387 | 15.6k | Instruction* inst, std::vector<Instruction*>* replacements) { |
388 | 15.6k | Instruction* type = GetStorageType(inst); |
389 | | |
390 | 15.6k | std::unique_ptr<std::unordered_set<int64_t>> components_used = |
391 | 15.6k | GetUsedComponents(inst); |
392 | | |
393 | 15.6k | uint32_t elem = 0; |
394 | 15.6k | switch (type->opcode()) { |
395 | 13.7k | case spv::Op::OpTypeStruct: |
396 | 13.7k | type->ForEachInOperand( |
397 | 40.1k | [this, inst, &elem, replacements, &components_used](uint32_t* id) { |
398 | 40.1k | if (!components_used || components_used->count(elem)) { |
399 | 32.5k | CreateVariable(*id, inst, elem, replacements); |
400 | 32.5k | } else { |
401 | 7.59k | replacements->push_back(GetUndef(*id)); |
402 | 7.59k | } |
403 | 40.1k | elem++; |
404 | 40.1k | }); |
405 | 13.7k | break; |
406 | 1.84k | case spv::Op::OpTypeArray: |
407 | 91.2k | for (uint32_t i = 0; i != GetArrayLength(type); ++i) { |
408 | 89.3k | if (!components_used || components_used->count(i)) { |
409 | 46.4k | CreateVariable(type->GetSingleWordInOperand(0u), inst, i, |
410 | 46.4k | replacements); |
411 | 46.4k | } else { |
412 | 42.9k | uint32_t element_type_id = type->GetSingleWordInOperand(0); |
413 | 42.9k | replacements->push_back(GetUndef(element_type_id)); |
414 | 42.9k | } |
415 | 89.3k | } |
416 | 1.84k | break; |
417 | | |
418 | 0 | case spv::Op::OpTypeMatrix: |
419 | 0 | case spv::Op::OpTypeVector: |
420 | 0 | for (uint32_t i = 0; i != GetNumElements(type); ++i) { |
421 | 0 | CreateVariable(type->GetSingleWordInOperand(0u), inst, i, replacements); |
422 | 0 | } |
423 | 0 | break; |
424 | | |
425 | 0 | default: |
426 | 0 | assert(false && "Unexpected type."); |
427 | 0 | break; |
428 | 15.6k | } |
429 | | |
430 | 15.6k | TransferAnnotations(inst, replacements); |
431 | 15.6k | return std::find(replacements->begin(), replacements->end(), nullptr) == |
432 | 15.6k | replacements->end(); |
433 | 15.6k | } |
434 | | |
435 | 50.4k | Instruction* ScalarReplacementPass::GetUndef(uint32_t type_id) { |
436 | 50.4k | return get_def_use_mgr()->GetDef(Type2Undef(type_id)); |
437 | 50.4k | } |
438 | | |
439 | | void ScalarReplacementPass::TransferAnnotations( |
440 | 15.6k | const Instruction* source, std::vector<Instruction*>* replacements) { |
441 | | // Only transfer invariant and restrict decorations on the variable. There are |
442 | | // no type or member decorations that are necessary to transfer. |
443 | 15.6k | for (auto inst : |
444 | 15.6k | get_decoration_mgr()->GetDecorationsFor(source->result_id(), false)) { |
445 | 7 | assert(inst->opcode() == spv::Op::OpDecorate); |
446 | 7 | auto decoration = spv::Decoration(inst->GetSingleWordInOperand(1u)); |
447 | 7 | if (decoration == spv::Decoration::Invariant || |
448 | 7 | decoration == spv::Decoration::Restrict) { |
449 | 195 | for (auto var : *replacements) { |
450 | 195 | if (var == nullptr) { |
451 | 0 | continue; |
452 | 0 | } |
453 | | |
454 | 195 | std::unique_ptr<Instruction> annotation(new Instruction( |
455 | 195 | context(), spv::Op::OpDecorate, 0, 0, |
456 | 195 | std::initializer_list<Operand>{ |
457 | 195 | {SPV_OPERAND_TYPE_ID, {var->result_id()}}, |
458 | 195 | {SPV_OPERAND_TYPE_DECORATION, {uint32_t(decoration)}}})); |
459 | 195 | for (uint32_t i = 2; i < inst->NumInOperands(); ++i) { |
460 | 0 | Operand copy(inst->GetInOperand(i)); |
461 | 0 | annotation->AddOperand(std::move(copy)); |
462 | 0 | } |
463 | 195 | context()->AddAnnotationInst(std::move(annotation)); |
464 | 195 | get_def_use_mgr()->AnalyzeInstUse(&*--context()->annotation_end()); |
465 | 195 | } |
466 | 7 | } |
467 | 7 | } |
468 | 15.6k | } |
469 | | |
470 | | void ScalarReplacementPass::CreateVariable( |
471 | | uint32_t type_id, Instruction* var_inst, uint32_t index, |
472 | 79.0k | std::vector<Instruction*>* replacements) { |
473 | 79.0k | uint32_t ptr_id = GetOrCreatePointerType(type_id); |
474 | 79.0k | uint32_t id = TakeNextId(); |
475 | | |
476 | 79.0k | if (id == 0) { |
477 | 0 | replacements->push_back(nullptr); |
478 | 0 | } |
479 | | |
480 | 79.0k | std::unique_ptr<Instruction> variable( |
481 | 79.0k | new Instruction(context(), spv::Op::OpVariable, ptr_id, id, |
482 | 79.0k | std::initializer_list<Operand>{ |
483 | 79.0k | {SPV_OPERAND_TYPE_STORAGE_CLASS, |
484 | 79.0k | {uint32_t(spv::StorageClass::Function)}}})); |
485 | | |
486 | 79.0k | BasicBlock* block = context()->get_instr_block(var_inst); |
487 | 79.0k | block->begin().InsertBefore(std::move(variable)); |
488 | 79.0k | Instruction* inst = &*block->begin(); |
489 | | |
490 | | // If varInst was initialized, make sure to initialize its replacement. |
491 | 79.0k | GetOrCreateInitialValue(var_inst, index, inst); |
492 | 79.0k | get_def_use_mgr()->AnalyzeInstDefUse(inst); |
493 | 79.0k | context()->set_instr_block(inst, block); |
494 | | |
495 | 79.0k | CopyDecorationsToVariable(var_inst, inst, index); |
496 | 79.0k | inst->UpdateDebugInfoFrom(var_inst); |
497 | | |
498 | 79.0k | replacements->push_back(inst); |
499 | 79.0k | } |
500 | | |
501 | 79.0k | uint32_t ScalarReplacementPass::GetOrCreatePointerType(uint32_t id) { |
502 | 79.0k | auto iter = pointee_to_pointer_.find(id); |
503 | 79.0k | if (iter != pointee_to_pointer_.end()) return iter->second; |
504 | | |
505 | 1.95k | analysis::TypeManager* type_mgr = context()->get_type_mgr(); |
506 | 1.95k | uint32_t ptr_type_id = |
507 | 1.95k | type_mgr->FindPointerToType(id, spv::StorageClass::Function); |
508 | 1.95k | pointee_to_pointer_[id] = ptr_type_id; |
509 | 1.95k | return ptr_type_id; |
510 | 79.0k | } |
511 | | |
512 | | void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source, |
513 | | uint32_t index, |
514 | 79.0k | Instruction* newVar) { |
515 | 79.0k | assert(source->opcode() == spv::Op::OpVariable); |
516 | 79.0k | if (source->NumInOperands() < 2) return; |
517 | | |
518 | 0 | uint32_t initId = source->GetSingleWordInOperand(1u); |
519 | 0 | uint32_t storageId = GetStorageType(newVar)->result_id(); |
520 | 0 | Instruction* init = get_def_use_mgr()->GetDef(initId); |
521 | 0 | uint32_t newInitId = 0; |
522 | | // TODO(dnovillo): Refactor this with constant propagation. |
523 | 0 | if (init->opcode() == spv::Op::OpConstantNull) { |
524 | | // Initialize to appropriate NULL. |
525 | 0 | auto iter = type_to_null_.find(storageId); |
526 | 0 | if (iter == type_to_null_.end()) { |
527 | 0 | newInitId = TakeNextId(); |
528 | 0 | type_to_null_[storageId] = newInitId; |
529 | 0 | context()->AddGlobalValue( |
530 | 0 | MakeUnique<Instruction>(context(), spv::Op::OpConstantNull, storageId, |
531 | 0 | newInitId, std::initializer_list<Operand>{})); |
532 | 0 | Instruction* newNull = &*--context()->types_values_end(); |
533 | 0 | get_def_use_mgr()->AnalyzeInstDefUse(newNull); |
534 | 0 | } else { |
535 | 0 | newInitId = iter->second; |
536 | 0 | } |
537 | 0 | } else if (IsSpecConstantInst(init->opcode())) { |
538 | | // Create a new constant extract. |
539 | 0 | newInitId = TakeNextId(); |
540 | 0 | context()->AddGlobalValue(MakeUnique<Instruction>( |
541 | 0 | context(), spv::Op::OpSpecConstantOp, storageId, newInitId, |
542 | 0 | std::initializer_list<Operand>{ |
543 | 0 | {SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER, |
544 | 0 | {uint32_t(spv::Op::OpCompositeExtract)}}, |
545 | 0 | {SPV_OPERAND_TYPE_ID, {init->result_id()}}, |
546 | 0 | {SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}}})); |
547 | 0 | Instruction* newSpecConst = &*--context()->types_values_end(); |
548 | 0 | get_def_use_mgr()->AnalyzeInstDefUse(newSpecConst); |
549 | 0 | } else if (init->opcode() == spv::Op::OpConstantComposite) { |
550 | | // Get the appropriate index constant. |
551 | 0 | newInitId = init->GetSingleWordInOperand(index); |
552 | 0 | Instruction* element = get_def_use_mgr()->GetDef(newInitId); |
553 | 0 | if (element->opcode() == spv::Op::OpUndef) { |
554 | | // Undef is not a valid initializer for a variable. |
555 | 0 | newInitId = 0; |
556 | 0 | } |
557 | 0 | } else { |
558 | 0 | assert(false); |
559 | 0 | } |
560 | | |
561 | 0 | if (newInitId != 0) { |
562 | 0 | newVar->AddOperand({SPV_OPERAND_TYPE_ID, {newInitId}}); |
563 | 0 | } |
564 | 0 | } |
565 | | |
566 | | uint64_t ScalarReplacementPass::GetArrayLength( |
567 | 108k | const Instruction* arrayType) const { |
568 | 108k | assert(arrayType->opcode() == spv::Op::OpTypeArray); |
569 | 108k | const Instruction* length = |
570 | 108k | get_def_use_mgr()->GetDef(arrayType->GetSingleWordInOperand(1u)); |
571 | 108k | return context() |
572 | 108k | ->get_constant_mgr() |
573 | 108k | ->GetConstantFromInst(length) |
574 | 108k | ->GetZeroExtendedValue(); |
575 | 108k | } |
576 | | |
577 | 0 | uint64_t ScalarReplacementPass::GetNumElements(const Instruction* type) const { |
578 | 0 | assert(type->opcode() == spv::Op::OpTypeVector || |
579 | 0 | type->opcode() == spv::Op::OpTypeMatrix); |
580 | 0 | const Operand& op = type->GetInOperand(1u); |
581 | 0 | assert(op.words.size() <= 2); |
582 | 0 | uint64_t len = 0; |
583 | 0 | for (size_t i = 0; i != op.words.size(); ++i) { |
584 | 0 | len |= (static_cast<uint64_t>(op.words[i]) << (32ull * i)); |
585 | 0 | } |
586 | 0 | return len; |
587 | 0 | } |
588 | | |
589 | 10.0k | bool ScalarReplacementPass::IsSpecConstant(uint32_t id) const { |
590 | 10.0k | const Instruction* inst = get_def_use_mgr()->GetDef(id); |
591 | 10.0k | assert(inst); |
592 | 10.0k | return spvOpcodeIsSpecConstant(inst->opcode()); |
593 | 10.0k | } |
594 | | |
595 | | Instruction* ScalarReplacementPass::GetStorageType( |
596 | 552k | const Instruction* inst) const { |
597 | 552k | assert(inst->opcode() == spv::Op::OpVariable); |
598 | | |
599 | 552k | uint32_t ptrTypeId = inst->type_id(); |
600 | 552k | uint32_t typeId = |
601 | 552k | get_def_use_mgr()->GetDef(ptrTypeId)->GetSingleWordInOperand(1u); |
602 | 552k | return get_def_use_mgr()->GetDef(typeId); |
603 | 552k | } |
604 | | |
605 | | bool ScalarReplacementPass::CanReplaceVariable( |
606 | 255k | const Instruction* varInst) const { |
607 | 255k | assert(varInst->opcode() == spv::Op::OpVariable); |
608 | | |
609 | | // Can only replace function scope variables. |
610 | 255k | if (spv::StorageClass(varInst->GetSingleWordInOperand(0u)) != |
611 | 255k | spv::StorageClass::Function) { |
612 | 0 | return false; |
613 | 0 | } |
614 | | |
615 | 255k | if (!CheckTypeAnnotations(get_def_use_mgr()->GetDef(varInst->type_id()))) { |
616 | 18 | return false; |
617 | 18 | } |
618 | | |
619 | 255k | const Instruction* typeInst = GetStorageType(varInst); |
620 | 255k | if (!CheckType(typeInst)) { |
621 | 229k | return false; |
622 | 229k | } |
623 | | |
624 | 26.3k | if (!CheckAnnotations(varInst)) { |
625 | 1.50k | return false; |
626 | 1.50k | } |
627 | | |
628 | 24.8k | if (!CheckUses(varInst)) { |
629 | 9.26k | return false; |
630 | 9.26k | } |
631 | | |
632 | 15.6k | return true; |
633 | 24.8k | } |
634 | | |
635 | 255k | bool ScalarReplacementPass::CheckType(const Instruction* typeInst) const { |
636 | 255k | if (!CheckTypeAnnotations(typeInst)) { |
637 | 1.28k | return false; |
638 | 1.28k | } |
639 | | |
640 | 254k | switch (typeInst->opcode()) { |
641 | 18.4k | case spv::Op::OpTypeStruct: |
642 | | // Don't bother with empty structs or very large structs. |
643 | 18.4k | if (typeInst->NumInOperands() == 0 || |
644 | 18.4k | IsLargerThanSizeLimit(typeInst->NumInOperands())) { |
645 | 6 | return false; |
646 | 6 | } |
647 | 18.4k | return true; |
648 | 10.0k | case spv::Op::OpTypeArray: |
649 | 10.0k | if (IsSpecConstant(typeInst->GetSingleWordInOperand(1u))) { |
650 | 237 | return false; |
651 | 237 | } |
652 | 9.82k | if (IsLargerThanSizeLimit(GetArrayLength(typeInst))) { |
653 | 1.91k | return false; |
654 | 1.91k | } |
655 | 7.90k | return true; |
656 | | // TODO(alanbaker): Develop some heuristics for when this should be |
657 | | // re-enabled. |
658 | | //// Specifically including matrix and vector in an attempt to reduce the |
659 | | //// number of vector registers required. |
660 | | // case spv::Op::OpTypeMatrix: |
661 | | // case spv::Op::OpTypeVector: |
662 | | // if (IsLargerThanSizeLimit(GetNumElements(typeInst))) return false; |
663 | | // return true; |
664 | | |
665 | 1 | case spv::Op::OpTypeRuntimeArray: |
666 | 226k | default: |
667 | 226k | return false; |
668 | 254k | } |
669 | 254k | } |
670 | | |
671 | | bool ScalarReplacementPass::CheckTypeAnnotations( |
672 | 511k | const Instruction* typeInst) const { |
673 | 511k | for (auto inst : |
674 | 511k | get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) { |
675 | 17.0k | uint32_t decoration; |
676 | 17.0k | if (inst->opcode() == spv::Op::OpDecorate || |
677 | 17.0k | inst->opcode() == spv::Op::OpDecorateId) { |
678 | 62 | decoration = inst->GetSingleWordInOperand(1u); |
679 | 16.9k | } else { |
680 | 16.9k | assert(inst->opcode() == spv::Op::OpMemberDecorate); |
681 | 16.9k | decoration = inst->GetSingleWordInOperand(2u); |
682 | 16.9k | } |
683 | | |
684 | 17.0k | switch (spv::Decoration(decoration)) { |
685 | 239 | case spv::Decoration::RowMajor: |
686 | 240 | case spv::Decoration::ColMajor: |
687 | 241 | case spv::Decoration::ArrayStride: |
688 | 249 | case spv::Decoration::MatrixStride: |
689 | 249 | case spv::Decoration::CPacked: |
690 | 547 | case spv::Decoration::Invariant: |
691 | 598 | case spv::Decoration::Restrict: |
692 | 608 | case spv::Decoration::Offset: |
693 | 608 | case spv::Decoration::Alignment: |
694 | 608 | case spv::Decoration::AlignmentId: |
695 | 608 | case spv::Decoration::MaxByteOffset: |
696 | 15.7k | case spv::Decoration::RelaxedPrecision: |
697 | 15.7k | case spv::Decoration::AliasedPointer: |
698 | 15.7k | case spv::Decoration::RestrictPointer: |
699 | 15.7k | break; |
700 | 1.30k | default: |
701 | 1.30k | return false; |
702 | 17.0k | } |
703 | 17.0k | } |
704 | | |
705 | 510k | return true; |
706 | 511k | } |
707 | | |
708 | 26.3k | bool ScalarReplacementPass::CheckAnnotations(const Instruction* varInst) const { |
709 | 26.3k | for (auto inst : |
710 | 26.3k | get_decoration_mgr()->GetDecorationsFor(varInst->result_id(), false)) { |
711 | 1.51k | assert(inst->opcode() == spv::Op::OpDecorate); |
712 | 1.51k | auto decoration = spv::Decoration(inst->GetSingleWordInOperand(1u)); |
713 | 1.51k | switch (decoration) { |
714 | 4 | case spv::Decoration::Invariant: |
715 | 9 | case spv::Decoration::Restrict: |
716 | 9 | case spv::Decoration::Alignment: |
717 | 9 | case spv::Decoration::AlignmentId: |
718 | 9 | case spv::Decoration::MaxByteOffset: |
719 | 9 | case spv::Decoration::AliasedPointer: |
720 | 9 | case spv::Decoration::RestrictPointer: |
721 | 9 | break; |
722 | 1.50k | default: |
723 | 1.50k | return false; |
724 | 1.51k | } |
725 | 1.51k | } |
726 | | |
727 | 24.8k | return true; |
728 | 26.3k | } |
729 | | |
730 | 24.8k | bool ScalarReplacementPass::CheckUses(const Instruction* inst) const { |
731 | 24.8k | VariableStats stats = {0, 0}; |
732 | 24.8k | bool ok = CheckUses(inst, &stats); |
733 | | |
734 | | // TODO(alanbaker/greg-lunarg): Add some meaningful heuristics about when |
735 | | // SRoA is costly, such as when the structure has many (unaccessed?) |
736 | | // members. |
737 | | |
738 | 24.8k | return ok; |
739 | 24.8k | } |
740 | | |
741 | | bool ScalarReplacementPass::CheckUses(const Instruction* inst, |
742 | 24.8k | VariableStats* stats) const { |
743 | 24.8k | uint64_t max_legal_index = GetMaxLegalIndex(inst); |
744 | | |
745 | 24.8k | bool ok = true; |
746 | 24.8k | get_def_use_mgr()->ForEachUse(inst, [this, max_legal_index, stats, &ok]( |
747 | 24.8k | const Instruction* user, |
748 | 281k | uint32_t index) { |
749 | 281k | if (user->GetCommonDebugOpcode() == CommonDebugInfoDebugDeclare || |
750 | 281k | user->GetCommonDebugOpcode() == CommonDebugInfoDebugValue) { |
751 | | // TODO: include num_partial_accesses if it uses Fragment operation or |
752 | | // DebugValue has Indexes operand. |
753 | 0 | stats->num_full_accesses++; |
754 | 0 | return; |
755 | 0 | } |
756 | | |
757 | | // Annotations are check as a group separately. |
758 | 281k | if (!IsAnnotationInst(user->opcode())) { |
759 | 281k | switch (user->opcode()) { |
760 | 174k | case spv::Op::OpAccessChain: |
761 | 178k | case spv::Op::OpInBoundsAccessChain: |
762 | 178k | if (index == 2u && user->NumInOperands() > 1) { |
763 | 178k | uint32_t id = user->GetSingleWordInOperand(1u); |
764 | 178k | const Instruction* opInst = get_def_use_mgr()->GetDef(id); |
765 | 178k | const auto* constant = |
766 | 178k | context()->get_constant_mgr()->GetConstantFromInst(opInst); |
767 | 178k | if (!constant) { |
768 | 54.8k | ok = false; |
769 | 123k | } else if (constant->GetZeroExtendedValue() >= max_legal_index) { |
770 | 36.0k | ok = false; |
771 | 87.6k | } else { |
772 | 87.6k | if (!CheckUsesRelaxed(user)) ok = false; |
773 | 87.6k | } |
774 | 178k | stats->num_partial_accesses++; |
775 | 178k | } else { |
776 | 0 | ok = false; |
777 | 0 | } |
778 | 178k | break; |
779 | 14.4k | case spv::Op::OpLoad: |
780 | 14.4k | if (!CheckLoad(user, index)) ok = false; |
781 | 14.4k | stats->num_full_accesses++; |
782 | 14.4k | break; |
783 | 82.2k | case spv::Op::OpStore: |
784 | 82.2k | if (!CheckStore(user, index)) ok = false; |
785 | 82.2k | stats->num_full_accesses++; |
786 | 82.2k | break; |
787 | 2.54k | case spv::Op::OpName: |
788 | 2.54k | case spv::Op::OpMemberName: |
789 | 2.54k | break; |
790 | 3.91k | default: |
791 | 3.91k | ok = false; |
792 | 3.91k | break; |
793 | 281k | } |
794 | 281k | } |
795 | 281k | }); |
796 | | |
797 | 24.8k | return ok; |
798 | 24.8k | } |
799 | | |
800 | 87.6k | bool ScalarReplacementPass::CheckUsesRelaxed(const Instruction* inst) const { |
801 | 87.6k | bool ok = true; |
802 | 87.6k | get_def_use_mgr()->ForEachUse( |
803 | 91.0k | inst, [this, &ok](const Instruction* user, uint32_t index) { |
804 | 91.0k | switch (user->opcode()) { |
805 | 0 | case spv::Op::OpAccessChain: |
806 | 0 | case spv::Op::OpInBoundsAccessChain: |
807 | 0 | if (index != 2u) { |
808 | 0 | ok = false; |
809 | 0 | } else { |
810 | 0 | if (!CheckUsesRelaxed(user)) ok = false; |
811 | 0 | } |
812 | 0 | break; |
813 | 56.4k | case spv::Op::OpLoad: |
814 | 56.4k | if (!CheckLoad(user, index)) ok = false; |
815 | 56.4k | break; |
816 | 34.3k | case spv::Op::OpStore: |
817 | 34.3k | if (!CheckStore(user, index)) ok = false; |
818 | 34.3k | break; |
819 | 0 | case spv::Op::OpImageTexelPointer: |
820 | 0 | if (!CheckImageTexelPointer(index)) ok = false; |
821 | 0 | break; |
822 | 0 | case spv::Op::OpExtInst: |
823 | 0 | if (user->GetCommonDebugOpcode() != CommonDebugInfoDebugDeclare || |
824 | 0 | !CheckDebugDeclare(index)) |
825 | 0 | ok = false; |
826 | 0 | break; |
827 | 293 | default: |
828 | 293 | ok = false; |
829 | 293 | break; |
830 | 91.0k | } |
831 | 91.0k | }); |
832 | | |
833 | 87.6k | return ok; |
834 | 87.6k | } |
835 | | |
836 | 0 | bool ScalarReplacementPass::CheckImageTexelPointer(uint32_t index) const { |
837 | 0 | return index == 2u; |
838 | 0 | } |
839 | | |
840 | | bool ScalarReplacementPass::CheckLoad(const Instruction* inst, |
841 | 70.9k | uint32_t index) const { |
842 | 70.9k | if (index != 2u) return false; |
843 | 70.9k | if (inst->NumInOperands() >= 2 && |
844 | 70.9k | inst->GetSingleWordInOperand(1u) & |
845 | 218 | uint32_t(spv::MemoryAccessMask::Volatile)) |
846 | 163 | return false; |
847 | 70.7k | return true; |
848 | 70.9k | } |
849 | | |
850 | | bool ScalarReplacementPass::CheckStore(const Instruction* inst, |
851 | 116k | uint32_t index) const { |
852 | 116k | if (index != 0u) return false; |
853 | 116k | if (inst->NumInOperands() >= 3 && |
854 | 116k | inst->GetSingleWordInOperand(2u) & |
855 | 483 | uint32_t(spv::MemoryAccessMask::Volatile)) |
856 | 270 | return false; |
857 | 116k | return true; |
858 | 116k | } |
859 | | |
860 | 0 | bool ScalarReplacementPass::CheckDebugDeclare(uint32_t index) const { |
861 | 0 | if (index != kDebugDeclareOperandVariableIndex) return false; |
862 | 0 | return true; |
863 | 0 | } |
864 | | |
865 | 28.2k | bool ScalarReplacementPass::IsLargerThanSizeLimit(uint64_t length) const { |
866 | 28.2k | if (max_num_elements_ == 0) { |
867 | 0 | return false; |
868 | 0 | } |
869 | 28.2k | return length > max_num_elements_; |
870 | 28.2k | } |
871 | | |
872 | | std::unique_ptr<std::unordered_set<int64_t>> |
873 | 15.6k | ScalarReplacementPass::GetUsedComponents(Instruction* inst) { |
874 | 15.6k | std::unique_ptr<std::unordered_set<int64_t>> result( |
875 | 15.6k | new std::unordered_set<int64_t>()); |
876 | | |
877 | 15.6k | analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); |
878 | | |
879 | 15.6k | def_use_mgr->WhileEachUser(inst, [&result, def_use_mgr, |
880 | 66.4k | this](Instruction* use) { |
881 | 66.4k | switch (use->opcode()) { |
882 | 9.45k | case spv::Op::OpLoad: { |
883 | | // Look for extract from the load. |
884 | 9.45k | std::vector<uint32_t> t; |
885 | 22.2k | if (def_use_mgr->WhileEachUser(use, [&t](Instruction* use2) { |
886 | 22.2k | if (use2->opcode() != spv::Op::OpCompositeExtract || |
887 | 22.2k | use2->NumInOperands() <= 1) { |
888 | 8.98k | return false; |
889 | 8.98k | } |
890 | 13.3k | t.push_back(use2->GetSingleWordInOperand(1)); |
891 | 13.3k | return true; |
892 | 22.2k | })) { |
893 | 470 | result->insert(t.begin(), t.end()); |
894 | 470 | return true; |
895 | 8.98k | } else { |
896 | 8.98k | result.reset(nullptr); |
897 | 8.98k | return false; |
898 | 8.98k | } |
899 | 9.45k | } |
900 | 1.09k | case spv::Op::OpName: |
901 | 1.09k | case spv::Op::OpMemberName: |
902 | 27.0k | case spv::Op::OpStore: |
903 | | // No components are used. |
904 | 27.0k | return true; |
905 | 28.1k | case spv::Op::OpAccessChain: |
906 | 29.8k | case spv::Op::OpInBoundsAccessChain: { |
907 | | // Add the first index it if is a constant. |
908 | | // TODO: Could be improved by checking if the address is used in a load. |
909 | 29.8k | analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); |
910 | 29.8k | uint32_t index_id = use->GetSingleWordInOperand(1); |
911 | 29.8k | const analysis::Constant* index_const = |
912 | 29.8k | const_mgr->FindDeclaredConstant(index_id); |
913 | 29.8k | if (index_const) { |
914 | 29.8k | result->insert(index_const->GetSignExtendedValue()); |
915 | 29.8k | return true; |
916 | 29.8k | } else { |
917 | | // Could be any element. Assuming all are used. |
918 | 0 | result.reset(nullptr); |
919 | 0 | return false; |
920 | 0 | } |
921 | 29.8k | } |
922 | 7 | default: |
923 | | // We do not know what is happening. Have to assume the worst. |
924 | 7 | result.reset(nullptr); |
925 | 7 | return false; |
926 | 66.4k | } |
927 | 66.4k | }); |
928 | | |
929 | 15.6k | return result; |
930 | 15.6k | } |
931 | | |
932 | | uint64_t ScalarReplacementPass::GetMaxLegalIndex( |
933 | 24.8k | const Instruction* var_inst) const { |
934 | 24.8k | assert(var_inst->opcode() == spv::Op::OpVariable && |
935 | 24.8k | "|var_inst| must be a variable instruction."); |
936 | 24.8k | Instruction* type = GetStorageType(var_inst); |
937 | 24.8k | switch (type->opcode()) { |
938 | 17.9k | case spv::Op::OpTypeStruct: |
939 | 17.9k | return type->NumInOperands(); |
940 | 6.95k | case spv::Op::OpTypeArray: |
941 | 6.95k | return GetArrayLength(type); |
942 | 0 | case spv::Op::OpTypeMatrix: |
943 | 0 | case spv::Op::OpTypeVector: |
944 | 0 | return GetNumElements(type); |
945 | 0 | default: |
946 | 0 | return 0; |
947 | 24.8k | } |
948 | 0 | return 0; |
949 | 24.8k | } |
950 | | |
951 | | void ScalarReplacementPass::CopyDecorationsToVariable(Instruction* from, |
952 | | Instruction* to, |
953 | 79.0k | uint32_t member_index) { |
954 | 79.0k | CopyPointerDecorationsToVariable(from, to); |
955 | 79.0k | CopyNecessaryMemberDecorationsToVariable(from, to, member_index); |
956 | 79.0k | } |
957 | | |
958 | | void ScalarReplacementPass::CopyPointerDecorationsToVariable(Instruction* from, |
959 | 79.0k | Instruction* to) { |
960 | | // The RestrictPointer and AliasedPointer decorations are copied to all |
961 | | // members even if the new variable does not contain a pointer. It does |
962 | | // not hurt to do so. |
963 | 79.0k | for (auto dec_inst : |
964 | 79.0k | get_decoration_mgr()->GetDecorationsFor(from->result_id(), false)) { |
965 | 195 | uint32_t decoration; |
966 | 195 | decoration = dec_inst->GetSingleWordInOperand(1u); |
967 | 195 | switch (spv::Decoration(decoration)) { |
968 | 0 | case spv::Decoration::AliasedPointer: |
969 | 0 | case spv::Decoration::RestrictPointer: { |
970 | 0 | std::unique_ptr<Instruction> new_dec_inst(dec_inst->Clone(context())); |
971 | 0 | new_dec_inst->SetInOperand(0, {to->result_id()}); |
972 | 0 | context()->AddAnnotationInst(std::move(new_dec_inst)); |
973 | 0 | } break; |
974 | 195 | default: |
975 | 195 | break; |
976 | 195 | } |
977 | 195 | } |
978 | 79.0k | } |
979 | | |
980 | | void ScalarReplacementPass::CopyNecessaryMemberDecorationsToVariable( |
981 | 79.0k | Instruction* from, Instruction* to, uint32_t member_index) { |
982 | 79.0k | Instruction* type_inst = GetStorageType(from); |
983 | 79.0k | for (auto dec_inst : |
984 | 79.0k | get_decoration_mgr()->GetDecorationsFor(type_inst->result_id(), false)) { |
985 | 33.6k | uint32_t decoration; |
986 | 33.6k | if (dec_inst->opcode() == spv::Op::OpMemberDecorate) { |
987 | 33.6k | if (dec_inst->GetSingleWordInOperand(1) != member_index) { |
988 | 22.2k | continue; |
989 | 22.2k | } |
990 | | |
991 | 11.3k | decoration = dec_inst->GetSingleWordInOperand(2u); |
992 | 11.3k | switch (spv::Decoration(decoration)) { |
993 | 0 | case spv::Decoration::ArrayStride: |
994 | 0 | case spv::Decoration::Alignment: |
995 | 0 | case spv::Decoration::AlignmentId: |
996 | 0 | case spv::Decoration::MaxByteOffset: |
997 | 0 | case spv::Decoration::MaxByteOffsetId: |
998 | 11.0k | case spv::Decoration::RelaxedPrecision: { |
999 | 11.0k | std::unique_ptr<Instruction> new_dec_inst( |
1000 | 11.0k | new Instruction(context(), spv::Op::OpDecorate, 0, 0, {})); |
1001 | 11.0k | new_dec_inst->AddOperand( |
1002 | 11.0k | Operand(SPV_OPERAND_TYPE_ID, {to->result_id()})); |
1003 | 22.1k | for (uint32_t i = 2; i < dec_inst->NumInOperandWords(); ++i) { |
1004 | 11.0k | new_dec_inst->AddOperand(Operand(dec_inst->GetInOperand(i))); |
1005 | 11.0k | } |
1006 | 11.0k | context()->AddAnnotationInst(std::move(new_dec_inst)); |
1007 | 11.0k | } break; |
1008 | 331 | default: |
1009 | 331 | break; |
1010 | 11.3k | } |
1011 | 11.3k | } |
1012 | 33.6k | } |
1013 | 79.0k | } |
1014 | | |
1015 | | } // namespace opt |
1016 | | } // namespace spvtools |