Coverage Report

Created: 2026-01-16 06:48

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/spirv-tools/source/opt/amd_ext_to_khr.cpp
Line
Count
Source
1
// Copyright (c) 2019 Google LLC.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
#include "source/opt/amd_ext_to_khr.h"
16
17
#include <set>
18
#include <string>
19
20
#include "ir_builder.h"
21
#include "source/opt/ir_context.h"
22
#include "type_manager.h"
23
24
namespace spvtools {
25
namespace opt {
26
namespace {
27
28
// Extended instruction numbers for the AMD shader ballot instructions that
// this file replaces.
enum AmdShaderBallotExtOpcodes {
  AmdShaderBallotSwizzleInvocationsAMD = 1,        // SwizzleInvocationsAMD
  AmdShaderBallotSwizzleInvocationsMaskedAMD = 2,  // SwizzleInvocationsMaskedAMD
  AmdShaderBallotWriteInvocationAMD = 3,           // WriteInvocationAMD
  AmdShaderBallotMbcntAMD = 4,                     // MbcntAMD
};
34
35
// Extended instruction numbers for the AMD trinary min/max/mid instructions.
// NOTE(review): the F/U/S prefixes presumably select the float, unsigned, and
// signed variants -- confirm against the extension specification.
enum AmdShaderTrinaryMinMaxExtOpCodes {
  // Minimum of three values.
  FMin3AMD = 1,
  UMin3AMD = 2,
  SMin3AMD = 3,
  // Maximum of three values.
  FMax3AMD = 4,
  UMax3AMD = 5,
  SMax3AMD = 6,
  // Middle of three values.
  FMid3AMD = 7,
  UMid3AMD = 8,
  SMid3AMD = 9,
};
46
47
// Extended instruction numbers for the AMD GCN shader instructions, listed in
// increasing numeric order.
enum AmdGcnShader {
  CubeFaceIndexAMD = 1,
  CubeFaceCoordAMD = 2,
  TimeAMD = 3,
};
48
49
0
// Returns the instance of the 32-bit, non-signed integer type that the type
// manager of |ctx| has registered.
analysis::Type* GetUIntType(IRContext* ctx) {
  analysis::TypeManager* type_mgr = ctx->get_type_mgr();
  // Description of the requested type: width 32, signedness false.
  analysis::Integer uint32_desc(32, false);
  return type_mgr->GetRegisteredType(&uint32_desc);
}
53
54
// Returns a folding rule that replaces |op(a,b,c)| by |op(op(a,b),c)|, where
55
// |op| is either min or max. |opcode| is the binary opcode in the GLSLstd450
56
// extended instruction set that corresponds to the trinary instruction being
57
// replaced.
58
template <GLSLstd450 opcode>
59
bool ReplaceTrinaryMinMax(IRContext* ctx, Instruction* inst,
60
0
                          const std::vector<const analysis::Constant*>&) {
61
0
  uint32_t glsl405_ext_inst_id =
62
0
      ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
63
0
  if (glsl405_ext_inst_id == 0) {
64
0
    ctx->AddExtInstImport("GLSL.std.450");
65
0
    glsl405_ext_inst_id =
66
0
        ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
67
0
  }
68
69
0
  InstructionBuilder ir_builder(
70
0
      ctx, inst,
71
0
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
72
73
0
  uint32_t op1 = inst->GetSingleWordInOperand(2);
74
0
  uint32_t op2 = inst->GetSingleWordInOperand(3);
75
0
  uint32_t op3 = inst->GetSingleWordInOperand(4);
76
77
0
  Instruction* temp = ir_builder.AddNaryExtendedInstruction(
78
0
      inst->type_id(), glsl405_ext_inst_id, opcode, {op1, op2});
79
0
  if (temp == nullptr) return false;
80
81
0
  Instruction::OperandList new_operands;
82
0
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {glsl405_ext_inst_id}});
83
0
  new_operands.push_back({SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER,
84
0
                          {static_cast<uint32_t>(opcode)}});
85
0
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {temp->result_id()}});
86
0
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {op3}});
87
88
0
  inst->SetInOperands(std::move(new_operands));
89
0
  ctx->UpdateDefUse(inst);
90
0
  return true;
91
0
}
Unexecuted instantiation: amd_ext_to_khr.cpp:bool spvtools::opt::(anonymous namespace)::ReplaceTrinaryMinMax<(GLSLstd450)37>(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&)
Unexecuted instantiation: amd_ext_to_khr.cpp:bool spvtools::opt::(anonymous namespace)::ReplaceTrinaryMinMax<(GLSLstd450)38>(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&)
Unexecuted instantiation: amd_ext_to_khr.cpp:bool spvtools::opt::(anonymous namespace)::ReplaceTrinaryMinMax<(GLSLstd450)39>(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&)
Unexecuted instantiation: amd_ext_to_khr.cpp:bool spvtools::opt::(anonymous namespace)::ReplaceTrinaryMinMax<(GLSLstd450)40>(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&)
Unexecuted instantiation: amd_ext_to_khr.cpp:bool spvtools::opt::(anonymous namespace)::ReplaceTrinaryMinMax<(GLSLstd450)41>(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&)
Unexecuted instantiation: amd_ext_to_khr.cpp:bool spvtools::opt::(anonymous namespace)::ReplaceTrinaryMinMax<(GLSLstd450)42>(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&)
92
93
// Returns a folding rule that replaces |mid(a,b,c)| by |clamp(a, min(b,c),
94
// max(b,c)|. The three parameters are the opcode that correspond to the min,
95
// max, and clamp operations for the type of the instruction being replaced.
96
template <GLSLstd450 min_opcode, GLSLstd450 max_opcode, GLSLstd450 clamp_opcode>
97
bool ReplaceTrinaryMid(IRContext* ctx, Instruction* inst,
98
0
                       const std::vector<const analysis::Constant*>&) {
99
0
  uint32_t glsl405_ext_inst_id =
100
0
      ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
101
0
  if (glsl405_ext_inst_id == 0) {
102
0
    ctx->AddExtInstImport("GLSL.std.450");
103
0
    glsl405_ext_inst_id =
104
0
        ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
105
0
  }
106
107
0
  InstructionBuilder ir_builder(
108
0
      ctx, inst,
109
0
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
110
111
0
  uint32_t op1 = inst->GetSingleWordInOperand(2);
112
0
  uint32_t op2 = inst->GetSingleWordInOperand(3);
113
0
  uint32_t op3 = inst->GetSingleWordInOperand(4);
114
115
0
  Instruction* min = ir_builder.AddNaryExtendedInstruction(
116
0
      inst->type_id(), glsl405_ext_inst_id, static_cast<uint32_t>(min_opcode),
117
0
      {op2, op3});
118
0
  if (min == nullptr) return false;
119
120
0
  Instruction* max = ir_builder.AddNaryExtendedInstruction(
121
0
      inst->type_id(), glsl405_ext_inst_id, static_cast<uint32_t>(max_opcode),
122
0
      {op2, op3});
123
0
  if (max == nullptr) return false;
124
125
0
  Instruction::OperandList new_operands;
126
0
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {glsl405_ext_inst_id}});
127
0
  new_operands.push_back({SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER,
128
0
                          {static_cast<uint32_t>(clamp_opcode)}});
129
0
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {op1}});
130
0
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {min->result_id()}});
131
0
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {max->result_id()}});
132
133
0
  inst->SetInOperands(std::move(new_operands));
134
0
  ctx->UpdateDefUse(inst);
135
0
  return true;
