/src/spirv-tools/source/opt/amd_ext_to_khr.cpp
Line | Count | Source |
1 | | // Copyright (c) 2019 Google LLC. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | #include "source/opt/amd_ext_to_khr.h" |
16 | | |
17 | | #include <set> |
18 | | #include <string> |
19 | | |
20 | | #include "ir_builder.h" |
21 | | #include "source/opt/ir_context.h" |
22 | | #include "type_manager.h" |
23 | | |
24 | | namespace spvtools { |
25 | | namespace opt { |
26 | | namespace { |
27 | | |
// Extended-instruction numbers defined by the SPV_AMD_shader_ballot extended
// instruction set.  These match the values used in the binary encoding of
// OpExtInst for that set.
enum AmdShaderBallotExtOpcodes {
  AmdShaderBallotSwizzleInvocationsAMD = 1,
  AmdShaderBallotSwizzleInvocationsMaskedAMD = 2,
  AmdShaderBallotWriteInvocationAMD = 3,
  AmdShaderBallotMbcntAMD = 4
};
34 | | |
// Extended-instruction numbers defined by the SPV_AMD_shader_trinary_minmax
// extended instruction set (float/unsigned/signed variants of 3-operand
// min, max, and mid).
enum AmdShaderTrinaryMinMaxExtOpCodes {
  FMin3AMD = 1,
  UMin3AMD = 2,
  SMin3AMD = 3,
  FMax3AMD = 4,
  UMax3AMD = 5,
  SMax3AMD = 6,
  FMid3AMD = 7,
  UMid3AMD = 8,
  SMid3AMD = 9
};
46 | | |
47 | | enum AmdGcnShader { CubeFaceCoordAMD = 2, CubeFaceIndexAMD = 1, TimeAMD = 3 }; |
48 | | |
49 | 0 | analysis::Type* GetUIntType(IRContext* ctx) { |
50 | 0 | analysis::Integer int_type(32, false); |
51 | 0 | return ctx->get_type_mgr()->GetRegisteredType(&int_type); |
52 | 0 | } |
53 | | |
54 | | // Returns a folding rule that replaces |op(a,b,c)| by |op(op(a,b),c)|, where |
55 | | // |op| is either min or max. |opcode| is the binary opcode in the GLSLstd450 |
56 | | // extended instruction set that corresponds to the trinary instruction being |
57 | | // replaced. |
58 | | template <GLSLstd450 opcode> |
59 | | bool ReplaceTrinaryMinMax(IRContext* ctx, Instruction* inst, |
60 | 0 | const std::vector<const analysis::Constant*>&) { |
61 | 0 | uint32_t glsl405_ext_inst_id = |
62 | 0 | ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); |
63 | 0 | if (glsl405_ext_inst_id == 0) { |
64 | 0 | ctx->AddExtInstImport("GLSL.std.450"); |
65 | 0 | glsl405_ext_inst_id = |
66 | 0 | ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); |
67 | 0 | } |
68 | |
|
69 | 0 | InstructionBuilder ir_builder( |
70 | 0 | ctx, inst, |
71 | 0 | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); |
72 | |
|
73 | 0 | uint32_t op1 = inst->GetSingleWordInOperand(2); |
74 | 0 | uint32_t op2 = inst->GetSingleWordInOperand(3); |
75 | 0 | uint32_t op3 = inst->GetSingleWordInOperand(4); |
76 | |
|
77 | 0 | Instruction* temp = ir_builder.AddNaryExtendedInstruction( |
78 | 0 | inst->type_id(), glsl405_ext_inst_id, opcode, {op1, op2}); |
79 | 0 | if (temp == nullptr) return false; |
80 | | |
81 | 0 | Instruction::OperandList new_operands; |
82 | 0 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {glsl405_ext_inst_id}}); |
83 | 0 | new_operands.push_back({SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER, |
84 | 0 | {static_cast<uint32_t>(opcode)}}); |
85 | 0 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {temp->result_id()}}); |
86 | 0 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {op3}}); |
87 | |
|
88 | 0 | inst->SetInOperands(std::move(new_operands)); |
89 | 0 | ctx->UpdateDefUse(inst); |
90 | 0 | return true; |
91 | 0 | } Unexecuted instantiation: amd_ext_to_khr.cpp:bool spvtools::opt::(anonymous namespace)::ReplaceTrinaryMinMax<(GLSLstd450)37>(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&) Unexecuted instantiation: amd_ext_to_khr.cpp:bool spvtools::opt::(anonymous namespace)::ReplaceTrinaryMinMax<(GLSLstd450)38>(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&) Unexecuted instantiation: amd_ext_to_khr.cpp:bool spvtools::opt::(anonymous namespace)::ReplaceTrinaryMinMax<(GLSLstd450)39>(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&) Unexecuted instantiation: amd_ext_to_khr.cpp:bool spvtools::opt::(anonymous namespace)::ReplaceTrinaryMinMax<(GLSLstd450)40>(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&) Unexecuted instantiation: amd_ext_to_khr.cpp:bool spvtools::opt::(anonymous namespace)::ReplaceTrinaryMinMax<(GLSLstd450)41>(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&) Unexecuted instantiation: amd_ext_to_khr.cpp:bool spvtools::opt::(anonymous namespace)::ReplaceTrinaryMinMax<(GLSLstd450)42>(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&) |
92 | | |
93 | | // Returns a folding rule that replaces |mid(a,b,c)| by |clamp(a, min(b,c), |
94 | | // max(b,c)|. The three parameters are the opcode that correspond to the min, |
95 | | // max, and clamp operations for the type of the instruction being replaced. |
96 | | template <GLSLstd450 min_opcode, GLSLstd450 max_opcode, GLSLstd450 clamp_opcode> |
97 | | bool ReplaceTrinaryMid(IRContext* ctx, Instruction* inst, |
98 | 0 | const std::vector<const analysis::Constant*>&) { |
99 | 0 | uint32_t glsl405_ext_inst_id = |
100 | 0 | ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); |
101 | 0 | if (glsl405_ext_inst_id == 0) { |
102 | 0 | ctx->AddExtInstImport("GLSL.std.450"); |
103 | 0 | glsl405_ext_inst_id = |
104 | 0 | ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); |
105 | 0 | } |
106 | |
|
107 | 0 | InstructionBuilder ir_builder( |
108 | 0 | ctx, inst, |
109 | 0 | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); |
110 | |
|
111 | 0 | uint32_t op1 = inst->GetSingleWordInOperand(2); |
112 | 0 | uint32_t op2 = inst->GetSingleWordInOperand(3); |
113 | 0 | uint32_t op3 = inst->GetSingleWordInOperand(4); |
114 | |
|
115 | 0 | Instruction* min = ir_builder.AddNaryExtendedInstruction( |
116 | 0 | inst->type_id(), glsl405_ext_inst_id, static_cast<uint32_t>(min_opcode), |
117 | 0 | {op2, op3}); |
118 | 0 | if (min == nullptr) return false; |
119 | | |
120 | 0 | Instruction* max = ir_builder.AddNaryExtendedInstruction( |
121 | 0 | inst->type_id(), glsl405_ext_inst_id, static_cast<uint32_t>(max_opcode), |
122 | 0 | {op2, op3}); |
123 | 0 | if (max == nullptr) return false; |
124 | | |
125 | 0 | Instruction::OperandList new_operands; |
126 | 0 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {glsl405_ext_inst_id}}); |
127 | 0 | new_operands.push_back({SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER, |
128 | 0 | {static_cast<uint32_t>(clamp_opcode)}}); |
129 | 0 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {op1}}); |
130 | 0 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {min->result_id()}}); |
131 | 0 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {max->result_id()}}); |
132 | |
|
133 | 0 | inst->SetInOperands(std::move(new_operands)); |
134 | 0 | ctx->UpdateDefUse(inst); |
135 | 0 | return true; |
136 | 0 | } Unexecuted instantiation: amd_ext_to_khr.cpp:bool spvtools::opt::(anonymous namespace)::ReplaceTrinaryMid<(GLSLstd450)37, (GLSLstd450)40, (GLSLstd450)43>(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&) Unexecuted instantiation: amd_ext_to_khr.cpp:bool spvtools::opt::(anonymous namespace)::ReplaceTrinaryMid<(GLSLstd450)38, (GLSLstd450)41, (GLSLstd450)44>(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&) Unexecuted instantiation: amd_ext_to_khr.cpp:bool spvtools::opt::(anonymous namespace)::ReplaceTrinaryMid<(GLSLstd450)39, (GLSLstd450)42, (GLSLstd450)45>(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&) |
137 | | |
// Returns a folding rule that will replace the opcode with |opcode| and add
// the capabilities required. The folding rule assumes it is folding an
// OpGroup*NonUniformAMD instruction from the SPV_AMD_shader_ballot extension.
template <spv::Op new_opcode>
bool ReplaceGroupNonuniformOperationOpCode(
    IRContext* ctx, Instruction* inst,
    const std::vector<const analysis::Constant*>&) {
  // Debug-build sanity check: the template argument must be one of the KHR
  // group-non-uniform arithmetic opcodes this rule is allowed to produce.
  switch (new_opcode) {
    case spv::Op::OpGroupNonUniformIAdd:
    case spv::Op::OpGroupNonUniformFAdd:
    case spv::Op::OpGroupNonUniformUMin:
    case spv::Op::OpGroupNonUniformSMin:
    case spv::Op::OpGroupNonUniformFMin:
    case spv::Op::OpGroupNonUniformUMax:
    case spv::Op::OpGroupNonUniformSMax:
    case spv::Op::OpGroupNonUniformFMax:
      break;
    default:
      assert(
          false &&
          "Should be replacing with a group non uniform arithmetic operation.");
  }

  // Debug-build sanity check: |inst| must be one of the AMD instructions this
  // rule knows how to replace.  The AMD and KHR forms share the same operand
  // layout, so only the opcode needs to change.
  switch (inst->opcode()) {
    case spv::Op::OpGroupIAddNonUniformAMD:
    case spv::Op::OpGroupFAddNonUniformAMD:
    case spv::Op::OpGroupUMinNonUniformAMD:
    case spv::Op::OpGroupSMinNonUniformAMD:
    case spv::Op::OpGroupFMinNonUniformAMD:
    case spv::Op::OpGroupUMaxNonUniformAMD:
    case spv::Op::OpGroupSMaxNonUniformAMD:
    case spv::Op::OpGroupFMaxNonUniformAMD:
      break;
    default:
      assert(false &&
             "Should be replacing a group non uniform arithmetic operation.");
  }

  // The KHR opcodes require this capability; operands are left untouched.
  ctx->AddCapability(spv::Capability::GroupNonUniformArithmetic);
  inst->SetOpcode(new_opcode);
  return true;
}
180 | | |
// Returns a folding rule that will replace the SwizzleInvocationsAMD extended
// instruction in the SPV_AMD_shader_ballot extension.
//
// The instruction
//
//  %offset = OpConstantComposite %v4uint %x %y %z %w
//  %result = OpExtInst %type %1 SwizzleInvocationsAMD %data %offset
//
// is replaced with
//
// potentially new constants and types
//
// clang-format off
//         %uint_max = OpConstant %uint 0xFFFFFFFF
//           %v4uint = OpTypeVector %uint 4
//     %ballot_value = OpConstantComposite %v4uint %uint_max %uint_max %uint_max %uint_max
//             %null = OpConstantNull %type
// clang-format on
//
// and the following code in the function body
//
// clang-format off
//               %id = OpLoad %uint %SubgroupLocalInvocationId
//         %quad_idx = OpBitwiseAnd %uint %id %uint_3
//         %quad_ldr = OpBitwiseXor %uint %id %quad_idx
//        %my_offset = OpVectorExtractDynamic %uint %offset %quad_idx
//       %target_inv = OpIAdd %uint %quad_ldr %my_offset
//        %is_active = OpGroupNonUniformBallotBitExtract %bool %uint_3 %ballot_value %target_inv
//          %shuffle = OpGroupNonUniformShuffle %type %uint_3 %data %target_inv
//           %result = OpSelect %type %is_active %shuffle %null
// clang-format on
//
// Also adding the capabilities and builtins that are needed.
bool ReplaceSwizzleInvocations(IRContext* ctx, Instruction* inst,
                               const std::vector<const analysis::Constant*>&) {
  analysis::TypeManager* type_mgr = ctx->get_type_mgr();
  analysis::ConstantManager* const_mgr = ctx->get_constant_mgr();

  // The replacement uses KHR ballot/shuffle group operations.
  ctx->AddExtension("SPV_KHR_shader_ballot");
  ctx->AddCapability(spv::Capability::GroupNonUniformBallot);
  ctx->AddCapability(spv::Capability::GroupNonUniformShuffle);

  InstructionBuilder ir_builder(
      ctx, inst,
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);

  // In-operands 0 and 1 of an OpExtInst are the set id and the instruction
  // number; the real arguments start at in-operand index 2.
  uint32_t data_id = inst->GetSingleWordInOperand(2);
  uint32_t offset_id = inst->GetSingleWordInOperand(3);

  // Get the subgroup invocation id.
  uint32_t var_id = ctx->GetBuiltinInputVarId(
      uint32_t(spv::BuiltIn::SubgroupLocalInvocationId));
  if (var_id == 0) return false;
  Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id);
  if (var_inst == nullptr) return false;
  // The builtin variable is a pointer; in-operand 1 of the pointer type is
  // the pointee type, i.e. the uint type of the invocation id.
  Instruction* var_ptr_type =
      ctx->get_def_use_mgr()->GetDef(var_inst->type_id());
  if (var_ptr_type == nullptr) return false;
  uint32_t uint_type_id = var_ptr_type->GetSingleWordInOperand(1);
  if (uint_type_id == 0) return false;

  Instruction* id = ir_builder.AddLoad(uint_type_id, var_id);
  if (id == nullptr) return false;

  uint32_t quad_mask = ir_builder.GetUintConstantId(3);
  if (quad_mask == 0) return false;

  // This gives the offset in the group of 4 of this invocation.
  Instruction* quad_idx = ir_builder.AddBinaryOp(
      uint_type_id, spv::Op::OpBitwiseAnd, id->result_id(), quad_mask);
  if (quad_idx == nullptr) return false;

  // Get the invocation id of the first invocation in the group of 4.
  Instruction* quad_ldr =
      ir_builder.AddBinaryOp(uint_type_id, spv::Op::OpBitwiseXor,
                             id->result_id(), quad_idx->result_id());
  if (quad_ldr == nullptr) return false;

  // Get the offset of the target invocation from the offset vector.
  Instruction* my_offset =
      ir_builder.AddBinaryOp(uint_type_id, spv::Op::OpVectorExtractDynamic,
                             offset_id, quad_idx->result_id());
  if (my_offset == nullptr) return false;

  // Determine the index of the invocation to read from.
  Instruction* target_inv =
      ir_builder.AddBinaryOp(uint_type_id, spv::Op::OpIAdd,
                             quad_ldr->result_id(), my_offset->result_id());
  if (target_inv == nullptr) return false;

  // Do the group operations
  uint32_t uint_max_id = ir_builder.GetUintConstantId(0xFFFFFFFF);
  if (uint_max_id == 0) return false;
  uint32_t subgroup_scope =
      ir_builder.GetUintConstantId(uint32_t(spv::Scope::Subgroup));
  if (subgroup_scope == 0) return false;
  // An all-ones uvec4 ballot value: the bit-extract below tests whether the
  // target invocation's bit is set in this mask.
  const auto* vec_type = type_mgr->GetUIntVectorType(4);
  if (vec_type == nullptr) return false;
  const auto* ballot_value_const = const_mgr->GetConstant(
      vec_type, {uint_max_id, uint_max_id, uint_max_id, uint_max_id});
  if (ballot_value_const == nullptr) return false;
  Instruction* ballot_value =
      const_mgr->GetDefiningInstruction(ballot_value_const);
  if (ballot_value == nullptr) return false;
  uint32_t bool_type_id = type_mgr->GetBoolTypeId();
  if (bool_type_id == 0) return false;
  Instruction* is_active = ir_builder.AddNaryOp(
      bool_type_id, spv::Op::OpGroupNonUniformBallotBitExtract,
      {subgroup_scope, ballot_value->result_id(), target_inv->result_id()});
  if (is_active == nullptr) return false;
  Instruction* shuffle =
      ir_builder.AddNaryOp(inst->type_id(), spv::Op::OpGroupNonUniformShuffle,
                           {subgroup_scope, data_id, target_inv->result_id()});
  if (shuffle == nullptr) return false;

  // Create the null constant to use in the select.
  const auto* result_type = type_mgr->GetType(inst->type_id());
  if (result_type == nullptr) return false;
  const auto* null =
      const_mgr->GetConstant(result_type, std::vector<uint32_t>());
  if (null == nullptr) {
    return false;
  }
  Instruction* null_inst = const_mgr->GetDefiningInstruction(null);
  if (null_inst == nullptr) {
    return false;
  }

  // Build the select.  |inst| keeps its result id, so all existing uses of
  // the extended instruction now see the select's value.
  inst->SetOpcode(spv::Op::OpSelect);
  Instruction::OperandList new_operands;
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {is_active->result_id()}});
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {shuffle->result_id()}});
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {null_inst->result_id()}});

  inst->SetInOperands(std::move(new_operands));
  ctx->UpdateDefUse(inst);
  return true;
}
320 | | |
// Returns a folding rule that will replace the SwizzleInvocationsMaskedAMD
// extended instruction in the SPV_AMD_shader_ballot extension.
//
// The instruction
//
//    %mask = OpConstantComposite %v3uint %uint_x %uint_y %uint_z
//  %result = OpExtInst %uint %1 SwizzleInvocationsMaskedAMD %data %mask
//
// is replaced with
//
// potentially new constants and types
//
// clang-format off
// %uint_mask_extend = OpConstant %uint 0xFFFFFFE0
//         %uint_max = OpConstant %uint 0xFFFFFFFF
//           %v4uint = OpTypeVector %uint 4
//     %ballot_value = OpConstantComposite %v4uint %uint_max %uint_max %uint_max %uint_max
//             %null = OpConstantNull %type
// clang-format on
//
// and the following code in the function body
//
// clang-format off
//         %id = OpLoad %uint %SubgroupLocalInvocationId
//   %and_mask = OpBitwiseOr %uint %uint_x %uint_mask_extend
//        %and = OpBitwiseAnd %uint %id %and_mask
//         %or = OpBitwiseOr %uint %and %uint_y
// %target_inv = OpBitwiseXor %uint %or %uint_z
//  %is_active = OpGroupNonUniformBallotBitExtract %bool %uint_3 %ballot_value %target_inv
//    %shuffle = OpGroupNonUniformShuffle %type %uint_3 %data %target_inv
//     %result = OpSelect %type %is_active %shuffle %null
// clang-format on
//
// Also adding the capabilities and builtins that are needed.
bool ReplaceSwizzleInvocationsMasked(
    IRContext* ctx, Instruction* inst,
    const std::vector<const analysis::Constant*>&) {
  analysis::TypeManager* type_mgr = ctx->get_type_mgr();
  analysis::DefUseManager* def_use_mgr = ctx->get_def_use_mgr();
  analysis::ConstantManager* const_mgr = ctx->get_constant_mgr();

  // The replacement uses KHR ballot/shuffle group operations.
  ctx->AddCapability(spv::Capability::GroupNonUniformBallot);
  ctx->AddCapability(spv::Capability::GroupNonUniformShuffle);

  InstructionBuilder ir_builder(
      ctx, inst,
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);

  // Get the operands to inst, and the components of the mask
  uint32_t data_id = inst->GetSingleWordInOperand(2);

  Instruction* mask_inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(3));
  if (mask_inst == nullptr) return false;
  assert(mask_inst->opcode() == spv::Op::OpConstantComposite &&
         "The mask is suppose to be a vector constant.");
  assert(mask_inst->NumInOperands() == 3 &&
         "The mask is suppose to have 3 components.");

  // The three components are ids of the scalar constants (x, y, z) used in
  // the bitwise computation of the target invocation below.
  uint32_t uint_x = mask_inst->GetSingleWordInOperand(0);
  if (uint_x == 0) return false;
  uint32_t uint_y = mask_inst->GetSingleWordInOperand(1);
  if (uint_y == 0) return false;
  uint32_t uint_z = mask_inst->GetSingleWordInOperand(2);
  if (uint_z == 0) return false;

  // Get the subgroup invocation id.
  uint32_t var_id = ctx->GetBuiltinInputVarId(
      uint32_t(spv::BuiltIn::SubgroupLocalInvocationId));
  if (var_id == 0) return false;
  ctx->AddExtension("SPV_KHR_shader_ballot");
  Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id);
  if (var_inst == nullptr) return false;
  // The builtin variable is a pointer; in-operand 1 of the pointer type is
  // the pointee type, i.e. the uint type of the invocation id.
  Instruction* var_ptr_type =
      ctx->get_def_use_mgr()->GetDef(var_inst->type_id());
  if (var_ptr_type == nullptr) return false;
  uint32_t uint_type_id = var_ptr_type->GetSingleWordInOperand(1);
  if (uint_type_id == 0) return false;

  Instruction* id = ir_builder.AddLoad(uint_type_id, var_id);
  if (id == nullptr) return false;

  // Do the bitwise operations.
  // 0xFFFFFFE0 extends the 5-bit swizzle mask so that bits above the 32-wide
  // group are preserved from the invocation id.
  uint32_t mask_extended = ir_builder.GetUintConstantId(0xFFFFFFE0);
  if (mask_extended == 0) return false;
  Instruction* and_mask = ir_builder.AddBinaryOp(
      uint_type_id, spv::Op::OpBitwiseOr, uint_x, mask_extended);
  if (and_mask == nullptr) return false;
  Instruction* and_result =
      ir_builder.AddBinaryOp(uint_type_id, spv::Op::OpBitwiseAnd,
                             id->result_id(), and_mask->result_id());
  if (and_result == nullptr) return false;
  Instruction* or_result = ir_builder.AddBinaryOp(
      uint_type_id, spv::Op::OpBitwiseOr, and_result->result_id(), uint_y);
  if (or_result == nullptr) return false;
  Instruction* target_inv = ir_builder.AddBinaryOp(
      uint_type_id, spv::Op::OpBitwiseXor, or_result->result_id(), uint_z);
  if (target_inv == nullptr) return false;

  // Do the group operations
  uint32_t uint_max_id = ir_builder.GetUintConstantId(0xFFFFFFFF);
  if (uint_max_id == 0) return false;
  uint32_t subgroup_scope =
      ir_builder.GetUintConstantId(uint32_t(spv::Scope::Subgroup));
  if (subgroup_scope == 0) return false;
  // An all-ones uvec4 ballot value: the bit-extract below tests whether the
  // target invocation's bit is set in this mask.
  const auto* vec_type = type_mgr->GetUIntVectorType(4);
  if (vec_type == nullptr) return false;
  const auto* ballot_value_const = const_mgr->GetConstant(
      vec_type, {uint_max_id, uint_max_id, uint_max_id, uint_max_id});
  if (ballot_value_const == nullptr) return false;
  Instruction* ballot_value =
      const_mgr->GetDefiningInstruction(ballot_value_const);
  if (ballot_value == nullptr) return false;
  uint32_t bool_type_id = type_mgr->GetBoolTypeId();
  if (bool_type_id == 0) return false;
  Instruction* is_active = ir_builder.AddNaryOp(
      bool_type_id, spv::Op::OpGroupNonUniformBallotBitExtract,
      {subgroup_scope, ballot_value->result_id(), target_inv->result_id()});
  if (is_active == nullptr) return false;
  Instruction* shuffle =
      ir_builder.AddNaryOp(inst->type_id(), spv::Op::OpGroupNonUniformShuffle,
                           {subgroup_scope, data_id, target_inv->result_id()});
  if (shuffle == nullptr) return false;

  // Create the null constant to use in the select.
  const auto* result_type = type_mgr->GetType(inst->type_id());
  if (result_type == nullptr) return false;
  const auto* null =
      const_mgr->GetConstant(result_type, std::vector<uint32_t>());
  if (null == nullptr) return false;
  Instruction* null_inst = const_mgr->GetDefiningInstruction(null);
  if (null_inst == nullptr) return false;

  // Build the select.  |inst| keeps its result id, so all existing uses of
  // the extended instruction now see the select's value.
  inst->SetOpcode(spv::Op::OpSelect);
  Instruction::OperandList new_operands;
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {is_active->result_id()}});
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {shuffle->result_id()}});
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {null_inst->result_id()}});

  inst->SetInOperands(std::move(new_operands));
  ctx->UpdateDefUse(inst);
  return true;
}
463 | | |
464 | | // Returns a folding rule that will replace the WriteInvocationAMD extended |
465 | | // instruction in the SPV_AMD_shader_ballot extension. |
466 | | // |
467 | | // The instruction |
468 | | // |
469 | | // clang-format off |
470 | | // %result = OpExtInst %type %1 WriteInvocationAMD %input_value %write_value %invocation_index |
471 | | // clang-format on |
472 | | // |
473 | | // with |
474 | | // |
475 | | // %id = OpLoad %uint %SubgroupLocalInvocationId |
476 | | // %cmp = OpIEqual %bool %id %invocation_index |
477 | | // %result = OpSelect %type %cmp %write_value %input_value |
478 | | // |
479 | | // Also adding the capabilities and builtins that are needed. |
480 | | bool ReplaceWriteInvocation(IRContext* ctx, Instruction* inst, |
481 | 0 | const std::vector<const analysis::Constant*>&) { |
482 | 0 | uint32_t var_id = ctx->GetBuiltinInputVarId( |
483 | 0 | uint32_t(spv::BuiltIn::SubgroupLocalInvocationId)); |
484 | 0 | if (var_id == 0) return false; |
485 | 0 | Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id); |
486 | 0 | if (var_inst == nullptr) return false; |
487 | 0 | Instruction* var_ptr_type = |
488 | 0 | ctx->get_def_use_mgr()->GetDef(var_inst->type_id()); |
489 | 0 | if (var_ptr_type == nullptr) return false; |
490 | 0 | ctx->AddCapability(spv::Capability::SubgroupBallotKHR); |
491 | 0 | ctx->AddExtension("SPV_KHR_shader_ballot"); |
492 | |
|
493 | 0 | InstructionBuilder ir_builder( |
494 | 0 | ctx, inst, |
495 | 0 | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); |
496 | 0 | Instruction* t = |
497 | 0 | ir_builder.AddLoad(var_ptr_type->GetSingleWordInOperand(1), var_id); |
498 | 0 | if (t == nullptr) return false; |
499 | 0 | analysis::Bool bool_type; |
500 | 0 | uint32_t bool_type_id = ctx->get_type_mgr()->GetTypeInstruction(&bool_type); |
501 | 0 | if (bool_type_id == 0) return false; |
502 | 0 | Instruction* cmp = |
503 | 0 | ir_builder.AddBinaryOp(bool_type_id, spv::Op::OpIEqual, t->result_id(), |
504 | 0 | inst->GetSingleWordInOperand(4)); |
505 | 0 | if (cmp == nullptr) return false; |
506 | | |
507 | | // Build a select. |
508 | 0 | inst->SetOpcode(spv::Op::OpSelect); |
509 | 0 | Instruction::OperandList new_operands; |
510 | 0 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {cmp->result_id()}}); |
511 | 0 | new_operands.push_back(inst->GetInOperand(3)); |
512 | 0 | new_operands.push_back(inst->GetInOperand(2)); |
513 | |
|
514 | 0 | inst->SetInOperands(std::move(new_operands)); |
515 | 0 | ctx->UpdateDefUse(inst); |
516 | 0 | return true; |
517 | 0 | } |
518 | | |
// Returns a folding rule that will replace the MbcntAMD extended instruction in
// the SPV_AMD_shader_ballot extension.
//
// The instruction
//
//  %result = OpExtInst %uint %1 MbcntAMD %mask
//
// with
//
// Get SubgroupLtMask and convert the first 64-bits into a uint64_t because
// AMD's shader compiler expects a 64-bit integer mask.
//
//     %var = OpLoad %v4uint %SubgroupLtMaskKHR
// %shuffle = OpVectorShuffle %v2uint %var %var 0 1
//    %cast = OpBitcast %ulong %shuffle
//
// Perform the mask and count the bits.
//
//     %and = OpBitwiseAnd %ulong %cast %mask
//  %result = OpBitCount %uint %and
//
// Also adding the capabilities and builtins that are needed.
bool ReplaceMbcnt(IRContext* context, Instruction* inst,
                  const std::vector<const analysis::Constant*>&) {
  analysis::TypeManager* type_mgr = context->get_type_mgr();
  analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();

  // Locate (or create) the SubgroupLtMask builtin input variable.
  uint32_t var_id =
      context->GetBuiltinInputVarId(uint32_t(spv::BuiltIn::SubgroupLtMask));
  if (var_id == 0) return false;

  // Using the builtin requires the ballot capability.
  context->AddCapability(spv::Capability::GroupNonUniformBallot);
  // Walk variable -> pointer type -> pointee type to get the mask's vector
  // type instruction.
  Instruction* var_inst = def_use_mgr->GetDef(var_id);
  Instruction* var_ptr_type = def_use_mgr->GetDef(var_inst->type_id());
  Instruction* var_type =
      def_use_mgr->GetDef(var_ptr_type->GetSingleWordInOperand(1));
  assert(var_type->opcode() == spv::Op::OpTypeVector &&
         "Variable is suppose to be a vector of 4 ints");

  // Get the type for the shuffle.
  // A uvec2 holding the low two components of the mask, which together form
  // the low 64 bits.
  analysis::Vector temp_type(GetUIntType(context), 2);
  const analysis::Type* shuffle_type =
      context->get_type_mgr()->GetRegisteredType(&temp_type);
  if (shuffle_type == nullptr) return false;
  uint32_t shuffle_type_id = type_mgr->GetTypeInstruction(shuffle_type);

  // In-operand 2 of the OpExtInst is |mask| (operands 0 and 1 are the set id
  // and instruction number).
  uint32_t mask_id = inst->GetSingleWordInOperand(2);
  Instruction* mask_inst = def_use_mgr->GetDef(mask_id);

  // Testing with amd's shader compiler shows that a 64-bit mask is expected.
  assert(type_mgr->GetType(mask_inst->type_id())->AsInteger() != nullptr);
  assert(type_mgr->GetType(mask_inst->type_id())->AsInteger()->width() == 64);

  InstructionBuilder ir_builder(
      context, inst,
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  // %var = OpLoad %v4uint %SubgroupLtMaskKHR
  Instruction* load = ir_builder.AddLoad(var_type->result_id(), var_id);
  if (load == nullptr) return false;
  // %shuffle = OpVectorShuffle %v2uint %var %var 0 1
  Instruction* shuffle = ir_builder.AddVectorShuffle(
      shuffle_type_id, load->result_id(), load->result_id(), {0, 1});
  if (shuffle == nullptr) return false;
  // %cast = OpBitcast %ulong %shuffle
  Instruction* bitcast = ir_builder.AddUnaryOp(
      mask_inst->type_id(), spv::Op::OpBitcast, shuffle->result_id());
  if (bitcast == nullptr) return false;
  // %and = OpBitwiseAnd %ulong %cast %mask
  Instruction* t =
      ir_builder.AddBinaryOp(mask_inst->type_id(), spv::Op::OpBitwiseAnd,
                             bitcast->result_id(), mask_id);
  if (t == nullptr) return false;

  // Turn |inst| itself into the bit count, preserving its result id:
  // %result = OpBitCount %uint %and
  inst->SetOpcode(spv::Op::OpBitCount);
  inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {t->result_id()}}});
  context->UpdateDefUse(inst);
  return true;
}
593 | | |
// A folding rule that will replace the CubeFaceCoordAMD extended
// instruction from the SPV_AMD_gcn_shader extension. Returns true if the
// folding is successful.
//
// The instruction
//
//  %result = OpExtInst %v2float %1 CubeFaceCoordAMD %input
//
// is replaced with
//
//             %x = OpCompositeExtract %float %input 0
//             %y = OpCompositeExtract %float %input 1
//             %z = OpCompositeExtract %float %input 2
//            %nx = OpFNegate %float %x
//            %ny = OpFNegate %float %y
//            %nz = OpFNegate %float %z
//            %ax = OpExtInst %float %n_1 FAbs %x
//            %ay = OpExtInst %float %n_1 FAbs %y
//            %az = OpExtInst %float %n_1 FAbs %z
//      %amax_x_y = OpExtInst %float %n_1 FMax %ax %ay
//          %amax = OpExtInst %float %n_1 FMax %az %amax_x_y
//        %cubema = OpFMul %float %float_2 %amax
//      %is_z_max = OpFOrdGreaterThanEqual %bool %az %amax_x_y
//  %not_is_z_max = OpLogicalNot %bool %is_z_max
//        %y_gt_x = OpFOrdGreaterThanEqual %bool %ay %ax
//      %is_y_max = OpLogicalAnd %bool %not_is_z_max %y_gt_x
//      %is_z_neg = OpFOrdLessThan %bool %z %float_0
// %cubesc_case_1 = OpSelect %float %is_z_neg %nx %x
//      %is_x_neg = OpFOrdLessThan %bool %x %float_0
// %cubesc_case_2 = OpSelect %float %is_x_neg %z %nz
//           %sel = OpSelect %float %is_y_max %x %cubesc_case_2
//        %cubesc = OpSelect %float %is_z_max %cubesc_case_1 %sel
//      %is_y_neg = OpFOrdLessThan %bool %y %float_0
// %cubetc_case_1 = OpSelect %float %is_y_neg %nz %z
//        %cubetc = OpSelect %float %is_y_max %cubetc_case_1 %ny
//          %cube = OpCompositeConstruct %v2float %cubesc %cubetc
//         %denom = OpCompositeConstruct %v2float %cubema %cubema
//           %div = OpFDiv %v2float %cube %denom
//        %result = OpFAdd %v2float %div %const
//
// Also adds the capabilities and builtins that are needed.
bool ReplaceCubeFaceCoord(IRContext* ctx, Instruction* inst,
                          const std::vector<const analysis::Constant*>&) {
  analysis::TypeManager* type_mgr = ctx->get_type_mgr();
  analysis::ConstantManager* const_mgr = ctx->get_constant_mgr();

  // Gather the types needed for the replacement; give up if any is missing.
  uint32_t float_type_id = type_mgr->GetFloatTypeId();
  if (float_type_id == 0) return false;
  const analysis::Type* v2_float_type = type_mgr->GetFloatVectorType(2);
  if (v2_float_type == nullptr) return false;
  uint32_t v2_float_type_id = type_mgr->GetId(v2_float_type);
  if (v2_float_type_id == 0) return false;
  uint32_t bool_id = type_mgr->GetBoolTypeId();
  if (bool_id == 0) return false;

  InstructionBuilder ir_builder(
      ctx, inst,
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);

  // In-operand 2 is the first operand after the set id and opcode literal:
  // the 3-component direction vector.
  uint32_t input_id = inst->GetSingleWordInOperand(2);
  // Import GLSL.std.450 if the module does not already do so; the
  // replacement uses FAbs and FMax from that set.
  uint32_t glsl405_ext_inst_id =
      ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  if (glsl405_ext_inst_id == 0) {
    ctx->AddExtInstImport("GLSL.std.450");
    glsl405_ext_inst_id =
        ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  }
  if (glsl405_ext_inst_id == 0) return false;

  // Get the constants that will be used.
  uint32_t f0_const_id = const_mgr->GetFloatConstId(0.0);
  if (f0_const_id == 0) return false;
  uint32_t f2_const_id = const_mgr->GetFloatConstId(2.0);
  if (f2_const_id == 0) return false;
  uint32_t f0_5_const_id = const_mgr->GetFloatConstId(0.5);
  if (f0_5_const_id == 0) return false;
  // (0.5, 0.5): added at the end to remap from [-1,1] face space to [0,1].
  const analysis::Constant* vec_const =
      const_mgr->GetConstant(v2_float_type, {f0_5_const_id, f0_5_const_id});
  if (vec_const == nullptr) return false;
  Instruction* vec_const_inst = const_mgr->GetDefiningInstruction(vec_const);
  if (vec_const_inst == nullptr) return false;
  uint32_t vec_const_id = vec_const_inst->result_id();

  // Extract the input values.
  Instruction* x = ir_builder.AddCompositeExtract(float_type_id, input_id, {0});
  if (x == nullptr) return false;
  Instruction* y = ir_builder.AddCompositeExtract(float_type_id, input_id, {1});
  if (y == nullptr) return false;
  Instruction* z = ir_builder.AddCompositeExtract(float_type_id, input_id, {2});
  if (z == nullptr) return false;

  // Negate the input values.
  Instruction* nx =
      ir_builder.AddUnaryOp(float_type_id, spv::Op::OpFNegate, x->result_id());
  Instruction* ny =
      ir_builder.AddUnaryOp(float_type_id, spv::Op::OpFNegate, y->result_id());
  Instruction* nz =
      ir_builder.AddUnaryOp(float_type_id, spv::Op::OpFNegate, z->result_id());
  if (nx == nullptr) return false;
  if (ny == nullptr) return false;
  if (nz == nullptr) return false;

  // Get the absolute values of the inputs.
  Instruction* ax = ir_builder.AddNaryExtendedInstruction(
      float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {x->result_id()});
  if (ax == nullptr) return false;
  Instruction* ay = ir_builder.AddNaryExtendedInstruction(
      float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {y->result_id()});
  if (ay == nullptr) return false;
  Instruction* az = ir_builder.AddNaryExtendedInstruction(
      float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {z->result_id()});
  if (az == nullptr) return false;

  // Find which values are negative. Used in later computations.
  Instruction* is_z_neg = ir_builder.AddBinaryOp(
      bool_id, spv::Op::OpFOrdLessThan, z->result_id(), f0_const_id);
  if (is_z_neg == nullptr) return false;
  Instruction* is_y_neg = ir_builder.AddBinaryOp(
      bool_id, spv::Op::OpFOrdLessThan, y->result_id(), f0_const_id);
  if (is_y_neg == nullptr) return false;
  Instruction* is_x_neg = ir_builder.AddBinaryOp(
      bool_id, spv::Op::OpFOrdLessThan, x->result_id(), f0_const_id);
  if (is_x_neg == nullptr) return false;

  // Compute cubema: twice the largest absolute component (the denominator).
  Instruction* amax_x_y = ir_builder.AddNaryExtendedInstruction(
      float_type_id, glsl405_ext_inst_id, GLSLstd450FMax,
      {ax->result_id(), ay->result_id()});
  if (amax_x_y == nullptr) return false;
  Instruction* amax = ir_builder.AddNaryExtendedInstruction(
      float_type_id, glsl405_ext_inst_id, GLSLstd450FMax,
      {az->result_id(), amax_x_y->result_id()});
  if (amax == nullptr) return false;
  Instruction* cubema = ir_builder.AddBinaryOp(float_type_id, spv::Op::OpFMul,
                                               f2_const_id, amax->result_id());
  if (cubema == nullptr) return false;

  // Do the comparisons needed for computing cubesc and cubetc.
  Instruction* is_z_max =
      ir_builder.AddBinaryOp(bool_id, spv::Op::OpFOrdGreaterThanEqual,
                             az->result_id(), amax_x_y->result_id());
  if (is_z_max == nullptr) return false;
  Instruction* not_is_z_max = ir_builder.AddUnaryOp(
      bool_id, spv::Op::OpLogicalNot, is_z_max->result_id());
  if (not_is_z_max == nullptr) return false;
  Instruction* y_gr_x =
      ir_builder.AddBinaryOp(bool_id, spv::Op::OpFOrdGreaterThanEqual,
                             ay->result_id(), ax->result_id());
  if (y_gr_x == nullptr) return false;
  Instruction* is_y_max =
      ir_builder.AddBinaryOp(bool_id, spv::Op::OpLogicalAnd,
                             not_is_z_max->result_id(), y_gr_x->result_id());
  if (is_y_max == nullptr) return false;

  // Select the correct value for cubesc.
  Instruction* cubesc_case_1 = ir_builder.AddSelect(
      float_type_id, is_z_neg->result_id(), nx->result_id(), x->result_id());
  if (cubesc_case_1 == nullptr) return false;
  Instruction* cubesc_case_2 = ir_builder.AddSelect(
      float_type_id, is_x_neg->result_id(), z->result_id(), nz->result_id());
  if (cubesc_case_2 == nullptr) return false;
  Instruction* sel =
      ir_builder.AddSelect(float_type_id, is_y_max->result_id(), x->result_id(),
                           cubesc_case_2->result_id());
  if (sel == nullptr) return false;
  Instruction* cubesc =
      ir_builder.AddSelect(float_type_id, is_z_max->result_id(),
                           cubesc_case_1->result_id(), sel->result_id());
  if (cubesc == nullptr) return false;

  // Select the correct value for cubetc.
  Instruction* cubetc_case_1 = ir_builder.AddSelect(
      float_type_id, is_y_neg->result_id(), nz->result_id(), z->result_id());
  if (cubetc_case_1 == nullptr) return false;
  Instruction* cubetc =
      ir_builder.AddSelect(float_type_id, is_y_max->result_id(),
                           cubetc_case_1->result_id(), ny->result_id());
  if (cubetc == nullptr) return false;

  // Do the division
  Instruction* cube = ir_builder.AddCompositeConstruct(
      v2_float_type_id, {cubesc->result_id(), cubetc->result_id()});
  if (cube == nullptr) return false;
  Instruction* denom = ir_builder.AddCompositeConstruct(
      v2_float_type_id, {cubema->result_id(), cubema->result_id()});
  if (denom == nullptr) return false;
  Instruction* div = ir_builder.AddBinaryOp(
      v2_float_type_id, spv::Op::OpFDiv, cube->result_id(), denom->result_id());
  if (div == nullptr) return false;

  // Get the final result by adding 0.5 to |div|.  The original |inst| is
  // reused as the OpFAdd so every use of the result id remains valid.
  inst->SetOpcode(spv::Op::OpFAdd);
  Instruction::OperandList new_operands;
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {div->result_id()}});
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {vec_const_id}});

  inst->SetInOperands(std::move(new_operands));
  ctx->UpdateDefUse(inst);
  return true;
}
794 | | |
795 | | // A folding rule that will replace the CubeFaceIndexAMD extended |
796 | | // instruction in the SPV_AMD_gcn_shader_ballot. Returns true if the folding |
797 | | // is successful. |
798 | | // |
799 | | // The instruction |
800 | | // |
801 | | // %result = OpExtInst %float %1 CubeFaceIndexAMD %input |
802 | | // |
803 | | // with |
804 | | // |
805 | | // %x = OpCompositeExtract %float %input 0 |
806 | | // %y = OpCompositeExtract %float %input 1 |
807 | | // %z = OpCompositeExtract %float %input 2 |
808 | | // %ax = OpExtInst %float %n_1 FAbs %x |
809 | | // %ay = OpExtInst %float %n_1 FAbs %y |
810 | | // %az = OpExtInst %float %n_1 FAbs %z |
811 | | // %is_z_neg = OpFOrdLessThan %bool %z %float_0 |
812 | | // %is_y_neg = OpFOrdLessThan %bool %y %float_0 |
813 | | // %is_x_neg = OpFOrdLessThan %bool %x %float_0 |
814 | | // %amax_x_y = OpExtInst %float %n_1 FMax %ax %ay |
815 | | // %is_z_max = OpFOrdGreaterThanEqual %bool %az %amax_x_y |
816 | | // %y_gt_x = OpFOrdGreaterThanEqual %bool %ay %ax |
817 | | // %case_z = OpSelect %float %is_z_neg %float_5 %float4 |
818 | | // %case_y = OpSelect %float %is_y_neg %float_3 %float2 |
819 | | // %case_x = OpSelect %float %is_x_neg %float_1 %float0 |
820 | | // %sel = OpSelect %float %y_gt_x %case_y %case_x |
821 | | // %result = OpSelect %float %is_z_max %case_z %sel |
822 | | // |
823 | | // Also adding the capabilities and builtins that are needed. |
824 | | bool ReplaceCubeFaceIndex(IRContext* ctx, Instruction* inst, |
825 | 0 | const std::vector<const analysis::Constant*>&) { |
826 | 0 | analysis::TypeManager* type_mgr = ctx->get_type_mgr(); |
827 | 0 | analysis::ConstantManager* const_mgr = ctx->get_constant_mgr(); |
828 | |
|
829 | 0 | uint32_t float_type_id = type_mgr->GetFloatTypeId(); |
830 | 0 | if (float_type_id == 0) return false; |
831 | 0 | uint32_t bool_id = type_mgr->GetBoolTypeId(); |
832 | 0 | if (bool_id == 0) return false; |
833 | | |
834 | 0 | InstructionBuilder ir_builder( |
835 | 0 | ctx, inst, |
836 | 0 | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); |
837 | |
|
838 | 0 | uint32_t input_id = inst->GetSingleWordInOperand(2); |
839 | 0 | uint32_t glsl405_ext_inst_id = |
840 | 0 | ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); |
841 | 0 | if (glsl405_ext_inst_id == 0) { |
842 | 0 | ctx->AddExtInstImport("GLSL.std.450"); |
843 | 0 | glsl405_ext_inst_id = |
844 | 0 | ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); |
845 | 0 | } |
846 | | |
847 | | // Get the constants that will be used. |
848 | 0 | uint32_t f0_const_id = const_mgr->GetFloatConstId(0.0); |
849 | 0 | if (f0_const_id == 0) return false; |
850 | 0 | uint32_t f1_const_id = const_mgr->GetFloatConstId(1.0); |
851 | 0 | if (f1_const_id == 0) return false; |
852 | 0 | uint32_t f2_const_id = const_mgr->GetFloatConstId(2.0); |
853 | 0 | if (f2_const_id == 0) return false; |
854 | 0 | uint32_t f3_const_id = const_mgr->GetFloatConstId(3.0); |
855 | 0 | if (f3_const_id == 0) return false; |
856 | 0 | uint32_t f4_const_id = const_mgr->GetFloatConstId(4.0); |
857 | 0 | if (f4_const_id == 0) return false; |
858 | 0 | uint32_t f5_const_id = const_mgr->GetFloatConstId(5.0); |
859 | 0 | if (f5_const_id == 0) return false; |
860 | | |
861 | | // Extract the input values. |
862 | 0 | Instruction* x = ir_builder.AddCompositeExtract(float_type_id, input_id, {0}); |
863 | 0 | if (x == nullptr) return false; |
864 | 0 | Instruction* y = ir_builder.AddCompositeExtract(float_type_id, input_id, {1}); |
865 | 0 | if (y == nullptr) return false; |
866 | | // TODO(1-841): Handle id overflow. |
867 | 0 | Instruction* z = ir_builder.AddCompositeExtract(float_type_id, input_id, {2}); |
868 | 0 | if (z == nullptr) return false; |
869 | | |
870 | | // Get the absolute values of the inputs. |
871 | 0 | Instruction* ax = ir_builder.AddNaryExtendedInstruction( |
872 | 0 | float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {x->result_id()}); |
873 | 0 | if (ax == nullptr) return false; |
874 | 0 | Instruction* ay = ir_builder.AddNaryExtendedInstruction( |
875 | 0 | float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {y->result_id()}); |
876 | 0 | if (ay == nullptr) return false; |
877 | 0 | Instruction* az = ir_builder.AddNaryExtendedInstruction( |
878 | 0 | float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {z->result_id()}); |
879 | 0 | if (az == nullptr) return false; |
880 | | |
881 | | // Find which values are negative. Used in later computations. |
882 | 0 | Instruction* is_z_neg = ir_builder.AddBinaryOp( |
883 | 0 | bool_id, spv::Op::OpFOrdLessThan, z->result_id(), f0_const_id); |
884 | 0 | if (is_z_neg == nullptr) return false; |
885 | 0 | Instruction* is_y_neg = ir_builder.AddBinaryOp( |
886 | 0 | bool_id, spv::Op::OpFOrdLessThan, y->result_id(), f0_const_id); |
887 | 0 | if (is_y_neg == nullptr) return false; |
888 | 0 | Instruction* is_x_neg = ir_builder.AddBinaryOp( |
889 | 0 | bool_id, spv::Op::OpFOrdLessThan, x->result_id(), f0_const_id); |
890 | 0 | if (is_x_neg == nullptr) return false; |
891 | | |
892 | | // Find the max value. |
893 | 0 | Instruction* amax_x_y = ir_builder.AddNaryExtendedInstruction( |
894 | 0 | float_type_id, glsl405_ext_inst_id, GLSLstd450FMax, |
895 | 0 | {ax->result_id(), ay->result_id()}); |
896 | 0 | if (amax_x_y == nullptr) return false; |
897 | 0 | Instruction* is_z_max = |
898 | 0 | ir_builder.AddBinaryOp(bool_id, spv::Op::OpFOrdGreaterThanEqual, |
899 | 0 | az->result_id(), amax_x_y->result_id()); |
900 | 0 | if (is_z_max == nullptr) return false; |
901 | 0 | Instruction* y_gr_x = |
902 | 0 | ir_builder.AddBinaryOp(bool_id, spv::Op::OpFOrdGreaterThanEqual, |
903 | 0 | ay->result_id(), ax->result_id()); |
904 | 0 | if (y_gr_x == nullptr) return false; |
905 | | |
906 | | // Get the value for each case. |
907 | 0 | Instruction* case_z = ir_builder.AddSelect( |
908 | 0 | float_type_id, is_z_neg->result_id(), f5_const_id, f4_const_id); |
909 | 0 | if (case_z == nullptr) return false; |
910 | 0 | Instruction* case_y = ir_builder.AddSelect( |
911 | 0 | float_type_id, is_y_neg->result_id(), f3_const_id, f2_const_id); |
912 | 0 | if (case_y == nullptr) return false; |
913 | 0 | Instruction* case_x = ir_builder.AddSelect( |
914 | 0 | float_type_id, is_x_neg->result_id(), f1_const_id, f0_const_id); |
915 | 0 | if (case_x == nullptr) return false; |
916 | | |
917 | | // Select the correct case. |
918 | 0 | Instruction* sel = |
919 | 0 | ir_builder.AddSelect(float_type_id, y_gr_x->result_id(), |
920 | 0 | case_y->result_id(), case_x->result_id()); |
921 | 0 | if (sel == nullptr) return false; |
922 | | |
923 | | // Get the final result by adding 0.5 to |div|. |
924 | 0 | inst->SetOpcode(spv::Op::OpSelect); |
925 | 0 | Instruction::OperandList new_operands; |
926 | 0 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {is_z_max->result_id()}}); |
927 | 0 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {case_z->result_id()}}); |
928 | 0 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {sel->result_id()}}); |
929 | |
|
930 | 0 | inst->SetInOperands(std::move(new_operands)); |
931 | 0 | ctx->UpdateDefUse(inst); |
932 | 0 | return true; |
933 | 0 | } |
934 | | |
935 | | // A folding rule that will replace the TimeAMD extended instruction in the |
936 | | // SPV_AMD_gcn_shader_ballot. It returns true if the folding is successful. |
937 | | // It returns False, otherwise. |
938 | | // |
939 | | // The instruction |
940 | | // |
941 | | // %result = OpExtInst %uint64 %1 TimeAMD |
942 | | // |
943 | | // with |
944 | | // |
945 | | // %result = OpReadClockKHR %uint64 %uint_3 |
946 | | // |
947 | | // NOTE: TimeAMD uses subgroup scope (it is not a real time clock). |
948 | | bool ReplaceTimeAMD(IRContext* ctx, Instruction* inst, |
949 | 0 | const std::vector<const analysis::Constant*>&) { |
950 | 0 | InstructionBuilder ir_builder( |
951 | 0 | ctx, inst, |
952 | 0 | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); |
953 | 0 | ctx->AddExtension("SPV_KHR_shader_clock"); |
954 | 0 | ctx->AddCapability(spv::Capability::ShaderClockKHR); |
955 | |
|
956 | 0 | inst->SetOpcode(spv::Op::OpReadClockKHR); |
957 | 0 | Instruction::OperandList args; |
958 | 0 | uint32_t subgroup_scope_id = |
959 | 0 | ir_builder.GetUintConstantId(uint32_t(spv::Scope::Subgroup)); |
960 | 0 | args.push_back({SPV_OPERAND_TYPE_ID, {subgroup_scope_id}}); |
961 | 0 | inst->SetInOperands(std::move(args)); |
962 | 0 | ctx->UpdateDefUse(inst); |
963 | |
|
964 | 0 | return true; |
965 | 0 | } |
966 | | |
// Folding rules that rewrite every instruction and extended instruction from
// the AMD extensions handled by this pass (SPV_AMD_shader_ballot,
// SPV_AMD_shader_trinary_minmax, and SPV_AMD_gcn_shader) into core or KHR
// equivalents.
class AmdExtFoldingRules : public FoldingRules {
 public:
  explicit AmdExtFoldingRules(IRContext* ctx) : FoldingRules(ctx) {}

