/src/spirv-tools/source/opt/interface_var_sroa.cpp
Line | Count | Source |
1 | | // Copyright (c) 2022 Google LLC |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | #include "source/opt/interface_var_sroa.h" |
16 | | |
17 | | #include <iostream> |
18 | | |
19 | | #include "source/opt/decoration_manager.h" |
20 | | #include "source/opt/def_use_manager.h" |
21 | | #include "source/opt/function.h" |
22 | | #include "source/opt/log.h" |
23 | | #include "source/opt/type_manager.h" |
24 | | #include "source/util/make_unique.h" |
25 | | |
26 | | namespace spvtools { |
27 | | namespace opt { |
28 | | namespace { |
// In-operand indices handed to Instruction::GetSingleWordInOperand() (and used
// as loop bounds over in-operands) by the helpers in this file. Each name
// spells out the opcode the index applies to and the operand it selects.

// OpDecorate: the decoration kind, and the literal that follows it.
constexpr uint32_t kOpDecorateDecorationInOperandIndex = 1;
constexpr uint32_t kOpDecorateLiteralInOperandIndex = 2;
// OpEntryPoint: first in-operand scanned when collecting interface variables.
constexpr uint32_t kOpEntryPointInOperandInterface = 3;
// OpVariable: the storage class.
constexpr uint32_t kOpVariableStorageClassInOperandIndex = 0;
// OpTypeArray: the element type id and the id of the length constant.
constexpr uint32_t kOpTypeArrayElemTypeInOperandIndex = 0;
constexpr uint32_t kOpTypeArrayLengthInOperandIndex = 1;
// OpTypeMatrix: the column count and the column type id.
constexpr uint32_t kOpTypeMatrixColCountInOperandIndex = 1;
constexpr uint32_t kOpTypeMatrixColTypeInOperandIndex = 0;
// OpTypePointer: the pointee type id.
constexpr uint32_t kOpTypePtrTypeInOperandIndex = 1;
// OpConstant: the (first) value word.
constexpr uint32_t kOpConstantValueInOperandIndex = 0;
39 | | |
40 | | // Get the length of the OpTypeArray |array_type|. |
41 | | uint32_t GetArrayLength(analysis::DefUseManager* def_use_mgr, |
42 | 0 | Instruction* array_type) { |
43 | 0 | assert(array_type->opcode() == spv::Op::OpTypeArray); |
44 | 0 | uint32_t const_int_id = |
45 | 0 | array_type->GetSingleWordInOperand(kOpTypeArrayLengthInOperandIndex); |
46 | 0 | Instruction* array_length_inst = def_use_mgr->GetDef(const_int_id); |
47 | 0 | assert(array_length_inst->opcode() == spv::Op::OpConstant); |
48 | 0 | return array_length_inst->GetSingleWordInOperand( |
49 | 0 | kOpConstantValueInOperandIndex); |
50 | 0 | } |
51 | | |
52 | | // Get the element type instruction of the OpTypeArray |array_type|. |
53 | | Instruction* GetArrayElementType(analysis::DefUseManager* def_use_mgr, |
54 | 0 | Instruction* array_type) { |
55 | 0 | assert(array_type->opcode() == spv::Op::OpTypeArray); |
56 | 0 | uint32_t elem_type_id = |
57 | 0 | array_type->GetSingleWordInOperand(kOpTypeArrayElemTypeInOperandIndex); |
58 | 0 | return def_use_mgr->GetDef(elem_type_id); |
59 | 0 | } |
60 | | |
61 | | // Get the column type instruction of the OpTypeMatrix |matrix_type|. |
62 | | Instruction* GetMatrixColumnType(analysis::DefUseManager* def_use_mgr, |
63 | 0 | Instruction* matrix_type) { |
64 | 0 | assert(matrix_type->opcode() == spv::Op::OpTypeMatrix); |
65 | 0 | uint32_t column_type_id = |
66 | 0 | matrix_type->GetSingleWordInOperand(kOpTypeMatrixColTypeInOperandIndex); |
67 | 0 | return def_use_mgr->GetDef(column_type_id); |
68 | 0 | } |
69 | | |
70 | | // Traverses the component type of OpTypeArray or OpTypeMatrix. Repeats it |
71 | | // |depth_to_component| times recursively and returns the component type. |
72 | | // |type_id| is the result id of the OpTypeArray or OpTypeMatrix instruction. |
73 | | uint32_t GetComponentTypeOfArrayMatrix(analysis::DefUseManager* def_use_mgr, |
74 | | uint32_t type_id, |
75 | 0 | uint32_t depth_to_component) { |
76 | 0 | if (depth_to_component == 0) return type_id; |
77 | | |
78 | 0 | Instruction* type_inst = def_use_mgr->GetDef(type_id); |
79 | 0 | if (type_inst->opcode() == spv::Op::OpTypeArray) { |
80 | 0 | uint32_t elem_type_id = |
81 | 0 | type_inst->GetSingleWordInOperand(kOpTypeArrayElemTypeInOperandIndex); |
82 | 0 | return GetComponentTypeOfArrayMatrix(def_use_mgr, elem_type_id, |
83 | 0 | depth_to_component - 1); |
84 | 0 | } |
85 | | |
86 | 0 | assert(type_inst->opcode() == spv::Op::OpTypeMatrix); |
87 | 0 | uint32_t column_type_id = |
88 | 0 | type_inst->GetSingleWordInOperand(kOpTypeMatrixColTypeInOperandIndex); |
89 | 0 | return GetComponentTypeOfArrayMatrix(def_use_mgr, column_type_id, |
90 | 0 | depth_to_component - 1); |
91 | 0 | } |
92 | | |
93 | | // Creates an OpDecorate instruction whose Target is |var_id| and Decoration is |
94 | | // |decoration|. Adds |literal| as an extra operand of the instruction. |
95 | | void CreateDecoration(analysis::DecorationManager* decoration_mgr, |
96 | | uint32_t var_id, spv::Decoration decoration, |
97 | 0 | uint32_t literal) { |
98 | 0 | std::vector<Operand> operands({ |
99 | 0 | {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {var_id}}, |
100 | 0 | {spv_operand_type_t::SPV_OPERAND_TYPE_DECORATION, |
101 | 0 | {static_cast<uint32_t>(decoration)}}, |
102 | 0 | {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {literal}}, |
103 | 0 | }); |
104 | 0 | decoration_mgr->AddDecoration(spv::Op::OpDecorate, std::move(operands)); |
105 | 0 | } |
106 | | |
107 | | // Replaces load instructions with composite construct instructions in all the |
108 | | // users of the loads. |loads_to_composites| is the mapping from each load to |
109 | | // its corresponding OpCompositeConstruct. |
110 | | void ReplaceLoadWithCompositeConstruct( |
111 | | IRContext* context, |
112 | 0 | const std::unordered_map<Instruction*, Instruction*>& loads_to_composites) { |
113 | 0 | for (const auto& load_and_composite : loads_to_composites) { |
114 | 0 | Instruction* load = load_and_composite.first; |
115 | 0 | Instruction* composite_construct = load_and_composite.second; |
116 | |
|
117 | 0 | std::vector<Instruction*> users; |
118 | 0 | context->get_def_use_mgr()->ForEachUse( |
119 | 0 | load, [&users, composite_construct](Instruction* user, uint32_t index) { |
120 | 0 | user->GetOperand(index).words[0] = composite_construct->result_id(); |
121 | 0 | users.push_back(user); |
122 | 0 | }); |
123 | |
|
124 | 0 | for (Instruction* user : users) |
125 | 0 | context->get_def_use_mgr()->AnalyzeInstUse(user); |
126 | 0 | } |
127 | 0 | } |
128 | | |
129 | | // Returns the storage class of the instruction |var|. |
130 | 0 | spv::StorageClass GetStorageClass(Instruction* var) { |
131 | 0 | return static_cast<spv::StorageClass>( |
132 | 0 | var->GetSingleWordInOperand(kOpVariableStorageClassInOperandIndex)); |
133 | 0 | } |
134 | | |
135 | | } // namespace |
136 | | |
137 | | bool InterfaceVariableScalarReplacement::HasExtraArrayness( |
138 | 0 | Instruction& entry_point, Instruction* var) { |
139 | 0 | spv::ExecutionModel execution_model = |
140 | 0 | static_cast<spv::ExecutionModel>(entry_point.GetSingleWordInOperand(0)); |
141 | 0 | if (execution_model != spv::ExecutionModel::TessellationEvaluation && |
142 | 0 | execution_model != spv::ExecutionModel::TessellationControl) { |
143 | 0 | return false; |
144 | 0 | } |
145 | 0 | if (!context()->get_decoration_mgr()->HasDecoration( |
146 | 0 | var->result_id(), uint32_t(spv::Decoration::Patch))) { |
147 | 0 | if (execution_model == spv::ExecutionModel::TessellationControl) |
148 | 0 | return true; |
149 | 0 | return GetStorageClass(var) != spv::StorageClass::Output; |
150 | 0 | } |
151 | 0 | return false; |
152 | 0 | } |
153 | | |
154 | | bool InterfaceVariableScalarReplacement:: |
155 | | CheckExtraArraynessConflictBetweenEntries(Instruction* interface_var, |
156 | 0 | bool has_extra_arrayness) { |
157 | 0 | if (has_extra_arrayness) { |
158 | 0 | return !ReportErrorIfHasNoExtraArraynessForOtherEntry(interface_var); |
159 | 0 | } |
160 | 0 | return !ReportErrorIfHasExtraArraynessForOtherEntry(interface_var); |
161 | 0 | } |
162 | | |
163 | | bool InterfaceVariableScalarReplacement::GetVariableLocation( |
164 | 0 | Instruction* var, uint32_t* location) { |
165 | 0 | return !context()->get_decoration_mgr()->WhileEachDecoration( |
166 | 0 | var->result_id(), uint32_t(spv::Decoration::Location), |
167 | 0 | [location](const Instruction& inst) { |
168 | 0 | *location = |
169 | 0 | inst.GetSingleWordInOperand(kOpDecorateLiteralInOperandIndex); |
170 | 0 | return false; |
171 | 0 | }); |
172 | 0 | } |
173 | | |
174 | | bool InterfaceVariableScalarReplacement::GetVariableComponent( |
175 | 0 | Instruction* var, uint32_t* component) { |
176 | 0 | return !context()->get_decoration_mgr()->WhileEachDecoration( |
177 | 0 | var->result_id(), uint32_t(spv::Decoration::Component), |
178 | 0 | [component](const Instruction& inst) { |
179 | 0 | *component = |
180 | 0 | inst.GetSingleWordInOperand(kOpDecorateLiteralInOperandIndex); |
181 | 0 | return false; |
182 | 0 | }); |
183 | 0 | } |
184 | | |
185 | | std::vector<Instruction*> |
186 | | InterfaceVariableScalarReplacement::CollectInterfaceVariables( |
187 | 0 | Instruction& entry_point) { |
188 | 0 | std::vector<Instruction*> interface_vars; |
189 | 0 | for (uint32_t i = kOpEntryPointInOperandInterface; |
190 | 0 | i < entry_point.NumInOperands(); ++i) { |
191 | 0 | Instruction* interface_var = context()->get_def_use_mgr()->GetDef( |
192 | 0 | entry_point.GetSingleWordInOperand(i)); |
193 | 0 | assert(interface_var->opcode() == spv::Op::OpVariable); |
194 | | |
195 | 0 | spv::StorageClass storage_class = GetStorageClass(interface_var); |
196 | 0 | if (storage_class != spv::StorageClass::Input && |
197 | 0 | storage_class != spv::StorageClass::Output) { |
198 | 0 | continue; |
199 | 0 | } |
200 | | |
201 | 0 | interface_vars.push_back(interface_var); |
202 | 0 | } |
203 | 0 | return interface_vars; |
204 | 0 | } |
205 | | |
206 | | void InterfaceVariableScalarReplacement::KillInstructionAndUsers( |
207 | 0 | Instruction* inst) { |
208 | 0 | if (inst->opcode() == spv::Op::OpEntryPoint) { |
209 | 0 | return; |
210 | 0 | } |
211 | 0 | if (inst->opcode() != spv::Op::OpAccessChain) { |
212 | 0 | context()->KillInst(inst); |
213 | 0 | return; |
214 | 0 | } |
215 | 0 | std::vector<Instruction*> users; |
216 | 0 | context()->get_def_use_mgr()->ForEachUser( |
217 | 0 | inst, [&users](Instruction* user) { users.push_back(user); }); |
218 | 0 | for (auto user : users) { |
219 | 0 | context()->KillInst(user); |
220 | 0 | } |
221 | 0 | context()->KillInst(inst); |
222 | 0 | } |
223 | | |
224 | | void InterfaceVariableScalarReplacement::KillInstructionsAndUsers( |
225 | 0 | const std::vector<Instruction*>& insts) { |
226 | 0 | for (Instruction* inst : insts) { |
227 | 0 | KillInstructionAndUsers(inst); |
228 | 0 | } |
229 | 0 | } |
230 | | |
231 | | void InterfaceVariableScalarReplacement::KillLocationAndComponentDecorations( |
232 | 0 | uint32_t var_id) { |
233 | 0 | context()->get_decoration_mgr()->RemoveDecorationsFrom( |
234 | 0 | var_id, [](const Instruction& inst) { |
235 | 0 | spv::Decoration decoration = spv::Decoration( |
236 | 0 | inst.GetSingleWordInOperand(kOpDecorateDecorationInOperandIndex)); |
237 | 0 | return decoration == spv::Decoration::Location || |
238 | 0 | decoration == spv::Decoration::Component; |
239 | 0 | }); |
240 | 0 | } |
241 | | |
242 | | Pass::Status |
243 | | InterfaceVariableScalarReplacement::ReplaceInterfaceVariableWithScalars( |
244 | | Instruction* interface_var, Instruction* interface_var_type, |
245 | 0 | uint32_t location, uint32_t component, uint32_t extra_array_length) { |
246 | 0 | std::optional<NestedCompositeComponents> scalar_interface_vars = |
247 | 0 | CreateScalarInterfaceVarsForReplacement(interface_var_type, |
248 | 0 | GetStorageClass(interface_var), |
249 | 0 | extra_array_length); |
250 | |
|
251 | 0 | if (!scalar_interface_vars) { |
252 | 0 | return Status::Failure; |
253 | 0 | } |
254 | | |
255 | 0 | AddLocationAndComponentDecorations(*scalar_interface_vars, &location, |
256 | 0 | component); |
257 | 0 | KillLocationAndComponentDecorations(interface_var->result_id()); |
258 | |
|
259 | 0 | Status status = ReplaceInterfaceVarWith(interface_var, extra_array_length, |
260 | 0 | *scalar_interface_vars); |
261 | 0 | if (status == Status::Failure) { |
262 | 0 | return status; |
263 | 0 | } |
264 | | |
265 | 0 | context()->KillInst(interface_var); |
266 | 0 | return status; |
267 | 0 | } |
268 | | |
269 | | Pass::Status InterfaceVariableScalarReplacement::ReplaceInterfaceVarWith( |
270 | | Instruction* interface_var, uint32_t extra_array_length, |
271 | 0 | const NestedCompositeComponents& scalar_interface_vars) { |
272 | 0 | std::vector<Instruction*> users; |
273 | 0 | context()->get_def_use_mgr()->ForEachUser( |
274 | 0 | interface_var, [&users](Instruction* user) { users.push_back(user); }); |
275 | |
|
276 | 0 | std::vector<uint32_t> interface_var_component_indices; |
277 | 0 | std::unordered_map<Instruction*, Instruction*> loads_to_composites; |
278 | 0 | std::unordered_map<Instruction*, Instruction*> |
279 | 0 | loads_for_access_chain_to_composites; |
280 | 0 | if (extra_array_length != 0) { |
281 | | // Note that the extra arrayness is the first dimension of the array |
282 | | // interface variable. |
283 | 0 | for (uint32_t index = 0; index < extra_array_length; ++index) { |
284 | 0 | std::unordered_map<Instruction*, Instruction*> loads_to_component_values; |
285 | 0 | Status status = ReplaceComponentsOfInterfaceVarWith( |
286 | 0 | interface_var, users, scalar_interface_vars, |
287 | 0 | interface_var_component_indices, &index, &loads_to_component_values, |
288 | 0 | &loads_for_access_chain_to_composites); |
289 | 0 | if (status == Status::Failure) { |
290 | 0 | return Status::Failure; |
291 | 0 | } |
292 | 0 | AddComponentsToCompositesForLoads(loads_to_component_values, |
293 | 0 | &loads_to_composites, 0); |
294 | 0 | } |
295 | 0 | } else { |
296 | 0 | Status status = ReplaceComponentsOfInterfaceVarWith( |
297 | 0 | interface_var, users, scalar_interface_vars, |
298 | 0 | interface_var_component_indices, nullptr, &loads_to_composites, |
299 | 0 | &loads_for_access_chain_to_composites); |
300 | 0 | if (status == Status::Failure) { |
301 | 0 | return Status::Failure; |
302 | 0 | } |
303 | 0 | } |
304 | | |
305 | 0 | ReplaceLoadWithCompositeConstruct(context(), loads_to_composites); |
306 | 0 | ReplaceLoadWithCompositeConstruct(context(), |
307 | 0 | loads_for_access_chain_to_composites); |
308 | |
|
309 | 0 | KillInstructionsAndUsers(users); |
310 | 0 | return Status::SuccessWithChange; |
311 | 0 | } |
312 | | |
313 | | void InterfaceVariableScalarReplacement::AddLocationAndComponentDecorations( |
314 | | const NestedCompositeComponents& vars, uint32_t* location, |
315 | 0 | uint32_t component) { |
316 | 0 | if (!vars.HasMultipleComponents()) { |
317 | 0 | uint32_t var_id = vars.GetComponentVariable()->result_id(); |
318 | 0 | CreateDecoration(context()->get_decoration_mgr(), var_id, |
319 | 0 | spv::Decoration::Location, *location); |
320 | 0 | CreateDecoration(context()->get_decoration_mgr(), var_id, |
321 | 0 | spv::Decoration::Component, component); |
322 | 0 | ++(*location); |
323 | 0 | return; |
324 | 0 | } |
325 | 0 | for (const auto& var : vars.GetComponents()) { |
326 | 0 | AddLocationAndComponentDecorations(var, location, component); |
327 | 0 | } |
328 | 0 | } |
329 | | |
330 | | Pass::Status |
331 | | InterfaceVariableScalarReplacement::ReplaceComponentsOfInterfaceVarWith( |
332 | | Instruction* interface_var, |
333 | | const std::vector<Instruction*>& interface_var_users, |
334 | | const NestedCompositeComponents& scalar_interface_vars, |
335 | | std::vector<uint32_t>& interface_var_component_indices, |
336 | | const uint32_t* extra_array_index, |
337 | | std::unordered_map<Instruction*, Instruction*>* loads_to_composites, |
338 | | std::unordered_map<Instruction*, Instruction*>* |
339 | 0 | loads_for_access_chain_to_composites) { |
340 | 0 | if (!scalar_interface_vars.HasMultipleComponents()) { |
341 | 0 | for (Instruction* interface_var_user : interface_var_users) { |
342 | 0 | Status status = ReplaceComponentOfInterfaceVarWith( |
343 | 0 | interface_var, interface_var_user, |
344 | 0 | scalar_interface_vars.GetComponentVariable(), |
345 | 0 | interface_var_component_indices, extra_array_index, |
346 | 0 | loads_to_composites, loads_for_access_chain_to_composites); |
347 | 0 | if (status == Status::Failure) { |
348 | 0 | return Status::Failure; |
349 | 0 | } |
350 | 0 | } |
351 | 0 | return Status::SuccessWithChange; |
352 | 0 | } |
353 | 0 | return ReplaceMultipleComponentsOfInterfaceVarWith( |
354 | 0 | interface_var, interface_var_users, scalar_interface_vars.GetComponents(), |
355 | 0 | interface_var_component_indices, extra_array_index, loads_to_composites, |
356 | 0 | loads_for_access_chain_to_composites); |
357 | 0 | } |
358 | | |
359 | | Pass::Status |
360 | | InterfaceVariableScalarReplacement::ReplaceMultipleComponentsOfInterfaceVarWith( |
361 | | Instruction* interface_var, |
362 | | const std::vector<Instruction*>& interface_var_users, |
363 | | const std::vector<NestedCompositeComponents>& components, |
364 | | std::vector<uint32_t>& interface_var_component_indices, |
365 | | const uint32_t* extra_array_index, |
366 | | std::unordered_map<Instruction*, Instruction*>* loads_to_composites, |
367 | | std::unordered_map<Instruction*, Instruction*>* |
368 | 0 | loads_for_access_chain_to_composites) { |
369 | 0 | for (uint32_t i = 0; i < components.size(); ++i) { |
370 | 0 | interface_var_component_indices.push_back(i); |
371 | 0 | std::unordered_map<Instruction*, Instruction*> loads_to_component_values; |
372 | 0 | std::unordered_map<Instruction*, Instruction*> |
373 | 0 | loads_for_access_chain_to_component_values; |
374 | 0 | Status status = ReplaceComponentsOfInterfaceVarWith( |
375 | 0 | interface_var, interface_var_users, components[i], |
376 | 0 | interface_var_component_indices, extra_array_index, |
377 | 0 | &loads_to_component_values, |
378 | 0 | &loads_for_access_chain_to_component_values); |
379 | 0 | if (status == Status::Failure) { |
380 | 0 | return Status::Failure; |
381 | 0 | } |
382 | 0 | interface_var_component_indices.pop_back(); |
383 | |
|
384 | 0 | uint32_t depth_to_component = |
385 | 0 | static_cast<uint32_t>(interface_var_component_indices.size()); |
386 | 0 | AddComponentsToCompositesForLoads( |
387 | 0 | loads_for_access_chain_to_component_values, |
388 | 0 | loads_for_access_chain_to_composites, depth_to_component); |
389 | 0 | if (extra_array_index) ++depth_to_component; |
390 | 0 | AddComponentsToCompositesForLoads(loads_to_component_values, |
391 | 0 | loads_to_composites, depth_to_component); |
392 | 0 | } |
393 | 0 | return Status::SuccessWithChange; |
394 | 0 | } |
395 | | |
396 | | Pass::Status |
397 | | InterfaceVariableScalarReplacement::ReplaceComponentOfInterfaceVarWith( |
398 | | Instruction* interface_var, Instruction* interface_var_user, |
399 | | Instruction* scalar_var, |
400 | | const std::vector<uint32_t>& interface_var_component_indices, |
401 | | const uint32_t* extra_array_index, |
402 | | std::unordered_map<Instruction*, Instruction*>* loads_to_component_values, |
403 | | std::unordered_map<Instruction*, Instruction*>* |
404 | 0 | loads_for_access_chain_to_component_values) { |
405 | 0 | spv::Op opcode = interface_var_user->opcode(); |
406 | 0 | if (opcode == spv::Op::OpStore) { |
407 | 0 | uint32_t value_id = interface_var_user->GetSingleWordInOperand(1); |
408 | 0 | StoreComponentOfValueToScalarVar(value_id, interface_var_component_indices, |
409 | 0 | scalar_var, extra_array_index, |
410 | 0 | interface_var_user); |
411 | 0 | return Status::SuccessWithChange; |
412 | 0 | } |
413 | 0 | if (opcode == spv::Op::OpLoad) { |
414 | 0 | Instruction* scalar_load = |
415 | 0 | LoadScalarVar(scalar_var, extra_array_index, interface_var_user); |
416 | 0 | if (scalar_load == nullptr) { |
417 | 0 | return Status::Failure; |
418 | 0 | } |
419 | 0 | loads_to_component_values->insert({interface_var_user, scalar_load}); |
420 | 0 | return Status::SuccessWithChange; |
421 | 0 | } |
422 | | |
423 | | // Copy OpName and annotation instructions only once. Therefore, we create |
424 | | // them only for the first element of the extra array. |
425 | 0 | if (extra_array_index && *extra_array_index != 0) |
426 | 0 | return Status::SuccessWithChange; |
427 | | |
428 | 0 | if (opcode == spv::Op::OpDecorateId || opcode == spv::Op::OpDecorateString || |
429 | 0 | opcode == spv::Op::OpDecorate) { |
430 | 0 | CloneAnnotationForVariable(interface_var_user, scalar_var->result_id()); |
431 | 0 | return Status::SuccessWithChange; |
432 | 0 | } |
433 | | |
434 | 0 | if (opcode == spv::Op::OpName) { |
435 | 0 | std::unique_ptr<Instruction> new_inst(interface_var_user->Clone(context())); |
436 | 0 | new_inst->SetInOperand(0, {scalar_var->result_id()}); |
437 | 0 | context()->AddDebug2Inst(std::move(new_inst)); |
438 | 0 | return Status::SuccessWithChange; |
439 | 0 | } |
440 | | |
441 | 0 | if (opcode == spv::Op::OpEntryPoint) { |
442 | 0 | if (ReplaceInterfaceVarInEntryPoint(interface_var, interface_var_user, |
443 | 0 | scalar_var->result_id())) { |
444 | 0 | return Status::SuccessWithChange; |
445 | 0 | } |
446 | 0 | return Status::Failure; |
447 | 0 | } |
448 | | |
449 | 0 | if (opcode == spv::Op::OpAccessChain) { |
450 | 0 | ReplaceAccessChainWith(interface_var_user, interface_var_component_indices, |
451 | 0 | scalar_var, |
452 | 0 | loads_for_access_chain_to_component_values); |
453 | 0 | return Status::SuccessWithChange; |
454 | 0 | } |
455 | | |
456 | 0 | std::string message("Unhandled instruction"); |
457 | 0 | message += "\n " + interface_var_user->PrettyPrint( |
458 | 0 | SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); |
459 | 0 | message += |
460 | 0 | "\nfor interface variable scalar replacement\n " + |
461 | 0 | interface_var->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); |
462 | 0 | context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str()); |
463 | 0 | return Status::Failure; |
464 | 0 | } |
465 | | |
466 | | void InterfaceVariableScalarReplacement::UseBaseAccessChainForAccessChain( |
467 | 0 | Instruction* access_chain, Instruction* base_access_chain) { |
468 | 0 | assert(base_access_chain->opcode() == spv::Op::OpAccessChain && |
469 | 0 | access_chain->opcode() == spv::Op::OpAccessChain && |
470 | 0 | access_chain->GetSingleWordInOperand(0) == |
471 | 0 | base_access_chain->result_id()); |
472 | 0 | Instruction::OperandList new_operands; |
473 | 0 | for (uint32_t i = 0; i < base_access_chain->NumInOperands(); ++i) { |
474 | 0 | new_operands.emplace_back(base_access_chain->GetInOperand(i)); |
475 | 0 | } |
476 | 0 | for (uint32_t i = 1; i < access_chain->NumInOperands(); ++i) { |
477 | 0 | new_operands.emplace_back(access_chain->GetInOperand(i)); |
478 | 0 | } |
479 | 0 | access_chain->SetInOperands(std::move(new_operands)); |
480 | 0 | } |
481 | | |
482 | | Instruction* InterfaceVariableScalarReplacement::CreateAccessChainToVar( |
483 | | uint32_t var_type_id, Instruction* var, |
484 | | const std::vector<uint32_t>& index_ids, Instruction* insert_before, |
485 | 0 | uint32_t* component_type_id) { |
486 | 0 | analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); |
487 | 0 | *component_type_id = GetComponentTypeOfArrayMatrix( |
488 | 0 | def_use_mgr, var_type_id, static_cast<uint32_t>(index_ids.size())); |
489 | |
|
490 | 0 | uint32_t ptr_type_id = |
491 | 0 | GetPointerType(*component_type_id, GetStorageClass(var)); |
492 | |
|
493 | 0 | uint32_t new_id = TakeNextId(); |
494 | 0 | if (new_id == 0) { |
495 | 0 | return nullptr; |
496 | 0 | } |
497 | 0 | std::unique_ptr<Instruction> new_access_chain( |
498 | 0 | new Instruction(context(), spv::Op::OpAccessChain, ptr_type_id, new_id, |
499 | 0 | std::initializer_list<Operand>{ |
500 | 0 | {SPV_OPERAND_TYPE_ID, {var->result_id()}}})); |
501 | 0 | for (uint32_t index_id : index_ids) { |
502 | 0 | new_access_chain->AddOperand({SPV_OPERAND_TYPE_ID, {index_id}}); |
503 | 0 | } |
504 | |
|
505 | 0 | Instruction* inst = new_access_chain.get(); |
506 | 0 | def_use_mgr->AnalyzeInstDefUse(inst); |
507 | 0 | insert_before->InsertBefore(std::move(new_access_chain)); |
508 | 0 | return inst; |
509 | 0 | } |
510 | | |
511 | | Instruction* InterfaceVariableScalarReplacement::CreateAccessChainWithIndex( |
512 | | uint32_t component_type_id, Instruction* var, uint32_t index, |
513 | 0 | Instruction* insert_before) { |
514 | 0 | uint32_t ptr_type_id = |
515 | 0 | GetPointerType(component_type_id, GetStorageClass(var)); |
516 | 0 | uint32_t index_id = context()->get_constant_mgr()->GetUIntConstId(index); |
517 | 0 | uint32_t new_id = TakeNextId(); |
518 | 0 | if (new_id == 0) { |
519 | 0 | return nullptr; |
520 | 0 | } |
521 | 0 | std::unique_ptr<Instruction> new_access_chain( |
522 | 0 | new Instruction(context(), spv::Op::OpAccessChain, ptr_type_id, new_id, |
523 | 0 | std::initializer_list<Operand>{ |
524 | 0 | {SPV_OPERAND_TYPE_ID, {var->result_id()}}, |
525 | 0 | {SPV_OPERAND_TYPE_ID, {index_id}}, |
526 | 0 | })); |
527 | 0 | Instruction* inst = new_access_chain.get(); |
528 | 0 | context()->get_def_use_mgr()->AnalyzeInstDefUse(inst); |
529 | 0 | insert_before->InsertBefore(std::move(new_access_chain)); |
530 | 0 | return inst; |
531 | 0 | } |
532 | | |
// Rewrites the users of |access_chain| (an access chain rooted at the
// interface variable being replaced) so that they operate on |scalar_var|.
//
// The index operands of |access_chain| (every in-operand after the base) are
// collected into |indexes| and forwarded to the store/load helpers. Users are
// handled by opcode:
//  - OpAccessChain: the nested chain is first flattened onto |access_chain|'s
//    base and indexes, then processed recursively.
//  - OpStore: the relevant component of the stored value is stored through an
//    access chain to |scalar_var|.
//  - OpLoad: a replacement load is created and recorded in
//    |loads_to_component_values| under the original load.
//  - Anything else is ignored.
void InterfaceVariableScalarReplacement::ReplaceAccessChainWith(
    Instruction* access_chain,
    const std::vector<uint32_t>& interface_var_component_indices,
    Instruction* scalar_var,
    std::unordered_map<Instruction*, Instruction*>* loads_to_component_values) {
  // In-operand 0 is the base pointer; the rest are the chain's index ids.
  std::vector<uint32_t> indexes;
  for (uint32_t i = 1; i < access_chain->NumInOperands(); ++i) {
    indexes.push_back(access_chain->GetSingleWordInOperand(i));
  }

  // Note that we have a strong assumption that |access_chain| has only a single
  // index that is for the extra arrayness.
  context()->get_def_use_mgr()->ForEachUser(
      access_chain,
      [this, access_chain, &indexes, &interface_var_component_indices,
       scalar_var, loads_to_component_values](Instruction* user) {
        switch (user->opcode()) {
          case spv::Op::OpAccessChain: {
            // Make |user| carry the full index list itself, then treat it
            // like any other access chain to the interface variable.
            UseBaseAccessChainForAccessChain(user, access_chain);
            ReplaceAccessChainWith(user, interface_var_component_indices,
                                   scalar_var, loads_to_component_values);
            return;
          }
          case spv::Op::OpStore: {
            // In-operand 1 of OpStore is the id of the value being stored.
            uint32_t value_id = user->GetSingleWordInOperand(1);
            StoreComponentOfValueToAccessChainToScalarVar(
                value_id, interface_var_component_indices, scalar_var, indexes,
                user);
            return;
          }
          case spv::Op::OpLoad: {
            // NOTE(review): |value| is inserted without a null check.
            // LoadAccessChainToVar is defined outside this view; if it can
            // return nullptr (as CreateLoad does when no id is available),
            // later consumers of this map would dereference null - confirm.
            Instruction* value =
                LoadAccessChainToVar(scalar_var, indexes, user);
            loads_to_component_values->insert({user, value});
            return;
          }
          default:
            break;
        }
      });
}
574 | | |
575 | | void InterfaceVariableScalarReplacement::CloneAnnotationForVariable( |
576 | 0 | Instruction* annotation_inst, uint32_t var_id) { |
577 | 0 | assert(annotation_inst->opcode() == spv::Op::OpDecorate || |
578 | 0 | annotation_inst->opcode() == spv::Op::OpDecorateId || |
579 | 0 | annotation_inst->opcode() == spv::Op::OpDecorateString); |
580 | 0 | std::unique_ptr<Instruction> new_inst(annotation_inst->Clone(context())); |
581 | 0 | new_inst->SetInOperand(0, {var_id}); |
582 | 0 | context()->AddAnnotationInst(std::move(new_inst)); |
583 | 0 | } |
584 | | |
585 | | bool InterfaceVariableScalarReplacement::ReplaceInterfaceVarInEntryPoint( |
586 | | Instruction* interface_var, Instruction* entry_point, |
587 | 0 | uint32_t scalar_var_id) { |
588 | 0 | analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); |
589 | 0 | uint32_t interface_var_id = interface_var->result_id(); |
590 | 0 | if (interface_vars_removed_from_entry_point_operands_.find( |
591 | 0 | interface_var_id) != |
592 | 0 | interface_vars_removed_from_entry_point_operands_.end()) { |
593 | 0 | entry_point->AddOperand({SPV_OPERAND_TYPE_ID, {scalar_var_id}}); |
594 | 0 | def_use_mgr->AnalyzeInstUse(entry_point); |
595 | 0 | return true; |
596 | 0 | } |
597 | | |
598 | 0 | bool success = !entry_point->WhileEachInId( |
599 | 0 | [&interface_var_id, &scalar_var_id](uint32_t* id) { |
600 | 0 | if (*id == interface_var_id) { |
601 | 0 | *id = scalar_var_id; |
602 | 0 | return false; |
603 | 0 | } |
604 | 0 | return true; |
605 | 0 | }); |
606 | 0 | if (!success) { |
607 | 0 | std::string message( |
608 | 0 | "interface variable is not an operand of the entry point"); |
609 | 0 | message += "\n " + interface_var->PrettyPrint( |
610 | 0 | SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); |
611 | 0 | message += "\n " + entry_point->PrettyPrint( |
612 | 0 | SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); |
613 | 0 | context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str()); |
614 | 0 | return false; |
615 | 0 | } |
616 | | |
617 | 0 | def_use_mgr->AnalyzeInstUse(entry_point); |
618 | 0 | interface_vars_removed_from_entry_point_operands_.insert(interface_var_id); |
619 | 0 | return true; |
620 | 0 | } |
621 | | |
622 | | uint32_t InterfaceVariableScalarReplacement::GetPointeeTypeIdOfVar( |
623 | 0 | Instruction* var) { |
624 | 0 | assert(var->opcode() == spv::Op::OpVariable); |
625 | | |
626 | 0 | uint32_t ptr_type_id = var->type_id(); |
627 | 0 | analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); |
628 | 0 | Instruction* ptr_type_inst = def_use_mgr->GetDef(ptr_type_id); |
629 | |
|
630 | 0 | assert(ptr_type_inst->opcode() == spv::Op::OpTypePointer && |
631 | 0 | "Variable must have a pointer type."); |
632 | 0 | return ptr_type_inst->GetSingleWordInOperand(kOpTypePtrTypeInOperandIndex); |
633 | 0 | } |
634 | | |
635 | | void InterfaceVariableScalarReplacement::StoreComponentOfValueToScalarVar( |
636 | | uint32_t value_id, const std::vector<uint32_t>& component_indices, |
637 | | Instruction* scalar_var, const uint32_t* extra_array_index, |
638 | 0 | Instruction* insert_before) { |
639 | 0 | uint32_t component_type_id = GetPointeeTypeIdOfVar(scalar_var); |
640 | 0 | Instruction* ptr = scalar_var; |
641 | 0 | if (extra_array_index) { |
642 | 0 | auto* ty_mgr = context()->get_type_mgr(); |
643 | 0 | analysis::Array* array_type = ty_mgr->GetType(component_type_id)->AsArray(); |
644 | 0 | assert(array_type != nullptr); |
645 | 0 | component_type_id = ty_mgr->GetTypeInstruction(array_type->element_type()); |
646 | 0 | ptr = CreateAccessChainWithIndex(component_type_id, scalar_var, |
647 | 0 | *extra_array_index, insert_before); |
648 | 0 | if (ptr == nullptr) { |
649 | 0 | return; |
650 | 0 | } |
651 | 0 | } |
652 | | |
653 | 0 | StoreComponentOfValueTo(component_type_id, value_id, component_indices, ptr, |
654 | 0 | extra_array_index, insert_before); |
655 | 0 | } |
656 | | |
657 | | Instruction* InterfaceVariableScalarReplacement::LoadScalarVar( |
658 | | Instruction* scalar_var, const uint32_t* extra_array_index, |
659 | 0 | Instruction* insert_before) { |
660 | 0 | uint32_t component_type_id = GetPointeeTypeIdOfVar(scalar_var); |
661 | 0 | Instruction* ptr = scalar_var; |
662 | 0 | if (extra_array_index) { |
663 | 0 | auto* ty_mgr = context()->get_type_mgr(); |
664 | 0 | analysis::Array* array_type = ty_mgr->GetType(component_type_id)->AsArray(); |
665 | 0 | assert(array_type != nullptr); |
666 | 0 | component_type_id = ty_mgr->GetTypeInstruction(array_type->element_type()); |
667 | 0 | ptr = CreateAccessChainWithIndex(component_type_id, scalar_var, |
668 | 0 | *extra_array_index, insert_before); |
669 | 0 | if (ptr == nullptr) { |
670 | 0 | return nullptr; |
671 | 0 | } |
672 | 0 | } |
673 | | |
674 | 0 | return CreateLoad(component_type_id, ptr, insert_before); |
675 | 0 | } |
676 | | |
677 | | Instruction* InterfaceVariableScalarReplacement::CreateLoad( |
678 | 0 | uint32_t type_id, Instruction* ptr, Instruction* insert_before) { |
679 | 0 | uint32_t new_id = TakeNextId(); |
680 | 0 | if (new_id == 0) { |
681 | 0 | return nullptr; |
682 | 0 | } |
683 | 0 | std::unique_ptr<Instruction> load( |
684 | 0 | new Instruction(context(), spv::Op::OpLoad, type_id, new_id, |
685 | 0 | std::initializer_list<Operand>{ |
686 | 0 | {SPV_OPERAND_TYPE_ID, {ptr->result_id()}}})); |
687 | 0 | Instruction* load_inst = load.get(); |
688 | 0 | context()->get_def_use_mgr()->AnalyzeInstDefUse(load_inst); |
689 | 0 | insert_before->InsertBefore(std::move(load)); |
690 | 0 | return load_inst; |
691 | 0 | } |
692 | | |
693 | | void InterfaceVariableScalarReplacement::StoreComponentOfValueTo( |
694 | | uint32_t component_type_id, uint32_t value_id, |
695 | | const std::vector<uint32_t>& component_indices, Instruction* ptr, |
696 | 0 | const uint32_t* extra_array_index, Instruction* insert_before) { |
697 | 0 | std::unique_ptr<Instruction> composite_extract(CreateCompositeExtract( |
698 | 0 | component_type_id, value_id, component_indices, extra_array_index)); |
699 | 0 | if (composite_extract == nullptr) { |
700 | 0 | return; |
701 | 0 | } |
702 | | |
703 | 0 | std::unique_ptr<Instruction> new_store( |
704 | 0 | new Instruction(context(), spv::Op::OpStore)); |
705 | 0 | new_store->AddOperand({SPV_OPERAND_TYPE_ID, {ptr->result_id()}}); |
706 | 0 | new_store->AddOperand( |
707 | 0 | {SPV_OPERAND_TYPE_ID, {composite_extract->result_id()}}); |
708 | |
|
709 | 0 | analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); |
710 | 0 | def_use_mgr->AnalyzeInstDefUse(composite_extract.get()); |
711 | 0 | def_use_mgr->AnalyzeInstDefUse(new_store.get()); |
712 | |
|
713 | 0 | insert_before->InsertBefore(std::move(composite_extract)); |
714 | 0 | insert_before->InsertBefore(std::move(new_store)); |
715 | 0 | } |
716 | | |
717 | | Instruction* InterfaceVariableScalarReplacement::CreateCompositeExtract( |
718 | | uint32_t type_id, uint32_t composite_id, |
719 | 0 | const std::vector<uint32_t>& indexes, const uint32_t* extra_first_index) { |
720 | 0 | uint32_t component_id = TakeNextId(); |
721 | 0 | if (component_id == 0) { |
722 | 0 | return nullptr; |
723 | 0 | } |
724 | 0 | Instruction* composite_extract = new Instruction( |
725 | 0 | context(), spv::Op::OpCompositeExtract, type_id, component_id, |
726 | 0 | std::initializer_list<Operand>{{SPV_OPERAND_TYPE_ID, {composite_id}}}); |
727 | 0 | if (extra_first_index) { |
728 | 0 | composite_extract->AddOperand( |
729 | 0 | {SPV_OPERAND_TYPE_LITERAL_INTEGER, {*extra_first_index}}); |
730 | 0 | } |
731 | 0 | for (uint32_t index : indexes) { |
732 | 0 | composite_extract->AddOperand({SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}}); |
733 | 0 | } |
734 | 0 | return composite_extract; |
735 | 0 | } |
736 | | |
737 | | void InterfaceVariableScalarReplacement:: |
738 | | StoreComponentOfValueToAccessChainToScalarVar( |
739 | | uint32_t value_id, const std::vector<uint32_t>& component_indices, |
740 | | Instruction* scalar_var, |
741 | | const std::vector<uint32_t>& access_chain_indices, |
742 | 0 | Instruction* insert_before) { |
743 | 0 | uint32_t component_type_id = GetPointeeTypeIdOfVar(scalar_var); |
744 | 0 | Instruction* ptr = scalar_var; |
745 | 0 | if (!access_chain_indices.empty()) { |
746 | 0 | ptr = CreateAccessChainToVar(component_type_id, scalar_var, |
747 | 0 | access_chain_indices, insert_before, |
748 | 0 | &component_type_id); |
749 | 0 | } |
750 | |
|
751 | 0 | StoreComponentOfValueTo(component_type_id, value_id, component_indices, ptr, |
752 | 0 | nullptr, insert_before); |
753 | 0 | } |
754 | | |
755 | | Instruction* InterfaceVariableScalarReplacement::LoadAccessChainToVar( |
756 | | Instruction* var, const std::vector<uint32_t>& indexes, |
757 | 0 | Instruction* insert_before) { |
758 | 0 | uint32_t component_type_id = GetPointeeTypeIdOfVar(var); |
759 | 0 | Instruction* ptr = var; |
760 | 0 | if (!indexes.empty()) { |
761 | 0 | ptr = CreateAccessChainToVar(component_type_id, var, indexes, insert_before, |
762 | 0 | &component_type_id); |
763 | 0 | if (ptr == nullptr) { |
764 | 0 | return nullptr; |
765 | 0 | } |
766 | 0 | } |
767 | | |
768 | 0 | return CreateLoad(component_type_id, ptr, insert_before); |
769 | 0 | } |
770 | | |
771 | | Instruction* |
772 | | InterfaceVariableScalarReplacement::CreateCompositeConstructForComponentOfLoad( |
773 | 0 | Instruction* load, uint32_t depth_to_component) { |
774 | 0 | analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); |
775 | 0 | uint32_t type_id = load->type_id(); |
776 | 0 | if (depth_to_component != 0) { |
777 | 0 | type_id = GetComponentTypeOfArrayMatrix(def_use_mgr, load->type_id(), |
778 | 0 | depth_to_component); |
779 | 0 | } |
780 | 0 | uint32_t new_id = TakeNextId(); |
781 | 0 | if (new_id == 0) { |
782 | 0 | return nullptr; |
783 | 0 | } |
784 | 0 | std::unique_ptr<Instruction> new_composite_construct(new Instruction( |
785 | 0 | context(), spv::Op::OpCompositeConstruct, type_id, new_id, {})); |
786 | 0 | Instruction* composite_construct = new_composite_construct.get(); |
787 | 0 | def_use_mgr->AnalyzeInstDefUse(composite_construct); |
788 | | |
789 | | // Insert |new_composite_construct| after |load|. When there are multiple |
790 | | // recursive composite construct instructions for a load, we have to place the |
791 | | // composite construct with a lower depth later because it constructs the |
792 | | // composite that contains other composites with lower depths. |
793 | 0 | auto* insert_before = load->NextNode(); |
794 | 0 | while (true) { |
795 | 0 | auto itr = |
796 | 0 | composite_ids_to_component_depths.find(insert_before->result_id()); |
797 | 0 | if (itr == composite_ids_to_component_depths.end()) break; |
798 | 0 | if (itr->second <= depth_to_component) break; |
799 | 0 | insert_before = insert_before->NextNode(); |
800 | 0 | } |
801 | 0 | insert_before->InsertBefore(std::move(new_composite_construct)); |
802 | 0 | composite_ids_to_component_depths.insert({new_id, depth_to_component}); |
803 | 0 | return composite_construct; |
804 | 0 | } |
805 | | |
806 | | void InterfaceVariableScalarReplacement::AddComponentsToCompositesForLoads( |
807 | | const std::unordered_map<Instruction*, Instruction*>& |
808 | | loads_to_component_values, |
809 | | std::unordered_map<Instruction*, Instruction*>* loads_to_composites, |
810 | 0 | uint32_t depth_to_component) { |
811 | 0 | analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); |
812 | 0 | for (auto& load_and_component_vale : loads_to_component_values) { |
813 | 0 | Instruction* load = load_and_component_vale.first; |
814 | 0 | Instruction* component_value = load_and_component_vale.second; |
815 | 0 | Instruction* composite_construct = nullptr; |
816 | 0 | auto itr = loads_to_composites->find(load); |
817 | 0 | if (itr == loads_to_composites->end()) { |
818 | 0 | composite_construct = |
819 | 0 | CreateCompositeConstructForComponentOfLoad(load, depth_to_component); |
820 | 0 | if (composite_construct == nullptr) { |
821 | 0 | assert(false && "Could not create composite construct"); |
822 | 0 | return; |
823 | 0 | } |
824 | 0 | loads_to_composites->insert({load, composite_construct}); |
825 | 0 | } else { |
826 | 0 | composite_construct = itr->second; |
827 | 0 | } |
828 | 0 | composite_construct->AddOperand( |
829 | 0 | {SPV_OPERAND_TYPE_ID, {component_value->result_id()}}); |
830 | 0 | def_use_mgr->AnalyzeInstDefUse(composite_construct); |
831 | 0 | } |
832 | 0 | } |
833 | | |
834 | | uint32_t InterfaceVariableScalarReplacement::GetArrayType( |
835 | 0 | uint32_t elem_type_id, uint32_t array_length) { |
836 | 0 | analysis::Type* elem_type = context()->get_type_mgr()->GetType(elem_type_id); |
837 | 0 | uint32_t array_length_id = |
838 | 0 | context()->get_constant_mgr()->GetUIntConstId(array_length); |
839 | 0 | analysis::Array array_type( |
840 | 0 | elem_type, |
841 | 0 | analysis::Array::LengthInfo{array_length_id, {0, array_length}}); |
842 | 0 | return context()->get_type_mgr()->GetTypeInstruction(&array_type); |
843 | 0 | } |
844 | | |
845 | | uint32_t InterfaceVariableScalarReplacement::GetPointerType( |
846 | 0 | uint32_t type_id, spv::StorageClass storage_class) { |
847 | 0 | analysis::Type* type = context()->get_type_mgr()->GetType(type_id); |
848 | 0 | analysis::Pointer ptr_type(type, storage_class); |
849 | 0 | return context()->get_type_mgr()->GetTypeInstruction(&ptr_type); |
850 | 0 | } |
851 | | |
852 | | std::optional<InterfaceVariableScalarReplacement::NestedCompositeComponents> |
853 | | InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForArray( |
854 | | Instruction* interface_var_type, spv::StorageClass storage_class, |
855 | 0 | uint32_t extra_array_length) { |
856 | 0 | assert(interface_var_type->opcode() == spv::Op::OpTypeArray); |
857 | | |
858 | 0 | analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); |
859 | 0 | uint32_t array_length = GetArrayLength(def_use_mgr, interface_var_type); |
860 | 0 | Instruction* elem_type = GetArrayElementType(def_use_mgr, interface_var_type); |
861 | |
|
862 | 0 | NestedCompositeComponents scalar_vars; |
863 | 0 | while (array_length > 0) { |
864 | 0 | std::optional<NestedCompositeComponents> scalar_vars_for_element = |
865 | 0 | CreateScalarInterfaceVarsForReplacement(elem_type, storage_class, |
866 | 0 | extra_array_length); |
867 | 0 | if (!scalar_vars_for_element) { |
868 | 0 | return std::nullopt; |
869 | 0 | } |
870 | 0 | scalar_vars.AddComponent(*scalar_vars_for_element); |
871 | 0 | --array_length; |
872 | 0 | } |
873 | 0 | return scalar_vars; |
874 | 0 | } |
875 | | |
876 | | std::optional<InterfaceVariableScalarReplacement::NestedCompositeComponents> |
877 | | InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForMatrix( |
878 | | Instruction* interface_var_type, spv::StorageClass storage_class, |
879 | 0 | uint32_t extra_array_length) { |
880 | 0 | assert(interface_var_type->opcode() == spv::Op::OpTypeMatrix); |
881 | | |
882 | 0 | analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); |
883 | 0 | uint32_t column_count = interface_var_type->GetSingleWordInOperand( |
884 | 0 | kOpTypeMatrixColCountInOperandIndex); |
885 | 0 | Instruction* column_type = |
886 | 0 | GetMatrixColumnType(def_use_mgr, interface_var_type); |
887 | |
|
888 | 0 | NestedCompositeComponents scalar_vars; |
889 | 0 | while (column_count > 0) { |
890 | 0 | std::optional<NestedCompositeComponents> scalar_vars_for_column = |
891 | 0 | CreateScalarInterfaceVarsForReplacement(column_type, storage_class, |
892 | 0 | extra_array_length); |
893 | 0 | if (!scalar_vars_for_column) { |
894 | 0 | return std::nullopt; |
895 | 0 | } |
896 | 0 | scalar_vars.AddComponent(*scalar_vars_for_column); |
897 | 0 | --column_count; |
898 | 0 | } |
899 | 0 | return scalar_vars; |
900 | 0 | } |
901 | | |
902 | | std::optional<InterfaceVariableScalarReplacement::NestedCompositeComponents> |
903 | | InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForReplacement( |
904 | | Instruction* interface_var_type, spv::StorageClass storage_class, |
905 | 0 | uint32_t extra_array_length) { |
906 | | // Handle array case. |
907 | 0 | if (interface_var_type->opcode() == spv::Op::OpTypeArray) { |
908 | 0 | return CreateScalarInterfaceVarsForArray(interface_var_type, storage_class, |
909 | 0 | extra_array_length); |
910 | 0 | } |
911 | | |
912 | | // Handle matrix case. |
913 | 0 | if (interface_var_type->opcode() == spv::Op::OpTypeMatrix) { |
914 | 0 | return CreateScalarInterfaceVarsForMatrix(interface_var_type, storage_class, |
915 | 0 | extra_array_length); |
916 | 0 | } |
917 | | |
918 | | // Handle scalar or vector case. |
919 | 0 | NestedCompositeComponents scalar_var; |
920 | 0 | uint32_t type_id = interface_var_type->result_id(); |
921 | 0 | if (extra_array_length != 0) { |
922 | 0 | type_id = GetArrayType(type_id, extra_array_length); |
923 | 0 | } |
924 | 0 | uint32_t ptr_type_id = |
925 | 0 | context()->get_type_mgr()->FindPointerToType(type_id, storage_class); |
926 | 0 | uint32_t id = TakeNextId(); |
927 | 0 | if (id == 0) { |
928 | 0 | return std::nullopt; |
929 | 0 | } |
930 | 0 | std::unique_ptr<Instruction> variable( |
931 | 0 | new Instruction(context(), spv::Op::OpVariable, ptr_type_id, id, |
932 | 0 | std::initializer_list<Operand>{ |
933 | 0 | {SPV_OPERAND_TYPE_STORAGE_CLASS, |
934 | 0 | {static_cast<uint32_t>(storage_class)}}})); |
935 | 0 | scalar_var.SetSingleComponentVariable(variable.get()); |
936 | 0 | context()->AddGlobalValue(std::move(variable)); |
937 | 0 | return scalar_var; |
938 | 0 | } |
939 | | |
940 | | Instruction* InterfaceVariableScalarReplacement::GetTypeOfVariable( |
941 | 0 | Instruction* var) { |
942 | 0 | uint32_t pointee_type_id = GetPointeeTypeIdOfVar(var); |
943 | 0 | analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); |
944 | 0 | return def_use_mgr->GetDef(pointee_type_id); |
945 | 0 | } |
946 | | |
947 | 0 | Pass::Status InterfaceVariableScalarReplacement::Process() { |
948 | 0 | Pass::Status status = Status::SuccessWithoutChange; |
949 | 0 | for (Instruction& entry_point : get_module()->entry_points()) { |
950 | 0 | status = |
951 | 0 | CombineStatus(status, ReplaceInterfaceVarsWithScalars(entry_point)); |
952 | 0 | } |
953 | 0 | return status; |
954 | 0 | } |
955 | | |
956 | | bool InterfaceVariableScalarReplacement:: |
957 | 0 | ReportErrorIfHasExtraArraynessForOtherEntry(Instruction* var) { |
958 | 0 | if (vars_with_extra_arrayness.find(var) == vars_with_extra_arrayness.end()) |
959 | 0 | return false; |
960 | | |
961 | 0 | std::string message( |
962 | 0 | "A variable is arrayed for an entry point but it is not " |
963 | 0 | "arrayed for another entry point"); |
964 | 0 | message += |
965 | 0 | "\n " + var->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); |
966 | 0 | context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str()); |
967 | 0 | return true; |
968 | 0 | } |
969 | | |
970 | | bool InterfaceVariableScalarReplacement:: |
971 | 0 | ReportErrorIfHasNoExtraArraynessForOtherEntry(Instruction* var) { |
972 | 0 | if (vars_without_extra_arrayness.find(var) == |
973 | 0 | vars_without_extra_arrayness.end()) |
974 | 0 | return false; |
975 | | |
976 | 0 | std::string message( |
977 | 0 | "A variable is not arrayed for an entry point but it is " |
978 | 0 | "arrayed for another entry point"); |
979 | 0 | message += |
980 | 0 | "\n " + var->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); |
981 | 0 | context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str()); |
982 | 0 | return true; |
983 | 0 | } |
984 | | |
985 | | Pass::Status |
986 | | InterfaceVariableScalarReplacement::ReplaceInterfaceVarsWithScalars( |
987 | 0 | Instruction& entry_point) { |
988 | 0 | std::vector<Instruction*> interface_vars = |
989 | 0 | CollectInterfaceVariables(entry_point); |
990 | |
|
991 | 0 | Pass::Status status = Status::SuccessWithoutChange; |
992 | 0 | for (Instruction* interface_var : interface_vars) { |
993 | 0 | uint32_t location, component; |
994 | 0 | if (!GetVariableLocation(interface_var, &location)) continue; |
995 | 0 | if (!GetVariableComponent(interface_var, &component)) component = 0; |
996 | |
|
997 | 0 | Instruction* interface_var_type = GetTypeOfVariable(interface_var); |
998 | 0 | uint32_t extra_array_length = 0; |
999 | 0 | if (HasExtraArrayness(entry_point, interface_var)) { |
1000 | 0 | extra_array_length = |
1001 | 0 | GetArrayLength(context()->get_def_use_mgr(), interface_var_type); |
1002 | 0 | interface_var_type = |
1003 | 0 | GetArrayElementType(context()->get_def_use_mgr(), interface_var_type); |
1004 | 0 | vars_with_extra_arrayness.insert(interface_var); |
1005 | 0 | } else { |
1006 | 0 | vars_without_extra_arrayness.insert(interface_var); |
1007 | 0 | } |
1008 | |
|
1009 | 0 | if (!CheckExtraArraynessConflictBetweenEntries(interface_var, |
1010 | 0 | extra_array_length != 0)) { |
1011 | 0 | return Pass::Status::Failure; |
1012 | 0 | } |
1013 | | |
1014 | 0 | if (interface_var_type->opcode() != spv::Op::OpTypeArray && |
1015 | 0 | interface_var_type->opcode() != spv::Op::OpTypeMatrix) { |
1016 | 0 | continue; |
1017 | 0 | } |
1018 | | |
1019 | 0 | if (ReplaceInterfaceVariableWithScalars( |
1020 | 0 | interface_var, interface_var_type, location, component, |
1021 | 0 | extra_array_length) == Pass::Status::Failure) { |
1022 | 0 | return Pass::Status::Failure; |
1023 | 0 | } |
1024 | 0 | status = Pass::Status::SuccessWithChange; |
1025 | 0 | } |
1026 | | |
1027 | 0 | return status; |
1028 | 0 | } |
1029 | | |
1030 | | } // namespace opt |
1031 | | } // namespace spvtools |