136
0
}
Unexecuted instantiation: amd_ext_to_khr.cpp:bool spvtools::opt::(anonymous namespace)::ReplaceTrinaryMid<(GLSLstd450)37, (GLSLstd450)40, (GLSLstd450)43>(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&)
Unexecuted instantiation: amd_ext_to_khr.cpp:bool spvtools::opt::(anonymous namespace)::ReplaceTrinaryMid<(GLSLstd450)38, (GLSLstd450)41, (GLSLstd450)44>(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&)
Unexecuted instantiation: amd_ext_to_khr.cpp:bool spvtools::opt::(anonymous namespace)::ReplaceTrinaryMid<(GLSLstd450)39, (GLSLstd450)42, (GLSLstd450)45>(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&)
137
138
// Returns a folding rule that will replace the opcode with |opcode| and add
139
// the capabilities required.  The folding rule assumes it is folding an
140
// OpGroup*NonUniformAMD instruction from the SPV_AMD_shader_ballot extension.
141
template <spv::Op new_opcode>
142
bool ReplaceGroupNonuniformOperationOpCode(
143
    IRContext* ctx, Instruction* inst,
144
0
    const std::vector<const analysis::Constant*>&) {
145
0
  switch (new_opcode) {
146
0
    case spv::Op::OpGroupNonUniformIAdd:
147
0
    case spv::Op::OpGroupNonUniformFAdd:
148
0
    case spv::Op::OpGroupNonUniformUMin:
149
0
    case spv::Op::OpGroupNonUniformSMin:
150
0
    case spv::Op::OpGroupNonUniformFMin:
151
0
    case spv::Op::OpGroupNonUniformUMax:
152
0
    case spv::Op::OpGroupNonUniformSMax:
153
0
    case spv::Op::OpGroupNonUniformFMax:
154
0
      break;
155
0
    default:
156
0
      assert(
157
0
          false &&
158
0
          "Should be replacing with a group non uniform arithmetic operation.");
159
0
  }
160
161
0
  switch (inst->opcode()) {
162
0
    case spv::Op::OpGroupIAddNonUniformAMD:
163
0
    case spv::Op::OpGroupFAddNonUniformAMD:
164
0
    case spv::Op::OpGroupUMinNonUniformAMD:
165
0
    case spv::Op::OpGroupSMinNonUniformAMD:
166
0
    case spv::Op::OpGroupFMinNonUniformAMD:
167
0
    case spv::Op::OpGroupUMaxNonUniformAMD:
168
0
    case spv::Op::OpGroupSMaxNonUniformAMD:
169
0
    case spv::Op::OpGroupFMaxNonUniformAMD:
170
0
      break;
171
0
    default:
172
0
      assert(false &&
173
0
             "Should be replacing a group non uniform arithmetic operation.");
174
0
  }
175
176
0
  ctx->AddCapability(spv::Capability::GroupNonUniformArithmetic);
177
0
  inst->SetOpcode(new_opcode);
178
0
  return true;
179
0
}
Unexecuted instantiation: amd_ext_to_khr.cpp:bool spvtools::opt::(anonymous namespace)::ReplaceGroupNonuniformOperationOpCode<(spv::Op)349>(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&)
Unexecuted instantiation: amd_ext_to_khr.cpp:bool spvtools::opt::(anonymous namespace)::ReplaceGroupNonuniformOperationOpCode<(spv::Op)350>(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&)
Unexecuted instantiation: amd_ext_to_khr.cpp:bool spvtools::opt::(anonymous namespace)::ReplaceGroupNonuniformOperationOpCode<(spv::Op)354>(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&)
Unexecuted instantiation: amd_ext_to_khr.cpp:bool spvtools::opt::(anonymous namespace)::ReplaceGroupNonuniformOperationOpCode<(spv::Op)353>(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&)
Unexecuted instantiation: amd_ext_to_khr.cpp:bool spvtools::opt::(anonymous namespace)::ReplaceGroupNonuniformOperationOpCode<(spv::Op)355>(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&)
Unexecuted instantiation: amd_ext_to_khr.cpp:bool spvtools::opt::(anonymous namespace)::ReplaceGroupNonuniformOperationOpCode<(spv::Op)357>(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&)
Unexecuted instantiation: amd_ext_to_khr.cpp:bool spvtools::opt::(anonymous namespace)::ReplaceGroupNonuniformOperationOpCode<(spv::Op)356>(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&)
Unexecuted instantiation: amd_ext_to_khr.cpp:bool spvtools::opt::(anonymous namespace)::ReplaceGroupNonuniformOperationOpCode<(spv::Op)358>(spvtools::opt::IRContext*, spvtools::opt::Instruction*, std::__1::vector<spvtools::opt::analysis::Constant const*, std::__1::allocator<spvtools::opt::analysis::Constant const*> > const&)
180
181
// Returns a folding rule that will replace the SwizzleInvocationsAMD extended
182
// instruction in the SPV_AMD_shader_ballot extension.
183
//
184
// The instruction
185
//
186
//  %offset = OpConstantComposite %v4uint %x %y %z %w
187
//  %result = OpExtInst %type %1 SwizzleInvocationsAMD %data %offset
188
//
189
// is replaced with
190
//
191
// potentially new constants and types
192
//
193
// clang-format off
194
//         %uint_max = OpConstant %uint 0xFFFFFFFF
195
//           %v4uint = OpTypeVector %uint 4
196
//     %ballot_value = OpConstantComposite %v4uint %uint_max %uint_max %uint_max %uint_max
197
//             %null = OpConstantNull %type
198
// clang-format on
199
//
200
// and the following code in the function body
201
//
202
// clang-format off
203
//         %id = OpLoad %uint %SubgroupLocalInvocationId
204
//   %quad_idx = OpBitwiseAnd %uint %id %uint_3
205
//   %quad_ldr = OpBitwiseXor %uint %id %quad_idx
206
//  %my_offset = OpVectorExtractDynamic %uint %offset %quad_idx
207
// %target_inv = OpIAdd %uint %quad_ldr %my_offset
208
//  %is_active = OpGroupNonUniformBallotBitExtract %bool %uint_3 %ballot_value %target_inv
209
//    %shuffle = OpGroupNonUniformShuffle %type %uint_3 %data %target_inv
210
//     %result = OpSelect %type %is_active %shuffle %null
211
// clang-format on
212
//
213
// Also adding the capabilities and builtins that are needed.
214
bool ReplaceSwizzleInvocations(IRContext* ctx, Instruction* inst,
215
0
                               const std::vector<const analysis::Constant*>&) {
216
0
  analysis::TypeManager* type_mgr = ctx->get_type_mgr();
217
0
  analysis::ConstantManager* const_mgr = ctx->get_constant_mgr();
218
219
0
  ctx->AddExtension("SPV_KHR_shader_ballot");
220
0
  ctx->AddCapability(spv::Capability::GroupNonUniformBallot);
221
0
  ctx->AddCapability(spv::Capability::GroupNonUniformShuffle);
222
223
0
  InstructionBuilder ir_builder(
224
0
      ctx, inst,
225
0
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
226
227
0
  uint32_t data_id = inst->GetSingleWordInOperand(2);
228
0
  uint32_t offset_id = inst->GetSingleWordInOperand(3);
229
230
  // Get the subgroup invocation id.
231
0
  uint32_t var_id = ctx->GetBuiltinInputVarId(
232
0
      uint32_t(spv::BuiltIn::SubgroupLocalInvocationId));
233
0
  if (var_id == 0) return false;
234
0
  Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id);
235
0
  if (var_inst == nullptr) return false;
236
0
  Instruction* var_ptr_type =
237
0
      ctx->get_def_use_mgr()->GetDef(var_inst->type_id());
238
0
  if (var_ptr_type == nullptr) return false;
239
0
  uint32_t uint_type_id = var_ptr_type->GetSingleWordInOperand(1);
240
0
  if (uint_type_id == 0) return false;
241
242
0
  Instruction* id = ir_builder.AddLoad(uint_type_id, var_id);
243
0
  if (id == nullptr) return false;
244
245
0
  uint32_t quad_mask = ir_builder.GetUintConstantId(3);
246
0
  if (quad_mask == 0) return false;
247
248
  // This gives the offset in the group of 4 of this invocation.
249
0
  Instruction* quad_idx = ir_builder.AddBinaryOp(
250
0
      uint_type_id, spv::Op::OpBitwiseAnd, id->result_id(), quad_mask);
251
0
  if (quad_idx == nullptr) return false;
252
253
  // Get the invocation id of the first invocation in the group of 4.
254
0
  Instruction* quad_ldr =
255
0
      ir_builder.AddBinaryOp(uint_type_id, spv::Op::OpBitwiseXor,
256
0
                             id->result_id(), quad_idx->result_id());
257
0
  if (quad_ldr == nullptr) return false;
258
259
  // Get the offset of the target invocation from the offset vector.
260
0
  Instruction* my_offset =
261
0
      ir_builder.AddBinaryOp(uint_type_id, spv::Op::OpVectorExtractDynamic,
262
0
                             offset_id, quad_idx->result_id());
263
0
  if (my_offset == nullptr) return false;
264
265
  // Determine the index of the invocation to read from.
266
0
  Instruction* target_inv =
267
0
      ir_builder.AddBinaryOp(uint_type_id, spv::Op::OpIAdd,
268
0
                             quad_ldr->result_id(), my_offset->result_id());
269
0
  if (target_inv == nullptr) return false;
270
271
  // Do the group operations
272
0
  uint32_t uint_max_id = ir_builder.GetUintConstantId(0xFFFFFFFF);
273
0
  if (uint_max_id == 0) return false;
274
0
  uint32_t subgroup_scope =
275
0
      ir_builder.GetUintConstantId(uint32_t(spv::Scope::Subgroup));
276
0
  if (subgroup_scope == 0) return false;
277
0
  const auto* vec_type = type_mgr->GetUIntVectorType(4);
278
0
  if (vec_type == nullptr) return false;
279
0
  const auto* ballot_value_const = const_mgr->GetConstant(
280
0
      vec_type, {uint_max_id, uint_max_id, uint_max_id, uint_max_id});
281
0
  if (ballot_value_const == nullptr) return false;
282
0
  Instruction* ballot_value =
283
0
      const_mgr->GetDefiningInstruction(ballot_value_const);
284
0
  if (ballot_value == nullptr) return false;
285
0
  uint32_t bool_type_id = type_mgr->GetBoolTypeId();
286
0
  if (bool_type_id == 0) return false;
287
0
  Instruction* is_active = ir_builder.AddNaryOp(
288
0
      bool_type_id, spv::Op::OpGroupNonUniformBallotBitExtract,
289
0
      {subgroup_scope, ballot_value->result_id(), target_inv->result_id()});
290
0
  if (is_active == nullptr) return false;
291
0
  Instruction* shuffle =
292
0
      ir_builder.AddNaryOp(inst->type_id(), spv::Op::OpGroupNonUniformShuffle,
293
0
                           {subgroup_scope, data_id, target_inv->result_id()});
294
0
  if (shuffle == nullptr) return false;
295
296
  // Create the null constant to use in the select.
297
0
  const auto* result_type = type_mgr->GetType(inst->type_id());
298
0
  if (result_type == nullptr) return false;
299
0
  const auto* null =
300
0
      const_mgr->GetConstant(result_type, std::vector<uint32_t>());
301
0
  if (null == nullptr) {
302
0
    return false;
303
0
  }
304
0
  Instruction* null_inst = const_mgr->GetDefiningInstruction(null);
305
0
  if (null_inst == nullptr) {
306
0
    return false;
307
0
  }
308
309
  // Build the select.
310
0
  inst->SetOpcode(spv::Op::OpSelect);
311
0
  Instruction::OperandList new_operands;
312
0
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {is_active->result_id()}});
313
0
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {shuffle->result_id()}});
314
0
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {null_inst->result_id()}});
315
316
0
  inst->SetInOperands(std::move(new_operands));
