Coverage Report

Created: 2023-03-01 07:33

/src/spirv-tools/source/opt/mem_pass.cpp
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) 2017 The Khronos Group Inc.
2
// Copyright (c) 2017 Valve Corporation
3
// Copyright (c) 2017 LunarG Inc.
4
//
5
// Licensed under the Apache License, Version 2.0 (the "License");
6
// you may not use this file except in compliance with the License.
7
// You may obtain a copy of the License at
8
//
9
//     http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing, software
12
// distributed under the License is distributed on an "AS IS" BASIS,
13
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
// See the License for the specific language governing permissions and
15
// limitations under the License.
16
17
#include "source/opt/mem_pass.h"
18
19
#include <memory>
20
#include <set>
21
#include <vector>
22
23
#include "source/cfa.h"
24
#include "source/opt/basic_block.h"
25
#include "source/opt/dominator_analysis.h"
26
#include "source/opt/ir_context.h"
27
#include "source/opt/iterator.h"
28
29
namespace spvtools {
30
namespace opt {
31
namespace {
32
constexpr uint32_t kCopyObjectOperandInIdx = 0;
33
constexpr uint32_t kTypePointerStorageClassInIdx = 0;
34
constexpr uint32_t kTypePointerTypeIdInIdx = 1;
35
}  // namespace
36
37
403k
bool MemPass::IsBaseTargetType(const Instruction* typeInst) const {
38
403k
  switch (typeInst->opcode()) {
39
279k
    case spv::Op::OpTypeInt:
40
301k
    case spv::Op::OpTypeFloat:
41
333k
    case spv::Op::OpTypeBool:
42
357k
    case spv::Op::OpTypeVector:
43
367k
    case spv::Op::OpTypeMatrix:
44
367k
    case spv::Op::OpTypeImage:
45
367k
    case spv::Op::OpTypeSampler:
46
367k
    case spv::Op::OpTypeSampledImage:
47
367k
    case spv::Op::OpTypePointer:
48
367k
      return true;
49
36.0k
    default:
50
36.0k
      break;
51
403k
  }
52
36.0k
  return false;
53
403k
}
54
55
403k
bool MemPass::IsTargetType(const Instruction* typeInst) const {
56
403k
  if (IsBaseTargetType(typeInst)) return true;
57
36.0k
  if (typeInst->opcode() == spv::Op::OpTypeArray) {
58
18.1k
    if (!IsTargetType(
59
18.1k
            get_def_use_mgr()->GetDef(typeInst->GetSingleWordOperand(1)))) {
60
0
      return false;
61
0
    }
62
18.1k
    return true;
63
18.1k
  }
64
17.9k
  if (typeInst->opcode() != spv::Op::OpTypeStruct) return false;
65
  // All struct members must be math type
66
53.5k
  return typeInst->WhileEachInId([this](const uint32_t* tid) {
67
53.5k
    Instruction* compTypeInst = get_def_use_mgr()->GetDef(*tid);
68
53.5k
    if (!IsTargetType(compTypeInst)) return false;
69
53.5k
    return true;
70
53.5k
  });
71
17.9k
}
72
73
3.19M
bool MemPass::IsNonPtrAccessChain(const spv::Op opcode) const {
74
3.19M
  return opcode == spv::Op::OpAccessChain ||
75
3.19M
         opcode == spv::Op::OpInBoundsAccessChain;
76
3.19M
}
77
78
1.03M
bool MemPass::IsPtr(uint32_t ptrId) {
79
1.03M
  uint32_t varId = ptrId;
80
1.03M
  Instruction* ptrInst = get_def_use_mgr()->GetDef(varId);
81
1.03M
  while (ptrInst->opcode() == spv::Op::OpCopyObject) {
82
3
    varId = ptrInst->GetSingleWordInOperand(kCopyObjectOperandInIdx);
83
3
    ptrInst = get_def_use_mgr()->GetDef(varId);
84
3
  }
85
1.03M
  const spv::Op op = ptrInst->opcode();
86
1.03M
  if (op == spv::Op::OpVariable || IsNonPtrAccessChain(op)) return true;
87
29.8k
  const uint32_t varTypeId = ptrInst->type_id();
88
29.8k
  if (varTypeId == 0) return false;
89
29.8k
  const Instruction* varTypeInst = get_def_use_mgr()->GetDef(varTypeId);
90
29.8k
  return varTypeInst->opcode() == spv::Op::OpTypePointer;
91
29.8k
}
92
93
5.15M
Instruction* MemPass::GetPtr(uint32_t ptrId, uint32_t* varId) {
94
5.15M
  *varId = ptrId;
95
5.15M
  Instruction* ptrInst = get_def_use_mgr()->GetDef(*varId);
96
5.15M
  Instruction* varInst;
97
98
5.15M
  if (ptrInst->opcode() == spv::Op::OpConstantNull) {
99
0
    *varId = 0;
100
0
    return ptrInst;
101
0
  }
102
103
5.15M
  if (ptrInst->opcode() != spv::Op::OpVariable &&
104
5.15M
      ptrInst->opcode() != spv::Op::OpFunctionParameter) {
105
1.95M
    varInst = ptrInst->GetBaseAddress();
106
3.19M
  } else {
107
3.19M
    varInst = ptrInst;
108
3.19M
  }
109
5.15M
  if (varInst->opcode() == spv::Op::OpVariable) {
110
5.12M
    *varId = varInst->result_id();
111
5.12M
  } else {
112
30.4k
    *varId = 0;
113
30.4k
  }
114
115
5.15M
  while (ptrInst->opcode() == spv::Op::OpCopyObject) {
116
513
    uint32_t temp = ptrInst->GetSingleWordInOperand(0);
117
513
    ptrInst = get_def_use_mgr()->GetDef(temp);
118
513
  }
119
120
5.15M
  return ptrInst;
121
5.15M
}
122
123
4.17M
Instruction* MemPass::GetPtr(Instruction* ip, uint32_t* varId) {
124
4.17M
  assert(ip->opcode() == spv::Op::OpStore || ip->opcode() == spv::Op::OpLoad ||
125
4.17M
         ip->opcode() == spv::Op::OpImageTexelPointer ||
126
4.17M
         ip->IsAtomicWithLoad());
127
128
  // All of these opcode place the pointer in position 0.
129
0
  const uint32_t ptrId = ip->GetSingleWordInOperand(0);
130
4.17M
  return GetPtr(ptrId, varId);
131
4.17M
}
132
133
66.9k
bool MemPass::HasOnlyNamesAndDecorates(uint32_t id) const {
134
66.9k
  return get_def_use_mgr()->WhileEachUser(id, [this](Instruction* user) {
135
50.5k
    spv::Op op = user->opcode();
136
50.5k
    if (op != spv::Op::OpName && !IsNonTypeDecorate(op)) {
137
48.8k
      return false;
138
48.8k
    }
139
1.76k
    return true;
140
50.5k
  });
141
66.9k
}
142
143
70.1k
void MemPass::KillAllInsts(BasicBlock* bp, bool killLabel) {
144
70.1k
  bp->KillAllInsts(killLabel);
145
70.1k
}
146
147
1
bool MemPass::HasLoads(uint32_t varId) const {
148
3
  return !get_def_use_mgr()->WhileEachUser(varId, [this](Instruction* user) {
149
3
    spv::Op op = user->opcode();
150
    // TODO(): The following is slightly conservative. Could be
151
    // better handling of non-store/name.
152
3
    if (IsNonPtrAccessChain(op) || op == spv::Op::OpCopyObject) {
153
0
      if (HasLoads(user->result_id())) {
154
0
        return false;
155
0
      }
156
3
    } else if (op != spv::Op::OpStore && op != spv::Op::OpName &&
157
3
               !IsNonTypeDecorate(op)) {
158
1
      return false;
159
1
    }
160
2
    return true;
161
3
  });
162
1
}
163
164
11
bool MemPass::IsLiveVar(uint32_t varId) const {
165
11
  const Instruction* varInst = get_def_use_mgr()->GetDef(varId);
166
  // assume live if not a variable eg. function parameter
167
11
  if (varInst->opcode() != spv::Op::OpVariable) return true;
168
  // non-function scope vars are live
169
11
  const uint32_t varTypeId = varInst->type_id();
170
11
  const Instruction* varTypeInst = get_def_use_mgr()->GetDef(varTypeId);
171
11
  if (spv::StorageClass(varTypeInst->GetSingleWordInOperand(
172
11
          kTypePointerStorageClassInIdx)) != spv::StorageClass::Function)
173
10
    return true;
174
  // test if variable is loaded from
175
1
  return HasLoads(varId);
176
11
}
177
178
0
void MemPass::AddStores(uint32_t ptr_id, std::queue<Instruction*>* insts) {
179
0
  get_def_use_mgr()->ForEachUser(ptr_id, [this, insts](Instruction* user) {
180
0
    spv::Op op = user->opcode();
181
0
    if (IsNonPtrAccessChain(op)) {
182
0
      AddStores(user->result_id(), insts);
183
0
    } else if (op == spv::Op::OpStore) {
184
0
      insts->push(user);
185
0
    }
186
0
  });
187
0
}
188
189
void MemPass::DCEInst(Instruction* inst,
190
16.7k
                      const std::function<void(Instruction*)>& call_back) {
191
16.7k
  std::queue<Instruction*> deadInsts;
192
16.7k
  deadInsts.push(inst);
193
51.6k
  while (!deadInsts.empty()) {
194
34.8k
    Instruction* di = deadInsts.front();
195
    // Don't delete labels
196
34.8k
    if (di->opcode() == spv::Op::OpLabel) {
197
0
      deadInsts.pop();
198
0
      continue;
199
0
    }
200
    // Remember operands
201
34.8k
    std::set<uint32_t> ids;
202
66.9k
    di->ForEachInId([&ids](uint32_t* iid) { ids.insert(*iid); });
203
34.8k
    uint32_t varId = 0;
204
    // Remember variable if dead load
205
34.8k
    if (di->opcode() == spv::Op::OpLoad) (void)GetPtr(di, &varId);
206
34.8k
    if (call_back) {
207
34.8k
      call_back(di);
208
34.8k
    }
209
34.8k
    context()->KillInst(di);
210
    // For all operands with no remaining uses, add their instruction
211
    // to the dead instruction queue.
212
34.8k
    for (auto id : ids)
213
66.9k
      if (HasOnlyNamesAndDecorates(id)) {
214
18.1k
        Instruction* odi = get_def_use_mgr()->GetDef(id);
215
18.1k
        if (context()->IsCombinatorInstruction(odi)) deadInsts.push(odi);
216
18.1k
      }
217
    // if a load was deleted and it was the variable's
218
    // last load, add all its stores to dead queue
219
34.8k
    if (varId != 0 && !IsLiveVar(varId)) AddStores(varId, &deadInsts);
220
34.8k
    deadInsts.pop();
221
34.8k
  }
222
16.7k
}
223
224
427k
MemPass::MemPass() {}
225
226
318k
bool MemPass::HasOnlySupportedRefs(uint32_t varId) {
227
3.98M
  return get_def_use_mgr()->WhileEachUser(varId, [this](Instruction* user) {
228
3.98M
    auto dbg_op = user->GetCommonDebugOpcode();
229
3.98M
    if (dbg_op == CommonDebugInfoDebugDeclare ||
230
3.98M
        dbg_op == CommonDebugInfoDebugValue) {
231
0
      return true;
232
0
    }
233
3.98M
    spv::Op op = user->opcode();
234
3.98M
    if (op != spv::Op::OpStore && op != spv::Op::OpLoad &&
235
3.98M
        op != spv::Op::OpName && !IsNonTypeDecorate(op)) {
236
18.5k
      return false;
237
18.5k
    }
238
3.96M
    return true;
239
3.98M
  });
240
318k
}
241
242
208k
uint32_t MemPass::Type2Undef(uint32_t type_id) {
243
208k
  const auto uitr = type2undefs_.find(type_id);
244
208k
  if (uitr != type2undefs_.end()) return uitr->second;
245
7.76k
  const uint32_t undefId = TakeNextId();
246
7.76k
  if (undefId == 0) {
247
0
    return 0;
248
0
  }
249
250
7.76k
  std::unique_ptr<Instruction> undef_inst(
251
7.76k
      new Instruction(context(), spv::Op::OpUndef, type_id, undefId, {}));
252
7.76k
  get_def_use_mgr()->AnalyzeInstDefUse(&*undef_inst);
253
7.76k
  get_module()->AddGlobalValue(std::move(undef_inst));
254
7.76k
  type2undefs_[type_id] = undefId;
255
7.76k
  return undefId;
256
7.76k
}
257
258
3.06M
bool MemPass::IsTargetVar(uint32_t varId) {
259
3.06M
  if (varId == 0) {
260
112k
    return false;
261
112k
  }
262
263
2.95M
  if (seen_non_target_vars_.find(varId) != seen_non_target_vars_.end())
264
1.04M
    return false;
265
1.91M
  if (seen_target_vars_.find(varId) != seen_target_vars_.end()) return true;
266
390k
  const Instruction* varInst = get_def_use_mgr()->GetDef(varId);
267
390k
  if (varInst->opcode() != spv::Op::OpVariable) return false;
268
390k
  const uint32_t varTypeId = varInst->type_id();
269
390k
  const Instruction* varTypeInst = get_def_use_mgr()->GetDef(varTypeId);
270
390k
  if (spv::StorageClass(varTypeInst->GetSingleWordInOperand(
271
390k
          kTypePointerStorageClassInIdx)) != spv::StorageClass::Function) {
272
58.0k
    seen_non_target_vars_.insert(varId);
273
58.0k
    return false;
274
58.0k
  }
275
332k
  const uint32_t varPteTypeId =
276
332k
      varTypeInst->GetSingleWordInOperand(kTypePointerTypeIdInIdx);
277
332k
  Instruction* varPteTypeInst = get_def_use_mgr()->GetDef(varPteTypeId);
278
332k
  if (!IsTargetType(varPteTypeInst)) {
279
0
    seen_non_target_vars_.insert(varId);
280
0
    return false;
281
0
  }
282
332k
  seen_target_vars_.insert(varId);
283
332k
  return true;
284
332k
}
285
286
// Remove all |phi| operands coming from unreachable blocks (i.e., blocks not in
287
// |reachable_blocks|).  There are two types of removal that this function can
288
// perform:
289
//
290
// 1- Any operand that comes directly from an unreachable block is completely
291
//    removed.  Since the block is unreachable, the edge between the unreachable
292
//    block and the block holding |phi| has been removed.
293
//
294
// 2- Any operand that comes via a live block and was defined at an unreachable
295
//    block gets its value replaced with an OpUndef value. Since the argument
296
//    was generated in an unreachable block, it no longer exists, so it cannot
297
//    be referenced.  However, since the value does not reach |phi| directly
298
//    from the unreachable block, the operand cannot be removed from |phi|.
299
//    Therefore, we replace the argument value with OpUndef.
300
//
301
// For example, in the switch() below, assume that we want to remove the
302
// argument with value %11 coming from block %41.
303
//
304
//          [ ... ]
305
//          %41 = OpLabel                    <--- Unreachable block
306
//          %11 = OpLoad %int %y
307
//          [ ... ]
308
//                OpSelectionMerge %16 None
309
//                OpSwitch %12 %16 10 %13 13 %14 18 %15
310
//          %13 = OpLabel
311
//                OpBranch %16
312
//          %14 = OpLabel
313
//                OpStore %outparm %int_14
314
//                OpBranch %16
315
//          %15 = OpLabel
316
//                OpStore %outparm %int_15
317
//                OpBranch %16
318
//          %16 = OpLabel
319
//          %30 = OpPhi %int %11 %41 %int_42 %13 %11 %14 %11 %15
320
//
321
// Since %41 is now an unreachable block, the first operand of |phi| needs to
322
// be removed completely.  But the operands (%11 %14) and (%11 %15) cannot be
323
// removed because %14 and %15 are reachable blocks.  Since %11 no longer exist,
324
// in those arguments, we replace all references to %11 with an OpUndef value.
325
// This results in |phi| looking like:
326
//
327
//           %50 = OpUndef %int
328
//           [ ... ]
329
//           %30 = OpPhi %int %int_42 %13 %50 %14 %50 %15
330
void MemPass::RemovePhiOperands(
331
374k
    Instruction* phi, const std::unordered_set<BasicBlock*>& reachable_blocks) {
332
374k
  std::vector<Operand> keep_operands;
333
374k
  uint32_t type_id = 0;
334
  // The id of an undefined value we've generated.
335
374k
  uint32_t undef_id = 0;
336
337
  // Traverse all the operands in |phi|. Build the new operand vector by adding
338
  // all the original operands from |phi| except the unwanted ones.
339
1.94M
  for (uint32_t i = 0; i < phi->NumOperands();) {
340
1.57M
    if (i < 2) {
341
      // The first two arguments are always preserved.
342
748k
      keep_operands.push_back(phi->GetOperand(i));
343
748k
      ++i;
344
748k
      continue;
345
748k
    }
346
347
    // The remaining Phi arguments come in pairs. Index 'i' contains the
348
    // variable id, index 'i + 1' is the originating block id.
349
824k
    assert(i % 2 == 0 && i < phi->NumOperands() - 1 &&
350
824k
           "malformed Phi arguments");
351
352
0
    BasicBlock* in_block = cfg()->block(phi->GetSingleWordOperand(i + 1));
353
824k
    if (reachable_blocks.find(in_block) == reachable_blocks.end()) {
354
      // If the incoming block is unreachable, remove both operands as this
355
      // means that the |phi| has lost an incoming edge.
356
0
      i += 2;
357
0
      continue;
358
0
    }
359
360
    // In all other cases, the operand must be kept but may need to be changed.
361
824k
    uint32_t arg_id = phi->GetSingleWordOperand(i);
362
824k
    Instruction* arg_def_instr = get_def_use_mgr()->GetDef(arg_id);
363
824k
    BasicBlock* def_block = context()->get_instr_block(arg_def_instr);
364
824k
    if (def_block &&
365
824k
        reachable_blocks.find(def_block) == reachable_blocks.end()) {
366
      // If the current |phi| argument was defined in an unreachable block, it
367
      // means that this |phi| argument is no longer defined. Replace it with
368
      // |undef_id|.
369
0
      if (!undef_id) {
370
0
        type_id = arg_def_instr->type_id();
371
0
        undef_id = Type2Undef(type_id);
372
0
      }
373
0
      keep_operands.push_back(
374
0
          Operand(spv_operand_type_t::SPV_OPERAND_TYPE_ID, {undef_id}));
375
824k
    } else {
376
      // Otherwise, the argument comes from a reachable block or from no block
377
      // at all (meaning that it was defined in the global section of the
378
      // program).  In both cases, keep the argument intact.
379
824k
      keep_operands.push_back(phi->GetOperand(i));
380
824k
    }
381
382
824k
    keep_operands.push_back(phi->GetOperand(i + 1));
383
384
824k
    i += 2;
385
824k
  }
386
387
374k
  context()->ForgetUses(phi);
388
374k
  phi->ReplaceOperands(keep_operands);
389
374k
  context()->AnalyzeUses(phi);
390
374k
}
391
392
79.0k
void MemPass::RemoveBlock(Function::iterator* bi) {
393
79.0k
  auto& rm_block = **bi;
394
395
  // Remove instructions from the block.
396
298k
  rm_block.ForEachInst([&rm_block, this](Instruction* inst) {
397
    // Note that we do not kill the block label instruction here. The label
398
    // instruction is needed to identify the block, which is needed by the
399
    // removal of phi operands.
400
298k
    if (inst != rm_block.GetLabelInst()) {
401
219k
      context()->KillInst(inst);
402
219k
    }
403
298k
  });
404
405
  // Remove the label instruction last.
406
79.0k
  auto label = rm_block.GetLabelInst();
407
79.0k
  context()->KillInst(label);
408
409
79.0k
  *bi = bi->Erase();
410
79.0k
}
411
412
51.8k
bool MemPass::RemoveUnreachableBlocks(Function* func) {
413
51.8k
  bool modified = false;
414
415
  // Mark reachable all blocks reachable from the function's entry block.
416
51.8k
  std::unordered_set<BasicBlock*> reachable_blocks;
417
51.8k
  std::unordered_set<BasicBlock*> visited_blocks;
418
51.8k
  std::queue<BasicBlock*> worklist;
419
51.8k
  reachable_blocks.insert(func->entry().get());
420
421
  // Initially mark the function entry point as reachable.
422
51.8k
  worklist.push(func->entry().get());
423
424
51.8k
  auto mark_reachable = [&reachable_blocks, &visited_blocks, &worklist,
425
3.23M
                         this](uint32_t label_id) {
426
3.23M
    auto successor = cfg()->block(label_id);
427
3.23M
    if (visited_blocks.count(successor) == 0) {
428
2.14M
      reachable_blocks.insert(successor);
429
2.14M
      worklist.push(successor);
430
2.14M
      visited_blocks.insert(successor);
431
2.14M
    }
432
3.23M
  };
433
434
  // Transitively mark all blocks reachable from the entry as reachable.
435
2.24M
  while (!worklist.empty()) {
436
2.19M
    BasicBlock* block = worklist.front();
437
2.19M
    worklist.pop();
438
439
    // All the successors of a live block are also live.
440
2.19M
    static_cast<const BasicBlock*>(block)->ForEachSuccessorLabel(
441
2.19M
        mark_reachable);
442
443
    // All the Merge and ContinueTarget blocks of a live block are also live.
444
2.19M
    block->ForMergeAndContinueLabel(mark_reachable);
445
2.19M
  }
446
447
  // Update operands of Phi nodes that reference unreachable blocks.
448
2.27M
  for (auto& block : *func) {
449
    // If the block is about to be removed, don't bother updating its
450
    // Phi instructions.
451
2.27M
    if (reachable_blocks.count(&block) == 0) {
452
79.0k
      continue;
453
79.0k
    }
454
455
    // If the block is reachable and has Phi instructions, remove all
456
    // operands from its Phi instructions that reference unreachable blocks.
457
    // If the block has no Phi instructions, this is a no-op.
458
2.19M
    block.ForEachPhiInst([&reachable_blocks, this](Instruction* phi) {
459
374k
      RemovePhiOperands(phi, reachable_blocks);
460
374k
    });
461
2.19M
  }
462
463
  // Erase unreachable blocks.
464
2.32M
  for (auto ebi = func->begin(); ebi != func->end();) {
465
2.27M
    if (reachable_blocks.count(&*ebi) == 0) {
466
79.0k
      RemoveBlock(&ebi);
467
79.0k
      modified = true;
468
2.19M
    } else {
469
2.19M
      ++ebi;
470
2.19M
    }
471
2.27M
  }
472
473
51.8k
  return modified;
474
51.8k
}
475
476
51.8k
bool MemPass::CFGCleanup(Function* func) {
477
51.8k
  bool modified = false;
478
51.8k
  modified |= RemoveUnreachableBlocks(func);
479
51.8k
  return modified;
480
51.8k
}
481
482
16.4k
void MemPass::CollectTargetVars(Function* func) {
483
16.4k
  seen_target_vars_.clear();
484
16.4k
  seen_non_target_vars_.clear();
485
16.4k
  type2undefs_.clear();
486
487
  // Collect target (and non-) variable sets. Remove variables with
488
  // non-load/store refs from target variable set
489
403k
  for (auto& blk : *func) {
490
2.18M
    for (auto& inst : blk) {
491
2.18M
      switch (inst.opcode()) {
492
261k
        case spv::Op::OpStore:
493
625k
        case spv::Op::OpLoad: {
494
625k
          uint32_t varId;
495
625k
          (void)GetPtr(&inst, &varId);
496
625k
          if (!IsTargetVar(varId)) break;
497
318k
          if (HasOnlySupportedRefs(varId)) break;
498
18.5k
          seen_non_target_vars_.insert(varId);
499
18.5k
          seen_target_vars_.erase(varId);
500
18.5k
        } break;
501
1.55M
        default:
502
1.55M
          break;
503
2.18M
      }
504
2.18M
    }
505
403k
  }
506
16.4k
}
507
508
}  // namespace opt
509
}  // namespace spvtools