/src/spirv-tools/source/opt/legalize_multidim_array_pass.cpp
Line | Count | Source |
1 | | // Copyright (c) 2026 Google LLC |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | #include "source/opt/legalize_multidim_array_pass.h" |
16 | | |
17 | | #include "source/opt/constants.h" |
18 | | #include "source/opt/desc_sroa_util.h" |
19 | | #include "source/opt/ir_builder.h" |
20 | | #include "source/opt/ir_context.h" |
21 | | #include "source/opt/type_manager.h" |
22 | | |
23 | | namespace spvtools { |
24 | | namespace opt { |
25 | | |
26 | 7.29k | Pass::Status LegalizeMultidimArrayPass::Process() { |
27 | 7.29k | std::vector<Instruction*> vars_to_legalize; |
28 | | |
29 | 139k | for (auto& var : context()->types_values()) { |
30 | 139k | if (var.opcode() != spv::Op::OpVariable) continue; |
31 | 7.80k | if (!IsMultidimArrayOfResources(&var)) continue; |
32 | 0 | if (!CanLegalize(&var)) { |
33 | 0 | context()->EmitErrorMessage("Unable to legalize multidimensional array: ", |
34 | 0 | &var); |
35 | 0 | return Status::Failure; |
36 | 0 | } |
37 | 0 | vars_to_legalize.push_back(&var); |
38 | 0 | } |
39 | | |
40 | 7.29k | if (vars_to_legalize.empty()) return Status::SuccessWithoutChange; |
41 | | |
42 | 0 | for (auto* var : vars_to_legalize) { |
43 | 0 | uint32_t old_ptr_type_id = var->type_id(); |
44 | 0 | uint32_t new_ptr_type_id = FlattenArrayType(var); |
45 | 0 | if (new_ptr_type_id == 0) return Status::Failure; |
46 | 0 | if (!RewriteAccessChains(var, old_ptr_type_id)) return Status::Failure; |
47 | 0 | } |
48 | | |
49 | 0 | return Status::SuccessWithChange; |
50 | 0 | } |
51 | | |
52 | 7.80k | bool LegalizeMultidimArrayPass::IsMultidimArrayOfResources(Instruction* var) { |
53 | 7.80k | if (!descsroautil::IsDescriptorArray(context(), var)) return false; |
54 | | |
55 | 0 | uint32_t type_id = var->type_id(); |
56 | 0 | Instruction* type_inst = context()->get_def_use_mgr()->GetDef(type_id); |
57 | 0 | uint32_t pointee_type_id = type_inst->GetSingleWordInOperand(1); |
58 | 0 | std::vector<uint32_t> dims; |
59 | 0 | uint32_t element_type_id = 0; |
60 | 0 | GetArrayDimensions(pointee_type_id, &dims, &element_type_id); |
61 | |
|
62 | 0 | return dims.size() > 1; |
63 | 7.80k | } |
64 | | |
65 | | void LegalizeMultidimArrayPass::GetArrayDimensions(uint32_t type_id, |
66 | | std::vector<uint32_t>* dims, |
67 | 0 | uint32_t* element_type_id) { |
68 | 0 | assert(dims != nullptr && "dims cannot be null."); |
69 | 0 | dims->clear(); |
70 | |
|
71 | 0 | Instruction* type_inst = context()->get_def_use_mgr()->GetDef(type_id); |
72 | 0 | while (type_inst->opcode() == spv::Op::OpTypeArray) { |
73 | 0 | uint32_t length_id = type_inst->GetSingleWordInOperand(1); |
74 | 0 | Instruction* length_inst = context()->get_def_use_mgr()->GetDef(length_id); |
75 | | // Assume OpConstant. According to the spec the length could also be an |
76 | | // OpSpecConstantOp. However, DXC will not generate that type of code. The |
77 | | // code to handle spec constants will be much more complicated. |
78 | 0 | assert(length_inst->opcode() == spv::Op::OpConstant); |
79 | 0 | uint32_t length = length_inst->GetSingleWordInOperand(0); |
80 | 0 | dims->push_back(length); |
81 | 0 | type_id = type_inst->GetSingleWordInOperand(0); |
82 | 0 | type_inst = context()->get_def_use_mgr()->GetDef(type_id); |
83 | 0 | } |
84 | 0 | *element_type_id = type_id; |
85 | 0 | } |
86 | | |
87 | 0 | uint32_t LegalizeMultidimArrayPass::FlattenArrayType(Instruction* var) { |
88 | 0 | analysis::TypeManager* type_mgr = context()->get_type_mgr(); |
89 | 0 | analysis::ConstantManager* constant_mgr = context()->get_constant_mgr(); |
90 | |
|
91 | 0 | uint32_t ptr_type_id = var->type_id(); |
92 | 0 | Instruction* ptr_type_inst = |
93 | 0 | context()->get_def_use_mgr()->GetDef(ptr_type_id); |
94 | 0 | uint32_t pointee_type_id = ptr_type_inst->GetSingleWordInOperand(1); |
95 | |
|
96 | 0 | std::vector<uint32_t> dims; |
97 | 0 | uint32_t element_type_id = 0; |
98 | 0 | GetArrayDimensions(pointee_type_id, &dims, &element_type_id); |
99 | |
|
100 | 0 | uint32_t total_elements = 1; |
101 | 0 | for (uint32_t dim : dims) { |
102 | 0 | total_elements *= dim; |
103 | 0 | } |
104 | |
|
105 | 0 | const analysis::Constant* total_elements_const = |
106 | 0 | constant_mgr->GetIntConst(total_elements, 32, false); |
107 | |
|
108 | 0 | Instruction* total_elements_inst = |
109 | 0 | constant_mgr->GetDefiningInstruction(total_elements_const); |
110 | 0 | uint32_t total_elements_id = total_elements_inst->result_id(); |
111 | | |
112 | | // Create new OpTypeArray. |
113 | 0 | analysis::Type* element_type = type_mgr->GetType(element_type_id); |
114 | 0 | analysis::Array::LengthInfo length_info = { |
115 | 0 | total_elements_id, |
116 | 0 | {analysis::Array::LengthInfo::kConstant, total_elements}}; |
117 | 0 | analysis::Array new_array_type(element_type, length_info); |
118 | 0 | uint32_t new_array_type_id = type_mgr->GetTypeInstruction(&new_array_type); |
119 | | |
120 | | // Create new OpTypePointer. |
121 | 0 | spv::StorageClass sc = |
122 | 0 | static_cast<spv::StorageClass>(ptr_type_inst->GetSingleWordInOperand(0)); |
123 | 0 | analysis::Pointer new_ptr_type(type_mgr->GetType(new_array_type_id), sc); |
124 | 0 | uint32_t new_ptr_type_id = type_mgr->GetTypeInstruction(&new_ptr_type); |
125 | |
|
126 | 0 | var->SetResultType(new_ptr_type_id); |
127 | 0 | context()->UpdateDefUse(var); |
128 | | |
129 | | // Move the var after the new pointer type to avoid a def-before-use. |
130 | 0 | var->InsertAfter(get_def_use_mgr()->GetDef(new_ptr_type_id)); |
131 | |
|
132 | 0 | return new_ptr_type_id; |
133 | 0 | } |
134 | | |
135 | | bool LegalizeMultidimArrayPass::RewriteAccessChains(Instruction* var, |
136 | 0 | uint32_t old_ptr_type_id) { |
137 | 0 | uint32_t var_id = var->result_id(); |
138 | 0 | std::vector<Instruction*> users; |
139 | | // Use a worklist to handle transitive uses (e.g. through OpCopyObject) |
140 | 0 | std::vector<Instruction*> worklist; |
141 | |
|
142 | 0 | context()->get_def_use_mgr()->ForEachUser( |
143 | 0 | var_id, [&worklist](Instruction* user) { worklist.push_back(user); }); |
144 | |
|
145 | 0 | Instruction* old_ptr_type_inst = |
146 | 0 | context()->get_def_use_mgr()->GetDef(old_ptr_type_id); |
147 | 0 | uint32_t old_pointee_type_id = old_ptr_type_inst->GetSingleWordInOperand(1); |
148 | 0 | std::vector<uint32_t> dims; |
149 | 0 | uint32_t element_type_id = 0; |
150 | 0 | GetArrayDimensions(old_pointee_type_id, &dims, &element_type_id); |
151 | 0 | assert(dims.size() != 0 && |
152 | 0 | "This variable should have been rejected earlier."); |
153 | | |
154 | | // Calculate strides once |
155 | 0 | std::vector<uint32_t> strides(dims.size()); |
156 | 0 | strides[dims.size() - 1] = 1; |
157 | 0 | for (int i = static_cast<int>(dims.size()) - 2; i >= 0; --i) { |
158 | 0 | strides[i] = strides[i + 1] * dims[i + 1]; |
159 | 0 | } |
160 | | |
161 | | // Pre-calculate uint type id |
162 | 0 | uint32_t uint_type_id = context()->get_type_mgr()->GetUIntTypeId(); |
163 | 0 | if (uint_type_id == 0) return false; |
164 | | |
165 | 0 | while (!worklist.empty()) { |
166 | 0 | Instruction* user = worklist.back(); |
167 | 0 | worklist.pop_back(); |
168 | |
|
169 | 0 | if (user->opcode() == spv::Op::OpAccessChain || |
170 | 0 | user->opcode() == spv::Op::OpInBoundsAccessChain) { |
171 | 0 | uint32_t num_indices = user->NumInOperands() - 1; |
172 | 0 | assert(num_indices >= dims.size()); |
173 | | |
174 | 0 | InstructionBuilder builder(context(), user, IRContext::kAnalysisDefUse); |
175 | |
|
176 | 0 | uint32_t linearized_idx_id = 0; |
177 | 0 | for (uint32_t i = 0; i < dims.size(); ++i) { |
178 | 0 | uint32_t idx_id = user->GetSingleWordInOperand(i + 1); |
179 | |
|
180 | 0 | uint32_t term_id = idx_id; |
181 | 0 | if (strides[i] != 1) { |
182 | 0 | const analysis::Constant* stride_const = |
183 | 0 | context()->get_constant_mgr()->GetConstant( |
184 | 0 | context()->get_type_mgr()->GetType(uint_type_id), |
185 | 0 | {strides[i]}); |
186 | 0 | Instruction* stride_inst = |
187 | 0 | context()->get_constant_mgr()->GetDefiningInstruction( |
188 | 0 | stride_const); |
189 | |
|
190 | 0 | Instruction* mul_inst = builder.AddBinaryOp( |
191 | 0 | uint_type_id, spv::Op::OpIMul, idx_id, stride_inst->result_id()); |
192 | 0 | if (mul_inst == nullptr) return false; |
193 | 0 | term_id = mul_inst->result_id(); |
194 | 0 | } |
195 | | |
196 | 0 | if (linearized_idx_id == 0) { |
197 | 0 | linearized_idx_id = term_id; |
198 | 0 | } else { |
199 | 0 | Instruction* add_inst = builder.AddBinaryOp( |
200 | 0 | uint_type_id, spv::Op::OpIAdd, linearized_idx_id, term_id); |
201 | 0 | if (add_inst == nullptr) return false; |
202 | 0 | linearized_idx_id = add_inst->result_id(); |
203 | 0 | } |
204 | 0 | } |
205 | | |
206 | | // Create new AccessChain. |
207 | 0 | Instruction::OperandList new_operands; |
208 | 0 | new_operands.push_back(user->GetInOperand(0)); |
209 | 0 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {linearized_idx_id}}); |
210 | 0 | for (uint32_t i = static_cast<uint32_t>(dims.size()); i < num_indices; |
211 | 0 | ++i) { |
212 | 0 | new_operands.push_back(user->GetInOperand(i + 1)); |
213 | 0 | } |
214 | 0 | user->SetInOperands(std::move(new_operands)); |
215 | 0 | context()->UpdateDefUse(user); |
216 | 0 | } else if (user->opcode() == spv::Op::OpCopyObject) { |
217 | | // The type of the variable has changed so the result type of the |
218 | | // OpCopyObject will change as well. |
219 | |
|
220 | 0 | uint32_t operand_id = user->GetSingleWordInOperand(0); |
221 | 0 | Instruction* operand_inst = |
222 | 0 | context()->get_def_use_mgr()->GetDef(operand_id); |
223 | 0 | user->SetResultType(operand_inst->type_id()); |
224 | 0 | context()->UpdateDefUse(user); |
225 | | |
226 | | // Add users of this copy to worklist |
227 | 0 | context()->get_def_use_mgr()->ForEachUser( |
228 | 0 | user->result_id(), |
229 | 0 | [&worklist](Instruction* u) { worklist.push_back(u); }); |
230 | 0 | } |
231 | 0 | } |
232 | 0 | return true; |
233 | 0 | } |
234 | | |
235 | | bool LegalizeMultidimArrayPass::CheckUse(Instruction* inst, |
236 | 0 | uint32_t max_depth) { |
237 | 0 | if (inst->opcode() == spv::Op::OpAccessChain || |
238 | 0 | inst->opcode() == spv::Op::OpInBoundsAccessChain) { |
239 | 0 | uint32_t num_indices = inst->NumInOperands() - 1; |
240 | 0 | return num_indices >= max_depth; |
241 | 0 | } else if (inst->opcode() == spv::Op::OpCopyObject) { |
242 | 0 | bool ok = true; |
243 | 0 | return !context()->get_def_use_mgr()->WhileEachUser( |
244 | 0 | inst->result_id(), |
245 | 0 | [&](Instruction* u) { return !CheckUse(u, max_depth); }); |
246 | 0 | return ok; |
247 | 0 | } else if (inst->IsDecoration() || inst->opcode() == spv::Op::OpName || |
248 | 0 | inst->opcode() == spv::Op::OpMemberName) { |
249 | | // Metadata is fine. |
250 | 0 | return true; |
251 | 0 | } |
252 | | |
253 | | // Direct use of array or partial array without AccessChain is not allowed. |
254 | 0 | return false; |
255 | 0 | } |
256 | | |
257 | 0 | bool LegalizeMultidimArrayPass::CanLegalize(Instruction* var) { |
258 | 0 | bool ok = true; |
259 | 0 | uint32_t ptr_type_id = var->type_id(); |
260 | 0 | Instruction* ptr_type_inst = |
261 | 0 | context()->get_def_use_mgr()->GetDef(ptr_type_id); |
262 | 0 | uint32_t pointee_type_id = ptr_type_inst->GetSingleWordInOperand(1); |
263 | 0 | std::vector<uint32_t> dims; |
264 | 0 | uint32_t element_type_id = 0; |
265 | 0 | GetArrayDimensions(pointee_type_id, &dims, &element_type_id); |
266 | |
|
267 | 0 | context()->get_def_use_mgr()->ForEachUser( |
268 | 0 | var->result_id(), [&](Instruction* u) { |
269 | 0 | if (!CheckUse(u, static_cast<uint32_t>(dims.size()))) ok = false; |
270 | 0 | }); |
271 | 0 | return ok; |
272 | 0 | } |
273 | | |
274 | | } // namespace opt |
275 | | } // namespace spvtools |