317
0
  ctx->UpdateDefUse(inst);
318
0
  return true;
319
0
}
320
321
// Returns a folding rule that will replace the SwizzleInvocationsMaskedAMD
322
// extended instruction in the SPV_AMD_shader_ballot extension.
323
//
324
// The instruction
325
//
326
//    %mask = OpConstantComposite %v3uint %uint_x %uint_y %uint_z
327
//  %result = OpExtInst %uint %1 SwizzleInvocationsMaskedAMD %data %mask
328
//
329
// is replaced with
330
//
331
// potentially new constants and types
332
//
333
// clang-format off
334
// %uint_mask_extend = OpConstant %uint 0xFFFFFFE0
335
//         %uint_max = OpConstant %uint 0xFFFFFFFF
336
//           %v4uint = OpTypeVector %uint 4
337
//     %ballot_value = OpConstantComposite %v4uint %uint_max %uint_max %uint_max %uint_max
338
// clang-format on
339
//
340
// and the following code in the function body
341
//
342
// clang-format off
343
//         %id = OpLoad %uint %SubgroupLocalInvocationId
344
//   %and_mask = OpBitwiseOr %uint %uint_x %uint_mask_extend
345
//        %and = OpBitwiseAnd %uint %id %and_mask
346
//         %or = OpBitwiseOr %uint %and %uint_y
347
// %target_inv = OpBitwiseXor %uint %or %uint_z
348
//  %is_active = OpGroupNonUniformBallotBitExtract %bool %uint_3 %ballot_value %target_inv
349
//    %shuffle = OpGroupNonUniformShuffle %type %uint_3 %data %target_inv
350
//     %result = OpSelect %type %is_active %shuffle %null   ; %null = OpConstantNull %type
351
// clang-format on
352
//
353
// Also adding the capabilities and builtins that are needed.
354
bool ReplaceSwizzleInvocationsMasked(
355
    IRContext* ctx, Instruction* inst,
356
0
    const std::vector<const analysis::Constant*>&) {
357
0
  analysis::TypeManager* type_mgr = ctx->get_type_mgr();
358
0
  analysis::DefUseManager* def_use_mgr = ctx->get_def_use_mgr();
359
0
  analysis::ConstantManager* const_mgr = ctx->get_constant_mgr();
360
361
0
  ctx->AddCapability(spv::Capability::GroupNonUniformBallot);
362
0
  ctx->AddCapability(spv::Capability::GroupNonUniformShuffle);
363
364
0
  InstructionBuilder ir_builder(
365
0
      ctx, inst,
366
0
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
367
368
  // Get the operands to inst, and the components of the mask
369
0
  uint32_t data_id = inst->GetSingleWordInOperand(2);
370
371
0
  Instruction* mask_inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(3));
372
0
  if (mask_inst == nullptr) return false;
373
0
  assert(mask_inst->opcode() == spv::Op::OpConstantComposite &&
374
0
         "The mask is suppose to be a vector constant.");
375
0
  assert(mask_inst->NumInOperands() == 3 &&
376
0
         "The mask is suppose to have 3 components.");
377
378
0
  uint32_t uint_x = mask_inst->GetSingleWordInOperand(0);
379
0
  if (uint_x == 0) return false;
380
0
  uint32_t uint_y = mask_inst->GetSingleWordInOperand(1);
381
0
  if (uint_y == 0) return false;
382
0
  uint32_t uint_z = mask_inst->GetSingleWordInOperand(2);
383
0
  if (uint_z == 0) return false;
384
385
  // Get the subgroup invocation id.
386
0
  uint32_t var_id = ctx->GetBuiltinInputVarId(
387
0
      uint32_t(spv::BuiltIn::SubgroupLocalInvocationId));
388
0
  if (var_id == 0) return false;
389
0
  ctx->AddExtension("SPV_KHR_shader_ballot");
390
0
  Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id);
391
0
  if (var_inst == nullptr) return false;
392
0
  Instruction* var_ptr_type =
393
0
      ctx->get_def_use_mgr()->GetDef(var_inst->type_id());
394
0
  if (var_ptr_type == nullptr) return false;
395
0
  uint32_t uint_type_id = var_ptr_type->GetSingleWordInOperand(1);
396
0
  if (uint_type_id == 0) return false;
397
398
0
  Instruction* id = ir_builder.AddLoad(uint_type_id, var_id);
