Coverage Report

Created: 2025-12-31 06:15

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/spirv-tools/source/opt/interface_var_sroa.cpp
Line
Count
Source
1
// Copyright (c) 2022 Google LLC
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
#include "source/opt/interface_var_sroa.h"
16
17
#include <iostream>
#include <unordered_set>
18
19
#include "source/opt/decoration_manager.h"
20
#include "source/opt/def_use_manager.h"
21
#include "source/opt/function.h"
22
#include "source/opt/log.h"
23
#include "source/opt/type_manager.h"
24
#include "source/util/make_unique.h"
25
26
namespace spvtools {
27
namespace opt {
28
namespace {
29
// In-operand indices used to read fields of the SPIR-V instructions this pass
// inspects. In-operand numbering skips the result type and result id.

// OpDecorate: <target> <decoration> <literal...>
constexpr uint32_t kOpDecorateDecorationInOperandIndex = 1;
constexpr uint32_t kOpDecorateLiteralInOperandIndex = 2;

// OpEntryPoint: <execution model> <entry point> <name> <interface ids...>
constexpr uint32_t kOpEntryPointInOperandInterface = 3;

// OpVariable: <storage class> [<initializer>]
constexpr uint32_t kOpVariableStorageClassInOperandIndex = 0;

// OpTypeArray: <element type> <length>
constexpr uint32_t kOpTypeArrayElemTypeInOperandIndex = 0;
constexpr uint32_t kOpTypeArrayLengthInOperandIndex = 1;

// OpTypeMatrix: <column type> <column count>
constexpr uint32_t kOpTypeMatrixColTypeInOperandIndex = 0;
constexpr uint32_t kOpTypeMatrixColCountInOperandIndex = 1;

// OpTypePointer: <storage class> <pointee type>
constexpr uint32_t kOpTypePtrTypeInOperandIndex = 1;

// OpConstant: <value>
constexpr uint32_t kOpConstantValueInOperandIndex = 0;
39
40
// Get the length of the OpTypeArray |array_type|.
41
uint32_t GetArrayLength(analysis::DefUseManager* def_use_mgr,
42
0
                        Instruction* array_type) {
43
0
  assert(array_type->opcode() == spv::Op::OpTypeArray);
44
0
  uint32_t const_int_id =
45
0
      array_type->GetSingleWordInOperand(kOpTypeArrayLengthInOperandIndex);
46
0
  Instruction* array_length_inst = def_use_mgr->GetDef(const_int_id);
47
0
  assert(array_length_inst->opcode() == spv::Op::OpConstant);
48
0
  return array_length_inst->GetSingleWordInOperand(
49
0
      kOpConstantValueInOperandIndex);
50
0
}
51
52
// Get the element type instruction of the OpTypeArray |array_type|.
53
Instruction* GetArrayElementType(analysis::DefUseManager* def_use_mgr,
54
0
                                 Instruction* array_type) {
55
0
  assert(array_type->opcode() == spv::Op::OpTypeArray);
56
0
  uint32_t elem_type_id =
57
0
      array_type->GetSingleWordInOperand(kOpTypeArrayElemTypeInOperandIndex);
58
0
  return def_use_mgr->GetDef(elem_type_id);
59
0
}
60
61
// Get the column type instruction of the OpTypeMatrix |matrix_type|.
62
Instruction* GetMatrixColumnType(analysis::DefUseManager* def_use_mgr,
63
0
                                 Instruction* matrix_type) {
64
0
  assert(matrix_type->opcode() == spv::Op::OpTypeMatrix);
65
0
  uint32_t column_type_id =
66
0
      matrix_type->GetSingleWordInOperand(kOpTypeMatrixColTypeInOperandIndex);
67
0
  return def_use_mgr->GetDef(column_type_id);
68
0
}
69
70
// Traverses the component type of OpTypeArray or OpTypeMatrix. Repeats it
71
// |depth_to_component| times recursively and returns the component type.
72
// |type_id| is the result id of the OpTypeArray or OpTypeMatrix instruction.
73
uint32_t GetComponentTypeOfArrayMatrix(analysis::DefUseManager* def_use_mgr,
74
                                       uint32_t type_id,
75
0
                                       uint32_t depth_to_component) {
76
0
  if (depth_to_component == 0) return type_id;
77
78
0
  Instruction* type_inst = def_use_mgr->GetDef(type_id);
79
0
  if (type_inst->opcode() == spv::Op::OpTypeArray) {
80
0
    uint32_t elem_type_id =
81
0
        type_inst->GetSingleWordInOperand(kOpTypeArrayElemTypeInOperandIndex);
82
0
    return GetComponentTypeOfArrayMatrix(def_use_mgr, elem_type_id,
83
0
                                         depth_to_component - 1);
84
0
  }
85
86
0
  assert(type_inst->opcode() == spv::Op::OpTypeMatrix);
87
0
  uint32_t column_type_id =
88
0
      type_inst->GetSingleWordInOperand(kOpTypeMatrixColTypeInOperandIndex);
89
0
  return GetComponentTypeOfArrayMatrix(def_use_mgr, column_type_id,
90
0
                                       depth_to_component - 1);
91
0
}
92
93
// Creates an OpDecorate instruction whose Target is |var_id| and Decoration is
94
// |decoration|. Adds |literal| as an extra operand of the instruction.
95
void CreateDecoration(analysis::DecorationManager* decoration_mgr,
96
                      uint32_t var_id, spv::Decoration decoration,
97
0
                      uint32_t literal) {
98
0
  std::vector<Operand> operands({
99
0
      {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {var_id}},
100
0
      {spv_operand_type_t::SPV_OPERAND_TYPE_DECORATION,
101
0
       {static_cast<uint32_t>(decoration)}},
102
0
      {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {literal}},
103
0
  });
104
0
  decoration_mgr->AddDecoration(spv::Op::OpDecorate, std::move(operands));
105
0
}
106
107
// Replaces load instructions with composite construct instructions in all the
108
// users of the loads. |loads_to_composites| is the mapping from each load to
109
// its corresponding OpCompositeConstruct.
110
void ReplaceLoadWithCompositeConstruct(
111
    IRContext* context,
112
0
    const std::unordered_map<Instruction*, Instruction*>& loads_to_composites) {
113
0
  for (const auto& load_and_composite : loads_to_composites) {
114
0
    Instruction* load = load_and_composite.first;
115
0
    Instruction* composite_construct = load_and_composite.second;
116
117
0
    std::vector<Instruction*> users;
118
0
    context->get_def_use_mgr()->ForEachUse(
119
0
        load, [&users, composite_construct](Instruction* user, uint32_t index) {
120
0
          user->GetOperand(index).words[0] = composite_construct->result_id();
121
0
          users.push_back(user);
122
0
        });
123
124
0
    for (Instruction* user : users)
125
0
      context->get_def_use_mgr()->AnalyzeInstUse(user);
126
0
  }
127
0
}
128
129
// Returns the storage class of the instruction |var|.
130
0
spv::StorageClass GetStorageClass(Instruction* var) {
131
0
  return static_cast<spv::StorageClass>(
132
0
      var->GetSingleWordInOperand(kOpVariableStorageClassInOperandIndex));
133
0
}
134
135
}  // namespace
136
137
bool InterfaceVariableScalarReplacement::HasExtraArrayness(
138
0
    Instruction& entry_point, Instruction* var) {
139
0
  spv::ExecutionModel execution_model =
140
0
      static_cast<spv::ExecutionModel>(entry_point.GetSingleWordInOperand(0));
141
0
  if (execution_model != spv::ExecutionModel::TessellationEvaluation &&
142
0
      execution_model != spv::ExecutionModel::TessellationControl) {
143
0
    return false;
144
0
  }
145
0
  if (!context()->get_decoration_mgr()->HasDecoration(
146
0
          var->result_id(), uint32_t(spv::Decoration::Patch))) {
147
0
    if (execution_model == spv::ExecutionModel::TessellationControl)
148
0
      return true;
149
0
    return GetStorageClass(var) != spv::StorageClass::Output;
150
0
  }
151
0
  return false;
152
0
}
153
154
bool InterfaceVariableScalarReplacement::
155
    CheckExtraArraynessConflictBetweenEntries(Instruction* interface_var,
156
0
                                              bool has_extra_arrayness) {
157
0
  if (has_extra_arrayness) {
158
0
    return !ReportErrorIfHasNoExtraArraynessForOtherEntry(interface_var);
159
0
  }
160
0
  return !ReportErrorIfHasExtraArraynessForOtherEntry(interface_var);
161
0
}
162
163
bool InterfaceVariableScalarReplacement::GetVariableLocation(
164
0
    Instruction* var, uint32_t* location) {
165
0
  return !context()->get_decoration_mgr()->WhileEachDecoration(
166
0
      var->result_id(), uint32_t(spv::Decoration::Location),
167
0
      [location](const Instruction& inst) {
168
0
        *location =
169
0
            inst.GetSingleWordInOperand(kOpDecorateLiteralInOperandIndex);
170
0
        return false;
171
0
      });
172
0
}
173
174
bool InterfaceVariableScalarReplacement::GetVariableComponent(
175
0
    Instruction* var, uint32_t* component) {
176
0
  return !context()->get_decoration_mgr()->WhileEachDecoration(
177
0
      var->result_id(), uint32_t(spv::Decoration::Component),
178
0
      [component](const Instruction& inst) {
179
0
        *component =
180
0
            inst.GetSingleWordInOperand(kOpDecorateLiteralInOperandIndex);
181
0
        return false;
182
0
      });
183
0
}
184
185
std::vector<Instruction*>
186
InterfaceVariableScalarReplacement::CollectInterfaceVariables(
187
0
    Instruction& entry_point) {
188
0
  std::vector<Instruction*> interface_vars;
189
0
  for (uint32_t i = kOpEntryPointInOperandInterface;
190
0
       i < entry_point.NumInOperands(); ++i) {
191
0
    Instruction* interface_var = context()->get_def_use_mgr()->GetDef(
192
0
        entry_point.GetSingleWordInOperand(i));
193
0
    assert(interface_var->opcode() == spv::Op::OpVariable);
194
195
0
    spv::StorageClass storage_class = GetStorageClass(interface_var);
196
0
    if (storage_class != spv::StorageClass::Input &&
197
0
        storage_class != spv::StorageClass::Output) {
198
0
      continue;
199
0
    }
200
201
0
    interface_vars.push_back(interface_var);
202
0
  }
203
0
  return interface_vars;
204
0
}
205
206
void InterfaceVariableScalarReplacement::KillInstructionAndUsers(
207
0
    Instruction* inst) {
208
0
  if (inst->opcode() == spv::Op::OpEntryPoint) {
209
0
    return;
210
0
  }
211
0
  if (inst->opcode() != spv::Op::OpAccessChain) {
212
0
    context()->KillInst(inst);
213
0
    return;
214
0
  }
215
0
  std::vector<Instruction*> users;
216
0
  context()->get_def_use_mgr()->ForEachUser(
217
0
      inst, [&users](Instruction* user) { users.push_back(user); });
218
0
  for (auto user : users) {
219
0
    context()->KillInst(user);
220
0
  }
221
0
  context()->KillInst(inst);
222
0
}
223
224
void InterfaceVariableScalarReplacement::KillInstructionsAndUsers(
225
0
    const std::vector<Instruction*>& insts) {
226
0
  for (Instruction* inst : insts) {
227
0
    KillInstructionAndUsers(inst);
228
0
  }
229
0
}
230
231
void InterfaceVariableScalarReplacement::KillLocationAndComponentDecorations(
232
0
    uint32_t var_id) {
233
0
  context()->get_decoration_mgr()->RemoveDecorationsFrom(
234
0
      var_id, [](const Instruction& inst) {
235
0
        spv::Decoration decoration = spv::Decoration(
236
0
            inst.GetSingleWordInOperand(kOpDecorateDecorationInOperandIndex));
237
0
        return decoration == spv::Decoration::Location ||
238
0
               decoration == spv::Decoration::Component;
239
0
      });
240
0
}
241
242
Pass::Status
InterfaceVariableScalarReplacement::ReplaceInterfaceVariableWithScalars(
    Instruction* interface_var, Instruction* interface_var_type,
    uint32_t location, uint32_t component, uint32_t extra_array_length) {
  // Create the replacement variables for |interface_var|.
  std::optional<NestedCompositeComponents> replacements =
      CreateScalarInterfaceVarsForReplacement(interface_var_type,
                                              GetStorageClass(interface_var),
                                              extra_array_length);
  if (!replacements.has_value()) return Status::Failure;

  // Give the replacements consecutive locations starting at |location|, all
  // sharing |component|. Then drop the original's Location/Component
  // decorations.
  AddLocationAndComponentDecorations(*replacements, &location, component);
  KillLocationAndComponentDecorations(interface_var->result_id());

  // Rewrite every user. The original variable is removed only on success.
  const Status status =
      ReplaceInterfaceVarWith(interface_var, extra_array_length, *replacements);
  if (status != Status::Failure) {
    context()->KillInst(interface_var);
  }
  return status;
}
268
269
Pass::Status InterfaceVariableScalarReplacement::ReplaceInterfaceVarWith(
    Instruction* interface_var, uint32_t extra_array_length,
    const NestedCompositeComponents& scalar_interface_vars) {
  // Snapshot the users of |interface_var| before any rewriting starts.
  std::vector<Instruction*> var_users;
  context()->get_def_use_mgr()->ForEachUser(
      interface_var,
      [&var_users](Instruction* user) { var_users.push_back(user); });

  std::vector<uint32_t> component_indices;
  // Map each OpLoad of |interface_var| to the composite rebuilt from the
  // scalar variables. The second map does the same for loads made through
  // access chains.
  std::unordered_map<Instruction*, Instruction*> composites_for_loads;
  std::unordered_map<Instruction*, Instruction*> composites_for_chain_loads;

  if (extra_array_length == 0) {
    if (ReplaceComponentsOfInterfaceVarWith(
            interface_var, var_users, scalar_interface_vars, component_indices,
            nullptr, &composites_for_loads, &composites_for_chain_loads) ==
        Status::Failure) {
      return Status::Failure;
    }
  } else {
    // The extra arrayness is the first dimension of |interface_var|. Each of
    // its elements is replaced separately, and the per-element load values
    // are handed to AddComponentsToCompositesForLoads.
    for (uint32_t element = 0; element < extra_array_length; ++element) {
      std::unordered_map<Instruction*, Instruction*> values_for_loads;
      if (ReplaceComponentsOfInterfaceVarWith(
              interface_var, var_users, scalar_interface_vars,
              component_indices, &element, &values_for_loads,
              &composites_for_chain_loads) == Status::Failure) {
        return Status::Failure;
      }
      AddComponentsToCompositesForLoads(values_for_loads,
                                        &composites_for_loads, 0);
    }
  }

  ReplaceLoadWithCompositeConstruct(context(), composites_for_loads);
  ReplaceLoadWithCompositeConstruct(context(), composites_for_chain_loads);

  KillInstructionsAndUsers(var_users);
  return Status::SuccessWithChange;
}
312
313
void InterfaceVariableScalarReplacement::AddLocationAndComponentDecorations(
314
    const NestedCompositeComponents& vars, uint32_t* location,
315
0
    uint32_t component) {
316
0
  if (!vars.HasMultipleComponents()) {
317
0
    uint32_t var_id = vars.GetComponentVariable()->result_id();
318
0
    CreateDecoration(context()->get_decoration_mgr(), var_id,
319
0
                     spv::Decoration::Location, *location);
320
0
    CreateDecoration(context()->get_decoration_mgr(), var_id,
321
0
                     spv::Decoration::Component, component);
322
0
    ++(*location);
323
0
    return;
324
0
  }
325
0
  for (const auto& var : vars.GetComponents()) {
326
0
    AddLocationAndComponentDecorations(var, location, component);
327
0
  }
328
0
}
329
330
Pass::Status
InterfaceVariableScalarReplacement::ReplaceComponentsOfInterfaceVarWith(
    Instruction* interface_var,
    const std::vector<Instruction*>& interface_var_users,
    const NestedCompositeComponents& scalar_interface_vars,
    std::vector<uint32_t>& interface_var_component_indices,
    const uint32_t* extra_array_index,
    std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
    std::unordered_map<Instruction*, Instruction*>*
        loads_for_access_chain_to_composites) {
  // Composite node: handle each component separately.
  if (scalar_interface_vars.HasMultipleComponents()) {
    return ReplaceMultipleComponentsOfInterfaceVarWith(
        interface_var, interface_var_users,
        scalar_interface_vars.GetComponents(), interface_var_component_indices,
        extra_array_index, loads_to_composites,
        loads_for_access_chain_to_composites);
  }

  // Leaf: rewrite every user against the single scalar variable. Stop at the
  // first failure.
  for (Instruction* user : interface_var_users) {
    const Status status = ReplaceComponentOfInterfaceVarWith(
        interface_var, user, scalar_interface_vars.GetComponentVariable(),
        interface_var_component_indices, extra_array_index,
        loads_to_composites, loads_for_access_chain_to_composites);
    if (status == Status::Failure) return Status::Failure;
  }
  return Status::SuccessWithChange;
}
358
359
Pass::Status
InterfaceVariableScalarReplacement::ReplaceMultipleComponentsOfInterfaceVarWith(
    Instruction* interface_var,
    const std::vector<Instruction*>& interface_var_users,
    const std::vector<NestedCompositeComponents>& components,
    std::vector<uint32_t>& interface_var_component_indices,
    const uint32_t* extra_array_index,
    std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
    std::unordered_map<Instruction*, Instruction*>*
        loads_for_access_chain_to_composites) {
  for (uint32_t component_index = 0; component_index < components.size();
       ++component_index) {
    std::unordered_map<Instruction*, Instruction*> values_for_loads;
    std::unordered_map<Instruction*, Instruction*> values_for_chain_loads;

    // Descend into this component with its index appended to the path, and
    // restore the path afterwards.
    interface_var_component_indices.push_back(component_index);
    const Status status = ReplaceComponentsOfInterfaceVarWith(
        interface_var, interface_var_users, components[component_index],
        interface_var_component_indices, extra_array_index, &values_for_loads,
        &values_for_chain_loads);
    if (status == Status::Failure) return Status::Failure;
    interface_var_component_indices.pop_back();

    // Access chains are assumed to index only the extra arrayness (see
    // ReplaceAccessChainWith). So only direct loads get one more depth level
    // when extra arrayness is present.
    const uint32_t chain_depth =
        static_cast<uint32_t>(interface_var_component_indices.size());
    const uint32_t load_depth =
        extra_array_index != nullptr ? chain_depth + 1 : chain_depth;
    AddComponentsToCompositesForLoads(values_for_chain_loads,
                                      loads_for_access_chain_to_composites,
                                      chain_depth);
    AddComponentsToCompositesForLoads(values_for_loads, loads_to_composites,
                                      load_depth);
  }
  return Status::SuccessWithChange;
}
395
396
Pass::Status
InterfaceVariableScalarReplacement::ReplaceComponentOfInterfaceVarWith(
    Instruction* interface_var, Instruction* interface_var_user,
    Instruction* scalar_var,
    const std::vector<uint32_t>& interface_var_component_indices,
    const uint32_t* extra_array_index,
    std::unordered_map<Instruction*, Instruction*>* loads_to_component_values,
    std::unordered_map<Instruction*, Instruction*>*
        loads_for_access_chain_to_component_values) {
  const spv::Op opcode = interface_var_user->opcode();

  // Direct stores and loads are rewritten for every element of the extra
  // array.
  switch (opcode) {
    case spv::Op::OpStore: {
      const uint32_t stored_value_id =
          interface_var_user->GetSingleWordInOperand(1);
      StoreComponentOfValueToScalarVar(stored_value_id,
                                       interface_var_component_indices,
                                       scalar_var, extra_array_index,
                                       interface_var_user);
      return Status::SuccessWithChange;
    }
    case spv::Op::OpLoad: {
      Instruction* scalar_load =
          LoadScalarVar(scalar_var, extra_array_index, interface_var_user);
      if (scalar_load == nullptr) return Status::Failure;
      loads_to_component_values->insert({interface_var_user, scalar_load});
      return Status::SuccessWithChange;
    }
    default:
      break;
  }

  // Everything else (OpName, annotations, entry points, access chains) is
  // handled once: for the first element of the extra array, or when there is
  // no extra array.
  const bool is_first_element =
      extra_array_index == nullptr || *extra_array_index == 0;
  if (!is_first_element) return Status::SuccessWithChange;

  switch (opcode) {
    case spv::Op::OpDecorate:
    case spv::Op::OpDecorateId:
    case spv::Op::OpDecorateString:
      CloneAnnotationForVariable(interface_var_user, scalar_var->result_id());
      return Status::SuccessWithChange;
    case spv::Op::OpName: {
      std::unique_ptr<Instruction> name_inst(
          interface_var_user->Clone(context()));
      name_inst->SetInOperand(0, {scalar_var->result_id()});
      context()->AddDebug2Inst(std::move(name_inst));
      return Status::SuccessWithChange;
    }
    case spv::Op::OpEntryPoint:
      return ReplaceInterfaceVarInEntryPoint(interface_var, interface_var_user,
                                             scalar_var->result_id())
                 ? Status::SuccessWithChange
                 : Status::Failure;
    case spv::Op::OpAccessChain:
      ReplaceAccessChainWith(interface_var_user,
                             interface_var_component_indices, scalar_var,
                             loads_for_access_chain_to_component_values);
      return Status::SuccessWithChange;
    default:
      break;
  }

  // Any other user is unsupported: report it and fail the pass.
  std::string message("Unhandled instruction");
  message += "\n  " + interface_var_user->PrettyPrint(
                          SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
  message +=
      "\nfor interface variable scalar replacement\n  " +
      interface_var->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
  context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
  return Status::Failure;
}
465
466
void InterfaceVariableScalarReplacement::UseBaseAccessChainForAccessChain(
467
0
    Instruction* access_chain, Instruction* base_access_chain) {
468
0
  assert(base_access_chain->opcode() == spv::Op::OpAccessChain &&
469
0
         access_chain->opcode() == spv::Op::OpAccessChain &&
470
0
         access_chain->GetSingleWordInOperand(0) ==
471
0
             base_access_chain->result_id());
472
0
  Instruction::OperandList new_operands;
473
0
  for (uint32_t i = 0; i < base_access_chain->NumInOperands(); ++i) {
474
0
    new_operands.emplace_back(base_access_chain->GetInOperand(i));
475
0
  }
476
0
  for (uint32_t i = 1; i < access_chain->NumInOperands(); ++i) {
477
0
    new_operands.emplace_back(access_chain->GetInOperand(i));
478
0
  }
479
0
  access_chain->SetInOperands(std::move(new_operands));
480
0
}
481
482
Instruction* InterfaceVariableScalarReplacement::CreateAccessChainToVar(
    uint32_t var_type_id, Instruction* var,
    const std::vector<uint32_t>& index_ids, Instruction* insert_before,
    uint32_t* component_type_id) {
  // The chain points at what remains of |var_type_id| after stripping one
  // array/matrix level per index. That type id is reported through
  // |*component_type_id|.
  analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  *component_type_id = GetComponentTypeOfArrayMatrix(
      def_use_mgr, var_type_id, static_cast<uint32_t>(index_ids.size()));
  const uint32_t ptr_type_id =
      GetPointerType(*component_type_id, GetStorageClass(var));

  const uint32_t chain_id = TakeNextId();
  if (chain_id == 0) return nullptr;  // Out of ids.

  // Operands: base pointer, then the indices in order.
  Instruction::OperandList operands;
  operands.push_back({SPV_OPERAND_TYPE_ID, {var->result_id()}});
  for (uint32_t index_id : index_ids) {
    operands.push_back({SPV_OPERAND_TYPE_ID, {index_id}});
  }
  std::unique_ptr<Instruction> chain(new Instruction(
      context(), spv::Op::OpAccessChain, ptr_type_id, chain_id, operands));

  Instruction* chain_inst = chain.get();
  def_use_mgr->AnalyzeInstDefUse(chain_inst);
  insert_before->InsertBefore(std::move(chain));
  return chain_inst;
}
510
511
Instruction* InterfaceVariableScalarReplacement::CreateAccessChainWithIndex(
    uint32_t component_type_id, Instruction* var, uint32_t index,
    Instruction* insert_before) {
  // Builds "OpAccessChain %ptr_to_component %var %uint_<index>" before
  // |insert_before|. Returns nullptr if no id could be allocated.
  const uint32_t ptr_type_id =
      GetPointerType(component_type_id, GetStorageClass(var));
  const uint32_t index_id =
      context()->get_constant_mgr()->GetUIntConstId(index);
  const uint32_t chain_id = TakeNextId();
  if (chain_id == 0) return nullptr;

  Instruction::OperandList operands{{SPV_OPERAND_TYPE_ID, {var->result_id()}},
                                    {SPV_OPERAND_TYPE_ID, {index_id}}};
  std::unique_ptr<Instruction> chain(new Instruction(
      context(), spv::Op::OpAccessChain, ptr_type_id, chain_id, operands));

  Instruction* chain_inst = chain.get();
  context()->get_def_use_mgr()->AnalyzeInstDefUse(chain_inst);
  insert_before->InsertBefore(std::move(chain));
  return chain_inst;
}
532
533
void InterfaceVariableScalarReplacement::ReplaceAccessChainWith(
534
    Instruction* access_chain,
535
    const std::vector<uint32_t>& interface_var_component_indices,
536
    Instruction* scalar_var,
537
0
    std::unordered_map<Instruction*, Instruction*>* loads_to_component_values) {
538
0
  std::vector<uint32_t> indexes;
539
0
  for (uint32_t i = 1; i < access_chain->NumInOperands(); ++i) {
540
0
    indexes.push_back(access_chain->GetSingleWordInOperand(i));
541
0
  }
542
543
  // Note that we have a strong assumption that |access_chain| has only a single
544
  // index that is for the extra arrayness.
545
0
  context()->get_def_use_mgr()->ForEachUser(
546
0
      access_chain,
547
0
      [this, access_chain, &indexes, &interface_var_component_indices,
548
0
       scalar_var, loads_to_component_values](Instruction* user) {
549
0
        switch (user->opcode()) {
550
0
          case spv::Op::OpAccessChain: {
551
0
            UseBaseAccessChainForAccessChain(user, access_chain);
552
0
            ReplaceAccessChainWith(user, interface_var_component_indices,
553
0
                                   scalar_var, loads_to_component_values);
554
0
            return;
555
0
          }
556
0
          case spv::Op::OpStore: {
557
0
            uint32_t value_id = user->GetSingleWordInOperand(1);
558
0
            StoreComponentOfValueToAccessChainToScalarVar(
559
0
                value_id, interface_var_component_indices, scalar_var, indexes,
560
0
                user);
561
0
            return;
562
0
          }
563
0
          case spv::Op::OpLoad: {
564
0
            Instruction* value =
565
0
                LoadAccessChainToVar(scalar_var, indexes, user);
566
0
            loads_to_component_values->insert({user, value});
567
0
            return;
568
0
          }
569
0
          default:
570
0
            break;
571
0
        }
572
0
      });
573
0
}
574
575
void InterfaceVariableScalarReplacement::CloneAnnotationForVariable(
576
0
    Instruction* annotation_inst, uint32_t var_id) {
577
0
  assert(annotation_inst->opcode() == spv::Op::OpDecorate ||
578
0
         annotation_inst->opcode() == spv::Op::OpDecorateId ||
579
0
         annotation_inst->opcode() == spv::Op::OpDecorateString);
580
0
  std::unique_ptr<Instruction> new_inst(annotation_inst->Clone(context()));
581
0
  new_inst->SetInOperand(0, {var_id});
582
0
  context()->AddAnnotationInst(std::move(new_inst));
583
0
}
584
585
bool InterfaceVariableScalarReplacement::ReplaceInterfaceVarInEntryPoint(
586
    Instruction* interface_var, Instruction* entry_point,
587
0
    uint32_t scalar_var_id) {
588
0
  analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
589
0
  uint32_t interface_var_id = interface_var->result_id();
590
0
  if (interface_vars_removed_from_entry_point_operands_.find(
591
0
          interface_var_id) !=
592
0
      interface_vars_removed_from_entry_point_operands_.end()) {
593
0
    entry_point->AddOperand({SPV_OPERAND_TYPE_ID, {scalar_var_id}});
594
0
    def_use_mgr->AnalyzeInstUse(entry_point);
595
0
    return true;
596
0
  }
597
598
0
  bool success = !entry_point->WhileEachInId(
599
0
      [&interface_var_id, &scalar_var_id](uint32_t* id) {
600
0
        if (*id == interface_var_id) {
601
0
          *id = scalar_var_id;
602
0
          return false;
603
0
        }
604
0
        return true;
605
0
      });
606
0
  if (!success) {
607
0
    std::string message(
608
0
        "interface variable is not an operand of the entry point");
609
0
    message += "\n  " + interface_var->PrettyPrint(
610
0
                            SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
611
0
    message += "\n  " + entry_point->PrettyPrint(
612
0
                            SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
613
0
    context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
614
0
    return false;
615
0
  }
616
617
0
  def_use_mgr->AnalyzeInstUse(entry_point);
618
0
  interface_vars_removed_from_entry_point_operands_.insert(interface_var_id);
619
0
  return true;
620
0
}
621
622
uint32_t InterfaceVariableScalarReplacement::GetPointeeTypeIdOfVar(
623
0
    Instruction* var) {
624
0
  assert(var->opcode() == spv::Op::OpVariable);
625
626
0
  uint32_t ptr_type_id = var->type_id();
627
0
  analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
628
0
  Instruction* ptr_type_inst = def_use_mgr->GetDef(ptr_type_id);
629
630
0
  assert(ptr_type_inst->opcode() == spv::Op::OpTypePointer &&
631
0
         "Variable must have a pointer type.");
632
0
  return ptr_type_inst->GetSingleWordInOperand(kOpTypePtrTypeInOperandIndex);
633
0
}
634
635
void InterfaceVariableScalarReplacement::StoreComponentOfValueToScalarVar(
636
    uint32_t value_id, const std::vector<uint32_t>& component_indices,
637
    Instruction* scalar_var, const uint32_t* extra_array_index,
638
0
    Instruction* insert_before) {
639
0
  uint32_t component_type_id = GetPointeeTypeIdOfVar(scalar_var);
640
0
  Instruction* ptr = scalar_var;
641
0
  if (extra_array_index) {
642
0
    auto* ty_mgr = context()->get_type_mgr();
643
0
    analysis::Array* array_type = ty_mgr->GetType(component_type_id)->AsArray();
644
0
    assert(array_type != nullptr);
645
0
    component_type_id = ty_mgr->GetTypeInstruction(array_type->element_type());
646
0
    ptr = CreateAccessChainWithIndex(component_type_id, scalar_var,
647
0
                                     *extra_array_index, insert_before);
648
0
    if (ptr == nullptr) {
649
0
      return;
650
0
    }
651
0
  }
652
653
0
  StoreComponentOfValueTo(component_type_id, value_id, component_indices, ptr,
654
0
                          extra_array_index, insert_before);
655
0
}
656
657
Instruction* InterfaceVariableScalarReplacement::LoadScalarVar(
    Instruction* scalar_var, const uint32_t* extra_array_index,
    Instruction* insert_before) {
  uint32_t loaded_type_id = GetPointeeTypeIdOfVar(scalar_var);
  Instruction* source_ptr = scalar_var;

  if (extra_array_index != nullptr) {
    // With extra arrayness |scalar_var| holds an array. Load only element
    // |*extra_array_index| through a new access chain.
    auto* type_mgr = context()->get_type_mgr();
    const analysis::Array* array_type =
        type_mgr->GetType(loaded_type_id)->AsArray();
    assert(array_type != nullptr);
    loaded_type_id = type_mgr->GetTypeInstruction(array_type->element_type());
    source_ptr = CreateAccessChainWithIndex(loaded_type_id, scalar_var,
                                            *extra_array_index, insert_before);
    if (source_ptr == nullptr) return nullptr;  // Out of ids.
  }

  // May also return nullptr if no id is left for the load.
  return CreateLoad(loaded_type_id, source_ptr, insert_before);
}
676
677
Instruction* InterfaceVariableScalarReplacement::CreateLoad(
    uint32_t type_id, Instruction* ptr, Instruction* insert_before) {
  // Emits "%id = OpLoad %type_id %ptr" before |insert_before|, registering it
  // with the def-use manager first. Returns nullptr if no id is available.
  const uint32_t load_id = TakeNextId();
  if (load_id == 0) return nullptr;

  Instruction::OperandList operands{{SPV_OPERAND_TYPE_ID, {ptr->result_id()}}};
  std::unique_ptr<Instruction> new_load(new Instruction(
      context(), spv::Op::OpLoad, type_id, load_id, operands));

  Instruction* load_inst = new_load.get();
  context()->get_def_use_mgr()->AnalyzeInstDefUse(load_inst);
  insert_before->InsertBefore(std::move(new_load));
  return load_inst;
}
692
693
void InterfaceVariableScalarReplacement::StoreComponentOfValueTo(
694
    uint32_t component_type_id, uint32_t value_id,
695
    const std::vector<uint32_t>& component_indices, Instruction* ptr,
696
0
    const uint32_t* extra_array_index, Instruction* insert_before) {
697
0
  std::unique_ptr<Instruction> composite_extract(CreateCompositeExtract(
698
0
      component_type_id, value_id, component_indices, extra_array_index));
699
0
  if (composite_extract == nullptr) {
700
0
    return;
701
0
  }
702
703
0
  std::unique_ptr<Instruction> new_store(
704
0
      new Instruction(context(), spv::Op::OpStore));
705
0
  new_store->AddOperand({SPV_OPERAND_TYPE_ID, {ptr->result_id()}});
706
0
  new_store->AddOperand(
707
0
      {SPV_OPERAND_TYPE_ID, {composite_extract->result_id()}});
708
709
0
  analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
710
0
  def_use_mgr->AnalyzeInstDefUse(composite_extract.get());
711
0
  def_use_mgr->AnalyzeInstDefUse(new_store.get());
712
713
0
  insert_before->InsertBefore(std::move(composite_extract));
714
0
  insert_before->InsertBefore(std::move(new_store));
715
0
}
716
717
// Builds an OpCompositeExtract of type |type_id| from |composite_id| but does
// not insert it or register it with the def-use manager. When
// |extra_first_index| is non-null, its value is the first literal index,
// followed by every entry of |indexes|. The caller takes ownership of the
// returned instruction. Returns nullptr when no fresh id is available.
Instruction* InterfaceVariableScalarReplacement::CreateCompositeExtract(
    uint32_t type_id, uint32_t composite_id,
    const std::vector<uint32_t>& indexes, const uint32_t* extra_first_index) {
  const uint32_t result_id = TakeNextId();
  if (result_id == 0) return nullptr;

  // Hold the instruction in a smart pointer while operands are appended, then
  // hand ownership to the caller.
  std::unique_ptr<Instruction> extract(new Instruction(
      context(), spv::Op::OpCompositeExtract, type_id, result_id,
      std::initializer_list<Operand>{{SPV_OPERAND_TYPE_ID, {composite_id}}}));
  if (extra_first_index != nullptr) {
    extract->AddOperand(
        {SPV_OPERAND_TYPE_LITERAL_INTEGER, {*extra_first_index}});
  }
  for (const uint32_t literal : indexes) {
    extract->AddOperand({SPV_OPERAND_TYPE_LITERAL_INTEGER, {literal}});
  }
  return extract.release();
}
736
737
void InterfaceVariableScalarReplacement::
738
    StoreComponentOfValueToAccessChainToScalarVar(
739
        uint32_t value_id, const std::vector<uint32_t>& component_indices,
740
        Instruction* scalar_var,
741
        const std::vector<uint32_t>& access_chain_indices,
742
0
        Instruction* insert_before) {
743
0
  uint32_t component_type_id = GetPointeeTypeIdOfVar(scalar_var);
744
0
  Instruction* ptr = scalar_var;
745
0
  if (!access_chain_indices.empty()) {
746
0
    ptr = CreateAccessChainToVar(component_type_id, scalar_var,
747
0
                                 access_chain_indices, insert_before,
748
0
                                 &component_type_id);
749
0
  }
750
751
0
  StoreComponentOfValueTo(component_type_id, value_id, component_indices, ptr,
752
0
                          nullptr, insert_before);
753
0
}
754
755
// Emits an OpLoad before |insert_before|. With an empty |indexes|, the load
// reads |var| itself. Otherwise it reads the sub-object of |var| reached
// through |indexes|, via a new access chain. Returns the OpLoad, or nullptr if
// the access chain or the load could not be created.
Instruction* InterfaceVariableScalarReplacement::LoadAccessChainToVar(
    Instruction* var, const std::vector<uint32_t>& indexes,
    Instruction* insert_before) {
  uint32_t loaded_type_id = GetPointeeTypeIdOfVar(var);
  if (indexes.empty()) {
    return CreateLoad(loaded_type_id, var, insert_before);
  }
  // The last argument is an out-parameter: its value after the call is the
  // type id used for the load.
  Instruction* chain = CreateAccessChainToVar(loaded_type_id, var, indexes,
                                              insert_before, &loaded_type_id);
  return chain == nullptr ? nullptr
                          : CreateLoad(loaded_type_id, chain, insert_before);
}
770
771
// Creates an operand-less OpCompositeConstruct and returns it. Its result type
// depends on |depth_to_component|:
//   - 0: the type of |load|;
//   - otherwise: the component type that many levels into |load|'s
//     array/matrix type.
// The instruction is registered with the def-use manager, placed after |load|,
// and recorded in |composite_ids_to_component_depths|. Returns nullptr when no
// fresh id is available.
Instruction*
InterfaceVariableScalarReplacement::CreateCompositeConstructForComponentOfLoad(
    Instruction* load, uint32_t depth_to_component) {
  analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  const uint32_t type_id =
      depth_to_component == 0
          ? load->type_id()
          : GetComponentTypeOfArrayMatrix(def_use_mgr, load->type_id(),
                                          depth_to_component);
  const uint32_t construct_id = TakeNextId();
  if (construct_id == 0) return nullptr;

  std::unique_ptr<Instruction> owned_construct(new Instruction(
      context(), spv::Op::OpCompositeConstruct, type_id, construct_id, {}));
  Instruction* construct = owned_construct.get();
  def_use_mgr->AnalyzeInstDefUse(construct);

  // Place the construct after |load|, skipping any constructs already created
  // for deeper components. A construct with a lower depth builds a composite
  // that contains the deeper ones, so it has to come after them.
  Instruction* position = load->NextNode();
  for (auto it = composite_ids_to_component_depths.find(position->result_id());
       it != composite_ids_to_component_depths.end() &&
       it->second > depth_to_component;
       it = composite_ids_to_component_depths.find(position->result_id())) {
    position = position->NextNode();
  }
  position->InsertBefore(std::move(owned_construct));
  composite_ids_to_component_depths.insert({construct_id, depth_to_component});
  return construct;
}
805
806
void InterfaceVariableScalarReplacement::AddComponentsToCompositesForLoads(
807
    const std::unordered_map<Instruction*, Instruction*>&
808
        loads_to_component_values,
809
    std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
810
0
    uint32_t depth_to_component) {
811
0
  analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
812
0
  for (auto& load_and_component_vale : loads_to_component_values) {
813
0
    Instruction* load = load_and_component_vale.first;
814
0
    Instruction* component_value = load_and_component_vale.second;
815
0
    Instruction* composite_construct = nullptr;
816
0
    auto itr = loads_to_composites->find(load);
817
0
    if (itr == loads_to_composites->end()) {
818
0
      composite_construct =
819
0
          CreateCompositeConstructForComponentOfLoad(load, depth_to_component);
820
0
      if (composite_construct == nullptr) {
821
0
        assert(false && "Could not create composite construct");
822
0
        return;
823
0
      }
824
0
      loads_to_composites->insert({load, composite_construct});
825
0
    } else {
826
0
      composite_construct = itr->second;
827
0
    }
828
0
    composite_construct->AddOperand(
829
0
        {SPV_OPERAND_TYPE_ID, {component_value->result_id()}});
830
0
    def_use_mgr->AnalyzeInstDefUse(composite_construct);
831
0
  }
832
0
}
833
834
uint32_t InterfaceVariableScalarReplacement::GetArrayType(
835
0
    uint32_t elem_type_id, uint32_t array_length) {
836
0
  analysis::Type* elem_type = context()->get_type_mgr()->GetType(elem_type_id);
837
0
  uint32_t array_length_id =
838
0
      context()->get_constant_mgr()->GetUIntConstId(array_length);
839
0
  analysis::Array array_type(
840
0
      elem_type,
841
0
      analysis::Array::LengthInfo{array_length_id, {0, array_length}});
842
0
  return context()->get_type_mgr()->GetTypeInstruction(&array_type);
843
0
}
844
845
uint32_t InterfaceVariableScalarReplacement::GetPointerType(
846
0
    uint32_t type_id, spv::StorageClass storage_class) {
847
0
  analysis::Type* type = context()->get_type_mgr()->GetType(type_id);
848
0
  analysis::Pointer ptr_type(type, storage_class);
849
0
  return context()->get_type_mgr()->GetTypeInstruction(&ptr_type);
850
0
}
851
852
std::optional<InterfaceVariableScalarReplacement::NestedCompositeComponents>
853
InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForArray(
854
    Instruction* interface_var_type, spv::StorageClass storage_class,
855
0
    uint32_t extra_array_length) {
856
0
  assert(interface_var_type->opcode() == spv::Op::OpTypeArray);
857
858
0
  analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
859
0
  uint32_t array_length = GetArrayLength(def_use_mgr, interface_var_type);
860
0
  Instruction* elem_type = GetArrayElementType(def_use_mgr, interface_var_type);
861
862
0
  NestedCompositeComponents scalar_vars;
863
0
  while (array_length > 0) {
864
0
    std::optional<NestedCompositeComponents> scalar_vars_for_element =
865
0
        CreateScalarInterfaceVarsForReplacement(elem_type, storage_class,
866
0
                                                extra_array_length);
867
0
    if (!scalar_vars_for_element) {
868
0
      return std::nullopt;
869
0
    }
870
0
    scalar_vars.AddComponent(*scalar_vars_for_element);
871
0
    --array_length;
872
0
  }
873
0
  return scalar_vars;
874
0
}
875
876
std::optional<InterfaceVariableScalarReplacement::NestedCompositeComponents>
877
InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForMatrix(
878
    Instruction* interface_var_type, spv::StorageClass storage_class,
879
0
    uint32_t extra_array_length) {
880
0
  assert(interface_var_type->opcode() == spv::Op::OpTypeMatrix);
881
882
0
  analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
883
0
  uint32_t column_count = interface_var_type->GetSingleWordInOperand(
884
0
      kOpTypeMatrixColCountInOperandIndex);
885
0
  Instruction* column_type =
886
0
      GetMatrixColumnType(def_use_mgr, interface_var_type);
887
888
0
  NestedCompositeComponents scalar_vars;
889
0
  while (column_count > 0) {
890
0
    std::optional<NestedCompositeComponents> scalar_vars_for_column =
891
0
        CreateScalarInterfaceVarsForReplacement(column_type, storage_class,
892
0
                                                extra_array_length);
893
0
    if (!scalar_vars_for_column) {
894
0
      return std::nullopt;
895
0
    }
896
0
    scalar_vars.AddComponent(*scalar_vars_for_column);
897
0
    --column_count;
898
0
  }
899
0
  return scalar_vars;
900
0
}
901
902
std::optional<InterfaceVariableScalarReplacement::NestedCompositeComponents>
903
InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForReplacement(
904
    Instruction* interface_var_type, spv::StorageClass storage_class,
905
0
    uint32_t extra_array_length) {
906
  // Handle array case.
907
0
  if (interface_var_type->opcode() == spv::Op::OpTypeArray) {
908
0
    return CreateScalarInterfaceVarsForArray(interface_var_type, storage_class,
909
0
                                             extra_array_length);
910
0
  }
911
912
  // Handle matrix case.
913
0
  if (interface_var_type->opcode() == spv::Op::OpTypeMatrix) {
914
0
    return CreateScalarInterfaceVarsForMatrix(interface_var_type, storage_class,
915
0
                                              extra_array_length);
916
0
  }
917
918
  // Handle scalar or vector case.
919
0
  NestedCompositeComponents scalar_var;
920
0
  uint32_t type_id = interface_var_type->result_id();
921
0
  if (extra_array_length != 0) {
922
0
    type_id = GetArrayType(type_id, extra_array_length);
923
0
  }
924
0
  uint32_t ptr_type_id =
925
0
      context()->get_type_mgr()->FindPointerToType(type_id, storage_class);
926
0
  uint32_t id = TakeNextId();
927
0
  if (id == 0) {
928
0
    return std::nullopt;
929
0
  }
930
0
  std::unique_ptr<Instruction> variable(
931
0
      new Instruction(context(), spv::Op::OpVariable, ptr_type_id, id,
932
0
                      std::initializer_list<Operand>{
933
0
                          {SPV_OPERAND_TYPE_STORAGE_CLASS,
934
0
                           {static_cast<uint32_t>(storage_class)}}}));
935
0
  scalar_var.SetSingleComponentVariable(variable.get());
936
0
  context()->AddGlobalValue(std::move(variable));
937
0
  return scalar_var;
938
0
}
939
940
// Returns the type instruction defining the pointee type of |var|.
Instruction* InterfaceVariableScalarReplacement::GetTypeOfVariable(
    Instruction* var) {
  const uint32_t pointee_id = GetPointeeTypeIdOfVar(var);
  return context()->get_def_use_mgr()->GetDef(pointee_id);
}
946
947
0
// Runs the interface-variable scalar replacement on every entry point of the
// module and combines the per-entry-point results with CombineStatus.
Pass::Status InterfaceVariableScalarReplacement::Process() {
  Pass::Status result = Status::SuccessWithoutChange;
  for (auto& ep : get_module()->entry_points()) {
    const Pass::Status ep_status = ReplaceInterfaceVarsWithScalars(ep);
    result = CombineStatus(result, ep_status);
  }
  return result;
}
955
956
bool InterfaceVariableScalarReplacement::
957
0
    ReportErrorIfHasExtraArraynessForOtherEntry(Instruction* var) {
958
0
  if (vars_with_extra_arrayness.find(var) == vars_with_extra_arrayness.end())
959
0
    return false;
960
961
0
  std::string message(
962
0
      "A variable is arrayed for an entry point but it is not "
963
0
      "arrayed for another entry point");
964
0
  message +=
965
0
      "\n  " + var->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
966
0
  context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
967
0
  return true;
968
0
}
969
970
bool InterfaceVariableScalarReplacement::
971
0
    ReportErrorIfHasNoExtraArraynessForOtherEntry(Instruction* var) {
972
0
  if (vars_without_extra_arrayness.find(var) ==
973
0
      vars_without_extra_arrayness.end())
974
0
    return false;
975
976
0
  std::string message(
977
0
      "A variable is not arrayed for an entry point but it is "
978
0
      "arrayed for another entry point");
979
0
  message +=
980
0
      "\n  " + var->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
981
0
  context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
982
0
  return true;
983
0
}
984
985
// Scalarizes the array- and matrix-typed interface variables of |entry_point|
// for which GetVariableLocation finds a location. If a variable has extra
// arrayness for this entry point, its outermost array is peeled off and that
// array's length is forwarded as |extra_array_length|.
// Returns:
//   - Failure if a variable's arrayness conflicts with another entry point's
//     use of it, or if a replacement fails;
//   - SuccessWithChange if any variable was replaced;
//   - SuccessWithoutChange otherwise.
Pass::Status
InterfaceVariableScalarReplacement::ReplaceInterfaceVarsWithScalars(
    Instruction& entry_point) {
  const std::vector<Instruction*> interface_vars =
      CollectInterfaceVariables(entry_point);

  Pass::Status status = Status::SuccessWithoutChange;
  for (Instruction* var : interface_vars) {
    uint32_t location = 0;
    if (!GetVariableLocation(var, &location)) continue;
    uint32_t component = 0;
    if (!GetVariableComponent(var, &component)) component = 0;

    Instruction* var_type = GetTypeOfVariable(var);
    uint32_t extra_array_length = 0;
    if (!HasExtraArrayness(entry_point, var)) {
      vars_without_extra_arrayness.insert(var);
    } else {
      // The def-use manager is fetched afresh on each use, because earlier
      // replacements may have changed the analyses.
      extra_array_length =
          GetArrayLength(context()->get_def_use_mgr(), var_type);
      var_type = GetArrayElementType(context()->get_def_use_mgr(), var_type);
      vars_with_extra_arrayness.insert(var);
    }

    if (!CheckExtraArraynessConflictBetweenEntries(var,
                                                   extra_array_length != 0)) {
      return Pass::Status::Failure;
    }

    const spv::Op type_opcode = var_type->opcode();
    const bool is_array_or_matrix = type_opcode == spv::Op::OpTypeArray ||
                                    type_opcode == spv::Op::OpTypeMatrix;
    if (!is_array_or_matrix) continue;

    const auto replaced = ReplaceInterfaceVariableWithScalars(
        var, var_type, location, component, extra_array_length);
    if (replaced == Pass::Status::Failure) return Pass::Status::Failure;
    status = Pass::Status::SuccessWithChange;
  }

  return status;
}
1029
1030
}  // namespace opt
1031
}  // namespace spvtools