Coverage Report

Created: 2025-07-23 06:18

/src/spirv-tools/source/opt/copy_prop_arrays.cpp
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) 2018 Google LLC.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
#include "source/opt/copy_prop_arrays.h"
16
17
#include <utility>
18
19
#include "source/opt/ir_builder.h"
20
21
namespace spvtools {
22
namespace opt {
23
namespace {
24
25
constexpr uint32_t kLoadPointerInOperand = 0;
26
constexpr uint32_t kStorePointerInOperand = 0;
27
constexpr uint32_t kStoreObjectInOperand = 1;
28
constexpr uint32_t kCompositeExtractObjectInOperand = 0;
29
constexpr uint32_t kTypePointerStorageClassInIdx = 0;
30
constexpr uint32_t kTypePointerPointeeInIdx = 1;
31
constexpr uint32_t kExtInstSetInIdx = 0;
32
constexpr uint32_t kExtInstOpInIdx = 1;
33
constexpr uint32_t kInterpolantInIdx = 2;
34
35
4.35k
bool IsDebugDeclareOrValue(Instruction* di) {
36
4.35k
  auto dbg_opcode = di->GetCommonDebugOpcode();
37
4.35k
  return dbg_opcode == CommonDebugInfoDebugDeclare ||
38
4.35k
         dbg_opcode == CommonDebugInfoDebugValue;
39
4.35k
}
40
41
// Returns the number of members in |type|.  If |type| is not a composite type
42
// or the number of components is not known at compile time, the return value
43
// will be 0.
44
77
uint32_t GetNumberOfMembers(const analysis::Type* type, IRContext* context) {
45
77
  if (const analysis::Struct* struct_type = type->AsStruct()) {
46
5
    return static_cast<uint32_t>(struct_type->element_types().size());
47
72
  } else if (const analysis::Array* array_type = type->AsArray()) {
48
34
    const analysis::Constant* length_const =
49
34
        context->get_constant_mgr()->FindDeclaredConstant(
50
34
            array_type->LengthId());
51
52
34
    if (length_const == nullptr) {
53
      // This can happen if the length is an OpSpecConstant.
54
6
      return 0;
55
6
    }
56
28
    assert(length_const->type()->AsInteger());
57
28
    return length_const->GetU32();
58
38
  } else if (const analysis::Vector* vector_type = type->AsVector()) {
59
38
    return vector_type->element_count();
60
38
  } else if (const analysis::Matrix* matrix_type = type->AsMatrix()) {
61
0
    return matrix_type->element_count();
62
0
  } else {
63
0
    return 0;
64
0
  }
65
77
}
66
67
}  // namespace
68
69
14.8k
Pass::Status CopyPropagateArrays::Process() {
70
14.8k
  bool modified = false;
71
14.8k
  for (Function& function : *get_module()) {
72
14.1k
    if (function.IsDeclaration()) {
73
0
      continue;
74
0
    }
75
76
14.1k
    BasicBlock* entry_bb = &*function.begin();
77
78
14.1k
    for (auto var_inst = entry_bb->begin();
79
48.9k
         var_inst->opcode() == spv::Op::OpVariable; ++var_inst) {
80
34.7k
      worklist_.push(&*var_inst);
81
34.7k
    }
82
14.1k
  }
83
84
49.6k
  while (!worklist_.empty()) {
85
34.8k
    Instruction* var_inst = worklist_.front();
86
34.8k
    worklist_.pop();
87
88
    // Find the only store to the entire memory location, if it exists.
89
34.8k
    Instruction* store_inst = FindStoreInstruction(&*var_inst);
90
91
34.8k
    if (!store_inst) {
92
16.5k
      continue;
93
16.5k
    }
94
95
18.2k
    std::unique_ptr<MemoryObject> source_object =
96
18.2k
        FindSourceObjectIfPossible(&*var_inst, store_inst);
97
98
18.2k
    if (source_object == nullptr) {
99
16.0k
      continue;
100
16.0k
    }
101
102
2.23k
    if (!IsPointerToArrayType(var_inst->type_id()) &&
103
2.23k
        source_object->GetStorageClass() != spv::StorageClass::Input) {
104
1.97k
      continue;
105
1.97k
    }
106
107
260
    if (CanUpdateUses(&*var_inst, source_object->GetPointerTypeId(this))) {
108
260
      modified = true;
109
110
260
      PropagateObject(&*var_inst, source_object.get(), store_inst);
111
260
    }
112
260
  }
113
114
14.8k
  return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
115
14.8k
}
116
117
std::unique_ptr<CopyPropagateArrays::MemoryObject>
118
CopyPropagateArrays::FindSourceObjectIfPossible(Instruction* var_inst,
119
18.2k
                                                Instruction* store_inst) {
120
18.2k
  assert(var_inst->opcode() == spv::Op::OpVariable && "Expecting a variable.");
121
122
  // Check that the variable is a composite object where |store_inst|
123
  // dominates all of its loads.
124
18.2k
  if (!store_inst) {
125
0
    return nullptr;
126
0
  }
127
128
  // Look at the loads to ensure they are dominated by the store.
129
18.2k
  if (!HasValidReferencesOnly(var_inst, store_inst)) {
130
8.08k
    return nullptr;
131
8.08k
  }
132
133
  // If so, look at the store to see if it is the copy of an object.
134
10.1k
  std::unique_ptr<MemoryObject> source = GetSourceObjectIfAny(
135
10.1k
      store_inst->GetSingleWordInOperand(kStoreObjectInOperand));
136
137
10.1k
  if (!source) {
138
5.40k
    return nullptr;
139
5.40k
  }
140
141
  // Ensure that |source| does not change between the point at which it is
142
  // loaded, and the position in which |var_inst| is loaded.
143
  //
144
  // For now we will go with the easy to implement approach, and check that the
145
  // entire variable (not just the specific component) is never written to.
146
147
4.76k
  if (!HasNoStores(source->GetVariable())) {
148
2.53k
    return nullptr;
149
2.53k
  }
150
2.23k
  return source;
151
4.76k
}
152
153
Instruction* CopyPropagateArrays::FindStoreInstruction(
154
34.8k
    const Instruction* var_inst) const {
155
34.8k
  Instruction* store_inst = nullptr;
156
34.8k
  get_def_use_mgr()->WhileEachUser(
157
110k
      var_inst, [&store_inst, var_inst](Instruction* use) {
158
110k
        if (use->opcode() == spv::Op::OpStore &&
159
110k
            use->GetSingleWordInOperand(kStorePointerInOperand) ==
160
25.4k
                var_inst->result_id()) {
161
25.4k
          if (store_inst == nullptr) {
162
21.8k
            store_inst = use;
163
21.8k
          } else {
164
3.60k
            store_inst = nullptr;
165
3.60k
            return false;
166
3.60k
          }
167
25.4k
        }
168
106k
        return true;
169
110k
      });
170
34.8k
  return store_inst;
171
34.8k
}
172
173
void CopyPropagateArrays::PropagateObject(Instruction* var_inst,
174
                                          MemoryObject* source,
175
260
                                          Instruction* insertion_point) {
176
260
  assert(var_inst->opcode() == spv::Op::OpVariable &&
177
260
         "This function propagates variables.");
178
179
260
  Instruction* new_access_chain = BuildNewAccessChain(insertion_point, source);
180
260
  context()->KillNamesAndDecorates(var_inst);
181
260
  UpdateUses(var_inst, new_access_chain);
182
260
}
183
184
Instruction* CopyPropagateArrays::BuildNewAccessChain(
185
    Instruction* insertion_point,
186
260
    CopyPropagateArrays::MemoryObject* source) const {
187
260
  InstructionBuilder builder(
188
260
      context(), insertion_point,
189
260
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
190
191
260
  if (source->AccessChain().size() == 0) {
192
162
    return source->GetVariable();
193
162
  }
194
195
98
  source->BuildConstants();
196
98
  std::vector<uint32_t> access_ids(source->AccessChain().size());
197
98
  std::transform(
198
98
      source->AccessChain().cbegin(), source->AccessChain().cend(),
199
98
      access_ids.begin(), [](const AccessChainEntry& entry) {
200
98
        assert(entry.is_result_id && "Constants needs to be built first.");
201
98
        return entry.result_id;
202
98
      });
203
204
98
  return builder.AddAccessChain(source->GetPointerTypeId(this),
205
98
                                source->GetVariable()->result_id(), access_ids);
206
260
}
207
208
8.98k
bool CopyPropagateArrays::HasNoStores(Instruction* ptr_inst) {
209
20.5k
  return get_def_use_mgr()->WhileEachUser(ptr_inst, [this](Instruction* use) {
210
20.5k
    if (use->opcode() == spv::Op::OpLoad) {
211
11.1k
      return true;
212
11.1k
    } else if (use->opcode() == spv::Op::OpAccessChain) {
213
4.21k
      return HasNoStores(use);
214
5.27k
    } else if (use->IsDecoration() || use->opcode() == spv::Op::OpName) {
215
2.67k
      return true;
216
2.67k
    } else if (use->opcode() == spv::Op::OpStore) {
217
2.28k
      return false;
218
2.28k
    } else if (use->opcode() == spv::Op::OpImageTexelPointer) {
219
0
      return true;
220
327
    } else if (use->opcode() == spv::Op::OpEntryPoint) {
221
77
      return true;
222
250
    } else if (IsInterpolationInstruction(use)) {
223
0
      return true;
224
250
    } else if (use->IsCommonDebugInstr()) {
225
0
      return true;
226
0
    }
227
    // Some other instruction.  Be conservative.
228
250
    return false;
229
20.5k
  });
230
8.98k
}
231
232
bool CopyPropagateArrays::HasValidReferencesOnly(Instruction* ptr_inst,
233
30.1k
                                                 Instruction* store_inst) {
234
30.1k
  BasicBlock* store_block = context()->get_instr_block(store_inst);
235
30.1k
  DominatorAnalysis* dominator_analysis =
236
30.1k
      context()->GetDominatorAnalysis(store_block->GetParent());
237
238
30.1k
  return get_def_use_mgr()->WhileEachUser(
239
30.1k
      ptr_inst,
240
52.1k
      [this, store_inst, dominator_analysis, ptr_inst](Instruction* use) {
241
52.1k
        if (use->opcode() == spv::Op::OpLoad ||
242
52.1k
            use->opcode() == spv::Op::OpImageTexelPointer) {
243
          // TODO: If there are many loads in the same BB as |store_inst| the
244
          // time to do the multiple traverses can add up.  Consider collecting
245
          // those loads and doing a single traversal.
246
18.5k
          return dominator_analysis->Dominates(store_inst, use);
247
33.5k
        } else if (IsInterpolationInstruction(use)) {
248
          // GLSL InterpolateAt* instructions work similarly to loads
249
0
          uint32_t interpolant = use->GetSingleWordInOperand(kInterpolantInIdx);
250
0
          if (interpolant !=
251
0
              store_inst->GetSingleWordInOperand(kStorePointerInOperand))
252
0
            return false;
253
0
          return dominator_analysis->Dominates(store_inst, use);
254
33.5k
        } else if (use->opcode() == spv::Op::OpAccessChain) {
255
11.8k
          return HasValidReferencesOnly(use, store_inst);
256
21.7k
        } else if (use->IsDecoration() || use->opcode() == spv::Op::OpName) {
257
2.27k
          return true;
258
19.4k
        } else if (use->opcode() == spv::Op::OpStore) {
259
          // If we are storing to part of the object it is not a candidate.
260
15.8k
          return ptr_inst->opcode() == spv::Op::OpVariable &&
261
15.8k
                 store_inst->GetSingleWordInOperand(kStorePointerInOperand) ==
262
15.2k
                     ptr_inst->result_id();
263
15.8k
        } else if (IsDebugDeclareOrValue(use)) {
264
          // The store does not have to dominate debug instructions. We do not
265
          // want debugging info to stop the transformation. It will be fixed
266
          // up later.
267
0
          return true;
268
0
        }
269
        // Some other instruction.  Be conservative.
270
3.58k
        return false;
271
52.1k
      });
272
30.1k
}
273
274
std::unique_ptr<CopyPropagateArrays::MemoryObject>
275
15.0k
CopyPropagateArrays::GetSourceObjectIfAny(uint32_t result) {
276
15.0k
  Instruction* result_inst = context()->get_def_use_mgr()->GetDef(result);
277
278
15.0k
  switch (result_inst->opcode()) {
279
4.91k
    case spv::Op::OpLoad:
280
4.91k
      return BuildMemoryObjectFromLoad(result_inst);
281
4.25k
    case spv::Op::OpCompositeExtract:
282
4.25k
      return BuildMemoryObjectFromExtract(result_inst);
283
532
    case spv::Op::OpCompositeConstruct:
284
532
      return BuildMemoryObjectFromCompositeConstruct(result_inst);
285
0
    case spv::Op::OpCopyObject:
286
0
    case spv::Op::OpCopyLogical:
287
0
      return GetSourceObjectIfAny(result_inst->GetSingleWordInOperand(0));
288
39
    case spv::Op::OpCompositeInsert:
289
39
      return BuildMemoryObjectFromInsert(result_inst);
290
5.26k
    default:
291
5.26k
      return nullptr;
292
15.0k
  }
293
15.0k
}
294
295
std::unique_ptr<CopyPropagateArrays::MemoryObject>
296
4.91k
CopyPropagateArrays::BuildMemoryObjectFromLoad(Instruction* load_inst) {
297
4.91k
  std::vector<uint32_t> components_in_reverse;
298
4.91k
  analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
299
300
4.91k
  Instruction* current_inst = def_use_mgr->GetDef(
301
4.91k
      load_inst->GetSingleWordInOperand(kLoadPointerInOperand));
302
303
  // Build the access chain for the memory object by collecting the indices used
304
  // in the OpAccessChain instructions.  If we find a variable index, then
305
  // return |nullptr| because we cannot know for sure which memory location is
306
  // used.
307
  //
308
  // It is built in reverse order because the different |OpAccessChain|
309
  // instructions are visited in reverse order from which they are applied.
310
5.30k
  while (current_inst->opcode() == spv::Op::OpAccessChain) {
311
794
    for (uint32_t i = current_inst->NumInOperands() - 1; i >= 1; --i) {
312
400
      uint32_t element_index_id = current_inst->GetSingleWordInOperand(i);
313
400
      components_in_reverse.push_back(element_index_id);
314
400
    }
315
394
    current_inst = def_use_mgr->GetDef(current_inst->GetSingleWordInOperand(0));
316
394
  }
317
318
  // If the address in the load is not constructed from an |OpVariable|
319
  // instruction followed by a series of |OpAccessChain| instructions, then
320
  // return |nullptr| because we cannot identify the owner or access chain
321
  // exactly.
322
4.91k
  if (current_inst->opcode() != spv::Op::OpVariable) {
323
25
    return nullptr;
324
25
  }
325
326
  // Build the memory object.  Use |rbegin| and |rend| to put the access chain
327
  // back in the correct order.
328
4.88k
  return std::unique_ptr<CopyPropagateArrays::MemoryObject>(
329
4.88k
      new MemoryObject(current_inst, components_in_reverse.rbegin(),
330
4.88k
                       components_in_reverse.rend()));
331
4.91k
}
332
333
std::unique_ptr<CopyPropagateArrays::MemoryObject>
334
4.25k
CopyPropagateArrays::BuildMemoryObjectFromExtract(Instruction* extract_inst) {
335
4.25k
  assert(extract_inst->opcode() == spv::Op::OpCompositeExtract &&
336
4.25k
         "Expecting an OpCompositeExtract instruction.");
337
4.25k
  std::unique_ptr<MemoryObject> result = GetSourceObjectIfAny(
338
4.25k
      extract_inst->GetSingleWordInOperand(kCompositeExtractObjectInOperand));
339
340
4.25k
  if (!result) {
341
741
    return nullptr;
342
741
  }
343
344
  // Copy the indices of the extract instruction to |OpAccessChain| indices.
345
3.51k
  std::vector<AccessChainEntry> components;
346
7.02k
  for (uint32_t i = 1; i < extract_inst->NumInOperands(); ++i) {
347
3.51k
    components.push_back({false, {extract_inst->GetSingleWordInOperand(i)}});
348
3.51k
  }
349
3.51k
  result->PushIndirection(components);
350
3.51k
  return result;
351
4.25k
}
352
353
std::unique_ptr<CopyPropagateArrays::MemoryObject>
354
CopyPropagateArrays::BuildMemoryObjectFromCompositeConstruct(
355
532
    Instruction* conststruct_inst) {
356
532
  assert(conststruct_inst->opcode() == spv::Op::OpCompositeConstruct &&
357
532
         "Expecting an OpCompositeConstruct instruction.");
358
359
  // If all operands in the instruction are part of the same memory object, and
360
  // are being combined in the same order, then the result is the same as the
361
  // parent.
362
363
532
  std::unique_ptr<MemoryObject> memory_object =
364
532
      GetSourceObjectIfAny(conststruct_inst->GetSingleWordInOperand(0));
365
366
532
  if (!memory_object) {
367
426
    return nullptr;
368
426
  }
369
370
106
  if (!memory_object->IsMember()) {
371
21
    return nullptr;
372
21
  }
373
374
85
  AccessChainEntry last_access = memory_object->AccessChain().back();
375
85
  if (!IsAccessChainIndexValidAndEqualTo(last_access, 0)) {
376
47
    return nullptr;
377
47
  }
378
379
38
  memory_object->PopIndirection();
380
38
  if (memory_object->GetNumberOfMembers() !=
381
38
      conststruct_inst->NumInOperands()) {
382
4
    return nullptr;
383
4
  }
384
385
34
  for (uint32_t i = 1; i < conststruct_inst->NumInOperands(); ++i) {
386
34
    std::unique_ptr<MemoryObject> member_object =
387
34
        GetSourceObjectIfAny(conststruct_inst->GetSingleWordInOperand(i));
388
389
34
    if (!member_object) {
390
17
      return nullptr;
391
17
    }
392
393
17
    if (!member_object->IsMember()) {
394
0
      return nullptr;
395
0
    }
396
397
17
    if (!memory_object->Contains(member_object.get())) {
398
13
      return nullptr;
399
13
    }
400
401
4
    last_access = member_object->AccessChain().back();
402
4
    if (!IsAccessChainIndexValidAndEqualTo(last_access, i)) {
403
4
      return nullptr;
404
4
    }
405
4
  }
406
0
  return memory_object;
407
34
}
408
409
std::unique_ptr<CopyPropagateArrays::MemoryObject>
410
39
CopyPropagateArrays::BuildMemoryObjectFromInsert(Instruction* insert_inst) {
411
39
  assert(insert_inst->opcode() == spv::Op::OpCompositeInsert &&
412
39
         "Expecting an OpCompositeInsert instruction.");
413
414
39
  analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
415
39
  analysis::TypeManager* type_mgr = context()->get_type_mgr();
416
39
  const analysis::Type* result_type = type_mgr->GetType(insert_inst->type_id());
417
418
39
  uint32_t number_of_elements = GetNumberOfMembers(result_type, context());
419
420
39
  if (number_of_elements == 0) {
421
6
    return nullptr;
422
6
  }
423
424
33
  if (insert_inst->NumInOperands() != 3) {
425
5
    return nullptr;
426
5
  }
427
428
28
  if (insert_inst->GetSingleWordInOperand(2) != number_of_elements - 1) {
429
19
    return nullptr;
430
19
  }
431
432
9
  std::unique_ptr<MemoryObject> memory_object =
433
9
      GetSourceObjectIfAny(insert_inst->GetSingleWordInOperand(0));
434
435
9
  if (!memory_object) {
436
9
    return nullptr;
437
9
  }
438
439
0
  if (!memory_object->IsMember()) {
440
0
    return nullptr;
441
0
  }
442
443
0
  AccessChainEntry last_access = memory_object->AccessChain().back();
444
0
  if (!IsAccessChainIndexValidAndEqualTo(last_access, number_of_elements - 1)) {
445
0
    return nullptr;
446
0
  }
447
448
0
  memory_object->PopIndirection();
449
450
0
  Instruction* current_insert =
451
0
      def_use_mgr->GetDef(insert_inst->GetSingleWordInOperand(1));
452
0
  for (uint32_t i = number_of_elements - 1; i > 0; --i) {
453
0
    if (current_insert->opcode() != spv::Op::OpCompositeInsert) {
454
0
      return nullptr;
455
0
    }
456
457
0
    if (current_insert->NumInOperands() != 3) {
458
0
      return nullptr;
459
0
    }
460
461
0
    if (current_insert->GetSingleWordInOperand(2) != i - 1) {
462
0
      return nullptr;
463
0
    }
464
465
0
    std::unique_ptr<MemoryObject> current_memory_object =
466
0
        GetSourceObjectIfAny(current_insert->GetSingleWordInOperand(0));
467
468
0
    if (!current_memory_object) {
469
0
      return nullptr;
470
0
    }
471
472
0
    if (!current_memory_object->IsMember()) {
473
0
      return nullptr;
474
0
    }
475
476
0
    if (memory_object->AccessChain().size() + 1 !=
477
0
        current_memory_object->AccessChain().size()) {
478
0
      return nullptr;
479
0
    }
480
481
0
    if (!memory_object->Contains(current_memory_object.get())) {
482
0
      return nullptr;
483
0
    }
484
485
0
    AccessChainEntry current_last_access =
486
0
        current_memory_object->AccessChain().back();
487
0
    if (!IsAccessChainIndexValidAndEqualTo(current_last_access, i - 1)) {
488
0
      return nullptr;
489
0
    }
490
0
    current_insert =
491
0
        def_use_mgr->GetDef(current_insert->GetSingleWordInOperand(1));
492
0
  }
493
494
0
  return memory_object;
495
0
}
496
497
bool CopyPropagateArrays::IsAccessChainIndexValidAndEqualTo(
498
89
    const AccessChainEntry& entry, uint32_t value) const {
499
89
  if (!entry.is_result_id) {
500
4
    return entry.immediate == value;
501
4
  }
502
503
85
  analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
504
85
  const analysis::Constant* constant =
505
85
      const_mgr->FindDeclaredConstant(entry.result_id);
506
85
  if (!constant || !constant->type()->AsInteger()) {
507
17
    return false;
508
17
  }
509
68
  return constant->GetU32() == value;
510
85
}
511
512
2.23k
bool CopyPropagateArrays::IsPointerToArrayType(uint32_t type_id) {
513
2.23k
  analysis::TypeManager* type_mgr = context()->get_type_mgr();
514
2.23k
  analysis::Pointer* pointer_type = type_mgr->GetType(type_id)->AsPointer();
515
2.23k
  if (pointer_type) {
516
2.23k
    return pointer_type->pointee_type()->kind() == analysis::Type::kArray ||
517
2.23k
           pointer_type->pointee_type()->kind() == analysis::Type::kImage;
518
2.23k
  }
519
0
  return false;
520
2.23k
}
521
522
33.8k
bool CopyPropagateArrays::IsInterpolationInstruction(Instruction* inst) {
523
33.8k
  if (inst->opcode() == spv::Op::OpExtInst &&
524
33.8k
      inst->GetSingleWordInOperand(kExtInstSetInIdx) ==
525
0
          context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450()) {
526
0
    uint32_t ext_inst = inst->GetSingleWordInOperand(kExtInstOpInIdx);
527
0
    switch (ext_inst) {
528
0
      case GLSLstd450InterpolateAtCentroid:
529
0
      case GLSLstd450InterpolateAtOffset:
530
0
      case GLSLstd450InterpolateAtSample:
531
0
        return true;
532
0
    }
533
0
  }
534
33.8k
  return false;
535
33.8k
}
536
537
bool CopyPropagateArrays::CanUpdateUses(Instruction* original_ptr_inst,
538
264
                                        uint32_t type_id) {
539
264
  analysis::TypeManager* type_mgr = context()->get_type_mgr();
540
264
  analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
541
264
  analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
542
543
264
  analysis::Type* type = type_mgr->GetType(type_id);
544
264
  if (type->AsRuntimeArray()) {
545
0
    return false;
546
0
  }
547
548
264
  if (!type->AsStruct() && !type->AsArray() && !type->AsPointer()) {
549
    // If the type is not an aggregate, then the desired type must be the
550
    // same as the current type.  No work to do, and we can do that.
551
0
    return true;
552
0
  }
553
554
264
  return def_use_mgr->WhileEachUse(original_ptr_inst, [this, type_mgr,
555
264
                                                       const_mgr,
556
264
                                                       type](Instruction* use,
557
774
                                                             uint32_t) {
558
774
    if (IsDebugDeclareOrValue(use)) return true;
559
560
774
    switch (use->opcode()) {
561
56
      case spv::Op::OpLoad: {
562
56
        analysis::Pointer* pointer_type = type->AsPointer();
563
56
        uint32_t new_type_id = type_mgr->GetId(pointer_type->pointee_type());
564
565
56
        if (new_type_id != use->type_id()) {
566
0
          return CanUpdateUses(use, new_type_id);
567
0
        }
568
56
        return true;
569
56
      }
570
0
      case spv::Op::OpExtInst:
571
0
        if (IsInterpolationInstruction(use)) {
572
0
          return true;
573
0
        }
574
0
        return false;
575
303
      case spv::Op::OpAccessChain: {
576
303
        analysis::Pointer* pointer_type = type->AsPointer();
577
303
        const analysis::Type* pointee_type = pointer_type->pointee_type();
578
579
303
        std::vector<uint32_t> access_chain;
580
606
        for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
581
303
          const analysis::Constant* index_const =
582
303
              const_mgr->FindDeclaredConstant(use->GetSingleWordInOperand(i));
583
303
          if (index_const) {
584
54
            access_chain.push_back(index_const->GetU32());
585
249
          } else {
586
            // Variable index means the type is a type where every element
587
            // is the same type.  Use element 0 to get the type.
588
249
            access_chain.push_back(0);
589
590
            // We are trying to access a struct with variable indices.
591
            // This cannot happen.
592
249
            if (pointee_type->kind() == analysis::Type::kStruct) {
593
0
              return false;
594
0
            }
595
249
          }
596
303
        }
597
598
303
        const analysis::Type* new_pointee_type =
599
303
            type_mgr->GetMemberType(pointee_type, access_chain);
600
303
        analysis::Pointer pointerTy(new_pointee_type,
601
303
                                    pointer_type->storage_class());
602
303
        uint32_t new_pointer_type_id =
603
303
            context()->get_type_mgr()->GetTypeInstruction(&pointerTy);
604
303
        if (new_pointer_type_id == 0) {
605
0
          return false;
606
0
        }
607
608
303
        if (new_pointer_type_id != use->type_id()) {
609
4
          return CanUpdateUses(use, new_pointer_type_id);
610
4
        }
611
299
        return true;
612
303
      }
613
0
      case spv::Op::OpCompositeExtract: {
614
0
        std::vector<uint32_t> access_chain;
615
0
        for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
616
0
          access_chain.push_back(use->GetSingleWordInOperand(i));
617
0
        }
618
619
0
        const analysis::Type* new_type =
620
0
            type_mgr->GetMemberType(type, access_chain);
621
0
        uint32_t new_type_id = type_mgr->GetTypeInstruction(new_type);
622
0
        if (new_type_id == 0) {
623
0
          return false;
624
0
        }
625
626
0
        if (new_type_id != use->type_id()) {
627
0
          return CanUpdateUses(use, new_type_id);
628
0
        }
629
0
        return true;
630
0
      }
631
260
      case spv::Op::OpStore:
632
        // If needed, we can create an element-by-element copy to change the
633
        // type of the value being stored.  This way we can always handle
634
        // stores.
635
260
        return true;
636
0
      case spv::Op::OpImageTexelPointer:
637
53
      case spv::Op::OpName:
638
53
        return true;
639
102
      default:
640
102
        return use->IsDecoration();
641
774
    }
642
774
  });
643
264
}
644
645
void CopyPropagateArrays::UpdateUses(Instruction* original_ptr_inst,
646
264
                                     Instruction* new_ptr_inst) {
647
264
  analysis::TypeManager* type_mgr = context()->get_type_mgr();
648
264
  analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
649
264
  analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
650
651
264
  std::vector<std::pair<Instruction*, uint32_t> > uses;
652
264
  def_use_mgr->ForEachUse(original_ptr_inst,
653
620
                          [&uses](Instruction* use, uint32_t index) {
654
620
                            uses.push_back({use, index});
655
620
                          });
656
657
620
  for (auto pair : uses) {
658
620
    Instruction* use = pair.first;
659
620
    uint32_t index = pair.second;
660
661
620
    if (use->IsCommonDebugInstr()) {
662
      // It is possible that the debug instructions are not dominated by
663
      // `new_ptr_inst`. If not, move the debug instruction to just after
664
      // `new_ptr_inst`.
665
0
      BasicBlock* store_block = context()->get_instr_block(new_ptr_inst);
666
0
      if (store_block) {
667
0
        Function* function = store_block->GetParent();
668
0
        DominatorAnalysis* dominator_analysis =
669
0
            context()->GetDominatorAnalysis(function);
670
0
        if (!dominator_analysis->Dominates(new_ptr_inst, use)) {
671
0
          assert(dominator_analysis->Dominates(use, new_ptr_inst));
672
0
          use->InsertAfter(new_ptr_inst);
673
0
          context()->set_instr_block(use,
674
0
                                     context()->get_instr_block(new_ptr_inst));
675
0
        }
676
0
      }
677
678
0
      switch (use->GetCommonDebugOpcode()) {
679
0
        case CommonDebugInfoDebugDeclare: {
680
0
          if (new_ptr_inst->opcode() == spv::Op::OpVariable ||
681
0
              new_ptr_inst->opcode() == spv::Op::OpFunctionParameter) {
682
0
            context()->ForgetUses(use);
683
0
            use->SetOperand(index, {new_ptr_inst->result_id()});
684
0
            context()->AnalyzeUses(use);
685
0
          } else {
686
            // Based on the spec, we cannot use a pointer other than OpVariable
687
            // or OpFunctionParameter for DebugDeclare. We have to use
688
            // DebugValue with Deref.
689
690
0
            context()->ForgetUses(use);
691
692
            // Change DebugDeclare to DebugValue.
693
0
            use->SetOperand(index - 2,
694
0
                            {static_cast<uint32_t>(CommonDebugInfoDebugValue)});
695
0
            use->SetOperand(index, {new_ptr_inst->result_id()});
696
697
            // Add Deref operation.
698
0
            Instruction* dbg_expr =
699
0
                def_use_mgr->GetDef(use->GetSingleWordOperand(index + 1));
700
0
            auto* deref_expr_instr =
701
0
                context()->get_debug_info_mgr()->DerefDebugExpression(dbg_expr);
702
0
            use->SetOperand(index + 1, {deref_expr_instr->result_id()});
703
704
0
            context()->AnalyzeUses(deref_expr_instr);
705
0
            context()->AnalyzeUses(use);
706
0
          }
707
0
          break;
708
0
        }
709
0
        case CommonDebugInfoDebugValue:
710
0
          context()->ForgetUses(use);
711
0
          use->SetOperand(index, {new_ptr_inst->result_id()});
712
0
          context()->AnalyzeUses(use);
713
0
          break;
714
0
        default:
715
0
          assert(false && "Don't know how to rewrite instruction");
716
0
          break;
717
0
      }
718
0
      continue;
719
0
    }
720
721
620
    switch (use->opcode()) {
722
56
      case spv::Op::OpLoad: {
723
        // Replace the actual use.
724
56
        context()->ForgetUses(use);
725
56
        use->SetOperand(index, {new_ptr_inst->result_id()});
726
727
        // Update the type.
728
56
        Instruction* pointer_type_inst =
729
56
            def_use_mgr->GetDef(new_ptr_inst->type_id());
730
56
        uint32_t new_type_id =
731
56
            pointer_type_inst->GetSingleWordInOperand(kTypePointerPointeeInIdx);
732
56
        if (new_type_id != use->type_id()) {
733
0
          use->SetResultType(new_type_id);
734
0
          context()->AnalyzeUses(use);
735
0
          UpdateUses(use, use);
736
56
        } else {
737
56
          context()->AnalyzeUses(use);
738
56
        }
739
740
56
        AddUsesToWorklist(use);
741
56
      } break;
742
0
      case spv::Op::OpExtInst: {
743
0
        if (IsInterpolationInstruction(use)) {
744
          // Replace the actual use.
745
0
          context()->ForgetUses(use);
746
0
          use->SetOperand(index, {new_ptr_inst->result_id()});
747
0
          context()->AnalyzeUses(use);
748
0
        } else {
749
0
          assert(false && "Don't know how to rewrite instruction");
750
0
        }
751
0
      } break;
752
303
      case spv::Op::OpAccessChain: {
753
        // Update the actual use.
754
303
        context()->ForgetUses(use);
755
303
        use->SetOperand(index, {new_ptr_inst->result_id()});
756
757
        // Convert the ids on the OpAccessChain to indices that can be used to
758
        // get the specific member.
759
303
        std::vector<uint32_t> access_chain;
760
606
        for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
761
303
          const analysis::Constant* index_const =
762
303
              const_mgr->FindDeclaredConstant(use->GetSingleWordInOperand(i));
763
303
          if (index_const) {
764
54
            access_chain.push_back(index_const->GetU32());
765
249
          } else {
766
            // Variable index means the type is an type where every element
767
            // is the same type.  Use element 0 to get the type.
768
249
            access_chain.push_back(0);
769
249
          }
770
303
        }
771
772
303
        Instruction* pointer_type_inst =
773
303
            get_def_use_mgr()->GetDef(new_ptr_inst->type_id());
774
775
303
        uint32_t new_pointee_type_id = GetMemberTypeId(
776
303
            pointer_type_inst->GetSingleWordInOperand(kTypePointerPointeeInIdx),
777
303
            access_chain);
778
779
303
        spv::StorageClass storage_class = static_cast<spv::StorageClass>(
780
303
            pointer_type_inst->GetSingleWordInOperand(
781
303
                kTypePointerStorageClassInIdx));
782
783
303
        uint32_t new_pointer_type_id =
784
303
            type_mgr->FindPointerToType(new_pointee_type_id, storage_class);
785
786
303
        if (new_pointer_type_id != use->type_id()) {
787
4
          use->SetResultType(new_pointer_type_id);
788
4
          context()->AnalyzeUses(use);
789
4
          UpdateUses(use, use);
790
299
        } else {
791
299
          context()->AnalyzeUses(use);
792
299
        }
793
303
      } break;
794
0
      case spv::Op::OpCompositeExtract: {
795
        // Update the actual use.
796
0
        context()->ForgetUses(use);
797
0
        use->SetOperand(index, {new_ptr_inst->result_id()});
798
799
0
        uint32_t new_type_id = new_ptr_inst->type_id();
800
0
        std::vector<uint32_t> access_chain;
801
0
        for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
802
0
          access_chain.push_back(use->GetSingleWordInOperand(i));
803
0
        }
804
805
0
        new_type_id = GetMemberTypeId(new_type_id, access_chain);
806
807
0
        if (new_type_id != use->type_id()) {
808
0
          use->SetResultType(new_type_id);
809
0
          context()->AnalyzeUses(use);
810
0
          UpdateUses(use, use);
811
0
        } else {
812
0
          context()->AnalyzeUses(use);
813
0
        }
814
0
      } break;
815
260
      case spv::Op::OpStore:
816
        // If the use is the pointer, then it is the single store to that
817
        // variable.  We do not want to replace it.  Instead, it will become
818
        // dead after all of the loads are removed, and ADCE will get rid of it.
819
        //
820
        // If the use is the object being stored, we will create a copy of the
821
        // object turning it into the correct type. The copy is done by
822
        // decomposing the object into the base type, which must be the same,
823
        // and then rebuilding them.
824
260
        if (index == 1) {
825
0
          Instruction* target_pointer = def_use_mgr->GetDef(
826
0
              use->GetSingleWordInOperand(kStorePointerInOperand));
827
0
          Instruction* pointer_type =
828
0
              def_use_mgr->GetDef(target_pointer->type_id());
829
0
          uint32_t pointee_type_id =
830
0
              pointer_type->GetSingleWordInOperand(kTypePointerPointeeInIdx);
831
0
          uint32_t copy = GenerateCopy(original_ptr_inst, pointee_type_id, use);
832
0
          assert(copy != 0 &&
833
0
                 "Should not be updating uses unless we know it can be done.");
834
835
0
          context()->ForgetUses(use);
836
0
          use->SetInOperand(index, {copy});
837
0
          context()->AnalyzeUses(use);
838
0
        }
839
260
        break;
840
260
      case spv::Op::OpDecorate:
841
      // We treat an OpImageTexelPointer as a load.  The result type should
842
      // always have the Image storage class, and should not need to be
843
      // updated.
844
1
      case spv::Op::OpImageTexelPointer:
845
        // Replace the actual use.
846
1
        context()->ForgetUses(use);
847
1
        use->SetOperand(index, {new_ptr_inst->result_id()});
848
1
        context()->AnalyzeUses(use);
849
1
        break;
850
0
      default:
851
0
        assert(false && "Don't know how to rewrite instruction");
852
0
        break;
853
620
    }
854
620
  }