 protected:
  virtual void AddFoldingRules() override {
    // The SPV_AMD_shader_ballot group opcodes map one-to-one onto the
    // corresponding OpGroupNonUniform* reduction opcodes.
    rules_[spv::Op::OpGroupIAddNonUniformAMD].push_back(
        ReplaceGroupNonuniformOperationOpCode<spv::Op::OpGroupNonUniformIAdd>);
    rules_[spv::Op::OpGroupFAddNonUniformAMD].push_back(
        ReplaceGroupNonuniformOperationOpCode<spv::Op::OpGroupNonUniformFAdd>);
    rules_[spv::Op::OpGroupUMinNonUniformAMD].push_back(
        ReplaceGroupNonuniformOperationOpCode<spv::Op::OpGroupNonUniformUMin>);
    rules_[spv::Op::OpGroupSMinNonUniformAMD].push_back(
        ReplaceGroupNonuniformOperationOpCode<spv::Op::OpGroupNonUniformSMin>);
    rules_[spv::Op::OpGroupFMinNonUniformAMD].push_back(
        ReplaceGroupNonuniformOperationOpCode<spv::Op::OpGroupNonUniformFMin>);
    rules_[spv::Op::OpGroupUMaxNonUniformAMD].push_back(
        ReplaceGroupNonuniformOperationOpCode<spv::Op::OpGroupNonUniformUMax>);
    rules_[spv::Op::OpGroupSMaxNonUniformAMD].push_back(
        ReplaceGroupNonuniformOperationOpCode<spv::Op::OpGroupNonUniformSMax>);
    rules_[spv::Op::OpGroupFMaxNonUniformAMD].push_back(
        ReplaceGroupNonuniformOperationOpCode<spv::Op::OpGroupNonUniformFMax>);

    // Extended-instruction rules are keyed on the module's import id for the
    // extension's instruction set, so only register them when the module
    // actually imports the set.
    uint32_t extension_id =
        context()->module()->GetExtInstImportId("SPV_AMD_shader_ballot");

    if (extension_id != 0) {
      ext_rules_[{extension_id, AmdShaderBallotSwizzleInvocationsAMD}]
          .push_back(ReplaceSwizzleInvocations);
      ext_rules_[{extension_id, AmdShaderBallotSwizzleInvocationsMaskedAMD}]
          .push_back(ReplaceSwizzleInvocationsMasked);
      ext_rules_[{extension_id, AmdShaderBallotWriteInvocationAMD}].push_back(
          ReplaceWriteInvocation);
      ext_rules_[{extension_id, AmdShaderBallotMbcntAMD}].push_back(
          ReplaceMbcnt);
    }

    extension_id = context()->module()->GetExtInstImportId(
        "SPV_AMD_shader_trinary_minmax");

    if (extension_id != 0) {
      // Min3/Max3 become two chained GLSL.std.450 min/max calls; Mid3 becomes
      // a clamp of one input by the min and max of the other two.
      ext_rules_[{extension_id, FMin3AMD}].push_back(
          ReplaceTrinaryMinMax<GLSLstd450FMin>);
      ext_rules_[{extension_id, UMin3AMD}].push_back(
          ReplaceTrinaryMinMax<GLSLstd450UMin>);
      ext_rules_[{extension_id, SMin3AMD}].push_back(
          ReplaceTrinaryMinMax<GLSLstd450SMin>);
      ext_rules_[{extension_id, FMax3AMD}].push_back(
          ReplaceTrinaryMinMax<GLSLstd450FMax>);
      ext_rules_[{extension_id, UMax3AMD}].push_back(
          ReplaceTrinaryMinMax<GLSLstd450UMax>);
      ext_rules_[{extension_id, SMax3AMD}].push_back(
          ReplaceTrinaryMinMax<GLSLstd450SMax>);
      ext_rules_[{extension_id, FMid3AMD}].push_back(
          ReplaceTrinaryMid<GLSLstd450FMin, GLSLstd450FMax, GLSLstd450FClamp>);
      ext_rules_[{extension_id, UMid3AMD}].push_back(
          ReplaceTrinaryMid<GLSLstd450UMin, GLSLstd450UMax, GLSLstd450UClamp>);
      ext_rules_[{extension_id, SMid3AMD}].push_back(
          ReplaceTrinaryMid<GLSLstd450SMin, GLSLstd450SMax, GLSLstd450SClamp>);
    }

    extension_id =
        context()->module()->GetExtInstImportId("SPV_AMD_gcn_shader");

    if (extension_id != 0) {
      ext_rules_[{extension_id, CubeFaceCoordAMD}].push_back(
          ReplaceCubeFaceCoord);
      ext_rules_[{extension_id, CubeFaceIndexAMD}].push_back(
          ReplaceCubeFaceIndex);
      ext_rules_[{extension_id, TimeAMD}].push_back(ReplaceTimeAMD);
    }
  }
};
1040 | | |
1041 | | class AmdExtConstFoldingRules : public ConstantFoldingRules { |
1042 | | public: |
1043 | 0 | AmdExtConstFoldingRules(IRContext* ctx) : ConstantFoldingRules(ctx) {} |
1044 | | |
1045 | | protected: |
1046 | 0 | virtual void AddFoldingRules() override {} |
1047 | | }; |
1048 | | |
1049 | | } // namespace |
1050 | | |
1051 | 0 | Pass::Status AmdExtensionToKhrPass::Process() { |
1052 | 0 | bool changed = false; |
1053 | | |
1054 | | // Traverse the body of the functions to replace instructions that require |
1055 | | // the extensions. |
1056 | 0 | InstructionFolder folder( |
1057 | 0 | context(), |
1058 | 0 | std::unique_ptr<AmdExtFoldingRules>(new AmdExtFoldingRules(context())), |
1059 | 0 | MakeUnique<AmdExtConstFoldingRules>(context())); |
1060 | 0 | for (Function& func : *get_module()) { |
1061 | 0 | bool failed = |
1062 | 0 | !func.WhileEachInst([&changed, &folder, this](Instruction* inst) { |
1063 | 0 | if (folder.FoldInstruction(inst)) { |
1064 | 0 | changed = true; |
1065 | 0 | return true; |
1066 | 0 | } else if (context()->id_overflow()) { |
1067 | 0 | return false; |
1068 | 0 | } |
1069 | 0 | return true; |
1070 | 0 | }); |
1071 | |
|
1072 | 0 | if (failed) return Status::Failure; |
1073 | 0 | } |
1074 | | |
1075 | | // Now that instruction that require the extensions have been removed, we can |
1076 | | // remove the extension instructions. |
1077 | 0 | std::set<std::string> ext_to_remove = {"SPV_AMD_shader_ballot", |
1078 | 0 | "SPV_AMD_shader_trinary_minmax", |
1079 | 0 | "SPV_AMD_gcn_shader"}; |
1080 | |
|
1081 | 0 | std::vector<Instruction*> to_be_killed; |
1082 | 0 | for (Instruction& inst : context()->module()->extensions()) { |
1083 | 0 | if (inst.opcode() == spv::Op::OpExtension) { |
1084 | 0 | if (ext_to_remove.count(inst.GetInOperand(0).AsString()) != 0) { |
1085 | 0 | to_be_killed.push_back(&inst); |
1086 | 0 | } |
1087 | 0 | } |
1088 | 0 | } |
1089 | |
|
1090 | 0 | for (Instruction& inst : context()->ext_inst_imports()) { |
1091 | 0 | if (inst.opcode() == spv::Op::OpExtInstImport) { |
1092 | 0 | if (ext_to_remove.count(inst.GetInOperand(0).AsString()) != 0) { |
1093 | 0 | to_be_killed.push_back(&inst); |
1094 | 0 | } |
1095 | 0 | } |
1096 | 0 | } |
1097 | |
|
1098 | 0 | for (Instruction* inst : to_be_killed) { |
1099 | 0 | context()->KillInst(inst); |
1100 | 0 | changed = true; |
1101 | 0 | } |
1102 | | |
1103 | | // The replacements that take place use instructions that are missing before |
1104 | | // SPIR-V 1.3. If we changed something, we will have to make sure the version |
1105 | | // is at least SPIR-V 1.3 to make sure those instruction can be used. |
1106 | 0 | if (changed) { |
1107 | 0 | uint32_t version = get_module()->version(); |
1108 | 0 | if (version < 0x00010300 /*1.3*/) { |
1109 | 0 | get_module()->set_version(0x00010300); |
1110 | 0 | } |
1111 | 0 | } |
1112 | 0 | return changed ? Status::SuccessWithChange : Status::SuccessWithoutChange; |
1113 | 0 | } |
1114 | | |
1115 | | } // namespace opt |
1116 | | } // namespace spvtools |