Coverage Report

Created: 2025-06-13 06:49

/src/spirv-tools/source/opt/trim_capabilities_pass.h
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) 2023 Google Inc.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
#ifndef SOURCE_OPT_TRIM_CAPABILITIES_PASS_H_
16
#define SOURCE_OPT_TRIM_CAPABILITIES_PASS_H_
17
18
#include <algorithm>
#include <array>
#include <functional>
#include <optional>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>

#include "source/enum_set.h"
#include "source/extensions.h"
#include "source/opt/ir_context.h"
#include "source/opt/module.h"
#include "source/opt/pass.h"
#include "source/spirv_target_env.h"
#include "source/table2.h"
32
33
namespace spvtools {
34
namespace opt {
35
36
// Hash functor for scoped enumerations (`enum class`), e.g. spv::Capability
// and spv::Op.
//
// This is required for the NDK build: its unordered_set/unordered_map
// implementation does not work with class enums as keys unless an explicit
// hash is supplied. The enumerator is converted to its underlying integer type
// and hashed with the standard integer hash, so equal enumerators always hash
// identically.
struct ClassEnumHash {
  // Accepts any enumeration type. Non-enum arguments are rejected at compile
  // time (SFINAE), so the functor cannot be used to hash unrelated integers.
  template <typename Enum, typename = std::enable_if_t<std::is_enum_v<Enum>>>
  std::size_t operator()(Enum value) const {
    using StoringType = std::underlying_type_t<Enum>;
    return std::hash<StoringType>{}(static_cast<StoringType>(value));
  }
};
49
50
// An opcode handler is a function which, given an instruction, returns either
51
// the required capability, or nothing.
52
// Each handler checks one case for a capability requirement.
53
//
54
// Example:
55
//  - `OpTypeImage` can have operand `A` operand which requires capability 1
56
//  - `OpTypeImage` can also have operand `B` which requires capability 2.
57
//    -> We have 2 handlers: `Handler_OpTypeImage_1` and
58
//    `Handler_OpTypeImage_2`.
59
using OpcodeHandler =
60
    std::optional<spv::Capability> (*)(const Instruction* instruction);
61
62
// This pass tried to remove superfluous capabilities declared in the module.
63
// - If all the capabilities listed by an extension are removed, the extension
64
//   is also trimmed.
65
// - If the module countains any capability listed in `kForbiddenCapabilities`,
66
//   the module is left untouched.
67
// - No capabilities listed in `kUntouchableCapabilities` are trimmed, even when
68
//   not used.
69
// - Only capabilitied listed in `kSupportedCapabilities` are supported.
70
// - If the module contains unsupported capabilities, results might be
71
//   incorrect.
72
class TrimCapabilitiesPass : public Pass {
73
 private:
74
  // All the capabilities supported by this optimization pass. If your module
75
  // contains unsupported instruction, the pass could yield bad results.
76
  static constexpr std::array kSupportedCapabilities{
77
      // clang-format off
78
      spv::Capability::ComputeDerivativeGroupLinearKHR,
79
      spv::Capability::ComputeDerivativeGroupQuadsKHR,
80
      spv::Capability::Float16,
81
      spv::Capability::Float64,
82
      spv::Capability::FragmentShaderPixelInterlockEXT,
83
      spv::Capability::FragmentShaderSampleInterlockEXT,
84
      spv::Capability::FragmentShaderShadingRateInterlockEXT,
85
      spv::Capability::GroupNonUniform,
86
      spv::Capability::GroupNonUniformArithmetic,
87
      spv::Capability::GroupNonUniformClustered,
88
      spv::Capability::GroupNonUniformPartitionedNV,
89
      spv::Capability::GroupNonUniformVote,
90
      spv::Capability::Groups,
91
      spv::Capability::ImageMSArray,
92
      spv::Capability::Int16,
93
      spv::Capability::Int64,
94
      spv::Capability::InterpolationFunction,
95
      spv::Capability::Linkage,
96
      spv::Capability::MinLod,
97
      spv::Capability::PhysicalStorageBufferAddresses,
98
      spv::Capability::RayQueryKHR,
99
      spv::Capability::RayTracingKHR,
100
      spv::Capability::RayTraversalPrimitiveCullingKHR,
101
      spv::Capability::Shader,
102
      spv::Capability::ShaderClockKHR,
103
      spv::Capability::StorageBuffer16BitAccess,
104
      spv::Capability::StorageImageReadWithoutFormat,
105
      spv::Capability::StorageImageWriteWithoutFormat,
106
      spv::Capability::StorageInputOutput16,
107
      spv::Capability::StoragePushConstant16,
108
      spv::Capability::StorageUniform16,
109
      spv::Capability::StorageUniformBufferBlock16,
110
      spv::Capability::VulkanMemoryModelDeviceScope,
111
      spv::Capability::QuadControlKHR,
112
      // clang-format on
113
  };
114
115
  // Those capabilities disable all transformation of the module.
116
  static constexpr std::array kForbiddenCapabilities{
117
      spv::Capability::Linkage,
118
  };
119
120
  // Those capabilities are never removed from a module because we cannot
121
  // guess from the SPIR-V only if they are required or not.
122
  static constexpr std::array kUntouchableCapabilities{
123
      spv::Capability::Shader,
124
  };
125
126
 public:
127
  TrimCapabilitiesPass();
128
  TrimCapabilitiesPass(const TrimCapabilitiesPass&) = delete;
129
  TrimCapabilitiesPass(TrimCapabilitiesPass&&) = delete;
130
131
 private:
132
  // Inserts every capability listed by `descriptor` this pass supports into
133
  // `output`.
134
  template <typename Descriptor>
135
  void addSupportedCapabilitiesToSet(const Descriptor* const descriptor,
136
0
                                     CapabilitySet* output) const {
137
0
    for (auto capability : descriptor->capabilities()) {
138
0
      if (supportedCapabilities_.contains(capability)) {
139
0
        output->insert(capability);
140
0
      }
141
0
    }
142
0
  }
Unexecuted instantiation: void spvtools::opt::TrimCapabilitiesPass::addSupportedCapabilitiesToSet<spvtools::InstructionDesc>(spvtools::InstructionDesc const*, spvtools::EnumSet<spv::Capability>*) const
Unexecuted instantiation: void spvtools::opt::TrimCapabilitiesPass::addSupportedCapabilitiesToSet<spvtools::OperandDesc>(spvtools::OperandDesc const*, spvtools::EnumSet<spv::Capability>*) const
Unexecuted instantiation: void spvtools::opt::TrimCapabilitiesPass::addSupportedCapabilitiesToSet<spvtools::ExtInstDesc>(spvtools::ExtInstDesc const*, spvtools::EnumSet<spv::Capability>*) const
143
144
  // Inserts every extension listed by `descriptor` required by the module into
145
  // `output`. Expects a Descriptor like spvtools::OperandDesc or
146
  // spvtools::InstructionDesc.
147
  template <class Descriptor>
148
  inline void addSupportedExtensionsToSet(const Descriptor* const descriptor,
149
0
                                          ExtensionSet* output) const {
150
0
    if (descriptor->minVersion <=
151
0
        spvVersionForTargetEnv(context()->GetTargetEnv())) {
152
0
      return;
153
0
    }
154
0
    output->insert(descriptor->extensions().begin(),
155
0
                   descriptor->extensions().end());
156
0
  }
Unexecuted instantiation: void spvtools::opt::TrimCapabilitiesPass::addSupportedExtensionsToSet<spvtools::InstructionDesc>(spvtools::InstructionDesc const*, spvtools::EnumSet<spvtools::Extension>*) const
Unexecuted instantiation: void spvtools::opt::TrimCapabilitiesPass::addSupportedExtensionsToSet<spvtools::OperandDesc>(spvtools::OperandDesc const*, spvtools::EnumSet<spvtools::Extension>*) const
157
158
  void addInstructionRequirementsForOpcode(spv::Op opcode,
159
                                           CapabilitySet* capabilities,
160
                                           ExtensionSet* extensions) const;
161
  void addInstructionRequirementsForOperand(const Operand& operand,
162
                                            CapabilitySet* capabilities,
163
                                            ExtensionSet* extensions) const;
164
165
  void addInstructionRequirementsForExtInst(Instruction* instruction,
166
                                            CapabilitySet* capabilities) const;
167
168
  // Given an `instruction`, determines the capabilities it requires, and output
169
  // them in `capabilities`. The returned capabilities form a subset of
170
  // kSupportedCapabilities.
171
  void addInstructionRequirements(Instruction* instruction,
172
                                  CapabilitySet* capabilities,
173
                                  ExtensionSet* extensions) const;
174
175
  // Given an operand `type` and `value`, adds the extensions it would require
176
  // to `extensions`.
177
  void AddExtensionsForOperand(const spv_operand_type_t type,
178
                               const uint32_t value,
179
                               ExtensionSet* extensions) const;
180
181
  // Returns the list of required capabilities and extensions for the module.
182
  // The returned capabilities form a subset of kSupportedCapabilities.
183
  std::pair<CapabilitySet, ExtensionSet>
184
  DetermineRequiredCapabilitiesAndExtensions() const;
185
186
  // Trims capabilities not listed in `required_capabilities` if possible.
187
  // Returns whether or not the module was modified.
188
  Pass::Status TrimUnrequiredCapabilities(
189
      const CapabilitySet& required_capabilities) const;
190
191
  // Trims extensions not listed in `required_extensions` if supported by this
192
  // pass. An extensions is considered supported as soon as one capability this
193
  // pass support requires it.
194
  Pass::Status TrimUnrequiredExtensions(
195
      const ExtensionSet& required_extensions) const;
196
197
  // Returns if the analyzed module contains any forbidden capability.
198
  bool HasForbiddenCapabilities() const;
199
200
 public:
201
0
  const char* name() const override { return "trim-capabilities"; }
202
  Status Process() override;
203
204
 private:
205
  const CapabilitySet supportedCapabilities_;
206
  const CapabilitySet forbiddenCapabilities_;
207
  const CapabilitySet untouchableCapabilities_;
208
  const std::unordered_multimap<spv::Op, OpcodeHandler, ClassEnumHash>
209
      opcodeHandlers_;
210
};
211
212
}  // namespace opt
213
}  // namespace spvtools
214
#endif  // SOURCE_OPT_TRIM_CAPABILITIES_PASS_H_