Coverage Report

Created: 2026-02-14 06:30

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/spirv-tools/source/opt/legalize_multidim_array_pass.cpp
Line
Count
Source
1
// Copyright (c) 2026 Google LLC
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
#include "source/opt/legalize_multidim_array_pass.h"
16
17
#include "source/opt/constants.h"
18
#include "source/opt/desc_sroa_util.h"
19
#include "source/opt/ir_builder.h"
20
#include "source/opt/ir_context.h"
21
#include "source/opt/type_manager.h"
22
23
namespace spvtools {
24
namespace opt {
25
26
7.29k
// Finds every module-scope OpVariable that is a multidimensional array of
// resources, checks that each one can be rewritten, and then flattens each
// variable into a one-dimensional array and rewrites the access chains that
// index into it.
//
// Returns Failure (after emitting an error) as soon as a candidate has a use
// that cannot be legalized, or if flattening/rewriting a variable fails.
// Returns SuccessWithoutChange when there are no candidates.
Pass::Status LegalizeMultidimArrayPass::Process() {
  std::vector<Instruction*> candidates;

  // Phase 1: collect and validate every candidate before any rewriting
  // happens.
  for (auto& inst : context()->types_values()) {
    const bool is_candidate = inst.opcode() == spv::Op::OpVariable &&
                              IsMultidimArrayOfResources(&inst);
    if (!is_candidate) continue;

    if (CanLegalize(&inst)) {
      candidates.push_back(&inst);
      continue;
    }
    context()->EmitErrorMessage("Unable to legalize multidimensional array: ",
                                &inst);
    return Status::Failure;
  }

  if (candidates.empty()) {
    return Status::SuccessWithoutChange;
  }

  // Phase 2: flatten each variable's type, then fix up its users. The old
  // pointer type id must be captured first because flattening replaces it.
  for (Instruction* candidate : candidates) {
    const uint32_t original_ptr_type_id = candidate->type_id();
    if (FlattenArrayType(candidate) == 0) return Status::Failure;
    if (!RewriteAccessChains(candidate, original_ptr_type_id)) {
      return Status::Failure;
    }
  }
  return Status::SuccessWithChange;
}
51
52
7.80k
// Returns true if |var| is a descriptor array (as decided by
// descsroautil::IsDescriptorArray) whose pointee type is a chain of at least
// two nested OpTypeArray types.
bool LegalizeMultidimArrayPass::IsMultidimArrayOfResources(Instruction* var) {
  if (!descsroautil::IsDescriptorArray(context(), var)) {
    return false;
  }

  // An OpVariable's type is an OpTypePointer; in-operand 1 is the pointee.
  Instruction* pointer_type =
      context()->get_def_use_mgr()->GetDef(var->type_id());
  const uint32_t pointee_id = pointer_type->GetSingleWordInOperand(1);

  std::vector<uint32_t> extents;
  uint32_t innermost_type_id = 0;
  GetArrayDimensions(pointee_id, &extents, &innermost_type_id);
  return extents.size() >= 2;
}
64
65
void LegalizeMultidimArrayPass::GetArrayDimensions(uint32_t type_id,
66
                                                   std::vector<uint32_t>* dims,
67
0
                                                   uint32_t* element_type_id) {
68
0
  assert(dims != nullptr && "dims cannot be null.");
69
0
  dims->clear();
70
71
0
  Instruction* type_inst = context()->get_def_use_mgr()->GetDef(type_id);
72
0
  while (type_inst->opcode() == spv::Op::OpTypeArray) {
73
0
    uint32_t length_id = type_inst->GetSingleWordInOperand(1);
74
0
    Instruction* length_inst = context()->get_def_use_mgr()->GetDef(length_id);
75
    // Assume OpConstant. According to the spec the length could also be an
76
    // OpSpecConstantOp. However, DXC will not generate that type of code. The
77
    // code to handle spec constants will be much more complicated.
78
0
    assert(length_inst->opcode() == spv::Op::OpConstant);
79
0
    uint32_t length = length_inst->GetSingleWordInOperand(0);
80
0
    dims->push_back(length);
81
0
    type_id = type_inst->GetSingleWordInOperand(0);
82
0
    type_inst = context()->get_def_use_mgr()->GetDef(type_id);
83
0
  }
84
0
  *element_type_id = type_id;
85
0
}
86
87
0
// Changes the type of |var| from a pointer to a multidimensional array into
// a pointer (same storage class) to a one-dimensional array. The new array
// has the same innermost element type and as many elements as the product of
// the original dimensions. |var| is then moved to just after the new pointer
// type so that its result type is defined before it.
//
// Returns the id of the new pointer type, or 0 on failure: the flattened
// length does not fit in 32 bits, or a required constant/type could not be
// created (for example, on id overflow). On failure the result type of |var|
// is not modified.
uint32_t LegalizeMultidimArrayPass::FlattenArrayType(Instruction* var) {
  analysis::TypeManager* type_mgr = context()->get_type_mgr();
  analysis::ConstantManager* constant_mgr = context()->get_constant_mgr();

  uint32_t ptr_type_id = var->type_id();
  Instruction* ptr_type_inst =
      context()->get_def_use_mgr()->GetDef(ptr_type_id);
  // OpTypePointer in-operands: 0 = storage class, 1 = pointee type.
  uint32_t pointee_type_id = ptr_type_inst->GetSingleWordInOperand(1);

  std::vector<uint32_t> dims;
  uint32_t element_type_id = 0;
  GetArrayDimensions(pointee_type_id, &dims, &element_type_id);

  // Accumulate in 64 bits so that a product that does not fit in a 32-bit
  // array length is detected instead of silently wrapping around.
  uint64_t total_elements = 1;
  for (uint32_t dim : dims) {
    total_elements *= dim;
    if (total_elements > 0xFFFFFFFFull) return 0;
  }
  const uint32_t flat_length = static_cast<uint32_t>(total_elements);

  const analysis::Constant* total_elements_const =
      constant_mgr->GetIntConst(flat_length, 32, false);
  if (total_elements_const == nullptr) return 0;
  Instruction* total_elements_inst =
      constant_mgr->GetDefiningInstruction(total_elements_const);
  if (total_elements_inst == nullptr) return 0;
  uint32_t total_elements_id = total_elements_inst->result_id();

  // Create the new one-dimensional OpTypeArray.
  analysis::Type* element_type = type_mgr->GetType(element_type_id);
  analysis::Array::LengthInfo length_info = {
      total_elements_id,
      {analysis::Array::LengthInfo::kConstant, flat_length}};
  analysis::Array new_array_type(element_type, length_info);
  uint32_t new_array_type_id = type_mgr->GetTypeInstruction(&new_array_type);
  if (new_array_type_id == 0) return 0;

  // Create the new OpTypePointer with the original storage class.
  spv::StorageClass sc =
      static_cast<spv::StorageClass>(ptr_type_inst->GetSingleWordInOperand(0));
  analysis::Pointer new_ptr_type(type_mgr->GetType(new_array_type_id), sc);
  uint32_t new_ptr_type_id = type_mgr->GetTypeInstruction(&new_ptr_type);
  if (new_ptr_type_id == 0) return 0;

  var->SetResultType(new_ptr_type_id);
  context()->UpdateDefUse(var);

  // The pointer type may have just been created after |var|; move |var| after
  // it to avoid a use-before-def.
  var->InsertAfter(get_def_use_mgr()->GetDef(new_ptr_type_id));

  return new_ptr_type_id;
}
134
135
bool LegalizeMultidimArrayPass::RewriteAccessChains(Instruction* var,
136
0
                                                    uint32_t old_ptr_type_id) {
137
0
  uint32_t var_id = var->result_id();
138
0
  std::vector<Instruction*> users;
139
  // Use a worklist to handle transitive uses (e.g. through OpCopyObject)
140
0
  std::vector<Instruction*> worklist;
141
142
0
  context()->get_def_use_mgr()->ForEachUser(
143
0
      var_id, [&worklist](Instruction* user) { worklist.push_back(user); });
144
145
0
  Instruction* old_ptr_type_inst =
146
0
      context()->get_def_use_mgr()->GetDef(old_ptr_type_id);
147
0
  uint32_t old_pointee_type_id = old_ptr_type_inst->GetSingleWordInOperand(1);
148
0
  std::vector<uint32_t> dims;
149
0
  uint32_t element_type_id = 0;
150
0
  GetArrayDimensions(old_pointee_type_id, &dims, &element_type_id);
151
0
  assert(dims.size() != 0 &&
152
0
         "This variable should have been rejected earlier.");
153
154
  // Calculate strides once
155
0
  std::vector<uint32_t> strides(dims.size());
156
0
  strides[dims.size() - 1] = 1;
157
0
  for (int i = static_cast<int>(dims.size()) - 2; i >= 0; --i) {
158
0
    strides[i] = strides[i + 1] * dims[i + 1];
159
0
  }
160
161
  // Pre-calculate uint type id
162
0
  uint32_t uint_type_id = context()->get_type_mgr()->GetUIntTypeId();
163
0
  if (uint_type_id == 0) return false;
164
165
0
  while (!worklist.empty()) {
166
0
    Instruction* user = worklist.back();
167
0
    worklist.pop_back();
168
169
0
    if (user->opcode() == spv::Op::OpAccessChain ||
170
0
        user->opcode() == spv::Op::OpInBoundsAccessChain) {
171
0
      uint32_t num_indices = user->NumInOperands() - 1;
172
0
      assert(num_indices >= dims.size());
173
174
0
      InstructionBuilder builder(context(), user, IRContext::kAnalysisDefUse);
175
176
0
      uint32_t linearized_idx_id = 0;
177
0
      for (uint32_t i = 0; i < dims.size(); ++i) {
178
0
        uint32_t idx_id = user->GetSingleWordInOperand(i + 1);
179
180
0
        uint32_t term_id = idx_id;
181
0
        if (strides[i] != 1) {
182
0
          const analysis::Constant* stride_const =
183
0
              context()->get_constant_mgr()->GetConstant(
184
0
                  context()->get_type_mgr()->GetType(uint_type_id),
185
0
                  {strides[i]});
186
0
          Instruction* stride_inst =
187
0
              context()->get_constant_mgr()->GetDefiningInstruction(
188
0
                  stride_const);
189
190
0
          Instruction* mul_inst = builder.AddBinaryOp(
191
0
              uint_type_id, spv::Op::OpIMul, idx_id, stride_inst->result_id());
192
0
          if (mul_inst == nullptr) return false;
193
0
          term_id = mul_inst->result_id();
194
0
        }
195
196
0
        if (linearized_idx_id == 0) {
197
0
          linearized_idx_id = term_id;
198
0
        } else {
199
0
          Instruction* add_inst = builder.AddBinaryOp(
200
0
              uint_type_id, spv::Op::OpIAdd, linearized_idx_id, term_id);
201
0
          if (add_inst == nullptr) return false;
202
0
          linearized_idx_id = add_inst->result_id();
203
0
        }
204
0
      }
205
206
      // Create new AccessChain.
207
0
      Instruction::OperandList new_operands;
208
0
      new_operands.push_back(user->GetInOperand(0));
209
0
      new_operands.push_back({SPV_OPERAND_TYPE_ID, {linearized_idx_id}});
210
0
      for (uint32_t i = static_cast<uint32_t>(dims.size()); i < num_indices;
211
0
           ++i) {
212
0
        new_operands.push_back(user->GetInOperand(i + 1));
213
0
      }
214
0
      user->SetInOperands(std::move(new_operands));
215
0
      context()->UpdateDefUse(user);
216
0
    } else if (user->opcode() == spv::Op::OpCopyObject) {
217
      // The type of the variable has changed so the result type of the
218
      // OpCopyObject will change as well.
219
220
0
      uint32_t operand_id = user->GetSingleWordInOperand(0);
221
0
      Instruction* operand_inst =
222
0
          context()->get_def_use_mgr()->GetDef(operand_id);
223
0
      user->SetResultType(operand_inst->type_id());
224
0
      context()->UpdateDefUse(user);
225
226
      // Add users of this copy to worklist
227
0
      context()->get_def_use_mgr()->ForEachUser(
228
0
          user->result_id(),
229
0
          [&worklist](Instruction* u) { worklist.push_back(u); });
230
0
    }
231
0
  }
232
0
  return true;
233
0
}
234
235
bool LegalizeMultidimArrayPass::CheckUse(Instruction* inst,
236
0
                                         uint32_t max_depth) {
237
0
  if (inst->opcode() == spv::Op::OpAccessChain ||
238
0
      inst->opcode() == spv::Op::OpInBoundsAccessChain) {
239
0
    uint32_t num_indices = inst->NumInOperands() - 1;
240
0
    return num_indices >= max_depth;
241
0
  } else if (inst->opcode() == spv::Op::OpCopyObject) {
242
0
    bool ok = true;
243
0
    return !context()->get_def_use_mgr()->WhileEachUser(
244
0
        inst->result_id(),
245
0
        [&](Instruction* u) { return !CheckUse(u, max_depth); });
246
0
    return ok;
247
0
  } else if (inst->IsDecoration() || inst->opcode() == spv::Op::OpName ||
248
0
             inst->opcode() == spv::Op::OpMemberName) {
249
    // Metadata is fine.
250
0
    return true;
251
0
  }
252
253
  // Direct use of array or partial array without AccessChain is not allowed.
254
0
  return false;
255
0
}
256
257
0
// Returns true if every user of |var| is acceptable according to CheckUse,
// given the number of array dimensions in |var|'s pointee type.
bool LegalizeMultidimArrayPass::CanLegalize(Instruction* var) {
  // An OpVariable's type is an OpTypePointer; in-operand 1 is the pointee.
  Instruction* pointer_type =
      context()->get_def_use_mgr()->GetDef(var->type_id());
  std::vector<uint32_t> dims;
  uint32_t element_type_id = 0;
  GetArrayDimensions(pointer_type->GetSingleWordInOperand(1), &dims,
                     &element_type_id);
  const uint32_t depth = static_cast<uint32_t>(dims.size());

  // Stops at the first unacceptable user; CheckUse has no side effects, so
  // skipping the remaining users does not change the result.
  return context()->get_def_use_mgr()->WhileEachUser(
      var->result_id(),
      [this, depth](Instruction* user) { return CheckUse(user, depth); });
}
273
274
}  // namespace opt
275
}  // namespace spvtools