/src/spirv-tools/source/opt/resolve_binding_conflicts_pass.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | // Copyright (c) 2025 Google LLC |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
#include "source/opt/resolve_binding_conflicts_pass.h"

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iterator>
#include <limits>
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "source/opt/decoration_manager.h"
#include "source/opt/def_use_manager.h"
#include "source/opt/instruction.h"
#include "source/opt/ir_builder.h"
#include "source/opt/ir_context.h"
#include "spirv/unified1/spirv.h"
28 | | |
29 | | namespace spvtools { |
30 | | namespace opt { |
31 | | |
32 | | // A VarBindingInfo contains the binding information for a single resource |
33 | | // variable. |
34 | | // |
35 | | // Exactly one such object is created per resource variable in the |
36 | | // module. In particular, when a resource variable is statically used by |
37 | | // more than one entry point, those entry points share the same VarBindingInfo |
38 | | // object for that variable. |
39 | | struct VarBindingInfo { |
40 | | const Instruction* const var; |
41 | | const uint32_t descriptor_set; |
42 | | Instruction* const binding_decoration; |
43 | | |
44 | | // Returns the binding number. |
45 | 0 | uint32_t binding() const { |
46 | 0 | return binding_decoration->GetSingleWordInOperand(2); |
47 | 0 | } |
48 | | // Sets the binding number to 'b'. |
49 | 0 | void updateBinding(uint32_t b) { binding_decoration->SetOperand(2, {b}); } |
50 | | }; |
51 | | |
52 | | // The bindings in the same descriptor set that are used by an entry point. |
53 | | using BindingList = std::vector<VarBindingInfo*>; |
54 | | // A map from descriptor set number to the list of bindings in that descriptor |
55 | | // set, as used by a particular entry point. |
56 | | using DescriptorSets = std::unordered_map<uint32_t, BindingList>; |
57 | | |
58 | 0 | IRContext::Analysis ResolveBindingConflictsPass::GetPreservedAnalyses() { |
59 | | // All analyses are kept up to date. |
60 | | // At most this modifies the Binding numbers on variables. |
61 | 0 | return IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping | |
62 | 0 | IRContext::kAnalysisDecorations | IRContext::kAnalysisCombinators | |
63 | 0 | IRContext::kAnalysisCFG | IRContext::kAnalysisDominatorAnalysis | |
64 | 0 | IRContext::kAnalysisLoopAnalysis | IRContext::kAnalysisNameMap | |
65 | 0 | IRContext::kAnalysisScalarEvolution | |
66 | 0 | IRContext::kAnalysisRegisterPressure | |
67 | 0 | IRContext::kAnalysisValueNumberTable | |
68 | 0 | IRContext::kAnalysisStructuredCFG | IRContext::kAnalysisBuiltinVarId | |
69 | 0 | IRContext::kAnalysisIdToFuncMapping | IRContext::kAnalysisConstants | |
70 | 0 | IRContext::kAnalysisTypes | IRContext::kAnalysisDebugInfo | |
71 | 0 | IRContext::kAnalysisLiveness; |
72 | 0 | } |
73 | | |
74 | | // Orders variable binding info objects. |
75 | | // * The binding number is most signficant; |
76 | | // * Then a sampler-like object compares greater than non-sampler like object. |
77 | | // * Otherwise compare based on variable ID. |
78 | | // This provides a total order among bindings in a descriptor set for a valid |
79 | | // Vulkan module. |
80 | 0 | bool Less(const VarBindingInfo* const lhs, const VarBindingInfo* const rhs) { |
81 | 0 | if (lhs->binding() < rhs->binding()) return true; |
82 | 0 | if (lhs->binding() > rhs->binding()) return false; |
83 | | |
84 | | // Examine types. |
85 | | // In valid Vulkan the only conflict can occur between |
86 | | // images and samplers. We only care about a specific |
87 | | // comparison when one is a image-like thing and the other |
88 | | // is a sampler-like thing of the same shape. So unwrap |
89 | | // types until we hit one of those two. |
90 | | |
91 | 0 | auto* def_use_mgr = lhs->var->context()->get_def_use_mgr(); |
92 | | |
93 | | // Returns the type found by iteratively following pointer pointee type, |
94 | | // or array element type. |
95 | 0 | auto unwrap = [&def_use_mgr](Instruction* ty) { |
96 | 0 | bool keep_going = true; |
97 | 0 | do { |
98 | 0 | switch (ty->opcode()) { |
99 | 0 | case spv::Op::OpTypePointer: |
100 | 0 | ty = def_use_mgr->GetDef(ty->GetSingleWordInOperand(1)); |
101 | 0 | break; |
102 | 0 | case spv::Op::OpTypeArray: |
103 | 0 | case spv::Op::OpTypeRuntimeArray: |
104 | 0 | ty = def_use_mgr->GetDef(ty->GetSingleWordInOperand(0)); |
105 | 0 | break; |
106 | 0 | default: |
107 | 0 | keep_going = false; |
108 | 0 | break; |
109 | 0 | } |
110 | 0 | } while (keep_going); |
111 | 0 | return ty; |
112 | 0 | }; |
113 | |
|
114 | 0 | auto* lhs_ty = unwrap(def_use_mgr->GetDef(lhs->var->type_id())); |
115 | 0 | auto* rhs_ty = unwrap(def_use_mgr->GetDef(rhs->var->type_id())); |
116 | 0 | if (lhs_ty->opcode() == rhs_ty->opcode()) { |
117 | | // Pick based on variable ID. |
118 | 0 | return lhs->var->result_id() < rhs->var->result_id(); |
119 | 0 | } |
120 | | // A sampler is always greater than an image. |
121 | 0 | if (lhs_ty->opcode() == spv::Op::OpTypeSampler) { |
122 | 0 | return false; |
123 | 0 | } |
124 | 0 | if (rhs_ty->opcode() == spv::Op::OpTypeSampler) { |
125 | 0 | return true; |
126 | 0 | } |
127 | | // Pick based on variable ID. |
128 | 0 | return lhs->var->result_id() < rhs->var->result_id(); |
129 | 0 | } |
130 | | |
131 | | // Summarizes the caller-callee relationships between functions in a module. |
132 | | class CallGraph { |
133 | | public: |
134 | | // Returns the list of all functions statically reachable from entry points, |
135 | | // where callees precede callers. |
136 | 0 | const std::vector<uint32_t>& CalleesBeforeCallers() const { |
137 | 0 | return visit_order_; |
138 | 0 | } |
139 | | // Returns the list functions called from a given function. |
140 | 0 | const std::unordered_set<uint32_t>& Callees(uint32_t caller) { |
141 | 0 | return calls_[caller]; |
142 | 0 | } |
143 | | |
144 | 0 | CallGraph(IRContext& context) { |
145 | | // Populate calls_. |
146 | 0 | std::queue<uint32_t> callee_queue; |
147 | 0 | for (const auto& fn : *context.module()) { |
148 | 0 | auto& callees = calls_[fn.result_id()]; |
149 | 0 | context.AddCalls(&fn, &callee_queue); |
150 | 0 | while (!callee_queue.empty()) { |
151 | 0 | callees.insert(callee_queue.front()); |
152 | 0 | callee_queue.pop(); |
153 | 0 | } |
154 | 0 | } |
155 | | |
156 | | // Perform depth-first search, starting from each entry point. |
157 | | // Populates visit_order_. |
158 | 0 | for (const auto& ep : context.module()->entry_points()) { |
159 | 0 | Visit(ep.GetSingleWordInOperand(1)); |
160 | 0 | } |
161 | 0 | } |
162 | | |
163 | | private: |
164 | | // Visits a function, recursively visiting its callees. Adds this ID |
165 | | // to the visit_order after all callees have been visited. |
166 | 0 | void Visit(uint32_t func_id) { |
167 | 0 | if (visited_.count(func_id)) { |
168 | 0 | return; |
169 | 0 | } |
170 | 0 | visited_.insert(func_id); |
171 | 0 | for (auto callee_id : calls_[func_id]) { |
172 | 0 | Visit(callee_id); |
173 | 0 | } |
174 | 0 | visit_order_.push_back(func_id); |
175 | 0 | } |
176 | | |
177 | | // Maps the ID of a function to the IDs of functions it calls. |
178 | | std::unordered_map<uint32_t, std::unordered_set<uint32_t>> calls_; |
179 | | |
180 | | // IDs of visited functions; |
181 | | std::unordered_set<uint32_t> visited_; |
182 | | // IDs of functions, where callees precede callers. |
183 | | std::vector<uint32_t> visit_order_; |
184 | | }; |
185 | | |
186 | | // Returns vector binding info for all resource variables in the module. |
187 | 0 | auto GetVarBindings(IRContext& context) { |
188 | 0 | std::vector<VarBindingInfo> vars; |
189 | 0 | auto* deco_mgr = context.get_decoration_mgr(); |
190 | 0 | for (auto& inst : context.module()->types_values()) { |
191 | 0 | if (inst.opcode() == spv::Op::OpVariable) { |
192 | 0 | Instruction* descriptor_set_deco = nullptr; |
193 | 0 | Instruction* binding_deco = nullptr; |
194 | 0 | for (auto* deco : deco_mgr->GetDecorationsFor(inst.result_id(), false)) { |
195 | 0 | switch (static_cast<spv::Decoration>(deco->GetSingleWordInOperand(1))) { |
196 | 0 | case spv::Decoration::DescriptorSet: |
197 | 0 | assert(!descriptor_set_deco); |
198 | 0 | descriptor_set_deco = deco; |
199 | 0 | break; |
200 | 0 | case spv::Decoration::Binding: |
201 | 0 | assert(!binding_deco); |
202 | 0 | binding_deco = deco; |
203 | 0 | break; |
204 | 0 | default: |
205 | 0 | break; |
206 | 0 | } |
207 | 0 | } |
208 | 0 | if (descriptor_set_deco && binding_deco) { |
209 | 0 | vars.push_back({&inst, descriptor_set_deco->GetSingleWordInOperand(2), |
210 | 0 | binding_deco}); |
211 | 0 | } |
212 | 0 | } |
213 | 0 | } |
214 | 0 | return vars; |
215 | 0 | } |
216 | | |
217 | | // Merges the bindings from source into sink. Maintains order and uniqueness |
218 | | // within a list of bindings. |
219 | 0 | void Merge(DescriptorSets& sink, const DescriptorSets& source) { |
220 | 0 | for (auto index_and_bindings : source) { |
221 | 0 | const uint32_t index = index_and_bindings.first; |
222 | 0 | const BindingList& src1 = index_and_bindings.second; |
223 | 0 | const BindingList& src2 = sink[index]; |
224 | 0 | BindingList merged; |
225 | 0 | merged.resize(src1.size() + src2.size()); |
226 | 0 | auto merged_end = std::merge(src1.begin(), src1.end(), src2.begin(), |
227 | 0 | src2.end(), merged.begin(), Less); |
228 | 0 | auto unique_end = std::unique(merged.begin(), merged_end); |
229 | 0 | merged.resize(unique_end - merged.begin()); |
230 | 0 | sink[index] = std::move(merged); |
231 | 0 | } |
232 | 0 | } |
233 | | |
234 | | // Resolves conflicts within this binding list, so the binding number on an |
235 | | // item is at least one more than the binding number on the previous item. |
236 | | // When this does not yet hold, increase the binding number on the second |
237 | | // item in the pair. Returns true if any changes were applied. |
238 | 0 | bool ResolveConflicts(BindingList& bl) { |
239 | 0 | bool changed = false; |
240 | 0 | for (size_t i = 1; i < bl.size(); i++) { |
241 | 0 | const auto prev_num = bl[i - 1]->binding(); |
242 | 0 | if (prev_num >= bl[i]->binding()) { |
243 | 0 | bl[i]->updateBinding(prev_num + 1); |
244 | 0 | changed = true; |
245 | 0 | } |
246 | 0 | } |
247 | 0 | return changed; |
248 | 0 | } |
249 | | |
250 | 0 | Pass::Status ResolveBindingConflictsPass::Process() { |
251 | | // Assumes the descriptor set and binding decorations are not provided |
252 | | // via decoration groups. Decoration groups were deprecated in SPIR-V 1.3 |
253 | | // Revision 6. I have not seen any compiler generate them. --dneto |
254 | |
|
255 | 0 | auto vars = GetVarBindings(*context()); |
256 | | |
257 | | // Maps a function ID to the variables used directly or indirectly by the |
258 | | // function, organized into descriptor sets. Each descriptor set |
259 | | // consists of a BindingList of distinct variables. |
260 | 0 | std::unordered_map<uint32_t, DescriptorSets> used_vars; |
261 | | |
262 | | // Determine variables directly used by functions. |
263 | 0 | auto* def_use_mgr = context()->get_def_use_mgr(); |
264 | 0 | for (auto& var : vars) { |
265 | 0 | std::unordered_set<uint32_t> visited_functions_for_var; |
266 | 0 | def_use_mgr->ForEachUser(var.var, [&](Instruction* user) { |
267 | 0 | if (auto* block = context()->get_instr_block(user)) { |
268 | 0 | auto* fn = block->GetParent(); |
269 | 0 | assert(fn); |
270 | 0 | const auto fn_id = fn->result_id(); |
271 | 0 | if (visited_functions_for_var.insert(fn_id).second) { |
272 | 0 | used_vars[fn_id][var.descriptor_set].push_back(&var); |
273 | 0 | } |
274 | 0 | } |
275 | 0 | }); |
276 | 0 | } |
277 | | |
278 | | // Sort within a descriptor set by binding number. |
279 | 0 | for (auto& sets_for_fn : used_vars) { |
280 | 0 | for (auto& ds : sets_for_fn.second) { |
281 | 0 | BindingList& bindings = ds.second; |
282 | 0 | std::stable_sort(bindings.begin(), bindings.end(), Less); |
283 | 0 | } |
284 | 0 | } |
285 | | |
286 | | // Propagate from callees to callers. |
287 | 0 | CallGraph call_graph(*context()); |
288 | 0 | for (const uint32_t caller : call_graph.CalleesBeforeCallers()) { |
289 | 0 | DescriptorSets& caller_ds = used_vars[caller]; |
290 | 0 | for (const uint32_t callee : call_graph.Callees(caller)) { |
291 | 0 | Merge(caller_ds, used_vars[callee]); |
292 | 0 | } |
293 | 0 | } |
294 | | |
295 | | // At this point, the descriptor sets associated with each entry point |
296 | | // capture exactly the set of resource variables statically used |
297 | | // by the static call tree of that entry point. |
298 | | |
299 | | // Resolve conflicts. |
300 | | // VarBindingInfo objects may be shared between the bindings lists. |
301 | | // Updating a binding in one list can require updating another list later. |
302 | | // So repeat updates until settling. |
303 | | |
304 | | // The union of BindingLists across all entry points. |
305 | 0 | std::vector<BindingList*> ep_bindings; |
306 | |
|
307 | 0 | for (auto& ep : context()->module()->entry_points()) { |
308 | 0 | for (auto& ds : used_vars[ep.GetSingleWordInOperand(1)]) { |
309 | 0 | BindingList& bindings = ds.second; |
310 | 0 | ep_bindings.push_back(&bindings); |
311 | 0 | } |
312 | 0 | } |
313 | 0 | bool modified = false; |
314 | 0 | bool found_conflict; |
315 | 0 | do { |
316 | 0 | found_conflict = false; |
317 | 0 | for (BindingList* bl : ep_bindings) { |
318 | 0 | found_conflict |= ResolveConflicts(*bl); |
319 | 0 | } |
320 | 0 | modified |= found_conflict; |
321 | 0 | } while (found_conflict); |
322 | |
|
323 | 0 | return modified ? Pass::Status::SuccessWithChange |
324 | 0 | : Pass::Status::SuccessWithoutChange; |
325 | 0 | } |
326 | | |
327 | | } // namespace opt |
328 | | } // namespace spvtools |