/src/spirv-tools/source/opt/wrap_opkill.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | // Copyright (c) 2019 Google LLC |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | #include "source/opt/wrap_opkill.h" |
16 | | |
17 | | #include "ir_builder.h" |
18 | | |
19 | | namespace spvtools { |
20 | | namespace opt { |
21 | | |
22 | 14.8k | Pass::Status WrapOpKill::Process() { |
23 | 14.8k | bool modified = false; |
24 | | |
25 | 14.8k | auto func_to_process = |
26 | 14.8k | context()->GetStructuredCFGAnalysis()->FindFuncsCalledFromContinue(); |
27 | 14.8k | for (uint32_t func_id : func_to_process) { |
28 | 57 | Function* func = context()->GetFunction(func_id); |
29 | 1.89k | bool successful = func->WhileEachInst([this, &modified](Instruction* inst) { |
30 | 1.89k | const auto opcode = inst->opcode(); |
31 | 1.89k | if ((opcode == spv::Op::OpKill) || |
32 | 1.89k | (opcode == spv::Op::OpTerminateInvocation)) { |
33 | 10 | modified = true; |
34 | 10 | if (!ReplaceWithFunctionCall(inst)) { |
35 | 0 | return false; |
36 | 0 | } |
37 | 10 | } |
38 | 1.89k | return true; |
39 | 1.89k | }); |
40 | | |
41 | 57 | if (!successful) { |
42 | 0 | return Status::Failure; |
43 | 0 | } |
44 | 57 | } |
45 | | |
46 | 14.8k | if (opkill_function_ != nullptr) { |
47 | 10 | assert(modified && |
48 | 10 | "The function should only be generated if something was modified."); |
49 | 10 | context()->AddFunction(std::move(opkill_function_)); |
50 | 10 | } |
51 | 14.8k | if (opterminateinvocation_function_ != nullptr) { |
52 | 0 | assert(modified && |
53 | 0 | "The function should only be generated if something was modified."); |
54 | 0 | context()->AddFunction(std::move(opterminateinvocation_function_)); |
55 | 0 | } |
56 | 14.8k | return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange); |
57 | 14.8k | } |
58 | | |
59 | 10 | bool WrapOpKill::ReplaceWithFunctionCall(Instruction* inst) { |
60 | 10 | assert((inst->opcode() == spv::Op::OpKill || |
61 | 10 | inst->opcode() == spv::Op::OpTerminateInvocation) && |
62 | 10 | "|inst| must be an OpKill or OpTerminateInvocation instruction."); |
63 | 10 | InstructionBuilder ir_builder( |
64 | 10 | context(), inst, |
65 | 10 | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); |
66 | 10 | uint32_t func_id = GetKillingFuncId(inst->opcode()); |
67 | 10 | if (func_id == 0) { |
68 | 0 | return false; |
69 | 0 | } |
70 | 10 | Instruction* call_inst = |
71 | 10 | ir_builder.AddFunctionCall(GetVoidTypeId(), func_id, {}); |
72 | 10 | if (call_inst == nullptr) { |
73 | 0 | return false; |
74 | 0 | } |
75 | 10 | call_inst->UpdateDebugInfoFrom(inst); |
76 | | |
77 | 10 | Instruction* return_inst = nullptr; |
78 | 10 | uint32_t return_type_id = GetOwningFunctionsReturnType(inst); |
79 | 10 | if (return_type_id != GetVoidTypeId()) { |
80 | 10 | Instruction* undef = |
81 | 10 | ir_builder.AddNullaryOp(return_type_id, spv::Op::OpUndef); |
82 | 10 | if (undef == nullptr) { |
83 | 0 | return false; |
84 | 0 | } |
85 | 10 | return_inst = |
86 | 10 | ir_builder.AddUnaryOp(0, spv::Op::OpReturnValue, undef->result_id()); |
87 | 10 | } else { |
88 | 0 | return_inst = ir_builder.AddNullaryOp(0, spv::Op::OpReturn); |
89 | 0 | } |
90 | | |
91 | 10 | if (return_inst == nullptr) { |
92 | 0 | return false; |
93 | 0 | } |
94 | | |
95 | 10 | context()->KillInst(inst); |
96 | 10 | return true; |
97 | 10 | } |
98 | | |
99 | 30 | uint32_t WrapOpKill::GetVoidTypeId() { |
100 | 30 | if (void_type_id_ != 0) { |
101 | 20 | return void_type_id_; |
102 | 20 | } |
103 | | |
104 | 10 | analysis::TypeManager* type_mgr = context()->get_type_mgr(); |
105 | 10 | analysis::Void void_type; |
106 | 10 | void_type_id_ = type_mgr->GetTypeInstruction(&void_type); |
107 | 10 | return void_type_id_; |
108 | 30 | } |
109 | | |
110 | 10 | uint32_t WrapOpKill::GetVoidFunctionTypeId() { |
111 | 10 | analysis::TypeManager* type_mgr = context()->get_type_mgr(); |
112 | 10 | analysis::Void void_type; |
113 | 10 | const analysis::Type* registered_void_type = |
114 | 10 | type_mgr->GetRegisteredType(&void_type); |
115 | | |
116 | 10 | analysis::Function func_type(registered_void_type, {}); |
117 | 10 | return type_mgr->GetTypeInstruction(&func_type); |
118 | 10 | } |
119 | | |
120 | 10 | uint32_t WrapOpKill::GetKillingFuncId(spv::Op opcode) { |
121 | | // Parameterize by opcode |
122 | 10 | assert(opcode == spv::Op::OpKill || opcode == spv::Op::OpTerminateInvocation); |
123 | | |
124 | 10 | std::unique_ptr<Function>* const killing_func = |
125 | 10 | (opcode == spv::Op::OpKill) ? &opkill_function_ |
126 | 10 | : &opterminateinvocation_function_; |
127 | | |
128 | 10 | if (*killing_func != nullptr) { |
129 | 0 | return (*killing_func)->result_id(); |
130 | 0 | } |
131 | | |
132 | 10 | uint32_t killing_func_id = TakeNextId(); |
133 | 10 | if (killing_func_id == 0) { |
134 | 0 | return 0; |
135 | 0 | } |
136 | | |
137 | 10 | uint32_t void_type_id = GetVoidTypeId(); |
138 | 10 | if (void_type_id == 0) { |
139 | 0 | return 0; |
140 | 0 | } |
141 | | |
142 | | // Generate the function start instruction |
143 | 10 | std::unique_ptr<Instruction> func_start(new Instruction( |
144 | 10 | context(), spv::Op::OpFunction, void_type_id, killing_func_id, {})); |
145 | 10 | func_start->AddOperand({SPV_OPERAND_TYPE_FUNCTION_CONTROL, {0}}); |
146 | 10 | func_start->AddOperand({SPV_OPERAND_TYPE_ID, {GetVoidFunctionTypeId()}}); |
147 | 10 | (*killing_func).reset(new Function(std::move(func_start))); |
148 | | |
149 | | // Generate the function end instruction |
150 | 10 | std::unique_ptr<Instruction> func_end( |
151 | 10 | new Instruction(context(), spv::Op::OpFunctionEnd, 0, 0, {})); |
152 | 10 | (*killing_func)->SetFunctionEnd(std::move(func_end)); |
153 | | |
154 | | // Create the one basic block for the function. |
155 | 10 | uint32_t lab_id = TakeNextId(); |
156 | 10 | if (lab_id == 0) { |
157 | 0 | return 0; |
158 | 0 | } |
159 | 10 | std::unique_ptr<Instruction> label_inst( |
160 | 10 | new Instruction(context(), spv::Op::OpLabel, 0, lab_id, {})); |
161 | 10 | std::unique_ptr<BasicBlock> bb(new BasicBlock(std::move(label_inst))); |
162 | | |
163 | | // Add the OpKill to the basic block |
164 | 10 | std::unique_ptr<Instruction> kill_inst( |
165 | 10 | new Instruction(context(), opcode, 0, 0, {})); |
166 | 10 | bb->AddInstruction(std::move(kill_inst)); |
167 | | |
168 | | // Add the bb to the function |
169 | 10 | (*killing_func)->AddBasicBlock(std::move(bb)); |
170 | | |
171 | | // Add the function to the module. |
172 | 10 | if (context()->AreAnalysesValid(IRContext::kAnalysisDefUse)) { |
173 | 0 | (*killing_func)->ForEachInst([this](Instruction* inst) { |
174 | 0 | context()->AnalyzeDefUse(inst); |
175 | 0 | }); |
176 | 0 | } |
177 | | |
178 | 10 | if (context()->AreAnalysesValid(IRContext::kAnalysisInstrToBlockMapping)) { |
179 | 10 | for (BasicBlock& basic_block : *(*killing_func)) { |
180 | 10 | context()->set_instr_block(basic_block.GetLabelInst(), &basic_block); |
181 | 10 | for (Instruction& inst : basic_block) { |
182 | 10 | context()->set_instr_block(&inst, &basic_block); |
183 | 10 | } |
184 | 10 | } |
185 | 10 | } |
186 | | |
187 | 10 | return (*killing_func)->result_id(); |
188 | 10 | } |
189 | | |
190 | 10 | uint32_t WrapOpKill::GetOwningFunctionsReturnType(Instruction* inst) { |
191 | 10 | BasicBlock* bb = context()->get_instr_block(inst); |
192 | 10 | if (bb == nullptr) { |
193 | 0 | return 0; |
194 | 0 | } |
195 | | |
196 | 10 | Function* func = bb->GetParent(); |
197 | 10 | return func->type_id(); |
198 | 10 | } |
199 | | |
200 | | } // namespace opt |
201 | | } // namespace spvtools |