855
264
}
856
857
uint32_t CopyPropagateArrays::GetMemberTypeId(
858
661
    uint32_t id, const std::vector<uint32_t>& access_chain) const {
859
661
  for (uint32_t element_index : access_chain) {
860
499
    Instruction* type_inst = get_def_use_mgr()->GetDef(id);
861
499
    switch (type_inst->opcode()) {
862
299
      case spv::Op::OpTypeArray:
863
299
      case spv::Op::OpTypeRuntimeArray:
864
299
      case spv::Op::OpTypeMatrix:
865
431
      case spv::Op::OpTypeVector:
866
431
        id = type_inst->GetSingleWordInOperand(0);
867
431
        break;
868
68
      case spv::Op::OpTypeStruct:
869
68
        id = type_inst->GetSingleWordInOperand(element_index);
870
68
        break;
871
0
      default:
872
0
        break;
873
499
    }
874
499
    assert(id != 0 &&
875
499
           "Tried to extract from an object where it cannot be done.");
876
499
  }
877
661
  return id;
878
661
}
879
880
56
// Visits every use of |inst|.  For each OpStore among them, the pointer
// returned by GetPtr is pushed onto |worklist_| when it is an OpVariable.
void CopyPropagateArrays::AddUsesToWorklist(Instruction* inst) {
  analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();

  def_use_mgr->ForEachUse(inst, [this](Instruction* user, uint32_t) {
    if (user->opcode() != spv::Op::OpStore) {
      return;
    }

    // GetPtr reports the variable id through an out-parameter; only the
    // returned instruction is needed here.
    uint32_t var_id;
    Instruction* store_target = GetPtr(user, &var_id);
    if (store_target->opcode() == spv::Op::OpVariable) {
      worklist_.push(store_target);
    }
  });
}
895
896
void CopyPropagateArrays::MemoryObject::PushIndirection(
897
3.51k
    const std::vector<AccessChainEntry>& access_chain) {
898
3.51k
  access_chain_.insert(access_chain_.end(), access_chain.begin(),
899
3.51k
                       access_chain.end());
900
3.51k
}
901
902
38
uint32_t CopyPropagateArrays::MemoryObject::GetNumberOfMembers() {
903
38
  IRContext* context = variable_inst_->context();
904
38
  analysis::TypeManager* type_mgr = context->get_type_mgr();
905
906
38
  const analysis::Type* type = type_mgr->GetType(variable_inst_->type_id());
907
38
  type = type->AsPointer()->pointee_type();
908
909
38
  std::vector<uint32_t> access_indices = GetAccessIds();
910
38
  type = type_mgr->GetMemberType(type, access_indices);
911
912
38
  return opt::GetNumberOfMembers(type, context);
913
38
}
914
template <class iterator>
915
CopyPropagateArrays::MemoryObject::MemoryObject(Instruction* var_inst,
916
                                                iterator begin, iterator end)