399
0
  if (id == nullptr) return false;
400
401
  // Do the bitwise operations.
402
0
  uint32_t mask_extended = ir_builder.GetUintConstantId(0xFFFFFFE0);
403
0
  if (mask_extended == 0) return false;
404
0
  Instruction* and_mask = ir_builder.AddBinaryOp(
405
0
      uint_type_id, spv::Op::OpBitwiseOr, uint_x, mask_extended);
406
0
  if (and_mask == nullptr) return false;
407
0
  Instruction* and_result =
408
0
      ir_builder.AddBinaryOp(uint_type_id, spv::Op::OpBitwiseAnd,
409
0
                             id->result_id(), and_mask->result_id());
410
0
  if (and_result == nullptr) return false;
411
0
  Instruction* or_result = ir_builder.AddBinaryOp(
412
0
      uint_type_id, spv::Op::OpBitwiseOr, and_result->result_id(), uint_y);
413
0
  if (or_result == nullptr) return false;
414
0
  Instruction* target_inv = ir_builder.AddBinaryOp(
415
0
      uint_type_id, spv::Op::OpBitwiseXor, or_result->result_id(), uint_z);
416
0
  if (target_inv == nullptr) return false;
417
418
  // Do the group operations
419
0
  uint32_t uint_max_id = ir_builder.GetUintConstantId(0xFFFFFFFF);
420
0
  if (uint_max_id == 0) return false;
421
0
  uint32_t subgroup_scope =
422
0
      ir_builder.GetUintConstantId(uint32_t(spv::Scope::Subgroup));
423
0
  if (subgroup_scope == 0) return false;
424
0
  const auto* vec_type = type_mgr->GetUIntVectorType(4);
425
0
  if (vec_type == nullptr) return false;
426
0
  const auto* ballot_value_const = const_mgr->GetConstant(
427
0
      vec_type, {uint_max_id, uint_max_id, uint_max_id, uint_max_id});
428
0
  if (ballot_value_const == nullptr) return false;
429
0
  Instruction* ballot_value =
430
0
      const_mgr->GetDefiningInstruction(ballot_value_const);
431
0
  if (ballot_value == nullptr) return false;
432
0
  uint32_t bool_type_id = type_mgr->GetBoolTypeId();
433
0
  if (bool_type_id == 0) return false;
434
0
  Instruction* is_active = ir_builder.AddNaryOp(
435
0
      bool_type_id, spv::Op::OpGroupNonUniformBallotBitExtract,
436
0
      {subgroup_scope, ballot_value->result_id(), target_inv->result_id()});
437
0
  if (is_active == nullptr) return false;
438
0
  Instruction* shuffle =
439
0
      ir_builder.AddNaryOp(inst->type_id(), spv::Op::OpGroupNonUniformShuffle,
440
0
                           {subgroup_scope, data_id, target_inv->result_id()});
441
0
  if (shuffle == nullptr) return false;
442
443
  // Create the null constant to use in the select.
444
0
  const auto* result_type = type_mgr->GetType(inst->type_id());
445
0
  if (result_type == nullptr) return false;
446
0
  const auto* null =
447
0
      const_mgr->GetConstant(result_type, std::vector<uint32_t>());
448
0
  if (null == nullptr) return false;
449
0
  Instruction* null_inst = const_mgr->GetDefiningInstruction(null);
450
0
  if (null_inst == nullptr) return false;
451
452
  // Build the select.
453
0
  inst->SetOpcode(spv::Op::OpSelect);
454
0
  Instruction::OperandList new_operands;
455
0
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {is_active->result_id()}});
456
0
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {shuffle->result_id()}});
457
0
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {null_inst->result_id()}});
458
459
0
  inst->SetInOperands(std::move(new_operands));
460
0
  ctx->UpdateDefUse(inst);
461
0
  return true;
462
0
}
463
464
// Returns a folding rule that will replace the WriteInvocationAMD extended
465
// instruction in the SPV_AMD_shader_ballot extension.
466
//
467
// The instruction
468
//
469
// clang-format off
470
//    %result = OpExtInst %type %1 WriteInvocationAMD %input_value %write_value %invocation_index
471
// clang-format on
472
//
473
// with
474
//
475
//     %id = OpLoad %uint %SubgroupLocalInvocationId
476
//    %cmp = OpIEqual %bool %id %invocation_index
477
// %result = OpSelect %type %cmp %write_value %input_value
478
//
479
// Also adding the capabilities and builtins that are needed.
480
bool ReplaceWriteInvocation(IRContext* ctx, Instruction* inst,
481
0
                            const std::vector<const analysis::Constant*>&) {
482
0
  uint32_t var_id = ctx->GetBuiltinInputVarId(
483
0
      uint32_t(spv::BuiltIn::SubgroupLocalInvocationId));
484
0
  if (var_id == 0) return false;
485
0
  Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id);
486
0
  if (var_inst == nullptr) return false;
487
0
  Instruction* var_ptr_type =
488
0
      ctx->get_def_use_mgr()->GetDef(var_inst->type_id());
489
0
  if (var_ptr_type == nullptr) return false;
490
0
  ctx->AddCapability(spv::Capability::SubgroupBallotKHR);
491
0
  ctx->AddExtension("SPV_KHR_shader_ballot");
492
493
0
  InstructionBuilder ir_builder(
494
0
      ctx, inst,
495
0
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
496
0
  Instruction* t =
497
0
      ir_builder.AddLoad(var_ptr_type->GetSingleWordInOperand(1), var_id);
498
0
  if (t == nullptr) return false;
499
0
  analysis::Bool bool_type;
500
0
  uint32_t bool_type_id = ctx->get_type_mgr()->GetTypeInstruction(&bool_type);
501
0
  if (bool_type_id == 0) return false;
502
0
  Instruction* cmp =
503
0
      ir_builder.AddBinaryOp(bool_type_id, spv::Op::OpIEqual, t->result_id(),
504
0
                             inst->GetSingleWordInOperand(4));
505
0
  if (cmp == nullptr) return false;
506
507
  // Build a select.
508
0
  inst->SetOpcode(spv::Op::OpSelect);
509
0
  Instruction::OperandList new_operands;
510
0
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {cmp->result_id()}});
511
0
  new_operands.push_back(inst->GetInOperand(3));
512
0
  new_operands.push_back(inst->GetInOperand(2));
513
514
0
  inst->SetInOperands(std::move(new_operands));
515
0
  ctx->UpdateDefUse(inst);
516
0
  return true;
