/src/spirv-tools/source/opt/loop_fusion.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | // Copyright (c) 2018 Google LLC. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | #include "source/opt/loop_fusion.h" |
16 | | |
17 | | #include <algorithm> |
18 | | #include <vector> |
19 | | |
20 | | #include "source/opt/ir_context.h" |
21 | | #include "source/opt/loop_dependence.h" |
22 | | #include "source/opt/loop_descriptor.h" |
23 | | |
24 | | namespace spvtools { |
25 | | namespace opt { |
26 | | namespace { |
27 | | |
28 | | // Append all the loops nested in |loop| to |loops|. |
29 | 0 | void CollectChildren(Loop* loop, std::vector<const Loop*>* loops) { |
30 | 0 | for (auto child : *loop) { |
31 | 0 | loops->push_back(child); |
32 | 0 | if (child->NumImmediateChildren() != 0) { |
33 | 0 | CollectChildren(child, loops); |
34 | 0 | } |
35 | 0 | } |
36 | 0 | } |
37 | | |
38 | | // Return the set of locations accessed by |stores| and |loads|. |
39 | | std::set<Instruction*> GetLocationsAccessed( |
40 | | const std::map<Instruction*, std::vector<Instruction*>>& stores, |
41 | 0 | const std::map<Instruction*, std::vector<Instruction*>>& loads) { |
42 | 0 | std::set<Instruction*> locations{}; |
43 | |
|
44 | 0 | for (const auto& kv : stores) { |
45 | 0 | locations.insert(std::get<0>(kv)); |
46 | 0 | } |
47 | |
|
48 | 0 | for (const auto& kv : loads) { |
49 | 0 | locations.insert(std::get<0>(kv)); |
50 | 0 | } |
51 | |
|
52 | 0 | return locations; |
53 | 0 | } |
54 | | |
55 | | // Append all dependences from |sources| to |destinations| to |dependences|. |
56 | | void GetDependences(std::vector<DistanceVector>* dependences, |
57 | | LoopDependenceAnalysis* analysis, |
58 | | const std::vector<Instruction*>& sources, |
59 | | const std::vector<Instruction*>& destinations, |
60 | 0 | size_t num_entries) { |
61 | 0 | for (auto source : sources) { |
62 | 0 | for (auto destination : destinations) { |
63 | 0 | DistanceVector dist(num_entries); |
64 | 0 | if (!analysis->GetDependence(source, destination, &dist)) { |
65 | 0 | dependences->push_back(dist); |
66 | 0 | } |
67 | 0 | } |
68 | 0 | } |
69 | 0 | } |
70 | | |
71 | | // Apped all instructions in |block| to |instructions|. |
72 | | void AddInstructionsInBlock(std::vector<Instruction*>* instructions, |
73 | 0 | BasicBlock* block) { |
74 | 0 | for (auto& inst : *block) { |
75 | 0 | instructions->push_back(&inst); |
76 | 0 | } |
77 | |
|
78 | 0 | instructions->push_back(block->GetLabelInst()); |
79 | 0 | } |
80 | | |
81 | | } // namespace |
82 | | |
83 | | bool LoopFusion::UsedInContinueOrConditionBlock(Instruction* phi_instruction, |
84 | 0 | Loop* loop) { |
85 | 0 | auto condition_block = loop->FindConditionBlock()->id(); |
86 | 0 | auto continue_block = loop->GetContinueBlock()->id(); |
87 | 0 | auto not_used = context_->get_def_use_mgr()->WhileEachUser( |
88 | 0 | phi_instruction, |
89 | 0 | [this, condition_block, continue_block](Instruction* instruction) { |
90 | 0 | auto block_id = context_->get_instr_block(instruction)->id(); |
91 | 0 | return block_id != condition_block && block_id != continue_block; |
92 | 0 | }); |
93 | |
|
94 | 0 | return !not_used; |
95 | 0 | } |
96 | | |
97 | | void LoopFusion::RemoveIfNotUsedContinueOrConditionBlock( |
98 | 0 | std::vector<Instruction*>* instructions, Loop* loop) { |
99 | 0 | instructions->erase( |
100 | 0 | std::remove_if(std::begin(*instructions), std::end(*instructions), |
101 | 0 | [this, loop](Instruction* instruction) { |
102 | 0 | return !UsedInContinueOrConditionBlock(instruction, |
103 | 0 | loop); |
104 | 0 | }), |
105 | 0 | std::end(*instructions)); |
106 | 0 | } |
107 | | |
108 | 0 | bool LoopFusion::AreCompatible() { |
109 | | // Check that the loops are in the same function. |
110 | 0 | if (loop_0_->GetHeaderBlock()->GetParent() != |
111 | 0 | loop_1_->GetHeaderBlock()->GetParent()) { |
112 | 0 | return false; |
113 | 0 | } |
114 | | |
115 | | // Check that both loops have pre-header blocks. |
116 | 0 | if (!loop_0_->GetPreHeaderBlock() || !loop_1_->GetPreHeaderBlock()) { |
117 | 0 | return false; |
118 | 0 | } |
119 | | |
120 | | // Check there are no breaks. |
121 | 0 | if (context_->cfg()->preds(loop_0_->GetMergeBlock()->id()).size() != 1 || |
122 | 0 | context_->cfg()->preds(loop_1_->GetMergeBlock()->id()).size() != 1) { |
123 | 0 | return false; |
124 | 0 | } |
125 | | |
126 | | // Check there are no continues. |
127 | 0 | if (context_->cfg()->preds(loop_0_->GetContinueBlock()->id()).size() != 1 || |
128 | 0 | context_->cfg()->preds(loop_1_->GetContinueBlock()->id()).size() != 1) { |
129 | 0 | return false; |
130 | 0 | } |
131 | | |
132 | | // |GetInductionVariables| returns all OpPhi in the header. Check that both |
133 | | // loops have exactly one that is used in the continue and condition blocks. |
134 | 0 | std::vector<Instruction*> inductions_0{}, inductions_1{}; |
135 | 0 | loop_0_->GetInductionVariables(inductions_0); |
136 | 0 | RemoveIfNotUsedContinueOrConditionBlock(&inductions_0, loop_0_); |
137 | |
|
138 | 0 | if (inductions_0.size() != 1) { |
139 | 0 | return false; |
140 | 0 | } |
141 | | |
142 | 0 | induction_0_ = inductions_0.front(); |
143 | |
|
144 | 0 | loop_1_->GetInductionVariables(inductions_1); |
145 | 0 | RemoveIfNotUsedContinueOrConditionBlock(&inductions_1, loop_1_); |
146 | |
|
147 | 0 | if (inductions_1.size() != 1) { |
148 | 0 | return false; |
149 | 0 | } |
150 | | |
151 | 0 | induction_1_ = inductions_1.front(); |
152 | |
|
153 | 0 | if (!CheckInit()) { |
154 | 0 | return false; |
155 | 0 | } |
156 | | |
157 | 0 | if (!CheckCondition()) { |
158 | 0 | return false; |
159 | 0 | } |
160 | | |
161 | 0 | if (!CheckStep()) { |
162 | 0 | return false; |
163 | 0 | } |
164 | | |
165 | | // Check adjacency, |loop_0_| should come just before |loop_1_|. |
166 | | // There is always at least one block between loops, even if it's empty. |
167 | | // We'll check at most 2 preceding blocks. |
168 | | |
169 | 0 | auto pre_header_1 = loop_1_->GetPreHeaderBlock(); |
170 | |
|
171 | 0 | std::vector<BasicBlock*> block_to_check{}; |
172 | 0 | block_to_check.push_back(pre_header_1); |
173 | |
|
174 | 0 | if (loop_0_->GetMergeBlock() != loop_1_->GetPreHeaderBlock()) { |
175 | | // Follow CFG for one more block. |
176 | 0 | auto preds = context_->cfg()->preds(pre_header_1->id()); |
177 | 0 | if (preds.size() == 1) { |
178 | 0 | auto block = &*containing_function_->FindBlock(preds.front()); |
179 | 0 | if (block == loop_0_->GetMergeBlock()) { |
180 | 0 | block_to_check.push_back(block); |
181 | 0 | } else { |
182 | 0 | return false; |
183 | 0 | } |
184 | 0 | } else { |
185 | 0 | return false; |
186 | 0 | } |
187 | 0 | } |
188 | | |
189 | | // Check that the separating blocks are either empty or only contains a store |
190 | | // to a local variable that is never read (left behind by |
191 | | // '--eliminate-local-multi-store'). Also allow OpPhi, since the loop could be |
192 | | // in LCSSA form. |
193 | 0 | for (auto block : block_to_check) { |
194 | 0 | for (auto& inst : *block) { |
195 | 0 | if (inst.opcode() == spv::Op::OpStore) { |
196 | | // Get the definition of the target to check it's function scope so |
197 | | // there are no observable side effects. |
198 | 0 | auto variable = |
199 | 0 | context_->get_def_use_mgr()->GetDef(inst.GetSingleWordInOperand(0)); |
200 | |
|
201 | 0 | if (variable->opcode() != spv::Op::OpVariable || |
202 | 0 | spv::StorageClass(variable->GetSingleWordInOperand(0)) != |
203 | 0 | spv::StorageClass::Function) { |
204 | 0 | return false; |
205 | 0 | } |
206 | | |
207 | | // Check the target is never loaded. |
208 | 0 | auto is_used = false; |
209 | 0 | context_->get_def_use_mgr()->ForEachUse( |
210 | 0 | inst.GetSingleWordInOperand(0), |
211 | 0 | [&is_used](Instruction* use_inst, uint32_t) { |
212 | 0 | if (use_inst->opcode() == spv::Op::OpLoad) { |
213 | 0 | is_used = true; |
214 | 0 | } |
215 | 0 | }); |
216 | |
|
217 | 0 | if (is_used) { |
218 | 0 | return false; |
219 | 0 | } |
220 | 0 | } else if (inst.opcode() == spv::Op::OpPhi) { |
221 | 0 | if (inst.NumInOperands() != 2) { |
222 | 0 | return false; |
223 | 0 | } |
224 | 0 | } else if (inst.opcode() != spv::Op::OpBranch) { |
225 | 0 | return false; |
226 | 0 | } |
227 | 0 | } |
228 | 0 | } |
229 | | |
230 | 0 | return true; |
231 | 0 | } // namespace opt |
232 | | |
233 | 0 | bool LoopFusion::ContainsBarriersOrFunctionCalls(Loop* loop) { |
234 | 0 | for (const auto& block : loop->GetBlocks()) { |
235 | 0 | for (const auto& inst : *containing_function_->FindBlock(block)) { |
236 | 0 | auto opcode = inst.opcode(); |
237 | 0 | if (opcode == spv::Op::OpFunctionCall || |
238 | 0 | opcode == spv::Op::OpControlBarrier || |
239 | 0 | opcode == spv::Op::OpMemoryBarrier || |
240 | 0 | opcode == spv::Op::OpTypeNamedBarrier || |
241 | 0 | opcode == spv::Op::OpNamedBarrierInitialize || |
242 | 0 | opcode == spv::Op::OpMemoryNamedBarrier) { |
243 | 0 | return true; |
244 | 0 | } |
245 | 0 | } |
246 | 0 | } |
247 | | |
248 | 0 | return false; |
249 | 0 | } |
250 | | |
251 | 0 | bool LoopFusion::CheckInit() { |
252 | 0 | int64_t loop_0_init; |
253 | 0 | if (!loop_0_->GetInductionInitValue(induction_0_, &loop_0_init)) { |
254 | 0 | return false; |
255 | 0 | } |
256 | | |
257 | 0 | int64_t loop_1_init; |
258 | 0 | if (!loop_1_->GetInductionInitValue(induction_1_, &loop_1_init)) { |
259 | 0 | return false; |
260 | 0 | } |
261 | | |
262 | 0 | if (loop_0_init != loop_1_init) { |
263 | 0 | return false; |
264 | 0 | } |
265 | | |
266 | 0 | return true; |
267 | 0 | } |
268 | | |
269 | 0 | bool LoopFusion::CheckCondition() { |
270 | 0 | auto condition_0 = loop_0_->GetConditionInst(); |
271 | 0 | auto condition_1 = loop_1_->GetConditionInst(); |
272 | |
|
273 | 0 | if (!loop_0_->IsSupportedCondition(condition_0->opcode()) || |
274 | 0 | !loop_1_->IsSupportedCondition(condition_1->opcode())) { |
275 | 0 | return false; |
276 | 0 | } |
277 | | |
278 | 0 | if (condition_0->opcode() != condition_1->opcode()) { |
279 | 0 | return false; |
280 | 0 | } |
281 | | |
282 | 0 | for (uint32_t i = 0; i < condition_0->NumInOperandWords(); ++i) { |
283 | 0 | auto arg_0 = context_->get_def_use_mgr()->GetDef( |
284 | 0 | condition_0->GetSingleWordInOperand(i)); |
285 | 0 | auto arg_1 = context_->get_def_use_mgr()->GetDef( |
286 | 0 | condition_1->GetSingleWordInOperand(i)); |
287 | |
|
288 | 0 | if (arg_0 == induction_0_ && arg_1 == induction_1_) { |
289 | 0 | continue; |
290 | 0 | } |
291 | | |
292 | 0 | if (arg_0 == induction_0_ && arg_1 != induction_1_) { |
293 | 0 | return false; |
294 | 0 | } |
295 | | |
296 | 0 | if (arg_1 == induction_1_ && arg_0 != induction_0_) { |
297 | 0 | return false; |
298 | 0 | } |
299 | | |
300 | 0 | if (arg_0 != arg_1) { |
301 | 0 | return false; |
302 | 0 | } |
303 | 0 | } |
304 | | |
305 | 0 | return true; |
306 | 0 | } |
307 | | |
308 | 0 | bool LoopFusion::CheckStep() { |
309 | 0 | auto scalar_analysis = context_->GetScalarEvolutionAnalysis(); |
310 | 0 | SENode* induction_node_0 = scalar_analysis->SimplifyExpression( |
311 | 0 | scalar_analysis->AnalyzeInstruction(induction_0_)); |
312 | 0 | if (!induction_node_0->AsSERecurrentNode()) { |
313 | 0 | return false; |
314 | 0 | } |
315 | | |
316 | 0 | SENode* induction_step_0 = |
317 | 0 | induction_node_0->AsSERecurrentNode()->GetCoefficient(); |
318 | 0 | if (!induction_step_0->AsSEConstantNode()) { |
319 | 0 | return false; |
320 | 0 | } |
321 | | |
322 | 0 | SENode* induction_node_1 = scalar_analysis->SimplifyExpression( |
323 | 0 | scalar_analysis->AnalyzeInstruction(induction_1_)); |
324 | 0 | if (!induction_node_1->AsSERecurrentNode()) { |
325 | 0 | return false; |
326 | 0 | } |
327 | | |
328 | 0 | SENode* induction_step_1 = |
329 | 0 | induction_node_1->AsSERecurrentNode()->GetCoefficient(); |
330 | 0 | if (!induction_step_1->AsSEConstantNode()) { |
331 | 0 | return false; |
332 | 0 | } |
333 | | |
334 | 0 | if (*induction_step_0 != *induction_step_1) { |
335 | 0 | return false; |
336 | 0 | } |
337 | | |
338 | 0 | return true; |
339 | 0 | } |
340 | | |
341 | | std::map<Instruction*, std::vector<Instruction*>> LoopFusion::LocationToMemOps( |
342 | 0 | const std::vector<Instruction*>& mem_ops) { |
343 | 0 | std::map<Instruction*, std::vector<Instruction*>> location_map{}; |
344 | |
|
345 | 0 | for (auto instruction : mem_ops) { |
346 | 0 | auto access_location = context_->get_def_use_mgr()->GetDef( |
347 | 0 | instruction->GetSingleWordInOperand(0)); |
348 | |
|
349 | 0 | while (access_location->opcode() == spv::Op::OpAccessChain) { |
350 | 0 | access_location = context_->get_def_use_mgr()->GetDef( |
351 | 0 | access_location->GetSingleWordInOperand(0)); |
352 | 0 | } |
353 | |
|
354 | 0 | location_map[access_location].push_back(instruction); |
355 | 0 | } |
356 | |
|
357 | 0 | return location_map; |
358 | 0 | } |
359 | | |
360 | | std::pair<std::vector<Instruction*>, std::vector<Instruction*>> |
361 | 0 | LoopFusion::GetLoadsAndStoresInLoop(Loop* loop) { |
362 | 0 | std::vector<Instruction*> loads{}; |
363 | 0 | std::vector<Instruction*> stores{}; |
364 | |
|
365 | 0 | for (auto block_id : loop->GetBlocks()) { |
366 | 0 | if (block_id == loop->GetContinueBlock()->id()) { |
367 | 0 | continue; |
368 | 0 | } |
369 | | |
370 | 0 | for (auto& instruction : *containing_function_->FindBlock(block_id)) { |
371 | 0 | if (instruction.opcode() == spv::Op::OpLoad) { |
372 | 0 | loads.push_back(&instruction); |
373 | 0 | } else if (instruction.opcode() == spv::Op::OpStore) { |
374 | 0 | stores.push_back(&instruction); |
375 | 0 | } |
376 | 0 | } |
377 | 0 | } |
378 | |
|
379 | 0 | return std::make_pair(loads, stores); |
380 | 0 | } |
381 | | |
382 | 0 | bool LoopFusion::IsUsedInLoop(Instruction* instruction, Loop* loop) { |
383 | 0 | auto not_used = context_->get_def_use_mgr()->WhileEachUser( |
384 | 0 | instruction, [this, loop](Instruction* user) { |
385 | 0 | auto block_id = context_->get_instr_block(user)->id(); |
386 | 0 | return !loop->IsInsideLoop(block_id); |
387 | 0 | }); |
388 | |
|
389 | 0 | return !not_used; |
390 | 0 | } |
391 | | |
392 | 0 | bool LoopFusion::IsLegal() { |
393 | 0 | assert(AreCompatible() && "Fusion can't be legal, loops are not compatible."); |
394 | | |
395 | | // Bail out if there are function calls as they could have side-effects that |
396 | | // cause dependencies or if there are any barriers. |
397 | 0 | if (ContainsBarriersOrFunctionCalls(loop_0_) || |
398 | 0 | ContainsBarriersOrFunctionCalls(loop_1_)) { |
399 | 0 | return false; |
400 | 0 | } |
401 | | |
402 | 0 | std::vector<Instruction*> phi_instructions{}; |
403 | 0 | loop_0_->GetInductionVariables(phi_instructions); |
404 | | |
405 | | // Check no OpPhi in |loop_0_| is used in |loop_1_|. |
406 | 0 | for (auto phi_instruction : phi_instructions) { |
407 | 0 | if (IsUsedInLoop(phi_instruction, loop_1_)) { |
408 | 0 | return false; |
409 | 0 | } |
410 | 0 | } |
411 | | |
412 | | // Check no LCSSA OpPhi in merge block of |loop_0_| is used in |loop_1_|. |
413 | 0 | auto phi_used = false; |
414 | 0 | loop_0_->GetMergeBlock()->ForEachPhiInst( |
415 | 0 | [this, &phi_used](Instruction* phi_instruction) { |
416 | 0 | phi_used |= IsUsedInLoop(phi_instruction, loop_1_); |
417 | 0 | }); |
418 | |
|
419 | 0 | if (phi_used) { |
420 | 0 | return false; |
421 | 0 | } |
422 | | |
423 | | // Grab loads & stores from both loops. |
424 | 0 | auto loads_stores_0 = GetLoadsAndStoresInLoop(loop_0_); |
425 | 0 | auto loads_stores_1 = GetLoadsAndStoresInLoop(loop_1_); |
426 | | |
427 | | // Build memory location to operation maps. |
428 | 0 | auto load_locs_0 = LocationToMemOps(std::get<0>(loads_stores_0)); |
429 | 0 | auto store_locs_0 = LocationToMemOps(std::get<1>(loads_stores_0)); |
430 | |
|
431 | 0 | auto load_locs_1 = LocationToMemOps(std::get<0>(loads_stores_1)); |
432 | 0 | auto store_locs_1 = LocationToMemOps(std::get<1>(loads_stores_1)); |
433 | | |
434 | | // Get the locations accessed in both loops. |
435 | 0 | auto locations_0 = GetLocationsAccessed(store_locs_0, load_locs_0); |
436 | 0 | auto locations_1 = GetLocationsAccessed(store_locs_1, load_locs_1); |
437 | |
|
438 | 0 | std::vector<Instruction*> potential_clashes{}; |
439 | |
|
440 | 0 | std::set_intersection(std::begin(locations_0), std::end(locations_0), |
441 | 0 | std::begin(locations_1), std::end(locations_1), |
442 | 0 | std::back_inserter(potential_clashes)); |
443 | | |
444 | | // If the loops don't access the same variables, the fusion is legal. |
445 | 0 | if (potential_clashes.empty()) { |
446 | 0 | return true; |
447 | 0 | } |
448 | | |
449 | | // Find variables that have at least one store. |
450 | 0 | std::vector<Instruction*> potential_clashes_with_stores{}; |
451 | 0 | for (auto location : potential_clashes) { |
452 | 0 | if (store_locs_0.find(location) != std::end(store_locs_0) || |
453 | 0 | store_locs_1.find(location) != std::end(store_locs_1)) { |
454 | 0 | potential_clashes_with_stores.push_back(location); |
455 | 0 | } |
456 | 0 | } |
457 | | |
458 | | // If there are only loads to the same variables, the fusion is legal. |
459 | 0 | if (potential_clashes_with_stores.empty()) { |
460 | 0 | return true; |
461 | 0 | } |
462 | | |
463 | | // Else if loads and at least one store (across loops) to the same variable |
464 | | // there is a potential dependence and we need to check the dependence |
465 | | // distance. |
466 | | |
467 | | // Find all the loops in this loop nest for the dependency analysis. |
468 | 0 | std::vector<const Loop*> loops{}; |
469 | | |
470 | | // Find the parents. |
471 | 0 | for (auto current_loop = loop_0_; current_loop != nullptr; |
472 | 0 | current_loop = current_loop->GetParent()) { |
473 | 0 | loops.push_back(current_loop); |
474 | 0 | } |
475 | |
|
476 | 0 | auto this_loop_position = loops.size() - 1; |
477 | 0 | std::reverse(std::begin(loops), std::end(loops)); |
478 | | |
479 | | // Find the children. |
480 | 0 | CollectChildren(loop_0_, &loops); |
481 | 0 | CollectChildren(loop_1_, &loops); |
482 | | |
483 | | // Check that any dependes created are legal. That means the fused loops do |
484 | | // not have any dependencies with dependence distance greater than 0 that did |
485 | | // not exist in the original loops. |
486 | |
|
487 | 0 | LoopDependenceAnalysis analysis(context_, loops); |
488 | |
|
489 | 0 | analysis.GetScalarEvolution()->AddLoopsToPretendAreTheSame( |
490 | 0 | {loop_0_, loop_1_}); |
491 | |
|
492 | 0 | for (auto location : potential_clashes_with_stores) { |
493 | | // Analyse dependences from |loop_0_| to |loop_1_|. |
494 | 0 | std::vector<DistanceVector> dependences; |
495 | | // Read-After-Write. |
496 | 0 | GetDependences(&dependences, &analysis, store_locs_0[location], |
497 | 0 | load_locs_1[location], loops.size()); |
498 | | // Write-After-Read. |
499 | 0 | GetDependences(&dependences, &analysis, load_locs_0[location], |
500 | 0 | store_locs_1[location], loops.size()); |
501 | | // Write-After-Write. |
502 | 0 | GetDependences(&dependences, &analysis, store_locs_0[location], |
503 | 0 | store_locs_1[location], loops.size()); |
504 | | |
505 | | // Check that the induction variables either don't appear in the subscripts |
506 | | // or the dependence distance is negative. |
507 | 0 | for (const auto& dependence : dependences) { |
508 | 0 | const auto& entry = dependence.GetEntries()[this_loop_position]; |
509 | 0 | if ((entry.dependence_information == |
510 | 0 | DistanceEntry::DependenceInformation::DISTANCE && |
511 | 0 | entry.distance < 1) || |
512 | 0 | (entry.dependence_information == |
513 | 0 | DistanceEntry::DependenceInformation::IRRELEVANT)) { |
514 | 0 | continue; |
515 | 0 | } else { |
516 | 0 | return false; |
517 | 0 | } |
518 | 0 | } |
519 | 0 | } |
520 | | |
521 | 0 | return true; |
522 | 0 | } |
523 | | |
524 | | void ReplacePhiParentWith(Instruction* inst, uint32_t orig_block, |
525 | 0 | uint32_t new_block) { |
526 | 0 | if (inst->GetSingleWordInOperand(1) == orig_block) { |
527 | 0 | inst->SetInOperand(1, {new_block}); |
528 | 0 | } else { |
529 | 0 | inst->SetInOperand(3, {new_block}); |
530 | 0 | } |
531 | 0 | } |
532 | | |
533 | 0 | void LoopFusion::Fuse() { |
534 | 0 | assert(AreCompatible() && "Can't fuse, loops aren't compatible"); |
535 | 0 | assert(IsLegal() && "Can't fuse, illegal"); |
536 | | |
537 | | // Save the pointers/ids, won't be found in the middle of doing modifications. |
538 | 0 | auto header_1 = loop_1_->GetHeaderBlock()->id(); |
539 | 0 | auto condition_1 = loop_1_->FindConditionBlock()->id(); |
540 | 0 | auto continue_1 = loop_1_->GetContinueBlock()->id(); |
541 | 0 | auto continue_0 = loop_0_->GetContinueBlock()->id(); |
542 | 0 | auto condition_block_of_0 = loop_0_->FindConditionBlock(); |
543 | | |
544 | | // Find the blocks whose branches need updating. |
545 | 0 | auto first_block_of_1 = &*(++containing_function_->FindBlock(condition_1)); |
546 | 0 | auto last_block_of_1 = &*(--containing_function_->FindBlock(continue_1)); |
547 | 0 | auto last_block_of_0 = &*(--containing_function_->FindBlock(continue_0)); |
548 | | |
549 | | // Update the branch for |last_block_of_loop_0| to go to |first_block_of_1|. |
550 | 0 | last_block_of_0->ForEachSuccessorLabel( |
551 | 0 | [first_block_of_1](uint32_t* succ) { *succ = first_block_of_1->id(); }); |
552 | | |
553 | | // Update the branch for the |last_block_of_loop_1| to go to the continue |
554 | | // block of |loop_0_|. |
555 | 0 | last_block_of_1->ForEachSuccessorLabel( |
556 | 0 | [this](uint32_t* succ) { *succ = loop_0_->GetContinueBlock()->id(); }); |
557 | | |
558 | | // Update merge block id in the header of |loop_0_| to the merge block of |
559 | | // |loop_1_|. |
560 | 0 | loop_0_->GetHeaderBlock()->ForEachInst([this](Instruction* inst) { |
561 | 0 | if (inst->opcode() == spv::Op::OpLoopMerge) { |
562 | 0 | inst->SetInOperand(0, {loop_1_->GetMergeBlock()->id()}); |
563 | 0 | } |
564 | 0 | }); |
565 | | |
566 | | // Update condition branch target in |loop_0_| to the merge block of |
567 | | // |loop_1_|. |
568 | 0 | condition_block_of_0->ForEachInst([this](Instruction* inst) { |
569 | 0 | if (inst->opcode() == spv::Op::OpBranchConditional) { |
570 | 0 | auto loop_0_merge_block_id = loop_0_->GetMergeBlock()->id(); |
571 | |
|
572 | 0 | if (inst->GetSingleWordInOperand(1) == loop_0_merge_block_id) { |
573 | 0 | inst->SetInOperand(1, {loop_1_->GetMergeBlock()->id()}); |
574 | 0 | } else { |
575 | 0 | inst->SetInOperand(2, {loop_1_->GetMergeBlock()->id()}); |
576 | 0 | } |
577 | 0 | } |
578 | 0 | }); |
579 | | |
580 | | // Move OpPhi instructions not corresponding to the induction variable from |
581 | | // the header of |loop_1_| to the header of |loop_0_|. |
582 | 0 | std::vector<Instruction*> instructions_to_move{}; |
583 | 0 | for (auto& instruction : *loop_1_->GetHeaderBlock()) { |
584 | 0 | if (instruction.opcode() == spv::Op::OpPhi && |
585 | 0 | &instruction != induction_1_) { |
586 | 0 | instructions_to_move.push_back(&instruction); |
587 | 0 | } |
588 | 0 | } |
589 | |
|
590 | 0 | for (auto& it : instructions_to_move) { |
591 | 0 | it->RemoveFromList(); |
592 | 0 | it->InsertBefore(induction_0_); |
593 | 0 | } |
594 | | |
595 | | // Update the OpPhi parents to the correct blocks in |loop_0_|. |
596 | 0 | loop_0_->GetHeaderBlock()->ForEachPhiInst([this](Instruction* i) { |
597 | 0 | ReplacePhiParentWith(i, loop_1_->GetPreHeaderBlock()->id(), |
598 | 0 | loop_0_->GetPreHeaderBlock()->id()); |
599 | |
|
600 | 0 | ReplacePhiParentWith(i, loop_1_->GetContinueBlock()->id(), |
601 | 0 | loop_0_->GetContinueBlock()->id()); |
602 | 0 | }); |
603 | | |
604 | | // Update instruction to block mapping & DefUseManager. |
605 | 0 | for (auto& phi_instruction : instructions_to_move) { |
606 | 0 | context_->set_instr_block(phi_instruction, loop_0_->GetHeaderBlock()); |
607 | 0 | context_->get_def_use_mgr()->AnalyzeInstUse(phi_instruction); |
608 | 0 | } |
609 | | |
610 | | // Replace the uses of the induction variable of |loop_1_| with that the |
611 | | // induction variable of |loop_0_|. |
612 | 0 | context_->ReplaceAllUsesWith(induction_1_->result_id(), |
613 | 0 | induction_0_->result_id()); |
614 | | |
615 | | // Replace LCSSA OpPhi in merge block of |loop_0_|. |
616 | 0 | loop_0_->GetMergeBlock()->ForEachPhiInst([this](Instruction* instruction) { |
617 | 0 | context_->ReplaceAllUsesWith(instruction->result_id(), |
618 | 0 | instruction->GetSingleWordInOperand(0)); |
619 | 0 | }); |
620 | | |
621 | | // Update LCSSA OpPhi in merge block of |loop_1_|. |
622 | 0 | loop_1_->GetMergeBlock()->ForEachPhiInst( |
623 | 0 | [condition_block_of_0](Instruction* instruction) { |
624 | 0 | instruction->SetInOperand(1, {condition_block_of_0->id()}); |
625 | 0 | }); |
626 | | |
627 | | // Move the continue block of |loop_0_| after the last block of |loop_1_|. |
628 | 0 | containing_function_->MoveBasicBlockToAfter(continue_0, last_block_of_1); |
629 | | |
630 | | // Gather all instructions to be killed from |loop_1_| (induction variable |
631 | | // initialisation, header, condition and continue blocks). |
632 | 0 | std::vector<Instruction*> instr_to_delete{}; |
633 | 0 | AddInstructionsInBlock(&instr_to_delete, loop_1_->GetPreHeaderBlock()); |
634 | 0 | AddInstructionsInBlock(&instr_to_delete, loop_1_->GetHeaderBlock()); |
635 | 0 | AddInstructionsInBlock(&instr_to_delete, loop_1_->FindConditionBlock()); |
636 | 0 | AddInstructionsInBlock(&instr_to_delete, loop_1_->GetContinueBlock()); |
637 | | |
638 | | // There was an additional empty block between the loops, kill that too. |
639 | 0 | if (loop_0_->GetMergeBlock() != loop_1_->GetPreHeaderBlock()) { |
640 | 0 | AddInstructionsInBlock(&instr_to_delete, loop_0_->GetMergeBlock()); |
641 | 0 | } |
642 | | |
643 | | // Update the CFG, so it wouldn't need invalidating. |
644 | 0 | auto cfg = context_->cfg(); |
645 | |
|
646 | 0 | cfg->ForgetBlock(loop_1_->GetPreHeaderBlock()); |
647 | 0 | cfg->ForgetBlock(loop_1_->GetHeaderBlock()); |
648 | 0 | cfg->ForgetBlock(loop_1_->FindConditionBlock()); |
649 | 0 | cfg->ForgetBlock(loop_1_->GetContinueBlock()); |
650 | |
|
651 | 0 | if (loop_0_->GetMergeBlock() != loop_1_->GetPreHeaderBlock()) { |
652 | 0 | cfg->ForgetBlock(loop_0_->GetMergeBlock()); |
653 | 0 | } |
654 | |
|
655 | 0 | cfg->RemoveEdge(last_block_of_0->id(), loop_0_->GetContinueBlock()->id()); |
656 | 0 | cfg->AddEdge(last_block_of_0->id(), first_block_of_1->id()); |
657 | |
|
658 | 0 | cfg->AddEdge(last_block_of_1->id(), loop_0_->GetContinueBlock()->id()); |
659 | |
|
660 | 0 | cfg->AddEdge(loop_0_->GetContinueBlock()->id(), |
661 | 0 | loop_1_->GetHeaderBlock()->id()); |
662 | |
|
663 | 0 | cfg->AddEdge(condition_block_of_0->id(), loop_1_->GetMergeBlock()->id()); |
664 | | |
665 | | // Update DefUseManager. |
666 | 0 | auto def_use_mgr = context_->get_def_use_mgr(); |
667 | | |
668 | | // Uses of labels that are in updated branches need analysing. |
669 | 0 | def_use_mgr->AnalyzeInstUse(last_block_of_0->terminator()); |
670 | 0 | def_use_mgr->AnalyzeInstUse(last_block_of_1->terminator()); |
671 | 0 | def_use_mgr->AnalyzeInstUse(loop_0_->GetHeaderBlock()->GetLoopMergeInst()); |
672 | 0 | def_use_mgr->AnalyzeInstUse(condition_block_of_0->terminator()); |
673 | | |
674 | | // Update the LoopDescriptor, so it wouldn't need invalidating. |
675 | 0 | auto ld = context_->GetLoopDescriptor(containing_function_); |
676 | | |
677 | | // Create a copy, so the iterator wouldn't be invalidated. |
678 | 0 | std::vector<Loop*> loops_to_add_remove{}; |
679 | 0 | for (auto child_loop : *loop_1_) { |
680 | 0 | loops_to_add_remove.push_back(child_loop); |
681 | 0 | } |
682 | |
|
683 | 0 | for (auto child_loop : loops_to_add_remove) { |
684 | 0 | loop_1_->RemoveChildLoop(child_loop); |
685 | 0 | loop_0_->AddNestedLoop(child_loop); |
686 | 0 | } |
687 | |
|
688 | 0 | auto loop_1_blocks = loop_1_->GetBlocks(); |
689 | |
|
690 | 0 | for (auto block : loop_1_blocks) { |
691 | 0 | loop_1_->RemoveBasicBlock(block); |
692 | 0 | if (block != header_1 && block != condition_1 && block != continue_1) { |
693 | 0 | loop_0_->AddBasicBlock(block); |
694 | 0 | if ((*ld)[block] == loop_1_) { |
695 | 0 | ld->SetBasicBlockToLoop(block, loop_0_); |
696 | 0 | } |
697 | 0 | } |
698 | |
|
699 | 0 | if ((*ld)[block] == loop_1_) { |
700 | 0 | ld->ForgetBasicBlock(block); |
701 | 0 | } |
702 | 0 | } |
703 | |
|
704 | 0 | loop_1_->RemoveBasicBlock(loop_1_->GetPreHeaderBlock()->id()); |
705 | 0 | ld->ForgetBasicBlock(loop_1_->GetPreHeaderBlock()->id()); |
706 | |
|
707 | 0 | if (loop_0_->GetMergeBlock() != loop_1_->GetPreHeaderBlock()) { |
708 | 0 | loop_0_->RemoveBasicBlock(loop_0_->GetMergeBlock()->id()); |
709 | 0 | ld->ForgetBasicBlock(loop_0_->GetMergeBlock()->id()); |
710 | 0 | } |
711 | |
|
712 | 0 | loop_0_->SetMergeBlock(loop_1_->GetMergeBlock()); |
713 | |
|
714 | 0 | loop_1_->ClearBlocks(); |
715 | |
|
716 | 0 | ld->RemoveLoop(loop_1_); |
717 | | |
718 | | // Kill unnecessary instructions and remove all empty blocks. |
719 | 0 | for (auto inst : instr_to_delete) { |
720 | 0 | context_->KillInst(inst); |
721 | 0 | } |
722 | |
|
723 | 0 | containing_function_->RemoveEmptyBlocks(); |
724 | | |
725 | | // Invalidate analyses. |
726 | 0 | context_->InvalidateAnalysesExceptFor( |
727 | 0 | IRContext::Analysis::kAnalysisInstrToBlockMapping | |
728 | 0 | IRContext::Analysis::kAnalysisLoopAnalysis | |
729 | 0 | IRContext::Analysis::kAnalysisDefUse | IRContext::Analysis::kAnalysisCFG); |
730 | 0 | } |
731 | | |
732 | | } // namespace opt |
733 | | } // namespace spvtools |