917
4.88k
    : variable_inst_(var_inst) {
918
4.88k
  std::transform(begin, end, std::back_inserter(access_chain_),
919
4.88k
                 [](uint32_t id) { return AccessChainEntry{true, {id}}; });
920
4.88k
}
921
922
396
std::vector<uint32_t> CopyPropagateArrays::MemoryObject::GetAccessIds() const {
923
396
  analysis::ConstantManager* const_mgr =
924
396
      variable_inst_->context()->get_constant_mgr();
925
926
396
  std::vector<uint32_t> indices(AccessChain().size());
927
396
  std::transform(AccessChain().cbegin(), AccessChain().cend(), indices.begin(),
928
396
                 [&const_mgr](const AccessChainEntry& entry) {
929
196
                   if (entry.is_result_id) {
930
162
                     const analysis::Constant* constant =
931
162
                         const_mgr->FindDeclaredConstant(entry.result_id);
932
162
                     return constant == nullptr ? 0 : constant->GetU32();
933
162
                   }
934
935
34
                   return entry.immediate;
936
196
                 });
937
396
  return indices;
938
396
}
939
940
bool CopyPropagateArrays::MemoryObject::Contains(
941
17
    CopyPropagateArrays::MemoryObject* other) {
942
17
  if (this->GetVariable() != other->GetVariable()) {
943
13
    return false;
944
13
  }
945
946
4
  if (AccessChain().size() > other->AccessChain().size()) {
947
0
    return false;
948
0
  }
949
950
4
  for (uint32_t i = 0; i < AccessChain().size(); i++) {
951
0
    if (AccessChain()[i] != other->AccessChain()[i]) {
952
0
      return false;
953
0
    }
954
0
  }
955
4
  return true;
956
4
}
957
958
98
void CopyPropagateArrays::MemoryObject::BuildConstants() {
959
98
  for (auto& entry : access_chain_) {
960
98
    if (entry.is_result_id) {
961
64
      continue;
962
64
    }
963
964
34
    auto context = variable_inst_->context();
965
34
    analysis::Integer int_type(32, false);
966
34
    const analysis::Type* uint32_type =
967
34
        context->get_type_mgr()->GetRegisteredType(&int_type);
968
34
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
969
34
    const analysis::Constant* index_const =
970
34
        const_mgr->GetConstant(uint32_type, {entry.immediate});
971
34
    entry.result_id =
972
34
        const_mgr->GetDefiningInstruction(index_const)->result_id();
973
34
    entry.is_result_id = true;
974
34
  }
975
98
}
976
977
}  // namespace opt
978
}  // namespace spvtools