Coverage Report

Created: 2026-01-16 06:48

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/spirv-tools/source/opt/trim_capabilities_pass.cpp
Line
Count
Source
1
// Copyright (c) 2023 Google Inc.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
#include "source/opt/trim_capabilities_pass.h"
16
17
#include <algorithm>
18
#include <array>
19
#include <cassert>
20
#include <functional>
21
#include <optional>
22
#include <queue>
23
#include <stack>
24
#include <unordered_map>
25
#include <unordered_set>
26
#include <vector>
27
28
#include "source/enum_set.h"
29
#include "source/ext_inst.h"
30
#include "source/opt/ir_context.h"
31
#include "source/opt/reflect.h"
32
#include "source/spirv_target_env.h"
33
#include "source/table2.h"
34
#include "source/util/string_utils.h"
35
36
namespace spvtools {
37
namespace opt {
38
39
namespace {
40
constexpr uint32_t kOpTypeFloatSizeIndex = 0;
41
constexpr uint32_t kOpTypePointerStorageClassIndex = 0;
42
constexpr uint32_t kTypeArrayTypeIndex = 0;
43
constexpr uint32_t kOpTypeScalarBitWidthIndex = 0;
44
constexpr uint32_t kTypePointerTypeIdInIndex = 1;
45
constexpr uint32_t kOpTypeIntSizeIndex = 0;
46
constexpr uint32_t kOpTypeImageDimIndex = 1;
47
constexpr uint32_t kOpTypeImageArrayedIndex = kOpTypeImageDimIndex + 2;
48
constexpr uint32_t kOpTypeImageMSIndex = kOpTypeImageArrayedIndex + 1;
49
constexpr uint32_t kOpTypeImageSampledIndex = kOpTypeImageMSIndex + 1;
50
constexpr uint32_t kOpTypeImageFormatIndex = kOpTypeImageSampledIndex + 1;
51
constexpr uint32_t kOpImageReadImageIndex = 0;
52
constexpr uint32_t kOpImageWriteImageIndex = 0;
53
constexpr uint32_t kOpImageSparseReadImageIndex = 0;
54
constexpr uint32_t kOpExtInstSetInIndex = 0;
55
constexpr uint32_t kOpExtInstInstructionInIndex = 1;
56
constexpr uint32_t kOpExtInstImportNameInIndex = 0;
57
58
// DFS visit of the type defined by `instruction`.
59
// If `condition` is true, children of the current node are visited.
60
// If `condition` is false, the children of the current node are ignored.
61
template <class UnaryPredicate>
62
0
static void DFSWhile(const Instruction* instruction, UnaryPredicate condition) {
63
0
  std::stack<uint32_t> instructions_to_visit;
64
0
  std::unordered_set<uint32_t> visited_instructions;
65
0
  instructions_to_visit.push(instruction->result_id());
66
0
  const auto* def_use_mgr = instruction->context()->get_def_use_mgr();
67
68
0
  while (!instructions_to_visit.empty()) {
69
0
    const Instruction* item = def_use_mgr->GetDef(instructions_to_visit.top());
70
0
    instructions_to_visit.pop();
71
72
    // Forward references can be allowed, meaning we can have cycles
73
    // between ID uses. Need to keep track of this.
74
0
    if (visited_instructions.count(item->result_id())) continue;
75
0
    visited_instructions.insert(item->result_id());
76
77
0
    if (!condition(item)) {
78
0
      continue;
79
0
    }
80
81
0
    if (item->opcode() == spv::Op::OpTypePointer) {
82
0
      instructions_to_visit.push(
83
0
          item->GetSingleWordInOperand(kTypePointerTypeIdInIndex));
84
0
      continue;
85
0
    }
86
87
0
    if (item->opcode() == spv::Op::OpTypeMatrix ||
88
0
        item->opcode() == spv::Op::OpTypeVector ||
89
0
        item->opcode() == spv::Op::OpTypeArray ||
90
0
        item->opcode() == spv::Op::OpTypeRuntimeArray) {
91
0
      instructions_to_visit.push(
92
0
          item->GetSingleWordInOperand(kTypeArrayTypeIndex));
93
0
      continue;
94
0
    }
95
96
0
    if (item->opcode() == spv::Op::OpTypeStruct) {
97
0
      item->ForEachInOperand([&instructions_to_visit](const uint32_t* op_id) {
98
0
        instructions_to_visit.push(*op_id);
99
0
      });
Unexecuted instantiation: trim_capabilities_pass.cpp:spvtools::opt::(anonymous namespace)::DFSWhile<spvtools::opt::(anonymous namespace)::AnyTypeOf<bool (*)(spvtools::opt::Instruction const*)>(spvtools::opt::Instruction const*, bool (*)(spvtools::opt::Instruction const*))::{lambda(spvtools::opt::Instruction const*)#1}>(spvtools::opt::Instruction const*, spvtools::opt::(anonymous namespace)::AnyTypeOf<bool (*)(spvtools::opt::Instruction const*)>(spvtools::opt::Instruction const*, bool (*)(spvtools::opt::Instruction const*))::{lambda(spvtools::opt::Instruction const*)#1})::{lambda(unsigned int const*)#1}::operator()(unsigned int const*) const
Unexecuted instantiation: trim_capabilities_pass.cpp:spvtools::opt::(anonymous namespace)::DFSWhile<spvtools::opt::Handler_OpTypePointer_StorageUniform16(spvtools::opt::Instruction const*)::$_0>(spvtools::opt::Instruction const*, spvtools::opt::Handler_OpTypePointer_StorageUniform16(spvtools::opt::Instruction const*)::$_0)::{lambda(unsigned int const*)#1}::operator()(unsigned int const*) const
Unexecuted instantiation: trim_capabilities_pass.cpp:spvtools::opt::(anonymous namespace)::DFSWhile<spvtools::opt::(anonymous namespace)::AnyTypeOf<spvtools::opt::Handler_OpTypePointer_StorageUniformBufferBlock16(spvtools::opt::Instruction const*)::$_0>(spvtools::opt::Instruction const*, spvtools::opt::Handler_OpTypePointer_StorageUniformBufferBlock16(spvtools::opt::Instruction const*)::$_0)::{lambda(spvtools::opt::Instruction const*)#1}>(spvtools::opt::Instruction const*, spvtools::opt::(anonymous namespace)::AnyTypeOf<spvtools::opt::Handler_OpTypePointer_StorageUniformBufferBlock16(spvtools::opt::Instruction const*)::$_0>(spvtools::opt::Instruction const*, spvtools::opt::Handler_OpTypePointer_StorageUniformBufferBlock16(spvtools::opt::Instruction const*)::$_0)::{lambda(spvtools::opt::Instruction const*)#1})::{lambda(unsigned int const*)#1}::operator()(unsigned int const*) const
Unexecuted instantiation: trim_capabilities_pass.cpp:spvtools::opt::(anonymous namespace)::DFSWhile<spvtools::opt::(anonymous namespace)::AnyTypeOf<spvtools::opt::Handler_OpTypePointer_StorageBuffer16BitAccess(spvtools::opt::Instruction const*)::$_0>(spvtools::opt::Instruction const*, spvtools::opt::Handler_OpTypePointer_StorageBuffer16BitAccess(spvtools::opt::Instruction const*)::$_0)::{lambda(spvtools::opt::Instruction const*)#1}>(spvtools::opt::Instruction const*, spvtools::opt::(anonymous namespace)::AnyTypeOf<spvtools::opt::Handler_OpTypePointer_StorageBuffer16BitAccess(spvtools::opt::Instruction const*)::$_0>(spvtools::opt::Instruction const*, spvtools::opt::Handler_OpTypePointer_StorageBuffer16BitAccess(spvtools::opt::Instruction const*)::$_0)::{lambda(spvtools::opt::Instruction const*)#1})::{lambda(unsigned int const*)#1}::operator()(unsigned int const*) const
100
0
      continue;
101
0
    }
102
0
  }
103
0
}
Unexecuted instantiation: trim_capabilities_pass.cpp:void spvtools::opt::(anonymous namespace)::DFSWhile<spvtools::opt::(anonymous namespace)::AnyTypeOf<bool (*)(spvtools::opt::Instruction const*)>(spvtools::opt::Instruction const*, bool (*)(spvtools::opt::Instruction const*))::{lambda(spvtools::opt::Instruction const*)#1}>(spvtools::opt::Instruction const*, spvtools::opt::(anonymous namespace)::AnyTypeOf<bool (*)(spvtools::opt::Instruction const*)>(spvtools::opt::Instruction const*, bool (*)(spvtools::opt::Instruction const*))::{lambda(spvtools::opt::Instruction const*)#1})
Unexecuted instantiation: trim_capabilities_pass.cpp:void spvtools::opt::(anonymous namespace)::DFSWhile<spvtools::opt::Handler_OpTypePointer_StorageUniform16(spvtools::opt::Instruction const*)::$_0>(spvtools::opt::Instruction const*, spvtools::opt::Handler_OpTypePointer_StorageUniform16(spvtools::opt::Instruction const*)::$_0)
Unexecuted instantiation: trim_capabilities_pass.cpp:void spvtools::opt::(anonymous namespace)::DFSWhile<spvtools::opt::(anonymous namespace)::AnyTypeOf<spvtools::opt::Handler_OpTypePointer_StorageUniformBufferBlock16(spvtools::opt::Instruction const*)::$_0>(spvtools::opt::Instruction const*, spvtools::opt::Handler_OpTypePointer_StorageUniformBufferBlock16(spvtools::opt::Instruction const*)::$_0)::{lambda(spvtools::opt::Instruction const*)#1}>(spvtools::opt::Instruction const*, spvtools::opt::(anonymous namespace)::AnyTypeOf<spvtools::opt::Handler_OpTypePointer_StorageUniformBufferBlock16(spvtools::opt::Instruction const*)::$_0>(spvtools::opt::Instruction const*, spvtools::opt::Handler_OpTypePointer_StorageUniformBufferBlock16(spvtools::opt::Instruction const*)::$_0)::{lambda(spvtools::opt::Instruction const*)#1})
Unexecuted instantiation: trim_capabilities_pass.cpp:void spvtools::opt::(anonymous namespace)::DFSWhile<spvtools::opt::(anonymous namespace)::AnyTypeOf<spvtools::opt::Handler_OpTypePointer_StorageBuffer16BitAccess(spvtools::opt::Instruction const*)::$_0>(spvtools::opt::Instruction const*, spvtools::opt::Handler_OpTypePointer_StorageBuffer16BitAccess(spvtools::opt::Instruction const*)::$_0)::{lambda(spvtools::opt::Instruction const*)#1}>(spvtools::opt::Instruction const*, spvtools::opt::(anonymous namespace)::AnyTypeOf<spvtools::opt::Handler_OpTypePointer_StorageBuffer16BitAccess(spvtools::opt::Instruction const*)::$_0>(spvtools::opt::Instruction const*, spvtools::opt::Handler_OpTypePointer_StorageBuffer16BitAccess(spvtools::opt::Instruction const*)::$_0)::{lambda(spvtools::opt::Instruction const*)#1})
104
105
// Walks the type defined by `instruction` (OpType* only).
106
// Returns `true` if any call to `predicate` with the type/subtype returns true.
107
template <class UnaryPredicate>
108
static bool AnyTypeOf(const Instruction* instruction,
109
0
                      UnaryPredicate predicate) {
110
0
  assert(IsTypeInst(instruction->opcode()) &&
111
0
         "AnyTypeOf called with a non-type instruction.");
112
113
0
  bool found_one = false;
114
0
  DFSWhile(instruction, [&found_one, predicate](const Instruction* node) {
115
0
    if (found_one || predicate(node)) {
116
0
      found_one = true;
117
0
      return false;
118
0
    }
119
120
0
    return true;
121
0
  });
Unexecuted instantiation: trim_capabilities_pass.cpp:spvtools::opt::(anonymous namespace)::AnyTypeOf<bool (*)(spvtools::opt::Instruction const*)>(spvtools::opt::Instruction const*, bool (*)(spvtools::opt::Instruction const*))::{lambda(spvtools::opt::Instruction const*)#1}::operator()(spvtools::opt::Instruction const*) const
Unexecuted instantiation: trim_capabilities_pass.cpp:spvtools::opt::(anonymous namespace)::AnyTypeOf<spvtools::opt::Handler_OpTypePointer_StorageUniformBufferBlock16(spvtools::opt::Instruction const*)::$_0>(spvtools::opt::Instruction const*, spvtools::opt::Handler_OpTypePointer_StorageUniformBufferBlock16(spvtools::opt::Instruction const*)::$_0)::{lambda(spvtools::opt::Instruction const*)#1}::operator()(spvtools::opt::Instruction const*) const
Unexecuted instantiation: trim_capabilities_pass.cpp:spvtools::opt::(anonymous namespace)::AnyTypeOf<spvtools::opt::Handler_OpTypePointer_StorageBuffer16BitAccess(spvtools::opt::Instruction const*)::$_0>(spvtools::opt::Instruction const*, spvtools::opt::Handler_OpTypePointer_StorageBuffer16BitAccess(spvtools::opt::Instruction const*)::$_0)::{lambda(spvtools::opt::Instruction const*)#1}::operator()(spvtools::opt::Instruction const*) const
122
0
  return found_one;
123
0
}
Unexecuted instantiation: trim_capabilities_pass.cpp:bool spvtools::opt::(anonymous namespace)::AnyTypeOf<bool (*)(spvtools::opt::Instruction const*)>(spvtools::opt::Instruction const*, bool (*)(spvtools::opt::Instruction const*))
Unexecuted instantiation: trim_capabilities_pass.cpp:bool spvtools::opt::(anonymous namespace)::AnyTypeOf<spvtools::opt::Handler_OpTypePointer_StorageUniformBufferBlock16(spvtools::opt::Instruction const*)::$_0>(spvtools::opt::Instruction const*, spvtools::opt::Handler_OpTypePointer_StorageUniformBufferBlock16(spvtools::opt::Instruction const*)::$_0)
Unexecuted instantiation: trim_capabilities_pass.cpp:bool spvtools::opt::(anonymous namespace)::AnyTypeOf<spvtools::opt::Handler_OpTypePointer_StorageBuffer16BitAccess(spvtools::opt::Instruction const*)::$_0>(spvtools::opt::Instruction const*, spvtools::opt::Handler_OpTypePointer_StorageBuffer16BitAccess(spvtools::opt::Instruction const*)::$_0)
124
125
0
static bool is16bitType(const Instruction* instruction) {
126
0
  if (instruction->opcode() != spv::Op::OpTypeInt &&
127
0
      instruction->opcode() != spv::Op::OpTypeFloat) {
128
0
    return false;
129
0
  }
130
131
0
  return instruction->GetSingleWordInOperand(kOpTypeScalarBitWidthIndex) == 16;
132
0
}
133
134
0
static bool Has16BitCapability(const FeatureManager* feature_manager) {
135
0
  const CapabilitySet& capabilities = feature_manager->GetCapabilities();
136
0
  return capabilities.contains(spv::Capability::Float16) ||
137
0
         capabilities.contains(spv::Capability::Int16);
138
0
}
139
140
}  // namespace
141
142
// ============== Begin opcode handler implementations. =======================
143
//
144
// Adding support for a new capability should only require adding a new handler,
145
// and updating the
146
// kSupportedCapabilities/kUntouchableCapabilities/kForbiddenCapabilities lists.
147
//
148
// Handler names follow the following convention:
149
//  Handler_<Opcode>_<Capability>()
150
151
static std::optional<spv::Capability> Handler_OpTypeFloat_Float16(
152
0
    const Instruction* instruction) {
153
0
  assert(instruction->opcode() == spv::Op::OpTypeFloat &&
154
0
         "This handler only support OpTypeFloat opcodes.");
155
156
0
  const uint32_t size =
157
0
      instruction->GetSingleWordInOperand(kOpTypeFloatSizeIndex);
158
0
  return size == 16 ? std::optional(spv::Capability::Float16) : std::nullopt;
159
0
}
160
161
static std::optional<spv::Capability> Handler_OpTypeFloat_Float64(
162
0
    const Instruction* instruction) {
163
0
  assert(instruction->opcode() == spv::Op::OpTypeFloat &&
164
0
         "This handler only support OpTypeFloat opcodes.");
165
166
0
  const uint32_t size =
167
0
      instruction->GetSingleWordInOperand(kOpTypeFloatSizeIndex);
168
0
  return size == 64 ? std::optional(spv::Capability::Float64) : std::nullopt;
169
0
}
170
171
static std::optional<spv::Capability>
172
0
Handler_OpTypePointer_StorageInputOutput16(const Instruction* instruction) {
173
0
  assert(instruction->opcode() == spv::Op::OpTypePointer &&
174
0
         "This handler only support OpTypePointer opcodes.");
175
176
  // This capability is only required if the variable has an Input/Output
177
  // storage class.
178
0
  spv::StorageClass storage_class = spv::StorageClass(
179
0
      instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
180
0
  if (storage_class != spv::StorageClass::Input &&
181
0
      storage_class != spv::StorageClass::Output) {
182
0
    return std::nullopt;
183
0
  }
184
185
0
  if (!Has16BitCapability(instruction->context()->get_feature_mgr())) {
186
0
    return std::nullopt;
187
0
  }
188
189
0
  return AnyTypeOf(instruction, is16bitType)
190
0
             ? std::optional(spv::Capability::StorageInputOutput16)
191
0
             : std::nullopt;
192
0
}
193
194
static std::optional<spv::Capability>
195
0
Handler_OpTypePointer_StoragePushConstant16(const Instruction* instruction) {
196
0
  assert(instruction->opcode() == spv::Op::OpTypePointer &&
197
0
         "This handler only support OpTypePointer opcodes.");
198
199
  // This capability is only required if the variable has a PushConstant storage
200
  // class.
201
0
  spv::StorageClass storage_class = spv::StorageClass(
202
0
      instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
203
0
  if (storage_class != spv::StorageClass::PushConstant) {
204
0
    return std::nullopt;
205
0
  }
206
207
0
  if (!Has16BitCapability(instruction->context()->get_feature_mgr())) {
208
0
    return std::nullopt;
209
0
  }
210
211
0
  return AnyTypeOf(instruction, is16bitType)
212
0
             ? std::optional(spv::Capability::StoragePushConstant16)
213
0
             : std::nullopt;
214
0
}
215
216
static std::optional<spv::Capability>
217
Handler_OpTypePointer_StorageUniformBufferBlock16(
218
0
    const Instruction* instruction) {
219
0
  assert(instruction->opcode() == spv::Op::OpTypePointer &&
220
0
         "This handler only support OpTypePointer opcodes.");
221
222
  // This capability is only required if the variable has a Uniform storage
223
  // class.
224
0
  spv::StorageClass storage_class = spv::StorageClass(
225
0
      instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
226
0
  if (storage_class != spv::StorageClass::Uniform) {
227
0
    return std::nullopt;
228
0
  }
229
230
0
  if (!Has16BitCapability(instruction->context()->get_feature_mgr())) {
231
0
    return std::nullopt;
232
0
  }
233
234
0
  const auto* decoration_mgr = instruction->context()->get_decoration_mgr();
235
0
  const bool matchesCondition =
236
0
      AnyTypeOf(instruction, [decoration_mgr](const Instruction* item) {
237
0
        if (!decoration_mgr->HasDecoration(item->result_id(),
238
0
                                           spv::Decoration::BufferBlock)) {
239
0
          return false;
240
0
        }
241
242
0
        return AnyTypeOf(item, is16bitType);
243
0
      });
244
245
0
  return matchesCondition
246
0
             ? std::optional(spv::Capability::StorageUniformBufferBlock16)
247
0
             : std::nullopt;
248
0
}
249
250
static std::optional<spv::Capability>
251
0
Handler_OpTypePointer_StorageBuffer16BitAccess(const Instruction* instruction) {
252
0
  assert(instruction->opcode() == spv::Op::OpTypePointer &&
253
0
         "This handler only support OpTypePointer opcodes.");
254
255
  // Requires StorageBuffer, ShaderRecordBufferKHR or PhysicalStorageBuffer
256
  // storage classes.
257
0
  spv::StorageClass storage_class = spv::StorageClass(
258
0
      instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
259
0
  if (storage_class != spv::StorageClass::StorageBuffer &&
260
0
      storage_class != spv::StorageClass::ShaderRecordBufferKHR &&
261
0
      storage_class != spv::StorageClass::PhysicalStorageBuffer) {
262
0
    return std::nullopt;
263
0
  }
264
265
0
  const auto* decoration_mgr = instruction->context()->get_decoration_mgr();
266
0
  const bool matchesCondition =
267
0
      AnyTypeOf(instruction, [decoration_mgr](const Instruction* item) {
268
0
        if (!decoration_mgr->HasDecoration(item->result_id(),
269
0
                                           spv::Decoration::Block)) {
270
0
          return false;
271
0
        }
272
273
0
        return AnyTypeOf(item, is16bitType);
274
0
      });
275
276
0
  return matchesCondition
277
0
             ? std::optional(spv::Capability::StorageBuffer16BitAccess)
278
0
             : std::nullopt;
279
0
}
280
281
static std::optional<spv::Capability> Handler_OpTypePointer_StorageUniform16(
282
0
    const Instruction* instruction) {
283
0
  assert(instruction->opcode() == spv::Op::OpTypePointer &&
284
0
         "This handler only support OpTypePointer opcodes.");
285
286
  // This capability is only required if the variable has a Uniform storage
287
  // class.
288
0
  spv::StorageClass storage_class = spv::StorageClass(
289
0
      instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
290
0
  if (storage_class != spv::StorageClass::Uniform) {
291
0
    return std::nullopt;
292
0
  }
293
294
0
  const auto* feature_manager = instruction->context()->get_feature_mgr();
295
0
  if (!Has16BitCapability(feature_manager)) {
296
0
    return std::nullopt;
297
0
  }
298
299
0
  const bool hasBufferBlockCapability =
300
0
      feature_manager->GetCapabilities().contains(
301
0
          spv::Capability::StorageUniformBufferBlock16);
302
0
  const auto* decoration_mgr = instruction->context()->get_decoration_mgr();
303
0
  bool found16bitType = false;
304
305
0
  DFSWhile(instruction, [decoration_mgr, hasBufferBlockCapability,
306
0
                         &found16bitType](const Instruction* item) {
307
0
    if (found16bitType) {
308
0
      return false;
309
0
    }
310
311
0
    if (hasBufferBlockCapability &&
312
0
        decoration_mgr->HasDecoration(item->result_id(),
313
0
                                      spv::Decoration::BufferBlock)) {
314
0
      return false;
315
0
    }
316
317
0
    if (is16bitType(item)) {
318
0
      found16bitType = true;
319
0
      return false;
320
0
    }
321
322
0
    return true;
323
0
  });
324
325
0
  return found16bitType ? std::optional(spv::Capability::StorageUniform16)
326
0
                        : std::nullopt;
327
0
}
328
329
static std::optional<spv::Capability> Handler_OpTypeInt_Int16(
330
0
    const Instruction* instruction) {
331
0
  assert(instruction->opcode() == spv::Op::OpTypeInt &&
332
0
         "This handler only support OpTypeInt opcodes.");
333
334
0
  const uint32_t size =
335
0
      instruction->GetSingleWordInOperand(kOpTypeIntSizeIndex);
336
0
  return size == 16 ? std::optional(spv::Capability::Int16) : std::nullopt;
337
0
}
338
339
static std::optional<spv::Capability> Handler_OpTypeInt_Int64(
340
0
    const Instruction* instruction) {
341
0
  assert(instruction->opcode() == spv::Op::OpTypeInt &&
342
0
         "This handler only support OpTypeInt opcodes.");
343
344
0
  const uint32_t size =
345
0
      instruction->GetSingleWordInOperand(kOpTypeIntSizeIndex);
346
0
  return size == 64 ? std::optional(spv::Capability::Int64) : std::nullopt;
347
0
}
348
349
static std::optional<spv::Capability> Handler_OpTypeImage_ImageMSArray(
350
0
    const Instruction* instruction) {
351
0
  assert(instruction->opcode() == spv::Op::OpTypeImage &&
352
0
         "This handler only support OpTypeImage opcodes.");
353
354
0
  const uint32_t arrayed =
355
0
      instruction->GetSingleWordInOperand(kOpTypeImageArrayedIndex);
356
0
  const uint32_t ms = instruction->GetSingleWordInOperand(kOpTypeImageMSIndex);
357
0
  const uint32_t sampled =
358
0
      instruction->GetSingleWordInOperand(kOpTypeImageSampledIndex);
359
360
0
  return arrayed == 1 && sampled == 2 && ms == 1
361
0
             ? std::optional(spv::Capability::ImageMSArray)
362
0
             : std::nullopt;
363
0
}
364
365
static std::optional<spv::Capability>
366
Handler_OpImageRead_StorageImageReadWithoutFormat(
367
0
    const Instruction* instruction) {
368
0
  assert(instruction->opcode() == spv::Op::OpImageRead &&
369
0
         "This handler only support OpImageRead opcodes.");
370
0
  const auto* def_use_mgr = instruction->context()->get_def_use_mgr();
371
372
0
  const uint32_t image_index =
373
0
      instruction->GetSingleWordInOperand(kOpImageReadImageIndex);
374
0
  const uint32_t type_index = def_use_mgr->GetDef(image_index)->type_id();
375
0
  const Instruction* type = def_use_mgr->GetDef(type_index);
376
0
  const uint32_t dim = type->GetSingleWordInOperand(kOpTypeImageDimIndex);
377
0
  const uint32_t format = type->GetSingleWordInOperand(kOpTypeImageFormatIndex);
378
379
  // If the Image Format is Unknown and Dim is not SubpassData,
380
  // StorageImageReadWithoutFormat is required.
381
0
  const bool is_unknown = spv::ImageFormat(format) == spv::ImageFormat::Unknown;
382
0
  const bool requires_capability_for_unknown =
383
0
      spv::Dim(dim) != spv::Dim::SubpassData;
384
0
  return is_unknown && requires_capability_for_unknown
385
0
             ? std::optional(spv::Capability::StorageImageReadWithoutFormat)
386
0
             : std::nullopt;
387
0
}
388
389
static std::optional<spv::Capability>
390
Handler_OpImageWrite_StorageImageWriteWithoutFormat(
391
0
    const Instruction* instruction) {
392
0
  assert(instruction->opcode() == spv::Op::OpImageWrite &&
393
0
         "This handler only support OpImageWrite opcodes.");
394
0
  const auto* def_use_mgr = instruction->context()->get_def_use_mgr();
395
396
0
  const uint32_t image_index =
397
0
      instruction->GetSingleWordInOperand(kOpImageWriteImageIndex);
398
0
  const uint32_t type_index = def_use_mgr->GetDef(image_index)->type_id();
399
400
  // If the Image Format is Unknown, StorageImageWriteWithoutFormat is required.
401
0
  const Instruction* type = def_use_mgr->GetDef(type_index);
402
0
  const uint32_t format = type->GetSingleWordInOperand(kOpTypeImageFormatIndex);
403
0
  const bool is_unknown = spv::ImageFormat(format) == spv::ImageFormat::Unknown;
404
0
  return is_unknown
405
0
             ? std::optional(spv::Capability::StorageImageWriteWithoutFormat)
406
0
             : std::nullopt;
407
0
}
408
409
static std::optional<spv::Capability>
410
Handler_OpImageSparseRead_StorageImageReadWithoutFormat(
411
0
    const Instruction* instruction) {
412
0
  assert(instruction->opcode() == spv::Op::OpImageSparseRead &&
413
0
         "This handler only support OpImageSparseRead opcodes.");
414
0
  const auto* def_use_mgr = instruction->context()->get_def_use_mgr();
415
416
0
  const uint32_t image_index =
417
0
      instruction->GetSingleWordInOperand(kOpImageSparseReadImageIndex);
418
0
  const uint32_t type_index = def_use_mgr->GetDef(image_index)->type_id();
419
0
  const Instruction* type = def_use_mgr->GetDef(type_index);
420
0
  const uint32_t format = type->GetSingleWordInOperand(kOpTypeImageFormatIndex);
421
422
0
  return spv::ImageFormat(format) == spv::ImageFormat::Unknown
423
0
             ? std::optional(spv::Capability::StorageImageReadWithoutFormat)
424
0
             : std::nullopt;
425
0
}
426
427
// Opcodes of interest for capability requirements; one entry per handler.
428
constexpr std::array<std::pair<spv::Op, OpcodeHandler>, 13> kOpcodeHandlers{{
429
    // clang-format off
430
    {spv::Op::OpImageRead,                   Handler_OpImageRead_StorageImageReadWithoutFormat},
431
    {spv::Op::OpImageWrite,                  Handler_OpImageWrite_StorageImageWriteWithoutFormat},
432
    {spv::Op::OpImageSparseRead,             Handler_OpImageSparseRead_StorageImageReadWithoutFormat},
433
    {spv::Op::OpTypeFloat,                   Handler_OpTypeFloat_Float16 },
434
    {spv::Op::OpTypeFloat,                   Handler_OpTypeFloat_Float64 },
435
    {spv::Op::OpTypeImage,                   Handler_OpTypeImage_ImageMSArray},
436
    {spv::Op::OpTypeInt,                     Handler_OpTypeInt_Int16 },
437
    {spv::Op::OpTypeInt,                     Handler_OpTypeInt_Int64 },
438
    {spv::Op::OpTypePointer,                 Handler_OpTypePointer_StorageInputOutput16},
439
    {spv::Op::OpTypePointer,                 Handler_OpTypePointer_StoragePushConstant16},
440
    {spv::Op::OpTypePointer,                 Handler_OpTypePointer_StorageUniform16},
442
    {spv::Op::OpTypePointer,                 Handler_OpTypePointer_StorageUniformBufferBlock16},
443
    {spv::Op::OpTypePointer,                 Handler_OpTypePointer_StorageBuffer16BitAccess},
444
    // clang-format on
445
}};
446
447
// ==============  End opcode handler implementations.  =======================
448
449
namespace {
450
0
ExtensionSet getExtensionsRelatedTo(const CapabilitySet& capabilities) {
451
0
  ExtensionSet output;
452
0
  const spvtools::OperandDesc* desc = nullptr;
453
0
  for (auto capability : capabilities) {
454
0
    if (SPV_SUCCESS !=
455
0
        spvtools::LookupOperand(SPV_OPERAND_TYPE_CAPABILITY,
456
0
                                static_cast<uint32_t>(capability), &desc)) {
457
0
      continue;
458
0
    }
459
460
0
    for (auto extension : desc->extensions()) {
461
0
      output.insert(extension);
462
0
    }
463
0
  }
464
465
0
  return output;
466
0
}
467
468
0
bool hasOpcodeConflictingCapabilities(spv::Op opcode) {
469
0
  switch (opcode) {
470
0
    case spv::Op::OpBeginInvocationInterlockEXT:
471
0
    case spv::Op::OpEndInvocationInterlockEXT:
472
0
    case spv::Op::OpGroupNonUniformIAdd:
473
0
    case spv::Op::OpGroupNonUniformFAdd:
474
0
    case spv::Op::OpGroupNonUniformIMul:
475
0
    case spv::Op::OpGroupNonUniformFMul:
476
0
    case spv::Op::OpGroupNonUniformSMin:
477
0
    case spv::Op::OpGroupNonUniformUMin:
478
0
    case spv::Op::OpGroupNonUniformFMin:
479
0
    case spv::Op::OpGroupNonUniformSMax:
480
0
    case spv::Op::OpGroupNonUniformUMax:
481
0
    case spv::Op::OpGroupNonUniformFMax:
482
0
    case spv::Op::OpGroupNonUniformBitwiseAnd:
483
0
    case spv::Op::OpGroupNonUniformBitwiseOr:
484
0
    case spv::Op::OpGroupNonUniformBitwiseXor:
485
0
    case spv::Op::OpGroupNonUniformLogicalAnd:
486
0
    case spv::Op::OpGroupNonUniformLogicalOr:
487
0
    case spv::Op::OpGroupNonUniformLogicalXor:
488
0
      return true;
489
0
    default:
490
0
      return false;
491
0
  }
492
0
}
493
494
}  // namespace
495
496
TrimCapabilitiesPass::TrimCapabilitiesPass()
497
0
    : supportedCapabilities_(
498
0
          TrimCapabilitiesPass::kSupportedCapabilities.cbegin(),
499
0
          TrimCapabilitiesPass::kSupportedCapabilities.cend()),
500
0
      forbiddenCapabilities_(
501
0
          TrimCapabilitiesPass::kForbiddenCapabilities.cbegin(),
502
0
          TrimCapabilitiesPass::kForbiddenCapabilities.cend()),
503
0
      untouchableCapabilities_(
504
0
          TrimCapabilitiesPass::kUntouchableCapabilities.cbegin(),
505
0
          TrimCapabilitiesPass::kUntouchableCapabilities.cend()),
506
0
      opcodeHandlers_(kOpcodeHandlers.cbegin(), kOpcodeHandlers.cend()) {}
507
508
void TrimCapabilitiesPass::addInstructionRequirementsForOpcode(
509
    spv::Op opcode, CapabilitySet* capabilities,
510
0
    ExtensionSet* extensions) const {
511
0
  if (hasOpcodeConflictingCapabilities(opcode)) {
512
0
    return;
513
0
  }
514
515
0
  const spvtools::InstructionDesc* desc;
516
0
  auto result = spvtools::LookupOpcode(opcode, &desc);
517
0
  if (result != SPV_SUCCESS) {
518
0
    return;
519
0
  }
520
521
0
  addSupportedCapabilitiesToSet(desc, capabilities);
522
0
  addSupportedExtensionsToSet(desc, extensions);
523
0
}
524
525
void TrimCapabilitiesPass::addInstructionRequirementsForOperand(
526
    const Operand& operand, CapabilitySet* capabilities,
527
0
    ExtensionSet* extensions) const {
528
  // No supported capability relies on a 2+-word operand.
529
0
  if (operand.words.size() != 1) {
530
0
    return;
531
0
  }
532
533
  // No supported capability relies on a literal string operand or an ID.
534
0
  if (operand.type == SPV_OPERAND_TYPE_LITERAL_STRING ||
535
0
      operand.type == SPV_OPERAND_TYPE_ID ||
536
0
      operand.type == SPV_OPERAND_TYPE_RESULT_ID) {
537
0
    return;
538
0
  }
539
540
  // If the Vulkan memory model is declared and any instruction uses Device
541
  // scope, the VulkanMemoryModelDeviceScope capability must be declared. This
542
  // rule cannot be covered by the grammar, so must be checked explicitly.
543
0
  if (operand.type == SPV_OPERAND_TYPE_SCOPE_ID) {
544
0
    const Instruction* memory_model = context()->GetMemoryModel();
545
0
    if (memory_model && memory_model->GetSingleWordInOperand(1u) ==
546
0
                            uint32_t(spv::MemoryModel::Vulkan)) {
547
0
      capabilities->insert(spv::Capability::VulkanMemoryModelDeviceScope);
548
0
    }
549
0
  }
550
551
  // case 1: Operand is a single value, can directly lookup.
552
0
  if (!spvOperandIsConcreteMask(operand.type)) {
553
0
    const spvtools::OperandDesc* desc = nullptr;
554
0
    auto result =
555
0
        spvtools::LookupOperand(operand.type, operand.words[0], &desc);
556
0
    if (result != SPV_SUCCESS) {
557
0
      return;
558
0
    }
559
0
    addSupportedCapabilitiesToSet(desc, capabilities);
560
0
    addSupportedExtensionsToSet(desc, extensions);
561
0
    return;
562
0
  }
563
564
  // case 2: operand can be a bitmask, we need to decompose the lookup.
565
0
  for (uint32_t i = 0; i < 32; i++) {
566
0
    const uint32_t mask = (1u << i) & operand.words[0];  // unsigned: 1 << 31 overflows int
567
0
    if (!mask) {
568
0
      continue;
569
0
    }
570
571
0
    const spvtools::OperandDesc* desc = nullptr;
572
0
    auto result = spvtools::LookupOperand(operand.type, mask, &desc);
573
0
    if (result != SPV_SUCCESS) {
574
0
      continue;
575
0
    }
576
577
0
    addSupportedCapabilitiesToSet(desc, capabilities);
578
0
    addSupportedExtensionsToSet(desc, extensions);
579
0
  }
580
0
}
581
582
void TrimCapabilitiesPass::addInstructionRequirementsForExtInst(
583
0
    Instruction* instruction, CapabilitySet* capabilities) const {
584
0
  assert(instruction->opcode() == spv::Op::OpExtInst &&
585
0
         "addInstructionRequirementsForExtInst must be passed an OpExtInst "
586
0
         "instruction");
587
588
0
  const auto* def_use_mgr = context()->get_def_use_mgr();
589
590
0
  const Instruction* extInstImport = def_use_mgr->GetDef(
591
0
      instruction->GetSingleWordInOperand(kOpExtInstSetInIndex));
592
0
  uint32_t extInstruction =
593
0
      instruction->GetSingleWordInOperand(kOpExtInstInstructionInIndex);
594
595
0
  const Operand& extInstSet =
596
0
      extInstImport->GetInOperand(kOpExtInstImportNameInIndex);
597
598
0
  spv_ext_inst_type_t instructionSet =
599
0
      spvExtInstImportTypeGet(extInstSet.AsString().c_str());
600
601
0
  const ExtInstDesc* desc = nullptr;
602
0
  auto result = LookupExtInst(instructionSet, extInstruction, &desc);
603
0
  if (result != SPV_SUCCESS) {
604
0
    return;
605
0
  }
606
607
0
  addSupportedCapabilitiesToSet(desc, capabilities);
608
0
}
609
610
void TrimCapabilitiesPass::addInstructionRequirements(
611
    Instruction* instruction, CapabilitySet* capabilities,
612
0
    ExtensionSet* extensions) const {
613
  // Ignoring OpCapability and OpExtension instructions.
614
0
  if (instruction->opcode() == spv::Op::OpCapability ||
615
0
      instruction->opcode() == spv::Op::OpConditionalCapabilityINTEL ||
616
0
      instruction->opcode() == spv::Op::OpExtension ||
617
0
      instruction->opcode() == spv::Op::OpConditionalExtensionINTEL) {
618
0
    return;
619
0
  }
620
621
0
  if (instruction->opcode() == spv::Op::OpExtInst) {
622
0
    addInstructionRequirementsForExtInst(instruction, capabilities);
623
0
  } else {
624
0
    addInstructionRequirementsForOpcode(instruction->opcode(), capabilities,
625
0
                                        extensions);
626
0
  }
627
628
  // Second case: one of the opcode operand is gated by a capability.
629
0
  const uint32_t operandCount = instruction->NumOperands();
630
0
  for (uint32_t i = 0; i < operandCount; i++) {
631
0
    addInstructionRequirementsForOperand(instruction->GetOperand(i),
632
0
                                         capabilities, extensions);
633
0
  }
634
635
  // Last case: some complex logic needs to be run to determine capabilities.
636
0
  auto [begin, end] = opcodeHandlers_.equal_range(instruction->opcode());
637
0
  for (auto it = begin; it != end; it++) {
638
0
    const OpcodeHandler handler = it->second;
639
0
    auto result = handler(instruction);
640
0
    if (!result.has_value()) {
641
0
      continue;
642
0
    }
643
644
0
    capabilities->insert(*result);
645
0
  }
646
0
}
647
648
void TrimCapabilitiesPass::AddExtensionsForOperand(
649
    const spv_operand_type_t type, const uint32_t value,
650
0
    ExtensionSet* extensions) const {
651
0
  const spvtools::OperandDesc* desc = nullptr;
652
0
  spv_result_t result = spvtools::LookupOperand(type, value, &desc);
653
0
  if (result != SPV_SUCCESS) {
654
0
    return;
655
0
  }
656
0
  addSupportedExtensionsToSet(desc, extensions);
657
0
}
658
659
std::pair<CapabilitySet, ExtensionSet>
660
0
TrimCapabilitiesPass::DetermineRequiredCapabilitiesAndExtensions() const {
661
0
  CapabilitySet required_capabilities;
662
0
  ExtensionSet required_extensions;
663
664
0
  get_module()->ForEachInst([&](Instruction* instruction) {
665
0
    addInstructionRequirements(instruction, &required_capabilities,
666
0
                               &required_extensions);
667
0
  });
668
669
0
  for (auto capability : required_capabilities) {
670
0
    AddExtensionsForOperand(SPV_OPERAND_TYPE_CAPABILITY,
671
0
                            static_cast<uint32_t>(capability),
672
0
                            &required_extensions);
673
0
  }
674
675
0
#if !defined(NDEBUG)
676
  // Debug only. We check the outputted required capabilities against the
677
  // supported capabilities list. The supported capabilities list is useful for
678
  // API users to quickly determine if they can use the pass or not. But this
679
  // list has to remain up-to-date with the pass code. If we can detect a
680
  // capability as required, but it's not listed, it means the list is
681
  // out-of-sync. This method is not ideal, but should cover most cases.
682
0
  {
683
0
    for (auto capability : required_capabilities) {
684
0
      assert(supportedCapabilities_.contains(capability) &&
685
0
             "Module is using a capability that is not listed as supported.");
686
0
    }
687
0
  }
688
0
#endif
689
690
0
  return std::make_pair(std::move(required_capabilities),
691
0
                        std::move(required_extensions));
692
0
}
693
694
// Removes from the module every declared capability that is supported by
// this pass, is not marked untouchable, and is absent from
// |required_capabilities|. Returns SuccessWithChange when at least one
// capability was selected for removal, SuccessWithoutChange otherwise.
Pass::Status TrimCapabilitiesPass::TrimUnrequiredCapabilities(
    const CapabilitySet& required_capabilities) const {
  // Candidates are collected first and removed in a second loop.
  CapabilitySet removable;
  for (auto declared : context()->get_feature_mgr()->GetCapabilities()) {
    // Keep capabilities that cannot be safely removed, that this pass does
    // not support, or that the module still requires.
    const bool can_trim = !untouchableCapabilities_.contains(declared) &&
                          supportedCapabilities_.contains(declared) &&
                          !required_capabilities.contains(declared);
    if (can_trim) {
      removable.insert(declared);
    }
  }

  for (auto unused : removable) {
    context()->RemoveCapability(unused);
  }

  if (removable.size() == 0) {
    return Pass::Status::SuccessWithoutChange;
  }
  return Pass::Status::SuccessWithChange;
}
723
724
// Removes from the module each extension related to the supported
// capabilities that is not in |required_extensions|. Returns
// SuccessWithChange when at least one removal reported a change,
// SuccessWithoutChange otherwise.
Pass::Status TrimCapabilitiesPass::TrimUnrequiredExtensions(
    const ExtensionSet& required_extensions) const {
  // Only extensions tied to capabilities this pass supports are candidates.
  const auto candidates = getExtensionsRelatedTo(supportedCapabilities_);

  bool changed = false;
  for (auto candidate : candidates) {
    // RemoveExtension is only attempted for extensions that are not required.
    if (!required_extensions.contains(candidate) &&
        context()->RemoveExtension(candidate)) {
      changed = true;
    }
  }

  if (changed) {
    return Pass::Status::SuccessWithChange;
  }
  return Pass::Status::SuccessWithoutChange;
}
743
744
0
bool TrimCapabilitiesPass::HasForbiddenCapabilities() const {
745
  // EnumSet.HasAnyOf returns `true` if the given set is empty.
746
0
  if (forbiddenCapabilities_.size() == 0) {
747
0
    return false;
748
0
  }
749
750
0
  const auto& capabilities = context()->get_feature_mgr()->GetCapabilities();
751
0
  return capabilities.HasAnyOf(forbiddenCapabilities_);
752
0
}
753
754
0
// Pass entry point. Leaves the module untouched when it declares a forbidden
// capability; otherwise computes the required capabilities and extensions and
// trims the unrequired ones. Reports SuccessWithChange if either trimming
// step changed the module.
Pass::Status TrimCapabilitiesPass::Process() {
  if (HasForbiddenCapabilities()) {
    return Status::SuccessWithoutChange;
  }

  // |required.first| holds the capabilities, |required.second| the
  // extensions.
  const auto required = DetermineRequiredCapabilitiesAndExtensions();

  // Both trimming steps always run, capabilities first.
  const bool capabilities_changed =
      TrimUnrequiredCapabilities(required.first) ==
      Pass::Status::SuccessWithChange;
  const bool extensions_changed =
      TrimUnrequiredExtensions(required.second) ==
      Pass::Status::SuccessWithChange;

  if (capabilities_changed || extensions_changed) {
    return Pass::Status::SuccessWithChange;
  }
  return Pass::Status::SuccessWithoutChange;
}
770
771
}  // namespace opt
772
}  // namespace spvtools