Coverage Report

Created: 2025-07-23 06:18

/src/spirv-tools/source/opt/scalar_replacement_pass.cpp
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) 2017 Google Inc.
2
// Modifications Copyright (C) 2024 Advanced Micro Devices, Inc. All rights
3
// reserved.
4
//
5
// Licensed under the Apache License, Version 2.0 (the "License");
6
// you may not use this file except in compliance with the License.
7
// You may obtain a copy of the License at
8
//
9
//     http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing, software
12
// distributed under the License is distributed on an "AS IS" BASIS,
13
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
// See the License for the specific language governing permissions and
15
// limitations under the License.
16
17
#include "source/opt/scalar_replacement_pass.h"
18
19
#include <algorithm>
20
#include <queue>
21
#include <tuple>
22
#include <utility>
23
24
#include "source/extensions.h"
25
#include "source/opt/reflect.h"
26
#include "source/opt/types.h"
27
#include "source/util/make_unique.h"
28
29
namespace spvtools {
30
namespace opt {
31
namespace {
// Operand positions used with GetSingleWordOperand()/SetOperand() on debug-info
// extended instructions. These are full operand indexes (counting the result
// type and result id), not in-operand indexes.

// DebugValue: position of the 'Value' operand.
constexpr uint32_t kDebugValueOperandValueIndex{5};
// DebugValue: position of the 'Expression' operand.
constexpr uint32_t kDebugValueOperandExpressionIndex{6};
// DebugDeclare: position of the 'Variable' operand.
constexpr uint32_t kDebugDeclareOperandVariableIndex{5};
}  // namespace
36
37
22.7k
// Runs scalar replacement over every function in the module that has a body.
// Returns Failure as soon as any function fails. Otherwise returns
// SuccessWithChange if at least one function was modified, and
// SuccessWithoutChange if none was.
Pass::Status ScalarReplacementPass::Process() {
  bool module_changed = false;
  for (auto& func : *get_module()) {
    // Declarations have no body, hence no local variables to split.
    if (func.IsDeclaration()) continue;

    const Status result = ProcessFunction(&func);
    if (result == Status::Failure) return Status::Failure;
    if (result == Status::SuccessWithChange) module_changed = true;
  }
  return module_changed ? Status::SuccessWithChange
                        : Status::SuccessWithoutChange;
}
53
54
22.0k
// Scalarizes the replaceable OpVariables at the top of |function|'s entry
// block. Candidates are handled in FIFO order. Replacing one variable may
// enqueue newly created replacement variables for further splitting.
// Returns Failure on the first failed replacement. Otherwise returns whether
// anything changed.
Pass::Status ScalarReplacementPass::ProcessFunction(Function* function) {
  std::queue<Instruction*> pending;

  // Function storage class OpVariables must be the leading instructions of
  // the entry block, so the scan stops at the first non-variable.
  BasicBlock& entry_block = *function->begin();
  for (Instruction& candidate : entry_block) {
    if (candidate.opcode() != spv::Op::OpVariable) break;
    if (CanReplaceVariable(&candidate)) pending.push(&candidate);
  }

  bool function_changed = false;
  while (!pending.empty()) {
    Instruction* next_var = pending.front();
    pending.pop();

    const Status result = ReplaceVariable(next_var, &pending);
    if (result == Status::Failure) return Status::Failure;
    if (result == Status::SuccessWithChange) function_changed = true;
  }
  return function_changed ? Status::SuccessWithChange
                          : Status::SuccessWithoutChange;
}
82
83
// Replaces the composite variable |inst| with one replacement per element,
// as built by CreateReplacementVariables. Every user of |inst| is rewritten to
// use the replacements. Then the rewritten users and |inst| itself are killed.
// Replacement variables left without users are removed. Those that are
// themselves replaceable are pushed onto |worklist| for further splitting.
//
// Returns Failure if the replacements cannot be created or any user cannot be
// rewritten. Otherwise returns SuccessWithChange.
Pass::Status ScalarReplacementPass::ReplaceVariable(
    Instruction* inst, std::queue<Instruction*>* worklist) {
  std::vector<Instruction*> replacements;
  if (!CreateReplacementVariables(inst, &replacements)) {
    return Status::Failure;
  }

  std::vector<Instruction*> dead;
  bool replaced_all_uses = get_def_use_mgr()->WhileEachUser(
      inst, [this, &replacements, &dead](Instruction* user) {
        if (user->GetCommonDebugOpcode() == CommonDebugInfoDebugDeclare) {
          if (!ReplaceWholeDebugDeclare(user, replacements)) return false;
          dead.push_back(user);
          return true;
        }
        if (user->GetCommonDebugOpcode() == CommonDebugInfoDebugValue) {
          if (!ReplaceWholeDebugValue(user, replacements)) return false;
          dead.push_back(user);
          return true;
        }
        // Annotation users are not rewritten here. Decorations on the variable
        // are transferred by CreateReplacementVariables (TransferAnnotations).
        if (IsAnnotationInst(user->opcode())) return true;

        switch (user->opcode()) {
          case spv::Op::OpLoad:
            if (!ReplaceWholeLoad(user, replacements)) return false;
            dead.push_back(user);
            break;
          case spv::Op::OpStore:
            if (!ReplaceWholeStore(user, replacements)) return false;
            dead.push_back(user);
            break;
          case spv::Op::OpAccessChain:
          case spv::Op::OpInBoundsAccessChain:
            if (!ReplaceAccessChain(user, replacements)) return false;
            dead.push_back(user);
            break;
          case spv::Op::OpName:
          case spv::Op::OpMemberName:
            break;
          default:
            // CheckUses should have rejected any other kind of user. Report
            // failure rather than go on to kill |inst| while a use of it
            // remains (which is what happened when the assert was compiled
            // out).
            assert(false && "Unexpected opcode");
            return false;
        }
        return true;
      });

  if (!replaced_all_uses) {
    return Status::Failure;
  }
  // |dead| is never empty from here on: it always holds |inst| itself.
  dead.push_back(inst);

  // Clean up the rewritten users and the original variable, newest first.
  while (!dead.empty()) {
    Instruction* to_kill = dead.back();
    dead.pop_back();
    context()->KillInst(to_kill);
  }

  // Attempt to further scalarize. Non-variable entries are the undef
  // placeholders used for unused components.
  for (Instruction* var : replacements) {
    if (var->opcode() != spv::Op::OpVariable) continue;
    if (get_def_use_mgr()->NumUsers(var) == 0) {
      context()->KillInst(var);
    } else if (CanReplaceVariable(var)) {
      worklist->push(var);
    }
  }

  return Status::SuccessWithChange;
}
170
171
bool ScalarReplacementPass::ReplaceWholeDebugDeclare(
172
0
    Instruction* dbg_decl, const std::vector<Instruction*>& replacements) {
173
  // Insert Deref operation to the front of the operation list of |dbg_decl|.
174
0
  Instruction* dbg_expr = context()->get_def_use_mgr()->GetDef(
175
0
      dbg_decl->GetSingleWordOperand(kDebugValueOperandExpressionIndex));
176
0
  auto* deref_expr =
177
0
      context()->get_debug_info_mgr()->DerefDebugExpression(dbg_expr);
178
179
  // Add DebugValue instruction with Indexes operand and Deref operation.
180
0
  int32_t idx = 0;
181
0
  for (const auto* var : replacements) {
182
0
    Instruction* insert_before = var->NextNode();
183
0
    while (insert_before->opcode() == spv::Op::OpVariable)
184
0
      insert_before = insert_before->NextNode();
185
0
    assert(insert_before != nullptr && "unexpected end of list");
186
0
    Instruction* added_dbg_value =
187
0
        context()->get_debug_info_mgr()->AddDebugValueForDecl(
188
0
            dbg_decl, /*value_id=*/var->result_id(),
189
0
            /*insert_before=*/insert_before, /*line=*/dbg_decl);
190
191
0
    if (added_dbg_value == nullptr) return false;
192
0
    added_dbg_value->AddOperand(
193
0
        {SPV_OPERAND_TYPE_ID,
194
0
         {context()->get_constant_mgr()->GetSIntConstId(idx)}});
195
0
    added_dbg_value->SetOperand(kDebugValueOperandExpressionIndex,
196
0
                                {deref_expr->result_id()});
197
0
    if (context()->AreAnalysesValid(IRContext::Analysis::kAnalysisDefUse)) {
198
0
      context()->get_def_use_mgr()->AnalyzeInstUse(added_dbg_value);
199
0
    }
200
0
    ++idx;
201
0
  }
202
0
  return true;
203
0
}
204
205
bool ScalarReplacementPass::ReplaceWholeDebugValue(
206
0
    Instruction* dbg_value, const std::vector<Instruction*>& replacements) {
207
0
  int32_t idx = 0;
208
0
  BasicBlock* block = context()->get_instr_block(dbg_value);
209
0
  for (auto var : replacements) {
210
    // Clone the DebugValue.
211
0
    std::unique_ptr<Instruction> new_dbg_value(dbg_value->Clone(context()));
212
0
    uint32_t new_id = TakeNextId();
213
0
    if (new_id == 0) return false;
214
0
    new_dbg_value->SetResultId(new_id);
215
    // Update 'Value' operand to the |replacements|.
216
0
    new_dbg_value->SetOperand(kDebugValueOperandValueIndex, {var->result_id()});
217
    // Append 'Indexes' operand.
218
0
    new_dbg_value->AddOperand(
219
0
        {SPV_OPERAND_TYPE_ID,
220
0
         {context()->get_constant_mgr()->GetSIntConstId(idx)}});
221
    // Insert the new DebugValue to the basic block.
222
0
    auto* added_instr = dbg_value->InsertBefore(std::move(new_dbg_value));
223
0
    get_def_use_mgr()->AnalyzeInstDefUse(added_instr);
224
0
    context()->set_instr_block(added_instr, block);
225
0
    ++idx;
226
0
  }
227
0
  return true;
228
0
}
229
230
bool ScalarReplacementPass::ReplaceWholeLoad(
231
9.83k
    Instruction* load, const std::vector<Instruction*>& replacements) {
232
  // Replaces the load of the entire composite with a load from each replacement
233
  // variable followed by a composite construction.
234
9.83k
  BasicBlock* block = context()->get_instr_block(load);
235
9.83k
  std::vector<Instruction*> loads;
236
9.83k
  loads.reserve(replacements.size());
237
9.83k
  BasicBlock::iterator where(load);
238
96.1k
  for (auto var : replacements) {
239
    // Create a load of each replacement variable.
240
96.1k
    if (var->opcode() != spv::Op::OpVariable) {
241
8.12k
      loads.push_back(var);
242
8.12k
      continue;
243
8.12k
    }
244
245
87.9k
    Instruction* type = GetStorageType(var);
246
87.9k
    uint32_t loadId = TakeNextId();
247
87.9k
    if (loadId == 0) {
248
0
      return false;
249
0
    }
250
87.9k
    std::unique_ptr<Instruction> newLoad(
251
87.9k
        new Instruction(context(), spv::Op::OpLoad, type->result_id(), loadId,
252
87.9k
                        std::initializer_list<Operand>{
253
87.9k
                            {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
254
    // Copy memory access attributes which start at index 1. Index 0 is the
255
    // pointer to load.
256
87.9k
    for (uint32_t i = 1; i < load->NumInOperands(); ++i) {
257
0
      Operand copy(load->GetInOperand(i));
258
0
      newLoad->AddOperand(std::move(copy));
259
0
    }
260
87.9k
    where = where.InsertBefore(std::move(newLoad));
261
87.9k
    get_def_use_mgr()->AnalyzeInstDefUse(&*where);
262
87.9k
    context()->set_instr_block(&*where, block);
263
87.9k
    where->UpdateDebugInfoFrom(load);
264
87.9k
    loads.push_back(&*where);
265
87.9k
  }
266
267
  // Construct a new composite.
268
9.83k
  uint32_t compositeId = TakeNextId();
269
9.83k
  if (compositeId == 0) {
270
0
    return false;
271
0
  }
272
9.83k
  where = load;
273
9.83k
  std::unique_ptr<Instruction> compositeConstruct(
274
9.83k
      new Instruction(context(), spv::Op::OpCompositeConstruct, load->type_id(),
275
9.83k
                      compositeId, {}));
276
96.1k
  for (auto l : loads) {
277
96.1k
    Operand op(SPV_OPERAND_TYPE_ID,
278
96.1k
               std::initializer_list<uint32_t>{l->result_id()});
279
96.1k
    compositeConstruct->AddOperand(std::move(op));
280
96.1k
  }
281
9.83k
  where = where.InsertBefore(std::move(compositeConstruct));
282
9.83k
  get_def_use_mgr()->AnalyzeInstDefUse(&*where);
283
9.83k
  where->UpdateDebugInfoFrom(load);
284
9.83k
  context()->set_instr_block(&*where, block);
285
9.83k
  context()->ReplaceAllUsesWith(load->result_id(), compositeId);
286
9.83k
  return true;
287
9.83k
}
288
289
bool ScalarReplacementPass::ReplaceWholeStore(
290
26.3k
    Instruction* store, const std::vector<Instruction*>& replacements) {
291
  // Replaces a store to the whole composite with a series of extract and stores
292
  // to each element.
293
26.3k
  uint32_t storeInput = store->GetSingleWordInOperand(1u);
294
26.3k
  BasicBlock* block = context()->get_instr_block(store);
295
26.3k
  BasicBlock::iterator where(store);
296
26.3k
  uint32_t elementIndex = 0;
297
225k
  for (auto var : replacements) {
298
    // Create the extract.
299
225k
    if (var->opcode() != spv::Op::OpVariable) {
300
135k
      elementIndex++;
301
135k
      continue;
302
135k
    }
303
304
89.4k
    Instruction* type = GetStorageType(var);
305
89.4k
    uint32_t extractId = TakeNextId();
306
89.4k
    if (extractId == 0) {
307
0
      return false;
308
0
    }
309
89.4k
    std::unique_ptr<Instruction> extract(new Instruction(
310
89.4k
        context(), spv::Op::OpCompositeExtract, type->result_id(), extractId,
311
89.4k
        std::initializer_list<Operand>{
312
89.4k
            {SPV_OPERAND_TYPE_ID, {storeInput}},
313
89.4k
            {SPV_OPERAND_TYPE_LITERAL_INTEGER, {elementIndex++}}}));
314
89.4k
    auto iter = where.InsertBefore(std::move(extract));
315
89.4k
    iter->UpdateDebugInfoFrom(store);
316
89.4k
    get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
317
89.4k
    context()->set_instr_block(&*iter, block);
318
319
    // Create the store.
320
89.4k
    std::unique_ptr<Instruction> newStore(
321
89.4k
        new Instruction(context(), spv::Op::OpStore, 0, 0,
322
89.4k
                        std::initializer_list<Operand>{
323
89.4k
                            {SPV_OPERAND_TYPE_ID, {var->result_id()}},
324
89.4k
                            {SPV_OPERAND_TYPE_ID, {extractId}}}));
325
    // Copy memory access attributes which start at index 2. Index 0 is the
326
    // pointer and index 1 is the data.
327
89.4k
    for (uint32_t i = 2; i < store->NumInOperands(); ++i) {
328
14
      Operand copy(store->GetInOperand(i));
329
14
      newStore->AddOperand(std::move(copy));
330
14
    }
331
89.4k
    iter = where.InsertBefore(std::move(newStore));
332
89.4k
    iter->UpdateDebugInfoFrom(store);
333
89.4k
    get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
334
89.4k
    context()->set_instr_block(&*iter, block);
335
89.4k
  }
336
26.3k
  return true;
337
26.3k
}
338
339
bool ScalarReplacementPass::ReplaceAccessChain(
340
54.6k
    Instruction* chain, const std::vector<Instruction*>& replacements) {
341
  // Replaces the access chain with either another access chain (with one fewer
342
  // indexes) or a direct use of the replacement variable.
343
54.6k
  uint32_t indexId = chain->GetSingleWordInOperand(1u);
344
54.6k
  const Instruction* index = get_def_use_mgr()->GetDef(indexId);
345
54.6k
  int64_t indexValue = context()
346
54.6k
                           ->get_constant_mgr()
347
54.6k
                           ->GetConstantFromInst(index)
348
54.6k
                           ->GetSignExtendedValue();
349
54.6k
  if (indexValue < 0 ||
350
54.6k
      indexValue >= static_cast<int64_t>(replacements.size())) {
351
    // Out of bounds access, this is illegal IR.  Notice that OpAccessChain
352
    // indexing is 0-based, so we should also reject index == size-of-array.
353
0
    return false;
354
54.6k
  } else {
355
54.6k
    const Instruction* var = replacements[static_cast<size_t>(indexValue)];
356
54.6k
    if (chain->NumInOperands() > 2) {
357
      // Replace input access chain with another access chain.
358
2.18k
      BasicBlock::iterator chainIter(chain);
359
2.18k
      uint32_t replacementId = TakeNextId();
360
2.18k
      if (replacementId == 0) {
361
0
        return false;
362
0
      }
363
2.18k
      std::unique_ptr<Instruction> replacementChain(new Instruction(
364
2.18k
          context(), chain->opcode(), chain->type_id(), replacementId,
365
2.18k
          std::initializer_list<Operand>{
366
2.18k
              {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
367
      // Add the remaining indexes.
368
4.36k
      for (uint32_t i = 2; i < chain->NumInOperands(); ++i) {
369
2.18k
        Operand copy(chain->GetInOperand(i));
370
2.18k
        replacementChain->AddOperand(std::move(copy));
371
2.18k
      }
372
2.18k
      replacementChain->UpdateDebugInfoFrom(chain);
373
2.18k
      auto iter = chainIter.InsertBefore(std::move(replacementChain));
374
2.18k
      get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
375
2.18k
      context()->set_instr_block(&*iter, context()->get_instr_block(chain));
376
2.18k
      context()->ReplaceAllUsesWith(chain->result_id(), replacementId);
377
52.4k
    } else {
378
      // Replace with a use of the variable.
379
52.4k
      context()->ReplaceAllUsesWith(chain->result_id(), var->result_id());
380
52.4k
    }
381
54.6k
  }
382
383
54.6k
  return true;
384
54.6k
}
385
386
bool ScalarReplacementPass::CreateReplacementVariables(
387
15.6k
    Instruction* inst, std::vector<Instruction*>* replacements) {
388
15.6k
  Instruction* type = GetStorageType(inst);
389
390
15.6k
  std::unique_ptr<std::unordered_set<int64_t>> components_used =
391
15.6k
      GetUsedComponents(inst);
392
393
15.6k
  uint32_t elem = 0;
394
15.6k
  switch (type->opcode()) {
395
13.7k
    case spv::Op::OpTypeStruct:
396
13.7k
      type->ForEachInOperand(
397
40.1k
          [this, inst, &elem, replacements, &components_used](uint32_t* id) {
398
40.1k
            if (!components_used || components_used->count(elem)) {
399
32.5k
              CreateVariable(*id, inst, elem, replacements);
400
32.5k
            } else {
401
7.59k
              replacements->push_back(GetUndef(*id));
402
7.59k
            }
403
40.1k
            elem++;
404
40.1k
          });
405
13.7k
      break;
406
1.84k
    case spv::Op::OpTypeArray:
407
91.2k
      for (uint32_t i = 0; i != GetArrayLength(type); ++i) {
408
89.3k
        if (!components_used || components_used->count(i)) {
409
46.4k
          CreateVariable(type->GetSingleWordInOperand(0u), inst, i,
410
46.4k
                         replacements);
411
46.4k
        } else {
412
42.9k
          uint32_t element_type_id = type->GetSingleWordInOperand(0);
413
42.9k
          replacements->push_back(GetUndef(element_type_id));
414
42.9k
        }
415
89.3k
      }
416
1.84k
      break;
417
418
0
    case spv::Op::OpTypeMatrix:
419
0
    case spv::Op::OpTypeVector:
420
0
      for (uint32_t i = 0; i != GetNumElements(type); ++i) {
421
0
        CreateVariable(type->GetSingleWordInOperand(0u), inst, i, replacements);
422
0
      }
423
0
      break;
424
425
0
    default:
426
0
      assert(false && "Unexpected type.");
427
0
      break;
428
15.6k
  }
429
430
15.6k
  TransferAnnotations(inst, replacements);
431
15.6k
  return std::find(replacements->begin(), replacements->end(), nullptr) ==
432
15.6k
         replacements->end();
433
15.6k
}
434
435
50.4k
// Returns the instruction that defines the undef id Type2Undef yields for
// |type_id|. This is nullptr if that id has no definition.
Instruction* ScalarReplacementPass::GetUndef(uint32_t type_id) {
  const uint32_t undef_id = Type2Undef(type_id);
  return get_def_use_mgr()->GetDef(undef_id);
}
438
439
void ScalarReplacementPass::TransferAnnotations(
440
15.6k
    const Instruction* source, std::vector<Instruction*>* replacements) {
441
  // Only transfer invariant and restrict decorations on the variable. There are
442
  // no type or member decorations that are necessary to transfer.
443
15.6k
  for (auto inst :
444
15.6k
       get_decoration_mgr()->GetDecorationsFor(source->result_id(), false)) {
445
7
    assert(inst->opcode() == spv::Op::OpDecorate);
446
7
    auto decoration = spv::Decoration(inst->GetSingleWordInOperand(1u));
447
7
    if (decoration == spv::Decoration::Invariant ||
448
7
        decoration == spv::Decoration::Restrict) {
449
195
      for (auto var : *replacements) {
450
195
        if (var == nullptr) {
451
0
          continue;
452
0
        }
453
454
195
        std::unique_ptr<Instruction> annotation(new Instruction(
455
195
            context(), spv::Op::OpDecorate, 0, 0,
456
195
            std::initializer_list<Operand>{
457
195
                {SPV_OPERAND_TYPE_ID, {var->result_id()}},
458
195
                {SPV_OPERAND_TYPE_DECORATION, {uint32_t(decoration)}}}));
459
195
        for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
460
0
          Operand copy(inst->GetInOperand(i));
461
0
          annotation->AddOperand(std::move(copy));
462
0
        }
463
195
        context()->AddAnnotationInst(std::move(annotation));
464
195
        get_def_use_mgr()->AnalyzeInstUse(&*--context()->annotation_end());
465
195
      }
466
7
    }
467
7
  }
468
15.6k
}
469
470
void ScalarReplacementPass::CreateVariable(
471
    uint32_t type_id, Instruction* var_inst, uint32_t index,
472
79.0k
    std::vector<Instruction*>* replacements) {
473
79.0k
  uint32_t ptr_id = GetOrCreatePointerType(type_id);
474
79.0k
  uint32_t id = TakeNextId();
475
476
79.0k
  if (id == 0) {
477
0
    replacements->push_back(nullptr);
478
0
  }
479
480
79.0k
  std::unique_ptr<Instruction> variable(
481
79.0k
      new Instruction(context(), spv::Op::OpVariable, ptr_id, id,
482
79.0k
                      std::initializer_list<Operand>{
483
79.0k
                          {SPV_OPERAND_TYPE_STORAGE_CLASS,
484
79.0k
                           {uint32_t(spv::StorageClass::Function)}}}));
485
486
79.0k
  BasicBlock* block = context()->get_instr_block(var_inst);
487
79.0k
  block->begin().InsertBefore(std::move(variable));
488
79.0k
  Instruction* inst = &*block->begin();
489
490
  // If varInst was initialized, make sure to initialize its replacement.
491
79.0k
  GetOrCreateInitialValue(var_inst, index, inst);
492
79.0k
  get_def_use_mgr()->AnalyzeInstDefUse(inst);
493
79.0k
  context()->set_instr_block(inst, block);
494
495
79.0k
  CopyDecorationsToVariable(var_inst, inst, index);
496
79.0k
  inst->UpdateDebugInfoFrom(var_inst);
497
498
79.0k
  replacements->push_back(inst);
499
79.0k
}
500
501
79.0k
uint32_t ScalarReplacementPass::GetOrCreatePointerType(uint32_t id) {
502
79.0k
  auto iter = pointee_to_pointer_.find(id);
503
79.0k
  if (iter != pointee_to_pointer_.end()) return iter->second;
504
505
1.95k
  analysis::TypeManager* type_mgr = context()->get_type_mgr();
506
1.95k
  uint32_t ptr_type_id =
507
1.95k
      type_mgr->FindPointerToType(id, spv::StorageClass::Function);
508
1.95k
  pointee_to_pointer_[id] = ptr_type_id;
509
1.95k
  return ptr_type_id;
510
79.0k
}
511
512
void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source,
513
                                                    uint32_t index,
514
79.0k
                                                    Instruction* newVar) {
515
79.0k
  assert(source->opcode() == spv::Op::OpVariable);
516
79.0k
  if (source->NumInOperands() < 2) return;
517
518
0
  uint32_t initId = source->GetSingleWordInOperand(1u);
519
0
  uint32_t storageId = GetStorageType(newVar)->result_id();
520
0
  Instruction* init = get_def_use_mgr()->GetDef(initId);
521
0
  uint32_t newInitId = 0;
522
  // TODO(dnovillo): Refactor this with constant propagation.
523
0
  if (init->opcode() == spv::Op::OpConstantNull) {
524
    // Initialize to appropriate NULL.
525
0
    auto iter = type_to_null_.find(storageId);
526
0
    if (iter == type_to_null_.end()) {
527
0
      newInitId = TakeNextId();
528
0
      type_to_null_[storageId] = newInitId;
529
0
      context()->AddGlobalValue(
530
0
          MakeUnique<Instruction>(context(), spv::Op::OpConstantNull, storageId,
531
0
                                  newInitId, std::initializer_list<Operand>{}));
532
0
      Instruction* newNull = &*--context()->types_values_end();
533
0
      get_def_use_mgr()->AnalyzeInstDefUse(newNull);
534
0
    } else {
535
0
      newInitId = iter->second;
536
0
    }
537
0
  } else if (IsSpecConstantInst(init->opcode())) {
538
    // Create a new constant extract.
539
0
    newInitId = TakeNextId();
540
0
    context()->AddGlobalValue(MakeUnique<Instruction>(
541
0
        context(), spv::Op::OpSpecConstantOp, storageId, newInitId,
542
0
        std::initializer_list<Operand>{
543
0
            {SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER,
544
0
             {uint32_t(spv::Op::OpCompositeExtract)}},
545
0
            {SPV_OPERAND_TYPE_ID, {init->result_id()}},
546
0
            {SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}}}));
547
0
    Instruction* newSpecConst = &*--context()->types_values_end();
548
0
    get_def_use_mgr()->AnalyzeInstDefUse(newSpecConst);
549
0
  } else if (init->opcode() == spv::Op::OpConstantComposite) {
550
    // Get the appropriate index constant.
551
0
    newInitId = init->GetSingleWordInOperand(index);
552
0
    Instruction* element = get_def_use_mgr()->GetDef(newInitId);
553
0
    if (element->opcode() == spv::Op::OpUndef) {
554
      // Undef is not a valid initializer for a variable.
555
0
      newInitId = 0;
556
0
    }
557
0
  } else {
558
0
    assert(false);
559
0
  }
560
561
0
  if (newInitId != 0) {
562
0
    newVar->AddOperand({SPV_OPERAND_TYPE_ID, {newInitId}});
563
0
  }
564
0
}
565
566
uint64_t ScalarReplacementPass::GetArrayLength(
567
108k
    const Instruction* arrayType) const {
568
108k
  assert(arrayType->opcode() == spv::Op::OpTypeArray);
569
108k
  const Instruction* length =
570
108k
      get_def_use_mgr()->GetDef(arrayType->GetSingleWordInOperand(1u));
571
108k
  return context()
572
108k
      ->get_constant_mgr()
573
108k
      ->GetConstantFromInst(length)
574
108k
      ->GetZeroExtendedValue();
575
108k
}
576
577
0
uint64_t ScalarReplacementPass::GetNumElements(const Instruction* type) const {
578
0
  assert(type->opcode() == spv::Op::OpTypeVector ||
579
0
         type->opcode() == spv::Op::OpTypeMatrix);
580
0
  const Operand& op = type->GetInOperand(1u);
581
0
  assert(op.words.size() <= 2);
582
0
  uint64_t len = 0;
583
0
  for (size_t i = 0; i != op.words.size(); ++i) {
584
0
    len |= (static_cast<uint64_t>(op.words[i]) << (32ull * i));
585
0
  }
586
0
  return len;
587
0
}
588
589
10.0k
bool ScalarReplacementPass::IsSpecConstant(uint32_t id) const {
590
10.0k
  const Instruction* inst = get_def_use_mgr()->GetDef(id);
591
10.0k
  assert(inst);
592
10.0k
  return spvOpcodeIsSpecConstant(inst->opcode());
593
10.0k
}
594
595
// Returns the type instruction that variable |inst| points to. This is the
// type named by in-operand 1 of the variable's pointer result type.
Instruction* ScalarReplacementPass::GetStorageType(
    const Instruction* inst) const {
  assert(inst->opcode() == spv::Op::OpVariable);
  const Instruction* pointer_type = get_def_use_mgr()->GetDef(inst->type_id());
  const uint32_t pointee_id = pointer_type->GetSingleWordInOperand(1u);
  return get_def_use_mgr()->GetDef(pointee_id);
}
604
605
bool ScalarReplacementPass::CanReplaceVariable(
606
255k
    const Instruction* varInst) const {
607
255k
  assert(varInst->opcode() == spv::Op::OpVariable);
608
609
  // Can only replace function scope variables.
610
255k
  if (spv::StorageClass(varInst->GetSingleWordInOperand(0u)) !=
611
255k
      spv::StorageClass::Function) {
612
0
    return false;
613
0
  }
614
615
255k
  if (!CheckTypeAnnotations(get_def_use_mgr()->GetDef(varInst->type_id()))) {
616
18
    return false;
617
18
  }
618
619
255k
  const Instruction* typeInst = GetStorageType(varInst);
620
255k
  if (!CheckType(typeInst)) {
621
229k
    return false;
622
229k
  }
623
624
26.3k
  if (!CheckAnnotations(varInst)) {
625
1.50k
    return false;
626
1.50k
  }
627
628
24.8k
  if (!CheckUses(varInst)) {
629
9.26k
    return false;
630
9.26k
  }
631
632
15.6k
  return true;
633
24.8k
}
634
635
255k
bool ScalarReplacementPass::CheckType(const Instruction* typeInst) const {
636
255k
  if (!CheckTypeAnnotations(typeInst)) {
637
1.28k
    return false;
638
1.28k
  }
639
640
254k
  switch (typeInst->opcode()) {
641
18.4k
    case spv::Op::OpTypeStruct:
642
      // Don't bother with empty structs or very large structs.
643
18.4k
      if (typeInst->NumInOperands() == 0 ||
644
18.4k
          IsLargerThanSizeLimit(typeInst->NumInOperands())) {
645
6
        return false;
646
6
      }
647
18.4k
      return true;
648
10.0k
    case spv::Op::OpTypeArray:
649
10.0k
      if (IsSpecConstant(typeInst->GetSingleWordInOperand(1u))) {
650
237
        return false;
651
237
      }
652
9.82k
      if (IsLargerThanSizeLimit(GetArrayLength(typeInst))) {
653
1.91k
        return false;
654
1.91k
      }
655
7.90k
      return true;
656
      // TODO(alanbaker): Develop some heuristics for when this should be
657
      // re-enabled.
658
      //// Specifically including matrix and vector in an attempt to reduce the
659
      //// number of vector registers required.
660
      // case spv::Op::OpTypeMatrix:
661
      // case spv::Op::OpTypeVector:
662
      //  if (IsLargerThanSizeLimit(GetNumElements(typeInst))) return false;
663
      //  return true;
664
665
1
    case spv::Op::OpTypeRuntimeArray:
666
226k
    default:
667
226k
      return false;
668
254k
  }
669
254k
}
670
671
bool ScalarReplacementPass::CheckTypeAnnotations(
672
511k
    const Instruction* typeInst) const {
673
511k
  for (auto inst :
674
511k
       get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) {
675
17.0k
    uint32_t decoration;
676
17.0k
    if (inst->opcode() == spv::Op::OpDecorate ||
677
17.0k
        inst->opcode() == spv::Op::OpDecorateId) {
678
62
      decoration = inst->GetSingleWordInOperand(1u);
679
16.9k
    } else {
680
16.9k
      assert(inst->opcode() == spv::Op::OpMemberDecorate);
681
16.9k
      decoration = inst->GetSingleWordInOperand(2u);
682
16.9k
    }
683
684
17.0k
    switch (spv::Decoration(decoration)) {
685
239
      case spv::Decoration::RowMajor:
686
240
      case spv::Decoration::ColMajor:
687
241
      case spv::Decoration::ArrayStride:
688
249
      case spv::Decoration::MatrixStride:
689
249
      case spv::Decoration::CPacked:
690
547
      case spv::Decoration::Invariant:
691
598
      case spv::Decoration::Restrict:
692
608
      case spv::Decoration::Offset:
693
608
      case spv::Decoration::Alignment:
694
608
      case spv::Decoration::AlignmentId:
695
608
      case spv::Decoration::MaxByteOffset:
696
15.7k
      case spv::Decoration::RelaxedPrecision:
697
15.7k
      case spv::Decoration::AliasedPointer:
698
15.7k
      case spv::Decoration::RestrictPointer:
699
15.7k
        break;
700
1.30k
      default:
701
1.30k
        return false;
702
17.0k
    }
703
17.0k
  }
704
705
510k
  return true;
706
511k
}
707
708
26.3k
bool ScalarReplacementPass::CheckAnnotations(const Instruction* varInst) const {
709
26.3k
  for (auto inst :
710
26.3k
       get_decoration_mgr()->GetDecorationsFor(varInst->result_id(), false)) {
711
1.51k
    assert(inst->opcode() == spv::Op::OpDecorate);
712
1.51k
    auto decoration = spv::Decoration(inst->GetSingleWordInOperand(1u));
713
1.51k
    switch (decoration) {
714
4
      case spv::Decoration::Invariant:
715
9
      case spv::Decoration::Restrict:
716
9
      case spv::Decoration::Alignment:
717
9
      case spv::Decoration::AlignmentId:
718
9
      case spv::Decoration::MaxByteOffset:
719
9
      case spv::Decoration::AliasedPointer:
720
9
      case spv::Decoration::RestrictPointer:
721
9
        break;
722
1.50k
      default:
723
1.50k
        return false;
724
1.51k
    }
725
1.51k
  }
726
727
24.8k
  return true;
728
26.3k
}
729
730
24.8k
bool ScalarReplacementPass::CheckUses(const Instruction* inst) const {
731
24.8k
  VariableStats stats = {0, 0};
732
24.8k
  bool ok = CheckUses(inst, &stats);
733
734
  // TODO(alanbaker/greg-lunarg): Add some meaningful heuristics about when
735
  // SRoA is costly, such as when the structure has many (unaccessed?)
736
  // members.
737
738
24.8k
  return ok;
739
24.8k
}
740
741
bool ScalarReplacementPass::CheckUses(const Instruction* inst,
742
24.8k
                                      VariableStats* stats) const {
743
24.8k
  uint64_t max_legal_index = GetMaxLegalIndex(inst);
744
745
24.8k
  bool ok = true;
746
24.8k
  get_def_use_mgr()->ForEachUse(inst, [this, max_legal_index, stats, &ok](
747
24.8k
                                          const Instruction* user,
748
281k
                                          uint32_t index) {
749
281k
    if (user->GetCommonDebugOpcode() == CommonDebugInfoDebugDeclare ||
750
281k
        user->GetCommonDebugOpcode() == CommonDebugInfoDebugValue) {
751
      // TODO: include num_partial_accesses if it uses Fragment operation or
752
      // DebugValue has Indexes operand.
753
0
      stats->num_full_accesses++;
754
0
      return;
755
0
    }
756
757
    // Annotations are check as a group separately.
758
281k
    if (!IsAnnotationInst(user->opcode())) {
759
281k
      switch (user->opcode()) {
760
174k
        case spv::Op::OpAccessChain:
761
178k
        case spv::Op::OpInBoundsAccessChain:
762
178k
          if (index == 2u && user->NumInOperands() > 1) {
763
178k
            uint32_t id = user->GetSingleWordInOperand(1u);
764
178k
            const Instruction* opInst = get_def_use_mgr()->GetDef(id);
765
178k
            const auto* constant =
766
178k
                context()->get_constant_mgr()->GetConstantFromInst(opInst);
767
178k
            if (!constant) {
768
54.8k
              ok = false;
769
123k
            } else if (constant->GetZeroExtendedValue() >= max_legal_index) {
770
36.0k
              ok = false;
771
87.6k
            } else {
772
87.6k
              if (!CheckUsesRelaxed(user)) ok = false;
773
87.6k
            }
774
178k
            stats->num_partial_accesses++;
775
178k
          } else {
776
0
            ok = false;
777
0
          }
778
178k
          break;
779
14.4k
        case spv::Op::OpLoad:
780
14.4k
          if (!CheckLoad(user, index)) ok = false;
781
14.4k
          stats->num_full_accesses++;
782
14.4k
          break;
783
82.2k
        case spv::Op::OpStore:
784
82.2k
          if (!CheckStore(user, index)) ok = false;
785
82.2k
          stats->num_full_accesses++;
786
82.2k
          break;
787
2.54k
        case spv::Op::OpName:
788
2.54k
        case spv::Op::OpMemberName:
789
2.54k
          break;
790
3.91k
        default:
791
3.91k
          ok = false;
792
3.91k
          break;
793
281k
      }
794
281k
    }
795
281k
  });
796
797
24.8k
  return ok;
798
24.8k
}
799
800
87.6k
bool ScalarReplacementPass::CheckUsesRelaxed(const Instruction* inst) const {
801
87.6k
  bool ok = true;
802
87.6k
  get_def_use_mgr()->ForEachUse(
803
91.0k
      inst, [this, &ok](const Instruction* user, uint32_t index) {
804
91.0k
        switch (user->opcode()) {
805
0
          case spv::Op::OpAccessChain:
806
0
          case spv::Op::OpInBoundsAccessChain:
807
0
            if (index != 2u) {
808
0
              ok = false;
809
0
            } else {
810
0
              if (!CheckUsesRelaxed(user)) ok = false;
811
0
            }
812
0
            break;
813
56.4k
          case spv::Op::OpLoad:
814
56.4k
            if (!CheckLoad(user, index)) ok = false;
815
56.4k
            break;
816
34.3k
          case spv::Op::OpStore:
817
34.3k
            if (!CheckStore(user, index)) ok = false;
818
34.3k
            break;
819
0
          case spv::Op::OpImageTexelPointer:
820
0
            if (!CheckImageTexelPointer(index)) ok = false;
821
0
            break;
822
0
          case spv::Op::OpExtInst:
823
0
            if (user->GetCommonDebugOpcode() != CommonDebugInfoDebugDeclare ||
824
0
                !CheckDebugDeclare(index))
825
0
              ok = false;
826
0
            break;
827
293
          default:
828
293
            ok = false;
829
293
            break;
830
91.0k
        }
831
91.0k
      });
832
833
87.6k
  return ok;
834
87.6k
}
835
836
0
// An OpImageTexelPointer may reference the variable only through operand 2,
// its Image operand (operands 0 and 1 are the result type and result id).
bool ScalarReplacementPass::CheckImageTexelPointer(uint32_t index) const {
  const bool is_image_operand = (index == 2u);
  return is_image_operand;
}
839
840
bool ScalarReplacementPass::CheckLoad(const Instruction* inst,
841
70.9k
                                      uint32_t index) const {
842
70.9k
  if (index != 2u) return false;
843
70.9k
  if (inst->NumInOperands() >= 2 &&
844
70.9k
      inst->GetSingleWordInOperand(1u) &
845
218
          uint32_t(spv::MemoryAccessMask::Volatile))
846
163
    return false;
847
70.7k
  return true;
848
70.9k
}
849
850
bool ScalarReplacementPass::CheckStore(const Instruction* inst,
851
116k
                                       uint32_t index) const {
852
116k
  if (index != 0u) return false;
853
116k
  if (inst->NumInOperands() >= 3 &&
854
116k
      inst->GetSingleWordInOperand(2u) &
855
483
          uint32_t(spv::MemoryAccessMask::Volatile))
856
270
    return false;
857
116k
  return true;
858
116k
}
859
860
0
bool ScalarReplacementPass::CheckDebugDeclare(uint32_t index) const {
861
0
  if (index != kDebugDeclareOperandVariableIndex) return false;
862
0
  return true;
863
0
}
864
865
28.2k
bool ScalarReplacementPass::IsLargerThanSizeLimit(uint64_t length) const {
866
28.2k
  if (max_num_elements_ == 0) {
867
0
    return false;
868
0
  }
869
28.2k
  return length > max_num_elements_;
870
28.2k
}
871
872
std::unique_ptr<std::unordered_set<int64_t>>
873
15.6k
ScalarReplacementPass::GetUsedComponents(Instruction* inst) {
874
15.6k
  std::unique_ptr<std::unordered_set<int64_t>> result(
875
15.6k
      new std::unordered_set<int64_t>());
876
877
15.6k
  analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
878
879
15.6k
  def_use_mgr->WhileEachUser(inst, [&result, def_use_mgr,
880
66.4k
                                    this](Instruction* use) {
881
66.4k
    switch (use->opcode()) {
882
9.45k
      case spv::Op::OpLoad: {
883
        // Look for extract from the load.
884
9.45k
        std::vector<uint32_t> t;
885
22.2k
        if (def_use_mgr->WhileEachUser(use, [&t](Instruction* use2) {
886
22.2k
              if (use2->opcode() != spv::Op::OpCompositeExtract ||
887
22.2k
                  use2->NumInOperands() <= 1) {
888
8.98k
                return false;
889
8.98k
              }
890
13.3k
              t.push_back(use2->GetSingleWordInOperand(1));
891
13.3k
              return true;
892
22.2k
            })) {
893
470
          result->insert(t.begin(), t.end());
894
470
          return true;
895
8.98k
        } else {
896
8.98k
          result.reset(nullptr);
897
8.98k
          return false;
898
8.98k
        }
899
9.45k
      }
900
1.09k
      case spv::Op::OpName:
901
1.09k
      case spv::Op::OpMemberName:
902
27.0k
      case spv::Op::OpStore:
903
        // No components are used.
904
27.0k
        return true;
905
28.1k
      case spv::Op::OpAccessChain:
906
29.8k
      case spv::Op::OpInBoundsAccessChain: {
907
        // Add the first index it if is a constant.
908
        // TODO: Could be improved by checking if the address is used in a load.
909
29.8k
        analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
910
29.8k
        uint32_t index_id = use->GetSingleWordInOperand(1);
911
29.8k
        const analysis::Constant* index_const =
912
29.8k
            const_mgr->FindDeclaredConstant(index_id);
913
29.8k
        if (index_const) {
914
29.8k
          result->insert(index_const->GetSignExtendedValue());
915
29.8k
          return true;
916
29.8k
        } else {
917
          // Could be any element.  Assuming all are used.
918
0
          result.reset(nullptr);
919
0
          return false;
920
0
        }
921
29.8k
      }
922
7
      default:
923
        // We do not know what is happening.  Have to assume the worst.
924
7
        result.reset(nullptr);
925
7
        return false;
926
66.4k
    }
927
66.4k
  });
928
929
15.6k
  return result;
930
15.6k
}
931
932
uint64_t ScalarReplacementPass::GetMaxLegalIndex(
933
24.8k
    const Instruction* var_inst) const {
934
24.8k
  assert(var_inst->opcode() == spv::Op::OpVariable &&
935
24.8k
         "|var_inst| must be a variable instruction.");
936
24.8k
  Instruction* type = GetStorageType(var_inst);
937
24.8k
  switch (type->opcode()) {
938
17.9k
    case spv::Op::OpTypeStruct:
939
17.9k
      return type->NumInOperands();
940
6.95k
    case spv::Op::OpTypeArray:
941
6.95k
      return GetArrayLength(type);
942
0
    case spv::Op::OpTypeMatrix:
943
0
    case spv::Op::OpTypeVector:
944
0
      return GetNumElements(type);
945
0
    default:
946
0
      return 0;
947
24.8k
  }
948
0
  return 0;
949
24.8k
}
950
951
void ScalarReplacementPass::CopyDecorationsToVariable(Instruction* from,
952
                                                      Instruction* to,
953
79.0k
                                                      uint32_t member_index) {
954
79.0k
  CopyPointerDecorationsToVariable(from, to);
955
79.0k
  CopyNecessaryMemberDecorationsToVariable(from, to, member_index);
956
79.0k
}
957
958
void ScalarReplacementPass::CopyPointerDecorationsToVariable(Instruction* from,
959
79.0k
                                                             Instruction* to) {
960
  // The RestrictPointer and AliasedPointer decorations are copied to all
961
  // members even if the new variable does not contain a pointer. It does
962
  // not hurt to do so.
963
79.0k
  for (auto dec_inst :
964
79.0k
       get_decoration_mgr()->GetDecorationsFor(from->result_id(), false)) {
965
195
    uint32_t decoration;
966
195
    decoration = dec_inst->GetSingleWordInOperand(1u);
967
195
    switch (spv::Decoration(decoration)) {
968
0
      case spv::Decoration::AliasedPointer:
969
0
      case spv::Decoration::RestrictPointer: {
970
0
        std::unique_ptr<Instruction> new_dec_inst(dec_inst->Clone(context()));
971
0
        new_dec_inst->SetInOperand(0, {to->result_id()});
972
0
        context()->AddAnnotationInst(std::move(new_dec_inst));
973
0
      } break;
974
195
      default:
975
195
        break;
976
195
    }
977
195
  }
978
79.0k
}
979
980
void ScalarReplacementPass::CopyNecessaryMemberDecorationsToVariable(
981
79.0k
    Instruction* from, Instruction* to, uint32_t member_index) {
982
79.0k
  Instruction* type_inst = GetStorageType(from);
983
79.0k
  for (auto dec_inst :
984
79.0k
       get_decoration_mgr()->GetDecorationsFor(type_inst->result_id(), false)) {
985
33.6k
    uint32_t decoration;
986
33.6k
    if (dec_inst->opcode() == spv::Op::OpMemberDecorate) {
987
33.6k
      if (dec_inst->GetSingleWordInOperand(1) != member_index) {
988
22.2k
        continue;
989
22.2k
      }
990
991
11.3k
      decoration = dec_inst->GetSingleWordInOperand(2u);
992
11.3k
      switch (spv::Decoration(decoration)) {
993
0
        case spv::Decoration::ArrayStride:
994
0
        case spv::Decoration::Alignment:
995
0
        case spv::Decoration::AlignmentId:
996
0
        case spv::Decoration::MaxByteOffset:
997
0
        case spv::Decoration::MaxByteOffsetId:
998
11.0k
        case spv::Decoration::RelaxedPrecision: {
999
11.0k
          std::unique_ptr<Instruction> new_dec_inst(
1000
11.0k
              new Instruction(context(), spv::Op::OpDecorate, 0, 0, {}));
1001
11.0k
          new_dec_inst->AddOperand(
1002
11.0k
              Operand(SPV_OPERAND_TYPE_ID, {to->result_id()}));
1003
22.1k
          for (uint32_t i = 2; i < dec_inst->NumInOperandWords(); ++i) {
1004
11.0k
            new_dec_inst->AddOperand(Operand(dec_inst->GetInOperand(i)));
1005
11.0k
          }
1006
11.0k
          context()->AddAnnotationInst(std::move(new_dec_inst));
1007
11.0k
        } break;
1008
331
        default:
1009
331
          break;
1010
11.3k
      }
1011
11.3k
    }
1012
33.6k
  }
1013
79.0k
}
1014
1015
}  // namespace opt
1016
}  // namespace spvtools