517
0
}
518
519
// Returns a folding rule that will replace the MbcntAMD extended instruction in
520
// the SPV_AMD_shader_ballot extension.
521
//
522
// The instruction
523
//
524
//  %result = OpExtInst %uint %1 MbcntAMD %mask
525
//
526
// with
527
//
528
// Get SubgroupLtMask and convert the first 64-bits into a uint64_t because
529
// AMD's shader compiler expects a 64-bit integer mask.
530
//
531
//     %var = OpLoad %v4uint %SubgroupLtMaskKHR
532
// %shuffle = OpVectorShuffle %v2uint %var %var 0 1
533
//    %cast = OpBitcast %ulong %shuffle
534
//
535
// Perform the mask and count the bits.
536
//
537
//     %and = OpBitwiseAnd %ulong %cast %mask
538
//  %result = OpBitCount %uint %and
539
//
540
// Also adding the capabilities and builtins that are needed.
541
bool ReplaceMbcnt(IRContext* context, Instruction* inst,
542
0
                  const std::vector<const analysis::Constant*>&) {
543
0
  analysis::TypeManager* type_mgr = context->get_type_mgr();
544
0
  analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
545
546
0
  uint32_t var_id =
547
0
      context->GetBuiltinInputVarId(uint32_t(spv::BuiltIn::SubgroupLtMask));
548
0
  if (var_id == 0) return false;
549
550
0
  context->AddCapability(spv::Capability::GroupNonUniformBallot);
551
0
  Instruction* var_inst = def_use_mgr->GetDef(var_id);
552
0
  Instruction* var_ptr_type = def_use_mgr->GetDef(var_inst->type_id());
553
0
  Instruction* var_type =
554
0
      def_use_mgr->GetDef(var_ptr_type->GetSingleWordInOperand(1));
555
0
  assert(var_type->opcode() == spv::Op::OpTypeVector &&
556
0
         "Variable is suppose to be a vector of 4 ints");
557
558
  // Get the type for the shuffle.
559
0
  analysis::Vector temp_type(GetUIntType(context), 2);
560
0
  const analysis::Type* shuffle_type =
561
0
      context->get_type_mgr()->GetRegisteredType(&temp_type);
562
0
  if (shuffle_type == nullptr) return false;
563
0
  uint32_t shuffle_type_id = type_mgr->GetTypeInstruction(shuffle_type);
564
565
0
  uint32_t mask_id = inst->GetSingleWordInOperand(2);
566
0
  Instruction* mask_inst = def_use_mgr->GetDef(mask_id);
567
568
  // Testing with amd's shader compiler shows that a 64-bit mask is expected.
569
0
  assert(type_mgr->GetType(mask_inst->type_id())->AsInteger() != nullptr);
570
0
  assert(type_mgr->GetType(mask_inst->type_id())->AsInteger()->width() == 64);
571
572
0
  InstructionBuilder ir_builder(
573
0
      context, inst,
574
0
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
575
0
  Instruction* load = ir_builder.AddLoad(var_type->result_id(), var_id);
576
0
  if (load == nullptr) return false;
577
0
  Instruction* shuffle = ir_builder.AddVectorShuffle(
578
0
      shuffle_type_id, load->result_id(), load->result_id(), {0, 1});
579
0
  if (shuffle == nullptr) return false;
580
0
  Instruction* bitcast = ir_builder.AddUnaryOp(
581
0
      mask_inst->type_id(), spv::Op::OpBitcast, shuffle->result_id());
582
0
  if (bitcast == nullptr) return false;
583
0
  Instruction* t =
584
0
      ir_builder.AddBinaryOp(mask_inst->type_id(), spv::Op::OpBitwiseAnd,
585
0
                             bitcast->result_id(), mask_id);
586
0
  if (t == nullptr) return false;
587
588
0
  inst->SetOpcode(spv::Op::OpBitCount);
589
0
  inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {t->result_id()}}});
590
0
  context->UpdateDefUse(inst);
591
0
  return true;
