/src/spirv-tools/source/opt/private_to_local_pass.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | // Copyright (c) 2017 Google Inc. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | #include "source/opt/private_to_local_pass.h" |
16 | | |
17 | | #include <memory> |
18 | | #include <utility> |
19 | | #include <vector> |
20 | | |
21 | | #include "source/opt/ir_context.h" |
22 | | #include "source/spirv_constant.h" |
23 | | |
24 | | namespace spvtools { |
25 | | namespace opt { |
namespace {
// Index (among in-operands) of the storage-class operand of OpVariable.
constexpr uint32_t kVariableStorageClassInIdx{0};
// Index (among in-operands) of the pointee type id of OpTypePointer; in-operand
// 0 of OpTypePointer is its storage class.
constexpr uint32_t kSpvTypePointerTypeIdInIdx{1};
}  // namespace
30 | | |
31 | 11.9k | Pass::Status PrivateToLocalPass::Process() { |
32 | 11.9k | bool modified = false; |
33 | | |
34 | | // Private variables require the shader capability. If this is not a shader, |
35 | | // there is no work to do. |
36 | 11.9k | if (context()->get_feature_mgr()->HasCapability(spv::Capability::Addresses)) |
37 | 638 | return Status::SuccessWithoutChange; |
38 | | |
39 | 11.3k | std::vector<std::pair<Instruction*, Function*>> variables_to_move; |
40 | 11.3k | std::unordered_set<uint32_t> localized_variables; |
41 | 278k | for (auto& inst : context()->types_values()) { |
42 | 278k | if (inst.opcode() != spv::Op::OpVariable) { |
43 | 255k | continue; |
44 | 255k | } |
45 | | |
46 | 22.4k | if (spv::StorageClass(inst.GetSingleWordInOperand( |
47 | 22.4k | kVariableStorageClassInIdx)) != spv::StorageClass::Private) { |
48 | 16.7k | continue; |
49 | 16.7k | } |
50 | | |
51 | 5.71k | Function* target_function = FindLocalFunction(inst); |
52 | 5.71k | if (target_function != nullptr) { |
53 | 3.13k | variables_to_move.push_back({&inst, target_function}); |
54 | 3.13k | } |
55 | 5.71k | } |
56 | | |
57 | 11.3k | modified = !variables_to_move.empty(); |
58 | 11.3k | for (auto p : variables_to_move) { |
59 | 3.13k | if (!MoveVariable(p.first, p.second)) { |
60 | 0 | return Status::Failure; |
61 | 0 | } |
62 | 3.13k | localized_variables.insert(p.first->result_id()); |
63 | 3.13k | } |
64 | | |
65 | 11.3k | if (get_module()->version() >= SPV_SPIRV_VERSION_WORD(1, 4)) { |
66 | | // In SPIR-V 1.4 and later entry points must list private storage class |
67 | | // variables that are statically used by the entry point. Go through the |
68 | | // entry points and remove any references to variables that were localized. |
69 | 0 | for (auto& entry : get_module()->entry_points()) { |
70 | 0 | std::vector<Operand> new_operands; |
71 | 0 | for (uint32_t i = 0; i < entry.NumInOperands(); ++i) { |
72 | | // Execution model, function id and name are always kept. |
73 | 0 | if (i < 3 || |
74 | 0 | !localized_variables.count(entry.GetSingleWordInOperand(i))) { |
75 | 0 | new_operands.push_back(entry.GetInOperand(i)); |
76 | 0 | } |
77 | 0 | } |
78 | 0 | if (new_operands.size() != entry.NumInOperands()) { |
79 | 0 | entry.SetInOperands(std::move(new_operands)); |
80 | 0 | context()->AnalyzeUses(&entry); |
81 | 0 | } |
82 | 0 | } |
83 | 0 | } |
84 | | |
85 | 11.3k | return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange); |
86 | 11.3k | } |
87 | | |
88 | 5.71k | Function* PrivateToLocalPass::FindLocalFunction(const Instruction& inst) const { |
89 | 5.71k | bool found_first_use = false; |
90 | 5.71k | Function* target_function = nullptr; |
91 | 5.71k | context()->get_def_use_mgr()->ForEachUser( |
92 | 5.71k | inst.result_id(), |
93 | 75.0k | [&target_function, &found_first_use, this](Instruction* use) { |
94 | 75.0k | BasicBlock* current_block = context()->get_instr_block(use); |
95 | 75.0k | if (current_block == nullptr) { |
96 | 1.28k | return; |
97 | 1.28k | } |
98 | | |
99 | 73.7k | if (!IsValidUse(use)) { |
100 | 2.13k | found_first_use = true; |
101 | 2.13k | target_function = nullptr; |
102 | 2.13k | return; |
103 | 2.13k | } |
104 | 71.6k | Function* current_function = current_block->GetParent(); |
105 | 71.6k | if (!found_first_use) { |
106 | 3.46k | found_first_use = true; |
107 | 3.46k | target_function = current_function; |
108 | 68.1k | } else if (target_function != current_function) { |
109 | 12.6k | target_function = nullptr; |
110 | 12.6k | } |
111 | 71.6k | }); |
112 | 5.71k | return target_function; |
113 | 5.71k | } // namespace opt |
114 | | |
115 | | bool PrivateToLocalPass::MoveVariable(Instruction* variable, |
116 | 3.13k | Function* function) { |
117 | | // The variable needs to be removed from the global section, and placed in the |
118 | | // header of the function. First step remove from the global list. |
119 | 3.13k | variable->RemoveFromList(); |
120 | 3.13k | std::unique_ptr<Instruction> var(variable); // Take ownership. |
121 | 3.13k | context()->ForgetUses(variable); |
122 | | |
123 | | // Update the storage class of the variable. |
124 | 3.13k | variable->SetInOperand(kVariableStorageClassInIdx, |
125 | 3.13k | {uint32_t(spv::StorageClass::Function)}); |
126 | | |
127 | | // Update the type as well. |
128 | 3.13k | uint32_t new_type_id = GetNewType(variable->type_id()); |
129 | 3.13k | if (new_type_id == 0) { |
130 | 0 | return false; |
131 | 0 | } |
132 | 3.13k | variable->SetResultType(new_type_id); |
133 | | |
134 | | // Place the variable at the start of the first basic block. |
135 | 3.13k | context()->AnalyzeUses(variable); |
136 | 3.13k | context()->set_instr_block(variable, &*function->begin()); |
137 | 3.13k | function->begin()->begin()->InsertBefore(std::move(var)); |
138 | | |
139 | | // Update uses where the type may have changed. |
140 | 3.13k | return UpdateUses(variable); |
141 | 3.13k | } |
142 | | |
143 | 42.8k | uint32_t PrivateToLocalPass::GetNewType(uint32_t old_type_id) { |
144 | 42.8k | auto type_mgr = context()->get_type_mgr(); |
145 | 42.8k | Instruction* old_type_inst = get_def_use_mgr()->GetDef(old_type_id); |
146 | 42.8k | uint32_t pointee_type_id = |
147 | 42.8k | old_type_inst->GetSingleWordInOperand(kSpvTypePointerTypeIdInIdx); |
148 | 42.8k | uint32_t new_type_id = |
149 | 42.8k | type_mgr->FindPointerToType(pointee_type_id, spv::StorageClass::Function); |
150 | 42.8k | if (new_type_id != 0) { |
151 | 42.8k | context()->UpdateDefUse(context()->get_def_use_mgr()->GetDef(new_type_id)); |
152 | 42.8k | } |
153 | 42.8k | return new_type_id; |
154 | 42.8k | } |
155 | | |
156 | 129k | bool PrivateToLocalPass::IsValidUse(const Instruction* inst) const { |
157 | | // The cases in this switch have to match the cases in |UpdateUse|. |
158 | | // If we don't know how to update it, it is not valid. |
159 | 129k | if (inst->GetCommonDebugOpcode() == CommonDebugInfoDebugGlobalVariable) { |
160 | 0 | return true; |
161 | 0 | } |
162 | 129k | switch (inst->opcode()) { |
163 | 45.2k | case spv::Op::OpLoad: |
164 | 71.1k | case spv::Op::OpStore: |
165 | 71.1k | case spv::Op::OpImageTexelPointer: // Treat like a load |
166 | 71.1k | return true; |
167 | 56.5k | case spv::Op::OpAccessChain: |
168 | 56.5k | return context()->get_def_use_mgr()->WhileEachUser( |
169 | 56.5k | inst, [this](const Instruction* user) { |
170 | 56.1k | if (!IsValidUse(user)) return false; |
171 | 56.1k | return true; |
172 | 56.1k | }); |
173 | 15 | case spv::Op::OpName: |
174 | 15 | return true; |
175 | 2.14k | default: |
176 | 2.14k | return spvOpcodeIsDecoration(inst->opcode()); |
177 | 129k | } |
178 | 129k | } |
179 | | |
180 | 95.1k | bool PrivateToLocalPass::UpdateUse(Instruction* inst, Instruction* user) { |
181 | | // The cases in this switch have to match the cases in |IsValidUse|. If we |
182 | | // don't think it is valid, the optimization will not view the variable as a |
183 | | // candidate, and therefore the use will not be updated. |
184 | 95.1k | if (inst->GetCommonDebugOpcode() == CommonDebugInfoDebugGlobalVariable) { |
185 | 0 | context()->get_debug_info_mgr()->ConvertDebugGlobalToLocalVariable(inst, |
186 | 0 | user); |
187 | 0 | return true; |
188 | 0 | } |
189 | 95.1k | switch (inst->opcode()) { |
190 | 34.0k | case spv::Op::OpLoad: |
191 | 54.2k | case spv::Op::OpStore: |
192 | 54.2k | case spv::Op::OpImageTexelPointer: // Treat like a load |
193 | | // The type is fine because it is the type pointed to, and that does not |
194 | | // change. |
195 | 54.2k | break; |
196 | 39.7k | case spv::Op::OpAccessChain: { |
197 | 39.7k | context()->ForgetUses(inst); |
198 | 39.7k | uint32_t new_type_id = GetNewType(inst->type_id()); |
199 | 39.7k | if (new_type_id == 0) { |
200 | 0 | return false; |
201 | 0 | } |
202 | 39.7k | inst->SetResultType(new_type_id); |
203 | 39.7k | context()->AnalyzeUses(inst); |
204 | | |
205 | | // Update uses where the type may have changed. |
206 | 39.7k | if (!UpdateUses(inst)) { |
207 | 0 | return false; |
208 | 0 | } |
209 | 39.7k | } break; |
210 | 39.7k | case spv::Op::OpName: |
211 | 1.04k | case spv::Op::OpEntryPoint: // entry points will be updated separately. |
212 | 1.04k | break; |
213 | 109 | default: |
214 | 109 | assert(spvOpcodeIsDecoration(inst->opcode()) && |
215 | 109 | "Do not know how to update the type for this instruction."); |
216 | 0 | break; |
217 | 95.1k | } |
218 | 95.1k | return true; |
219 | 95.1k | } |
220 | | |
221 | 42.8k | bool PrivateToLocalPass::UpdateUses(Instruction* inst) { |
222 | 42.8k | uint32_t id = inst->result_id(); |
223 | 42.8k | std::vector<Instruction*> uses; |
224 | 42.8k | context()->get_def_use_mgr()->ForEachUser( |
225 | 95.1k | id, [&uses](Instruction* use) { uses.push_back(use); }); |
226 | | |
227 | 95.1k | for (Instruction* use : uses) { |
228 | 95.1k | if (!UpdateUse(use, inst)) { |
229 | 0 | return false; |
230 | 0 | } |
231 | 95.1k | } |
232 | 42.8k | return true; |
233 | 42.8k | } |
234 | | |
235 | | } // namespace opt |
236 | | } // namespace spvtools |