Coverage Report

Created: 2025-06-13 06:49

/src/spirv-tools/source/opt/resolve_binding_conflicts_pass.cpp
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) 2025 Google LLC
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
#include "source/opt/resolve_binding_conflicts_pass.h"

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iterator>
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "source/opt/decoration_manager.h"
#include "source/opt/def_use_manager.h"
#include "source/opt/instruction.h"
#include "source/opt/ir_builder.h"
#include "source/opt/ir_context.h"
#include "spirv/unified1/spirv.h"
28
29
namespace spvtools {
30
namespace opt {
31
32
// A VarBindingInfo contains the binding information for a single resource
33
// variable.
34
//
35
// Exactly one such object is created per resource variable in the
36
// module. In particular, when a resource variable is statically used by
37
// more than one entry point, those entry points share the same VarBindingInfo
38
// object for that variable.
39
struct VarBindingInfo {
40
  const Instruction* const var;
41
  const uint32_t descriptor_set;
42
  Instruction* const binding_decoration;
43
44
  // Returns the binding number.
45
0
  uint32_t binding() const {
46
0
    return binding_decoration->GetSingleWordInOperand(2);
47
0
  }
48
  // Sets the binding number to 'b'.
49
0
  void updateBinding(uint32_t b) { binding_decoration->SetOperand(2, {b}); }
50
};
51
52
// The bindings in the same descriptor set that are used by an entry point.
53
using BindingList = std::vector<VarBindingInfo*>;
54
// A map from descriptor set number to the list of bindings in that descriptor
55
// set, as used by a particular entry point.
56
using DescriptorSets = std::unordered_map<uint32_t, BindingList>;
57
58
0
IRContext::Analysis ResolveBindingConflictsPass::GetPreservedAnalyses() {
  // The only thing this pass rewrites is the literal binding number inside
  // existing Binding decorations, so every analysis stays valid.

  // Analyses describing instructions, functions and control flow.
  const IRContext::Analysis structure =
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping |
      IRContext::kAnalysisIdToFuncMapping | IRContext::kAnalysisCFG |
      IRContext::kAnalysisStructuredCFG |
      IRContext::kAnalysisDominatorAnalysis |
      IRContext::kAnalysisLoopAnalysis;

  // Analyses computed over values for optimization purposes.
  const IRContext::Analysis values =
      IRContext::kAnalysisCombinators | IRContext::kAnalysisScalarEvolution |
      IRContext::kAnalysisRegisterPressure |
      IRContext::kAnalysisValueNumberTable | IRContext::kAnalysisLiveness;

  // Module-level lookup tables.
  const IRContext::Analysis tables =
      IRContext::kAnalysisDecorations | IRContext::kAnalysisNameMap |
      IRContext::kAnalysisBuiltinVarId | IRContext::kAnalysisConstants |
      IRContext::kAnalysisTypes | IRContext::kAnalysisDebugInfo;

  return structure | values | tables;
}
73
74
// Orders variable binding info objects.
75
// * The binding number is most signficant;
76
// * Then a sampler-like object compares greater than non-sampler like object.
77
// * Otherwise compare based on variable ID.
78
// This provides a total order among bindings in a descriptor set for a valid
79
// Vulkan module.
80
0
bool Less(const VarBindingInfo* const lhs, const VarBindingInfo* const rhs) {
81
0
  if (lhs->binding() < rhs->binding()) return true;
82
0
  if (lhs->binding() > rhs->binding()) return false;
83
84
  // Examine types.
85
  // In valid Vulkan the only conflict can occur between
86
  // images and samplers.  We only care about a specific
87
  // comparison when one is a image-like thing and the other
88
  // is a sampler-like thing of the same shape.  So unwrap
89
  // types until we hit one of those two.
90
91
0
  auto* def_use_mgr = lhs->var->context()->get_def_use_mgr();
92
93
  // Returns the type found by iteratively following pointer pointee type,
94
  // or array element type.
95
0
  auto unwrap = [&def_use_mgr](Instruction* ty) {
96
0
    bool keep_going = true;
97
0
    do {
98
0
      switch (ty->opcode()) {
99
0
        case spv::Op::OpTypePointer:
100
0
          ty = def_use_mgr->GetDef(ty->GetSingleWordInOperand(1));
101
0
          break;
102
0
        case spv::Op::OpTypeArray:
103
0
        case spv::Op::OpTypeRuntimeArray:
104
0
          ty = def_use_mgr->GetDef(ty->GetSingleWordInOperand(0));
105
0
          break;
106
0
        default:
107
0
          keep_going = false;
108
0
          break;
109
0
      }
110
0
    } while (keep_going);
111
0
    return ty;
112
0
  };
113
114
0
  auto* lhs_ty = unwrap(def_use_mgr->GetDef(lhs->var->type_id()));
115
0
  auto* rhs_ty = unwrap(def_use_mgr->GetDef(rhs->var->type_id()));
116
0
  if (lhs_ty->opcode() == rhs_ty->opcode()) {
117
    // Pick based on variable ID.
118
0
    return lhs->var->result_id() < rhs->var->result_id();
119
0
  }
120
  // A sampler is always greater than an image.
121
0
  if (lhs_ty->opcode() == spv::Op::OpTypeSampler) {
122
0
    return false;
123
0
  }
124
0
  if (rhs_ty->opcode() == spv::Op::OpTypeSampler) {
125
0
    return true;
126
0
  }
127
  // Pick based on variable ID.
128
0
  return lhs->var->result_id() < rhs->var->result_id();
129
0
}
130
131
// Summarizes the caller-callee relationships between functions in a module.
132
class CallGraph {
133
 public:
134
  // Returns the list of all functions statically reachable from entry points,
135
  // where callees precede callers.
136
0
  const std::vector<uint32_t>& CalleesBeforeCallers() const {
137
0
    return visit_order_;
138
0
  }
139
  // Returns the list functions called from a given function.
140
0
  const std::unordered_set<uint32_t>& Callees(uint32_t caller) {
141
0
    return calls_[caller];
142
0
  }
143
144
0
  CallGraph(IRContext& context) {
145
    // Populate calls_.
146
0
    std::queue<uint32_t> callee_queue;
147
0
    for (const auto& fn : *context.module()) {
148
0
      auto& callees = calls_[fn.result_id()];
149
0
      context.AddCalls(&fn, &callee_queue);
150
0
      while (!callee_queue.empty()) {
151
0
        callees.insert(callee_queue.front());
152
0
        callee_queue.pop();
153
0
      }
154
0
    }
155
156
    // Perform depth-first search, starting from each entry point.
157
    // Populates visit_order_.
158
0
    for (const auto& ep : context.module()->entry_points()) {
159
0
      Visit(ep.GetSingleWordInOperand(1));
160
0
    }
161
0
  }
162
163
 private:
164
  // Visits a function, recursively visiting its callees. Adds this ID
165
  // to the visit_order after all callees have been visited.
166
0
  void Visit(uint32_t func_id) {
167
0
    if (visited_.count(func_id)) {
168
0
      return;
169
0
    }
170
0
    visited_.insert(func_id);
171
0
    for (auto callee_id : calls_[func_id]) {
172
0
      Visit(callee_id);
173
0
    }
174
0
    visit_order_.push_back(func_id);
175
0
  }
176
177
  // Maps the ID of a function to the IDs of functions it calls.
178
  std::unordered_map<uint32_t, std::unordered_set<uint32_t>> calls_;
179
180
  // IDs of visited functions;
181
  std::unordered_set<uint32_t> visited_;
182
  // IDs of functions, where callees precede callers.
183
  std::vector<uint32_t> visit_order_;
184
};
185
186
// Returns vector binding info for all resource variables in the module.
187
0
auto GetVarBindings(IRContext& context) {
188
0
  std::vector<VarBindingInfo> vars;
189
0
  auto* deco_mgr = context.get_decoration_mgr();
190
0
  for (auto& inst : context.module()->types_values()) {
191
0
    if (inst.opcode() == spv::Op::OpVariable) {
192
0
      Instruction* descriptor_set_deco = nullptr;
193
0
      Instruction* binding_deco = nullptr;
194
0
      for (auto* deco : deco_mgr->GetDecorationsFor(inst.result_id(), false)) {
195
0
        switch (static_cast<spv::Decoration>(deco->GetSingleWordInOperand(1))) {
196
0
          case spv::Decoration::DescriptorSet:
197
0
            assert(!descriptor_set_deco);
198
0
            descriptor_set_deco = deco;
199
0
            break;
200
0
          case spv::Decoration::Binding:
201
0
            assert(!binding_deco);
202
0
            binding_deco = deco;
203
0
            break;
204
0
          default:
205
0
            break;
206
0
        }
207
0
      }
208
0
      if (descriptor_set_deco && binding_deco) {
209
0
        vars.push_back({&inst, descriptor_set_deco->GetSingleWordInOperand(2),
210
0
                        binding_deco});
211
0
      }
212
0
    }
213
0
  }
214
0
  return vars;
215
0
}
216
217
// Merges the bindings from source into sink. Maintains order and uniqueness
218
// within a list of bindings.
219
0
void Merge(DescriptorSets& sink, const DescriptorSets& source) {
220
0
  for (auto index_and_bindings : source) {
221
0
    const uint32_t index = index_and_bindings.first;
222
0
    const BindingList& src1 = index_and_bindings.second;
223
0
    const BindingList& src2 = sink[index];
224
0
    BindingList merged;
225
0
    merged.resize(src1.size() + src2.size());
226
0
    auto merged_end = std::merge(src1.begin(), src1.end(), src2.begin(),
227
0
                                 src2.end(), merged.begin(), Less);
228
0
    auto unique_end = std::unique(merged.begin(), merged_end);
229
0
    merged.resize(unique_end - merged.begin());
230
0
    sink[index] = std::move(merged);
231
0
  }
232
0
}
233
234
// Resolves conflicts within this binding list, so the binding number on an
235
// item is at least one more than the binding number on the previous item.
236
// When this does not yet hold, increase the binding number on the second
237
// item in the pair. Returns true if any changes were applied.
238
0
bool ResolveConflicts(BindingList& bl) {
239
0
  bool changed = false;
240
0
  for (size_t i = 1; i < bl.size(); i++) {
241
0
    const auto prev_num = bl[i - 1]->binding();
242
0
    if (prev_num >= bl[i]->binding()) {
243
0
      bl[i]->updateBinding(prev_num + 1);
244
0
      changed = true;
245
0
    }
246
0
  }
247
0
  return changed;
248
0
}
249
250
0
Pass::Status ResolveBindingConflictsPass::Process() {
  // Assumes DescriptorSet and Binding decorations are applied directly, not
  // through decoration groups (deprecated in SPIR-V 1.3 Revision 6).
  auto resources = GetVarBindings(*context());

  // For each function ID: the resource variables it uses, directly or via
  // its callees, grouped by descriptor set. Each BindingList holds distinct
  // variables.
  std::unordered_map<uint32_t, DescriptorSets> uses_by_function;

  // Record direct uses; a variable is listed at most once per function.
  auto* def_use = context()->get_def_use_mgr();
  for (auto& info : resources) {
    std::unordered_set<uint32_t> seen_functions;
    def_use->ForEachUser(info.var, [&](Instruction* use) {
      auto* bb = context()->get_instr_block(use);
      if (bb == nullptr) return;  // Uses outside any basic block are ignored.
      auto* func = bb->GetParent();
      assert(func);
      const uint32_t func_id = func->result_id();
      if (!seen_functions.insert(func_id).second) return;
      uses_by_function[func_id][info.descriptor_set].push_back(&info);
    });
  }

  // Order each per-function list by Less, as Merge requires.
  for (auto& fn_entry : uses_by_function) {
    for (auto& set_entry : fn_entry.second) {
      std::stable_sort(set_entry.second.begin(), set_entry.second.end(),
                       Less);
    }
  }

  // Fold callee uses into callers. Callees come first in this order, so each
  // caller merges lists that are already complete.
  CallGraph graph(*context());
  for (const uint32_t fn_id : graph.CalleesBeforeCallers()) {
    DescriptorSets& fn_sets = uses_by_function[fn_id];
    for (const uint32_t callee_id : graph.Callees(fn_id)) {
      Merge(fn_sets, uses_by_function[callee_id]);
    }
  }

  // Each entry point's descriptor sets now hold exactly the resource
  // variables statically used by its call tree. Gather all of those lists.
  std::vector<BindingList*> lists;
  for (auto& entry_point : context()->module()->entry_points()) {
    DescriptorSets& sets =
        uses_by_function[entry_point.GetSingleWordInOperand(1)];
    for (auto& set_entry : sets) {
      lists.push_back(&set_entry.second);
    }
  }

  // VarBindingInfo objects are shared between lists, so fixing one list can
  // create a conflict in another. Sweep every list until a full round makes
  // no change.
  bool modified = false;
  for (bool changed_this_round = true; changed_this_round;) {
    changed_this_round = false;
    for (BindingList* list : lists) {
      if (ResolveConflicts(*list)) changed_this_round = true;
    }
    modified = modified || changed_this_round;
  }

  if (modified) return Pass::Status::SuccessWithChange;
  return Pass::Status::SuccessWithoutChange;
}
326
327
}  // namespace opt
328
}  // namespace spvtools