592
0
}
593
594
// A folding rule that will replace the CubeFaceCoordAMD extended
595
// instruction in SPV_AMD_gcn_shader.  Returns true if the folding is
596
// successful.
597
//
598
// The instruction
599
//
600
//  %result = OpExtInst %v2float %1 CubeFaceCoordAMD %input
601
//
602
// with
603
//
604
//             %x = OpCompositeExtract %float %input 0
605
//             %y = OpCompositeExtract %float %input 1
606
//             %z = OpCompositeExtract %float %input 2
607
//            %nx = OpFNegate %float %x
608
//            %ny = OpFNegate %float %y
609
//            %nz = OpFNegate %float %z
610
//            %ax = OpExtInst %float %n_1 FAbs %x
611
//            %ay = OpExtInst %float %n_1 FAbs %y
612
//            %az = OpExtInst %float %n_1 FAbs %z
613
//      %amax_x_y = OpExtInst %float %n_1 FMax %ax %ay
614
//          %amax = OpExtInst %float %n_1 FMax %az %amax_x_y
615
//        %cubema = OpFMul %float %float_2 %amax
616
//      %is_z_max = OpFOrdGreaterThanEqual %bool %az %amax_x_y
617
//  %not_is_z_max = OpLogicalNot %bool %is_z_max
618
//        %y_gt_x = OpFOrdGreaterThanEqual %bool %ay %ax
619
//      %is_y_max = OpLogicalAnd %bool %not_is_z_max %y_gt_x
620
//      %is_z_neg = OpFOrdLessThan %bool %z %float_0
621
// %cubesc_case_1 = OpSelect %float %is_z_neg %nx %x
622
//      %is_x_neg = OpFOrdLessThan %bool %x %float_0
623
// %cubesc_case_2 = OpSelect %float %is_x_neg %z %nz
624
//           %sel = OpSelect %float %is_y_max %x %cubesc_case_2
625
//        %cubesc = OpSelect %float %is_z_max %cubesc_case_1 %sel
626
//      %is_y_neg = OpFOrdLessThan %bool %y %float_0
627
// %cubetc_case_1 = OpSelect %float %is_y_neg %nz %z
628
//        %cubetc = OpSelect %float %is_y_max %cubetc_case_1 %ny
629
//          %cube = OpCompositeConstruct %v2float %cubesc %cubetc
630
//         %denom = OpCompositeConstruct %v2float %cubema %cubema
631
//           %div = OpFDiv %v2float %cube %denom
632
//        %result = OpFAdd %v2float %div %const
633
//
634
// Also adds the GLSL.std.450 extended instruction import if it is missing.
635
bool ReplaceCubeFaceCoord(IRContext* ctx, Instruction* inst,
636
0
                          const std::vector<const analysis::Constant*>&) {
637
0
  analysis::TypeManager* type_mgr = ctx->get_type_mgr();
638
0
  analysis::ConstantManager* const_mgr = ctx->get_constant_mgr();
639
640
0
  uint32_t float_type_id = type_mgr->GetFloatTypeId();
641
0
  if (float_type_id == 0) return false;
642
0
  const analysis::Type* v2_float_type = type_mgr->GetFloatVectorType(2);
643
0
  if (v2_float_type == nullptr) return false;
644
0
  uint32_t v2_float_type_id = type_mgr->GetId(v2_float_type);
645
0
  if (v2_float_type_id == 0) return false;
646
0
  uint32_t bool_id = type_mgr->GetBoolTypeId();
647
0
  if (bool_id == 0) return false;
648
649
0
  InstructionBuilder ir_builder(
650
0
      ctx, inst,
651
0
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
652
653
0
  uint32_t input_id = inst->GetSingleWordInOperand(2);
654
0
  uint32_t glsl405_ext_inst_id =
655
0
      ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
656
0
  if (glsl405_ext_inst_id == 0) {
657
0
    ctx->AddExtInstImport("GLSL.std.450");
658
0
    glsl405_ext_inst_id =
659
0
        ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
660
0
  }
661
0
  if (glsl405_ext_inst_id == 0) return false;
662
663
  // Get the constants that will be used.
664
0
  uint32_t f0_const_id = const_mgr->GetFloatConstId(0.0);
665
0
  if (f0_const_id == 0) return false;
666
0
  uint32_t f2_const_id = const_mgr->GetFloatConstId(2.0);
667
0
  if (f2_const_id == 0) return false;
668
0
  uint32_t f0_5_const_id = const_mgr->GetFloatConstId(0.5);
669
0
  if (f0_5_const_id == 0) return false;
670
0
  const analysis::Constant* vec_const =
671
0
      const_mgr->GetConstant(v2_float_type, {f0_5_const_id, f0_5_const_id});
672
0
  if (vec_const == nullptr) return false;
673
0
  Instruction* vec_const_inst = const_mgr->GetDefiningInstruction(vec_const);
674
0
  if (vec_const_inst == nullptr) return false;
675
0
  uint32_t vec_const_id = vec_const_inst->result_id();
676
677
  // Extract the input values.
678
0
  Instruction* x = ir_builder.AddCompositeExtract(float_type_id, input_id, {0});
679
0
  if (x == nullptr) return false;
680
0
  Instruction* y = ir_builder.AddCompositeExtract(float_type_id, input_id, {1});
681
0
  if (y == nullptr) return false;
682
0
  Instruction* z = ir_builder.AddCompositeExtract(float_type_id, input_id, {2});
683
0
  if (z == nullptr) return false;
684
685
  // Negate the input values.
686
0
  Instruction* nx =
687
0
      ir_builder.AddUnaryOp(float_type_id, spv::Op::OpFNegate, x->result_id());
688
0
  Instruction* ny =
689
0
      ir_builder.AddUnaryOp(float_type_id, spv::Op::OpFNegate, y->result_id());
690
0
  Instruction* nz =
691
0
      ir_builder.AddUnaryOp(float_type_id, spv::Op::OpFNegate, z->result_id());
692
0
  if (nx == nullptr) return false;
693
0
  if (ny == nullptr) return false;
694
0
  if (nz == nullptr) return false;
695
696
  // Get the abolsute values of the inputs.
697
0
  Instruction* ax = ir_builder.AddNaryExtendedInstruction(
698
0
      float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {x->result_id()});
699
0
  if (ax == nullptr) return false;
700
0
  Instruction* ay = ir_builder.AddNaryExtendedInstruction(
701
0
      float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {y->result_id()});
702
0
  if (ay == nullptr) return false;
703
0
  Instruction* az = ir_builder.AddNaryExtendedInstruction(
704
0
      float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {z->result_id()});
705
0
  if (az == nullptr) return false;
706
707
  // Find which values are negative.  Used in later computations.
708
0
  Instruction* is_z_neg = ir_builder.AddBinaryOp(
709
0
      bool_id, spv::Op::OpFOrdLessThan, z->result_id(), f0_const_id);
710
0
  if (is_z_neg == nullptr) return false;
711
0
  Instruction* is_y_neg = ir_builder.AddBinaryOp(
712
0
      bool_id, spv::Op::OpFOrdLessThan, y->result_id(), f0_const_id);
713
0
  if (is_y_neg == nullptr) return false;
714
0
  Instruction* is_x_neg = ir_builder.AddBinaryOp(
715
0
      bool_id, spv::Op::OpFOrdLessThan, x->result_id(), f0_const_id);
716
0
  if (is_x_neg == nullptr) return false;
717
718
  // Compute cubema
719
0
  Instruction* amax_x_y = ir_builder.AddNaryExtendedInstruction(
720
0
      float_type_id, glsl405_ext_inst_id, GLSLstd450FMax,
721
0
      {ax->result_id(), ay->result_id()});
722
0
  if (amax_x_y == nullptr) return false;
723
0
  Instruction* amax = ir_builder.AddNaryExtendedInstruction(
724
0
      float_type_id, glsl405_ext_inst_id, GLSLstd450FMax,
725
0
      {az->result_id(), amax_x_y->result_id()});
726
0
  if (amax == nullptr) return false;
727
0
  Instruction* cubema = ir_builder.AddBinaryOp(float_type_id, spv::Op::OpFMul,
728
0
                                               f2_const_id, amax->result_id());
729
0
  if (cubema == nullptr) return false;
730
731
  // Do the comparisons needed for computing cubesc and cubetc.
732
0
  Instruction* is_z_max =
733
0
      ir_builder.AddBinaryOp(bool_id, spv::Op::OpFOrdGreaterThanEqual,
734
0
                             az->result_id(), amax_x_y->result_id());
735
0
  if (is_z_max == nullptr) return false;
736
0
  Instruction* not_is_z_max = ir_builder.AddUnaryOp(
737
0
      bool_id, spv::Op::OpLogicalNot, is_z_max->result_id());
738
0
  if (not_is_z_max == nullptr) return false;
739
0
  Instruction* y_gr_x =
740
0
      ir_builder.AddBinaryOp(bool_id, spv::Op::OpFOrdGreaterThanEqual,
741
0
                             ay->result_id(), ax->result_id());
742
0
  if (y_gr_x == nullptr) return false;
743
0
  Instruction* is_y_max =
744
0
      ir_builder.AddBinaryOp(bool_id, spv::Op::OpLogicalAnd,
745
0
                             not_is_z_max->result_id(), y_gr_x->result_id());
746
0
  if (is_y_max == nullptr) return false;
747
748
  // Select the correct value for cubesc.
749
0
  Instruction* cubesc_case_1 = ir_builder.AddSelect(
750
0
      float_type_id, is_z_neg->result_id(), nx->result_id(), x->result_id());
751
0
  if (cubesc_case_1 == nullptr) return false;
752
0
  Instruction* cubesc_case_2 = ir_builder.AddSelect(
753
0
      float_type_id, is_x_neg->result_id(), z->result_id(), nz->result_id());
754
0
  if (cubesc_case_2 == nullptr) return false;
755
0
  Instruction* sel =
756
0
      ir_builder.AddSelect(float_type_id, is_y_max->result_id(), x->result_id(),
757
0
                           cubesc_case_2->result_id());
758
0
  if (sel == nullptr) return false;
759
0
  Instruction* cubesc =
760
0
      ir_builder.AddSelect(float_type_id, is_z_max->result_id(),
761
0
                           cubesc_case_1->result_id(), sel->result_id());
762
0
  if (cubesc == nullptr) return false;
763
764
  // Select the correct value for cubetc.
765
0
  Instruction* cubetc_case_1 = ir_builder.AddSelect(
766
0
      float_type_id, is_y_neg->result_id(), nz->result_id(), z->result_id());
767
0
  if (cubetc_case_1 == nullptr) return false;
768
0
  Instruction* cubetc =
769
0
      ir_builder.AddSelect(float_type_id, is_y_max->result_id(),
770
0
                           cubetc_case_1->result_id(), ny->result_id());
771
0
  if (cubetc == nullptr) return false;
772
773
  // Do the division
774
0
  Instruction* cube = ir_builder.AddCompositeConstruct(
775
0
      v2_float_type_id, {cubesc->result_id(), cubetc->result_id()});
776
0
  if (cube == nullptr) return false;
777
0
  Instruction* denom = ir_builder.AddCompositeConstruct(
778
0
      v2_float_type_id, {cubema->result_id(), cubema->result_id()});
779
0
  if (denom == nullptr) return false;
780
0
  Instruction* div = ir_builder.AddBinaryOp(
781
0
      v2_float_type_id, spv::Op::OpFDiv, cube->result_id(), denom->result_id());
782
0
  if (div == nullptr) return false;
783
784
  // Get the final result by adding 0.5 to |div|.
785
0
  inst->SetOpcode(spv::Op::OpFAdd);
786
0
  Instruction::OperandList new_operands;
787
0
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {div->result_id()}});
788
0
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {vec_const_id}});
789
790
0
  inst->SetInOperands(std::move(new_operands));
791
0
  ctx->UpdateDefUse(inst);
792
0
  return true;
