Coverage Report

Created: 2025-06-13 06:49

/src/spirv-tools/source/opt/code_sink.cpp
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) 2019 Google LLC
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
#include "code_sink.h"
16
17
#include <vector>
18
19
#include "source/opt/instruction.h"
20
#include "source/opt/ir_context.h"
21
#include "source/util/bit_vector.h"
22
23
namespace spvtools {
24
namespace opt {
25
26
0
Pass::Status CodeSinkingPass::Process() {
27
0
  bool modified = false;
28
0
  for (Function& function : *get_module()) {
29
0
    cfg()->ForEachBlockInPostOrder(function.entry().get(),
30
0
                                   [&modified, this](BasicBlock* bb) {
31
0
                                     if (SinkInstructionsInBB(bb)) {
32
0
                                       modified = true;
33
0
                                     }
34
0
                                   });
35
0
  }
36
0
  return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
37
0
}
38
39
0
bool CodeSinkingPass::SinkInstructionsInBB(BasicBlock* bb) {
40
0
  bool modified = false;
41
0
  for (auto inst = bb->rbegin(); inst != bb->rend(); ++inst) {
42
0
    if (SinkInstruction(&*inst)) {
43
0
      inst = bb->rbegin();
44
0
      modified = true;
45
0
    }
46
0
  }
47
0
  return modified;
48
0
}
49
50
0
bool CodeSinkingPass::SinkInstruction(Instruction* inst) {
51
0
  if (inst->opcode() != spv::Op::OpLoad &&
52
0
      inst->opcode() != spv::Op::OpAccessChain) {
53
0
    return false;
54
0
  }
55
56
0
  if (ReferencesMutableMemory(inst)) {
57
0
    return false;
58
0
  }
59
60
0
  if (BasicBlock* target_bb = FindNewBasicBlockFor(inst)) {
61
0
    Instruction* pos = &*target_bb->begin();
62
0
    while (pos->opcode() == spv::Op::OpPhi) {
63
0
      pos = pos->NextNode();
64
0
    }
65
66
0
    inst->InsertBefore(pos);
67
0
    context()->set_instr_block(inst, target_bb);
68
0
    return true;
69
0
  }
70
0
  return false;
71
0
}
72
73
0
BasicBlock* CodeSinkingPass::FindNewBasicBlockFor(Instruction* inst) {
74
0
  assert(inst->result_id() != 0 && "Instruction should have a result.");
75
0
  BasicBlock* original_bb = context()->get_instr_block(inst);
76
0
  BasicBlock* bb = original_bb;
77
78
0
  std::unordered_set<uint32_t> bbs_with_uses;
79
0
  get_def_use_mgr()->ForEachUse(
80
0
      inst, [&bbs_with_uses, this](Instruction* use, uint32_t idx) {
81
0
        if (use->opcode() != spv::Op::OpPhi) {
82
0
          BasicBlock* use_bb = context()->get_instr_block(use);
83
0
          if (use_bb) {
84
0
            bbs_with_uses.insert(use_bb->id());
85
0
          }
86
0
        } else {
87
0
          bbs_with_uses.insert(use->GetSingleWordOperand(idx + 1));
88
0
        }
89
0
      });
90
91
0
  while (true) {
92
    // If |inst| is used in |bb|, then |inst| cannot be moved any further.
93
0
    if (bbs_with_uses.count(bb->id())) {
94
0
      break;
95
0
    }
96
97
    // If |bb| has one successor (succ_bb), and |bb| is the only predecessor
98
    // of succ_bb, then |inst| can be moved to succ_bb.  If succ_bb has more
99
    // than one predecessor, then moving |inst| into succ_bb could cause it to
100
    // be executed more often, so the search has to stop.
101
0
    if (bb->terminator()->opcode() == spv::Op::OpBranch) {
102
0
      uint32_t succ_bb_id = bb->terminator()->GetSingleWordInOperand(0);
103
0
      if (cfg()->preds(succ_bb_id).size() == 1) {
104
0
        bb = context()->get_instr_block(succ_bb_id);
105
0
        continue;
106
0
      } else {
107
0
        break;
108
0
      }
109
0
    }
110
111
    // The remaining checks need to know the merge node.  If there is no merge
112
    // instruction or an OpLoopMerge, then it is a break or continue.  We could
113
    // figure it out, but not worth doing it now.
114
0
    Instruction* merge_inst = bb->GetMergeInst();
115
0
    if (merge_inst == nullptr ||
116
0
        merge_inst->opcode() != spv::Op::OpSelectionMerge) {
117
0
      break;
118
0
    }
119
120
    // Check all of the successors of |bb| to see which lead to a use of |inst|
121
    // before reaching the merge node.
122
0
    bool used_in_multiple_blocks = false;
123
0
    uint32_t bb_used_in = 0;
124
0
    bb->ForEachSuccessorLabel([this, bb, &bb_used_in, &used_in_multiple_blocks,
125
0
                               &bbs_with_uses](uint32_t* succ_bb_id) {
126
0
      if (IntersectsPath(*succ_bb_id, bb->MergeBlockIdIfAny(), bbs_with_uses)) {
127
0
        if (bb_used_in == 0) {
128
0
          bb_used_in = *succ_bb_id;
129
0
        } else {
130
0
          used_in_multiple_blocks = true;
131
0
        }
132
0
      }
133
0
    });
134
135
    // If more than one successor, which is not the merge block, uses |inst|
136
    // then we have to leave |inst| in bb because none of the
137
    // successors dominates all uses of |inst|.
138
0
    if (used_in_multiple_blocks) {
139
0
      break;
140
0
    }
141
142
0
    if (bb_used_in == 0) {
143
      // If |inst| is not used before reaching the merge node, then we can move
144
      // |inst| to the merge node.
145
0
      bb = context()->get_instr_block(bb->MergeBlockIdIfAny());
146
0
    } else {
147
      // If the only successor that leads to a use of |inst| has more than 1
148
      // predecessor, then moving |inst| could cause it to be executed more
149
      // often, so we cannot move it.
150
0
      if (cfg()->preds(bb_used_in).size() != 1) {
151
0
        break;
152
0
      }
153
154
      // If |inst| is used after the merge block, then |bb_used_in| does not
155
      // dominate all of the uses.  So we cannot move |inst| any further.
156
0
      if (IntersectsPath(bb->MergeBlockIdIfAny(), original_bb->id(),
157
0
                         bbs_with_uses)) {
158
0
        break;
159
0
      }
160
161
      // Otherwise, |bb_used_in| dominates all uses, so move |inst| into that
162
      // block.
163
0
      bb = context()->get_instr_block(bb_used_in);
164
0
    }
165
0
    continue;
166
0
  }
167
0
  return (bb != original_bb ? bb : nullptr);
168
0
}
169
170
0
bool CodeSinkingPass::ReferencesMutableMemory(Instruction* inst) {
171
0
  if (!inst->IsLoad()) {
172
0
    return false;
173
0
  }
174
175
0
  Instruction* base_ptr = inst->GetBaseAddress();
176
0
  if (base_ptr->opcode() != spv::Op::OpVariable) {
177
0
    return true;
178
0
  }
179
180
0
  if (base_ptr->IsReadOnlyPointer()) {
181
0
    return false;
182
0
  }
183
184
0
  if (HasUniformMemorySync()) {
185
0
    return true;
186
0
  }
187
188
0
  if (spv::StorageClass(base_ptr->GetSingleWordInOperand(0)) !=
189
0
      spv::StorageClass::Uniform) {
190
0
    return true;
191
0
  }
192
193
0
  return HasPossibleStore(base_ptr);
194
0
}
195
196
0
bool CodeSinkingPass::HasUniformMemorySync() {
197
0
  if (checked_for_uniform_sync_) {
198
0
    return has_uniform_sync_;
199
0
  }
200
201
0
  bool has_sync = false;
202
0
  get_module()->ForEachInst([this, &has_sync](Instruction* inst) {
203
0
    switch (inst->opcode()) {
204
0
      case spv::Op::OpMemoryBarrier: {
205
0
        uint32_t mem_semantics_id = inst->GetSingleWordInOperand(1);
206
0
        if (IsSyncOnUniform(mem_semantics_id)) {
207
0
          has_sync = true;
208
0
        }
209
0
        break;
210
0
      }
211
0
      case spv::Op::OpControlBarrier:
212
0
      case spv::Op::OpAtomicLoad:
213
0
      case spv::Op::OpAtomicStore:
214
0
      case spv::Op::OpAtomicExchange:
215
0
      case spv::Op::OpAtomicIIncrement:
216
0
      case spv::Op::OpAtomicIDecrement:
217
0
      case spv::Op::OpAtomicIAdd:
218
0
      case spv::Op::OpAtomicFAddEXT:
219
0
      case spv::Op::OpAtomicISub:
220
0
      case spv::Op::OpAtomicSMin:
221
0
      case spv::Op::OpAtomicUMin:
222
0
      case spv::Op::OpAtomicFMinEXT:
223
0
      case spv::Op::OpAtomicSMax:
224
0
      case spv::Op::OpAtomicUMax:
225
0
      case spv::Op::OpAtomicFMaxEXT:
226
0
      case spv::Op::OpAtomicAnd:
227
0
      case spv::Op::OpAtomicOr:
228
0
      case spv::Op::OpAtomicXor:
229
0
      case spv::Op::OpAtomicFlagTestAndSet:
230
0
      case spv::Op::OpAtomicFlagClear: {
231
0
        uint32_t mem_semantics_id = inst->GetSingleWordInOperand(2);
232
0
        if (IsSyncOnUniform(mem_semantics_id)) {
233
0
          has_sync = true;
234
0
        }
235
0
        break;
236
0
      }
237
0
      case spv::Op::OpAtomicCompareExchange:
238
0
      case spv::Op::OpAtomicCompareExchangeWeak:
239
0
        if (IsSyncOnUniform(inst->GetSingleWordInOperand(2)) ||
240
0
            IsSyncOnUniform(inst->GetSingleWordInOperand(3))) {
241
0
          has_sync = true;
242
0
        }
243
0
        break;
244
0
      default:
245
0
        break;
246
0
    }
247
0
  });
248
0
  has_uniform_sync_ = has_sync;
249
0
  return has_sync;
250
0
}
251
252
0
bool CodeSinkingPass::IsSyncOnUniform(uint32_t mem_semantics_id) const {
253
0
  const analysis::Constant* mem_semantics_const =
254
0
      context()->get_constant_mgr()->FindDeclaredConstant(mem_semantics_id);
255
0
  assert(mem_semantics_const != nullptr &&
256
0
         "Expecting memory semantics id to be a constant.");
257
0
  assert(mem_semantics_const->AsIntConstant() &&
258
0
         "Memory semantics should be an integer.");
259
0
  uint32_t mem_semantics_int = mem_semantics_const->GetU32();
260
261
  // If it does not affect uniform memory, then it does not apply to uniform
262
  // memory.
263
0
  if ((mem_semantics_int & uint32_t(spv::MemorySemanticsMask::UniformMemory)) ==
264
0
      0) {
265
0
    return false;
266
0
  }
267
268
  // Check if there is an acquire or release.  If not, it does not add
269
  // any memory constraints.
270
0
  return (mem_semantics_int &
271
0
          uint32_t(spv::MemorySemanticsMask::Acquire |
272
0
                   spv::MemorySemanticsMask::AcquireRelease |
273
0
                   spv::MemorySemanticsMask::Release)) != 0;
274
0
}
275
276
0
bool CodeSinkingPass::HasPossibleStore(Instruction* var_inst) {
277
0
  assert(var_inst->opcode() == spv::Op::OpVariable ||
278
0
         var_inst->opcode() == spv::Op::OpAccessChain ||
279
0
         var_inst->opcode() == spv::Op::OpPtrAccessChain);
280
281
0
  return get_def_use_mgr()->WhileEachUser(var_inst, [this](Instruction* use) {
282
0
    switch (use->opcode()) {
283
0
      case spv::Op::OpStore:
284
0
        return true;
285
0
      case spv::Op::OpAccessChain:
286
0
      case spv::Op::OpPtrAccessChain:
287
0
        return HasPossibleStore(use);
288
0
      default:
289
0
        return false;
290
0
    }
291
0
  });
292
0
}
293
294
bool CodeSinkingPass::IntersectsPath(uint32_t start, uint32_t end,
295
0
                                     const std::unordered_set<uint32_t>& set) {
296
0
  std::vector<uint32_t> worklist;
297
0
  worklist.push_back(start);
298
0
  std::unordered_set<uint32_t> already_done;
299
0
  already_done.insert(start);
300
301
0
  while (!worklist.empty()) {
302
0
    BasicBlock* bb = context()->get_instr_block(worklist.back());
303
0
    worklist.pop_back();
304
305
0
    if (bb->id() == end) {
306
0
      continue;
307
0
    }
308
309
0
    if (set.count(bb->id())) {
310
0
      return true;
311
0
    }
312
313
0
    bb->ForEachSuccessorLabel([&already_done, &worklist](uint32_t* succ_bb_id) {
314
0
      if (already_done.insert(*succ_bb_id).second) {
315
0
        worklist.push_back(*succ_bb_id);
316
0
      }
317
0
    });
318
0
  }
319
0
  return false;
320
0
}
321
322
// namespace opt
323
324
}  // namespace opt
325
}  // namespace spvtools