793
0
}
794
795
// A folding rule that will replace the CubeFaceIndexAMD extended
796
// instruction in SPV_AMD_gcn_shader.  Returns true if the folding
797
// is successful.
798
//
799
// The instruction
800
//
801
//  %result = OpExtInst %float %1 CubeFaceIndexAMD %input
802
//
803
// with
804
//
805
//             %x = OpCompositeExtract %float %input 0
806
//             %y = OpCompositeExtract %float %input 1
807
//             %z = OpCompositeExtract %float %input 2
808
//            %ax = OpExtInst %float %n_1 FAbs %x
809
//            %ay = OpExtInst %float %n_1 FAbs %y
810
//            %az = OpExtInst %float %n_1 FAbs %z
811
//      %is_z_neg = OpFOrdLessThan %bool %z %float_0
812
//      %is_y_neg = OpFOrdLessThan %bool %y %float_0
813
//      %is_x_neg = OpFOrdLessThan %bool %x %float_0
814
//      %amax_x_y = OpExtInst %float %n_1 FMax %ax %ay
815
//      %is_z_max = OpFOrdGreaterThanEqual %bool %az %amax_x_y
816
//        %y_gt_x = OpFOrdGreaterThanEqual %bool %ay %ax
817
//        %case_z = OpSelect %float %is_z_neg %float_5 %float4
818
//        %case_y = OpSelect %float %is_y_neg %float_3 %float2
819
//        %case_x = OpSelect %float %is_x_neg %float_1 %float0
820
//           %sel = OpSelect %float %y_gt_x %case_y %case_x
821
//        %result = OpSelect %float %is_z_max %case_z %sel
822
//
823
// Also adds the GLSL.std.450 extended instruction import if it is missing.
824
bool ReplaceCubeFaceIndex(IRContext* ctx, Instruction* inst,
825
0
                          const std::vector<const analysis::Constant*>&) {
826
0
  analysis::TypeManager* type_mgr = ctx->get_type_mgr();
827
0
  analysis::ConstantManager* const_mgr = ctx->get_constant_mgr();
828
829
0
  uint32_t float_type_id = type_mgr->GetFloatTypeId();
830
0
  if (float_type_id == 0) return false;
831
0
  uint32_t bool_id = type_mgr->GetBoolTypeId();
832
0
  if (bool_id == 0) return false;
833
834
0
  InstructionBuilder ir_builder(
835
0
      ctx, inst,
836
0
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
837
838
0
  uint32_t input_id = inst->GetSingleWordInOperand(2);
839
0
  uint32_t glsl405_ext_inst_id =
840
0
      ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
841
0
  if (glsl405_ext_inst_id == 0) {
842
0
    ctx->AddExtInstImport("GLSL.std.450");
843
0
    glsl405_ext_inst_id =
844
0
        ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
845
0
  }
846
847
  // Get the constants that will be used.
848
0
  uint32_t f0_const_id = const_mgr->GetFloatConstId(0.0);
849
0
  if (f0_const_id == 0) return false;
850
0
  uint32_t f1_const_id = const_mgr->GetFloatConstId(1.0);
851
0
  if (f1_const_id == 0) return false;
852
0
  uint32_t f2_const_id = const_mgr->GetFloatConstId(2.0);
853
0
  if (f2_const_id == 0) return false;
854
0
  uint32_t f3_const_id = const_mgr->GetFloatConstId(3.0);
855
0
  if (f3_const_id == 0) return false;
856
0
  uint32_t f4_const_id = const_mgr->GetFloatConstId(4.0);
857
0
  if (f4_const_id == 0) return false;
858
0
  uint32_t f5_const_id = const_mgr->GetFloatConstId(5.0);
859
0
  if (f5_const_id == 0) return false;
860
861
  // Extract the input values.
862
0
  Instruction* x = ir_builder.AddCompositeExtract(float_type_id, input_id, {0});
863
0
  if (x == nullptr) return false;
864
0
  Instruction* y = ir_builder.AddCompositeExtract(float_type_id, input_id, {1});
865
0
  if (y == nullptr) return false;
866
  // TODO(1-841): Handle id overflow.
867
0
  Instruction* z = ir_builder.AddCompositeExtract(float_type_id, input_id, {2});
868
0
  if (z == nullptr) return false;
869
870
  // Get the absolute values of the inputs.
871
0
  Instruction* ax = ir_builder.AddNaryExtendedInstruction(
872
0
      float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {x->result_id()});
873
0
  if (ax == nullptr) return false;
874
0
  Instruction* ay = ir_builder.AddNaryExtendedInstruction(
875
0
      float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {y->result_id()});
876
0
  if (ay == nullptr) return false;
877
0
  Instruction* az = ir_builder.AddNaryExtendedInstruction(
878
0
      float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {z->result_id()});
879
0
  if (az == nullptr) return false;
880
881
  // Find which values are negative.  Used in later computations.
882
0
  Instruction* is_z_neg = ir_builder.AddBinaryOp(
883
0
      bool_id, spv::Op::OpFOrdLessThan, z->result_id(), f0_const_id);
884
0
  if (is_z_neg == nullptr) return false;
885
0
  Instruction* is_y_neg = ir_builder.AddBinaryOp(
886
0
      bool_id, spv::Op::OpFOrdLessThan, y->result_id(), f0_const_id);
887
0
  if (is_y_neg == nullptr) return false;
888
0
  Instruction* is_x_neg = ir_builder.AddBinaryOp(
889
0
      bool_id, spv::Op::OpFOrdLessThan, x->result_id(), f0_const_id);
890
0
  if (is_x_neg == nullptr) return false;
891
892
  // Find the max value.
893
0
  Instruction* amax_x_y = ir_builder.AddNaryExtendedInstruction(
894
0
      float_type_id, glsl405_ext_inst_id, GLSLstd450FMax,
895
0
      {ax->result_id(), ay->result_id()});
896
0
  if (amax_x_y == nullptr) return false;
897
0
  Instruction* is_z_max =
898
0
      ir_builder.AddBinaryOp(bool_id, spv::Op::OpFOrdGreaterThanEqual,
899
0
                             az->result_id(), amax_x_y->result_id());
900
0
  if (is_z_max == nullptr) return false;
901
0
  Instruction* y_gr_x =
902
0
      ir_builder.AddBinaryOp(bool_id, spv::Op::OpFOrdGreaterThanEqual,
903
0
                             ay->result_id(), ax->result_id());
904
0
  if (y_gr_x == nullptr) return false;
905
906
  // Get the value for each case.
907
0
  Instruction* case_z = ir_builder.AddSelect(
908
0
      float_type_id, is_z_neg->result_id(), f5_const_id, f4_const_id);
909
0
  if (case_z == nullptr) return false;
910
0
  Instruction* case_y = ir_builder.AddSelect(
911
0
      float_type_id, is_y_neg->result_id(), f3_const_id, f2_const_id);
912
0
  if (case_y == nullptr) return false;
913
0
  Instruction* case_x = ir_builder.AddSelect(
914
0
      float_type_id, is_x_neg->result_id(), f1_const_id, f0_const_id);
915
0
  if (case_x == nullptr) return false;
916
917
  // Select the correct case.
918
0
  Instruction* sel =
919
0
      ir_builder.AddSelect(float_type_id, y_gr_x->result_id(),
920
0
                           case_y->result_id(), case_x->result_id());
921
0
  if (sel == nullptr) return false;
922
923
  // Get the final result by adding 0.5 to |div|.
924
0
  inst->SetOpcode(spv::Op::OpSelect);
925
0
  Instruction::OperandList new_operands;
926
0
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {is_z_max->result_id()}});
927
0
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {case_z->result_id()}});
928
0
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {sel->result_id()}});
929
930
0
  inst->SetInOperands(std::move(new_operands));
931
0
  ctx->UpdateDefUse(inst);
932
0
  return true;
933
0
}
934
935
// A folding rule that will replace the TimeAMD extended instruction in the
936
// SPV_AMD_gcn_shader extension.  It returns true if the folding is successful.
937
// It returns false otherwise.
938
//
939
// The instruction
940
//
941
//  %result = OpExtInst %uint64 %1 TimeAMD
942
//
943
// with
944
//
945
//  %result = OpReadClockKHR %uint64 %uint_3
946
//
947
// NOTE: TimeAMD uses subgroup scope (it is not a real time clock).
948
bool ReplaceTimeAMD(IRContext* ctx, Instruction* inst,
949
0
                    const std::vector<const analysis::Constant*>&) {
950
0
  InstructionBuilder ir_builder(
951
0
      ctx, inst,
952
0
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
953
0
  ctx->AddExtension("SPV_KHR_shader_clock");
954
0
  ctx->AddCapability(spv::Capability::ShaderClockKHR);
955
956
0
  inst->SetOpcode(spv::Op::OpReadClockKHR);
957
0
  Instruction::OperandList args;
958
0
  uint32_t subgroup_scope_id =
959
0
      ir_builder.GetUintConstantId(uint32_t(spv::Scope::Subgroup));
960
0
  args.push_back({SPV_OPERAND_TYPE_ID, {subgroup_scope_id}});
961
0
  inst->SetInOperands(std::move(args));
962
0
  ctx->UpdateDefUse(inst);
963
964
0
  return true;
965
0
}
966
967
class AmdExtFoldingRules : public FoldingRules {
968
 public:
969
0
  explicit AmdExtFoldingRules(IRContext* ctx) : FoldingRules(ctx) {}
970
971
 protected:
972
0
  virtual void AddFoldingRules() override {
973
0
    rules_[spv::Op::OpGroupIAddNonUniformAMD].push_back(
974
0
        ReplaceGroupNonuniformOperationOpCode<spv::Op::OpGroupNonUniformIAdd>);
975
0
    rules_[spv::Op::OpGroupFAddNonUniformAMD].push_back(
976
0
        ReplaceGroupNonuniformOperationOpCode<spv::Op::OpGroupNonUniformFAdd>);
977
0
    rules_[spv::Op::OpGroupUMinNonUniformAMD].push_back(
978
0
        ReplaceGroupNonuniformOperationOpCode<spv::Op::OpGroupNonUniformUMin>);
979
0
    rules_[spv::Op::OpGroupSMinNonUniformAMD].push_back(
980
0
        ReplaceGroupNonuniformOperationOpCode<spv::Op::OpGroupNonUniformSMin>);
981
0
    rules_[spv::Op::OpGroupFMinNonUniformAMD].push_back(
982
0
        ReplaceGroupNonuniformOperationOpCode<spv::Op::OpGroupNonUniformFMin>);
983
0
    rules_[spv::Op::OpGroupUMaxNonUniformAMD].push_back(
984
0
        ReplaceGroupNonuniformOperationOpCode<spv::Op::OpGroupNonUniformUMax>);
985
0
    rules_[spv::Op::OpGroupSMaxNonUniformAMD].push_back(
986
0
        ReplaceGroupNonuniformOperationOpCode<spv::Op::OpGroupNonUniformSMax>);
987
0
    rules_[spv::Op::OpGroupFMaxNonUniformAMD].push_back(
988
0
        ReplaceGroupNonuniformOperationOpCode<spv::Op::OpGroupNonUniformFMax>);
989
990
0
    uint32_t extension_id =
991
0
        context()->module()->GetExtInstImportId("SPV_AMD_shader_ballot");
992
993
0
    if (extension_id != 0) {
994
0
      ext_rules_[{extension_id, AmdShaderBallotSwizzleInvocationsAMD}]
995
0
          .push_back(ReplaceSwizzleInvocations);
996
0
      ext_rules_[{extension_id, AmdShaderBallotSwizzleInvocationsMaskedAMD}]
997
0
          .push_back(ReplaceSwizzleInvocationsMasked);
998
0
      ext_rules_[{extension_id, AmdShaderBallotWriteInvocationAMD}].push_back(
999
0
          ReplaceWriteInvocation);
1000
0
      ext_rules_[{extension_id, AmdShaderBallotMbcntAMD}].push_back(
1001
0
          ReplaceMbcnt);
1002
0
    }
1003
1004
0
    extension_id = context()->module()->GetExtInstImportId(
1005
0
        "SPV_AMD_shader_trinary_minmax");
1006
1007
0
    if (extension_id != 0) {
1008
0
      ext_rules_[{extension_id, FMin3AMD}].push_back(
1009
0
          ReplaceTrinaryMinMax<GLSLstd450FMin>);
1010
0
      ext_rules_[{extension_id, UMin3AMD}].push_back(
1011
0
          ReplaceTrinaryMinMax<GLSLstd450UMin>);
1012
0
      ext_rules_[{extension_id, SMin3AMD}].push_back(
1013
0
          ReplaceTrinaryMinMax<GLSLstd450SMin>);
1014
0
      ext_rules_[{extension_id, FMax3AMD}].push_back(
1015
0
          ReplaceTrinaryMinMax<GLSLstd450FMax>);
1016
0
      ext_rules_[{extension_id, UMax3AMD}].push_back(
1017
0
          ReplaceTrinaryMinMax<GLSLstd450UMax>);
1018
0
      ext_rules_[{extension_id, SMax3AMD}].push_back(
1019
0
          ReplaceTrinaryMinMax<GLSLstd450SMax>);
1020
0
      ext_rules_[{extension_id, FMid3AMD}].push_back(
1021
0
          ReplaceTrinaryMid<GLSLstd450FMin, GLSLstd450FMax, GLSLstd450FClamp>);
1022
0
      ext_rules_[{extension_id, UMid3AMD}].push_back(
1023
0
          ReplaceTrinaryMid<GLSLstd450UMin, GLSLstd450UMax, GLSLstd450UClamp>);
1024
0
      ext_rules_[{extension_id, SMid3AMD}].push_back(
1025
0
          ReplaceTrinaryMid<GLSLstd450SMin, GLSLstd450SMax, GLSLstd450SClamp>);
1026
0
    }
1027
1028
0
    extension_id =
1029
0
        context()->module()->GetExtInstImportId("SPV_AMD_gcn_shader");
1030
1031
0
    if (extension_id != 0) {
1032
0
      ext_rules_[{extension_id, CubeFaceCoordAMD}].push_back(
1033
0
          ReplaceCubeFaceCoord);
1034
0
      ext_rules_[{extension_id, CubeFaceIndexAMD}].push_back(
1035
0
          ReplaceCubeFaceIndex);
1036
0
      ext_rules_[{extension_id, TimeAMD}].push_back(ReplaceTimeAMD);
1037
0
    }
1038
0
  }
1039
};
1040
1041
class AmdExtConstFoldingRules : public ConstantFoldingRules {
1042
 public:
1043
0
  AmdExtConstFoldingRules(IRContext* ctx) : ConstantFoldingRules(ctx) {}
1044
1045
 protected:
1046
0
  virtual void AddFoldingRules() override {}
1047
};
1048
1049
}  // namespace
1050
1051
0
Pass::Status AmdExtensionToKhrPass::Process() {
  bool modified = false;

  // Step 1: walk every instruction of every function and let the AMD folding
  // rules rewrite the ones that depend on the extensions.
  InstructionFolder folder(context(),
                           MakeUnique<AmdExtFoldingRules>(context()),
                           MakeUnique<AmdExtConstFoldingRules>(context()));
  for (Function& function : *get_module()) {
    const bool completed = function.WhileEachInst(
        [&modified, &folder, this](Instruction* inst) {
          if (folder.FoldInstruction(inst)) {
            modified = true;
            return true;
          }
          // Nothing was folded: keep going unless the context reports that
          // the ids overflowed.
          return !context()->id_overflow();
        });
    if (!completed) return Status::Failure;
  }

  // Step 2: now that the instructions that require the extensions have been
  // replaced, the OpExtension and OpExtInstImport instructions naming them can
  // be removed.  They are collected first and killed afterwards.
  const std::set<std::string> amd_names = {"SPV_AMD_shader_ballot",
                                           "SPV_AMD_shader_trinary_minmax",
                                           "SPV_AMD_gcn_shader"};

  std::vector<Instruction*> doomed;
  for (Instruction& extension : context()->module()->extensions()) {
    if (extension.opcode() == spv::Op::OpExtension &&
        amd_names.count(extension.GetInOperand(0).AsString()) != 0) {
      doomed.push_back(&extension);
    }
  }
  for (Instruction& import : context()->ext_inst_imports()) {
    if (import.opcode() == spv::Op::OpExtInstImport &&
        amd_names.count(import.GetInOperand(0).AsString()) != 0) {
      doomed.push_back(&import);
    }
  }

  for (Instruction* victim : doomed) {
    context()->KillInst(victim);
  }
  if (!doomed.empty()) modified = true;

  if (!modified) return Status::SuccessWithoutChange;

  // Step 3: the replacements use instructions that are missing before
  // SPIR-V 1.3, so raise the module version to at least 1.3.
  constexpr uint32_t kSpirvVersion1_3 = 0x00010300;
  if (get_module()->version() < kSpirvVersion1_3) {
    get_module()->set_version(kSpirvVersion1_3);
  }
  return Status::SuccessWithChange;
}
1114
1115
}  // namespace opt
1116
}  // namespace spvtools