Coverage Report

Created: 2025-06-13 06:37

/src/spirv-tools/source/name_mapper.cpp
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) 2016 Google Inc.
2
// Modifications Copyright (C) 2024 Advanced Micro Devices, Inc. All rights
3
// reserved.
4
//
5
// Licensed under the Apache License, Version 2.0 (the "License");
6
// you may not use this file except in compliance with the License.
7
// You may obtain a copy of the License at
8
//
9
//     http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing, software
12
// distributed under the License is distributed on an "AS IS" BASIS,
13
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
// See the License for the specific language governing permissions and
15
// limitations under the License.
16
17
#include "source/name_mapper.h"
18
19
#include <algorithm>
20
#include <cassert>
21
#include <iterator>
22
#include <sstream>
23
#include <string>
24
#include <unordered_map>
25
#include <unordered_set>
26
27
#include "source/binary.h"
28
#include "source/latest_version_spirv_header.h"
29
#include "source/parsed_operand.h"
30
#include "source/table2.h"
31
#include "source/to_string.h"
32
#include "spirv-tools/libspirv.h"
33
34
namespace spvtools {
35
36
29.8k
// Returns a mapper that spells every id as its plain decimal value, with no
// attempt at friendly or unique naming.
NameMapper GetTrivialNameMapper() {
  const auto decimal_name = [](uint32_t id) { return spvtools::to_string(id); };
  return decimal_name;
}
39
40
// Builds the id -> friendly-name table by walking the given SPIR-V binary.
// Each instruction is routed to ParseInstruction() through
// ParseInstructionForwarder, with this mapper as the user data.
FriendlyNameMapper::FriendlyNameMapper(const spv_const_context context,
                                       const uint32_t* code,
                                       const size_t wordCount)
    : grammar_(AssemblyGrammar(context)) {
  spv_diagnostic diagnostic = nullptr;
  // A failed parse is deliberately tolerated: ids that never received a name
  // simply fall back to their number in NameForId().
  static_cast<void>(spvBinaryParse(context, /* user data */ this, code,
                                   wordCount, /* header callback */ nullptr,
                                   ParseInstructionForwarder, &diagnostic));
  spvDiagnosticDestroy(diagnostic);
}
50
51
1.15M
std::string FriendlyNameMapper::NameForId(uint32_t id) {
52
1.15M
  auto iter = name_for_id_.find(id);
53
1.15M
  if (iter == name_for_id_.end()) {
54
    // It must have been an invalid module, so just return a trivial mapping.
55
    // We don't care about uniqueness.
56
975k
    return to_string(id);
57
975k
  } else {
58
180k
    return iter->second;
59
180k
  }
60
1.15M
}
61
62
695k
// Transforms `suggested_name` into a string usable as an assembly Id name:
// every character outside [A-Za-z0-9_] becomes '_'. An empty suggestion maps
// to "_". Distinct inputs may map to the same output.
std::string FriendlyNameMapper::Sanitize(const std::string& suggested_name) {
  if (suggested_name.empty()) return "_";

  // Direct range checks avoid two costs of the old approach: building a
  // 63-character lookup string on every call (too long for SSO, so it
  // heap-allocates) and scanning it linearly for each input character.
  // std::isalnum is avoided because it is locale-dependent and undefined for
  // negative char values. The letter ranges assume an ASCII-compatible
  // execution character set, which matches the UTF-8 strings in SPIR-V.
  const auto is_valid = [](const char c) {
    return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') ||
           ('0' <= c && c <= '9') || c == '_';
  };

  std::string result;
  result.reserve(suggested_name.size());
  for (const char c : suggested_name) result.push_back(is_valid(c) ? c : '_');
  return result;
}
76
77
void FriendlyNameMapper::SaveName(uint32_t id,
78
749k
                                  const std::string& suggested_name) {
79
749k
  if (name_for_id_.find(id) != name_for_id_.end()) return;
80
81
695k
  const std::string sanitized_suggested_name = Sanitize(suggested_name);
82
695k
  std::string name = sanitized_suggested_name;
83
695k
  auto inserted = used_names_.insert(name);
84
695k
  if (!inserted.second) {
85
7.42k
    const std::string base_name = sanitized_suggested_name + "_";
86
43.7k
    for (uint32_t index = 0; !inserted.second; ++index) {
87
36.3k
      name = base_name + to_string(index);
88
36.3k
      inserted = used_names_.insert(name);
89
36.3k
    }
90
7.42k
  }
91
695k
  name_for_id_[id] = name;
92
695k
}
93
94
void FriendlyNameMapper::SaveBuiltInName(uint32_t target_id,
95
55.4k
                                         uint32_t built_in) {
96
55.4k
#define GLCASE(name)                  \
97
55.4k
  case spv::BuiltIn::name:            \
98
15.5k
    SaveName(target_id, "gl_" #name); \
99
15.5k
    return;
100
55.4k
#define GLCASE2(name, suggested)           \
101
55.4k
  case spv::BuiltIn::name:                 \
102
22.8k
    SaveName(target_id, "gl_" #suggested); \
103
22.8k
    return;
104
55.4k
#define CASE(name)              \
105
55.4k
  case spv::BuiltIn::name:      \
106
16.8k
    SaveName(target_id, #name); \
107
16.8k
    return;
108
55.4k
  switch (spv::BuiltIn(built_in)) {
109
351
    GLCASE(Position)
110
415
    GLCASE(PointSize)
111
267
    GLCASE(ClipDistance)
112
162
    GLCASE(CullDistance)
113
71
    GLCASE2(VertexId, VertexID)
114
76
    GLCASE2(InstanceId, InstanceID)
115
966
    GLCASE2(PrimitiveId, PrimitiveID)
116
250
    GLCASE2(InvocationId, InvocationID)
117
70
    GLCASE(Layer)
118
154
    GLCASE(ViewportIndex)
119
276
    GLCASE(TessLevelOuter)
120
772
    GLCASE(TessLevelInner)
121
70
    GLCASE(TessCoord)
122
5.86k
    GLCASE(PatchVertices)
123
474
    GLCASE(FragCoord)
124
436
    GLCASE(PointCoord)
125
84
    GLCASE(FrontFacing)
126
1.75k
    GLCASE2(SampleId, SampleID)
127
73
    GLCASE(SamplePosition)
128
77
    GLCASE(SampleMask)
129
76
    GLCASE(FragDepth)
130
91
    GLCASE(HelperInvocation)
131
1.85k
    GLCASE2(NumWorkgroups, NumWorkGroups)
132
783
    GLCASE2(WorkgroupSize, WorkGroupSize)
133
4.72k
    GLCASE2(WorkgroupId, WorkGroupID)
134
589
    GLCASE2(LocalInvocationId, LocalInvocationID)
135
11.7k
    GLCASE2(GlobalInvocationId, GlobalInvocationID)
136
1.28k
    GLCASE(LocalInvocationIndex)
137
466
    CASE(WorkDim)
138
4.75k
    CASE(GlobalSize)
139
82
    CASE(EnqueuedWorkgroupSize)
140
136
    CASE(GlobalOffset)
141
6.57k
    CASE(GlobalLinearId)
142
1.40k
    CASE(SubgroupSize)
143
80
    CASE(SubgroupMaxSize)
144
80
    CASE(NumSubgroups)
145
1.09k
    CASE(NumEnqueuedSubgroups)
146
70
    CASE(SubgroupId)
147
1.41k
    CASE(SubgroupLocalInvocationId)
148
1.29k
    GLCASE(VertexIndex)
149
3.24k
    GLCASE(InstanceIndex)
150
66
    GLCASE(BaseInstance)
151
199
    CASE(SubgroupEqMaskKHR)
152
70
    CASE(SubgroupGeMaskKHR)
153
70
    CASE(SubgroupGtMaskKHR)
154
331
    CASE(SubgroupLeMaskKHR)
155
70
    CASE(SubgroupLtMaskKHR)
156
98
    default:
157
98
      break;
158
55.4k
  }
159
55.4k
#undef GLCASE
160
55.4k
#undef GLCASE2
161
55.4k
#undef CASE
162
55.4k
}
163
164
spv_result_t FriendlyNameMapper::ParseInstruction(
165
4.40M
    const spv_parsed_instruction_t& inst) {
166
4.40M
  const auto result_id = inst.result_id;
167
4.40M
  switch (spv::Op(inst.opcode)) {
168
10.8k
    case spv::Op::OpName:
169
10.8k
      SaveName(inst.words[1], spvDecodeLiteralStringOperand(inst, 1));
170
10.8k
      break;
171
3.22M
    case spv::Op::OpDecorate:
172
      // Decorations come after OpName.  So OpName will take precedence over
173
      // decorations.
174
      //
175
      // In theory, we should also handle OpGroupDecorate.  But that's unlikely
176
      // to occur.
177
3.22M
      if (spv::Decoration(inst.words[2]) == spv::Decoration::BuiltIn) {
178
55.4k
        assert(inst.num_words > 3);
179
55.4k
        SaveBuiltInName(inst.words[1], inst.words[3]);
180
55.4k
      }
181
3.22M
      break;
182
3.22M
    case spv::Op::OpTypeVoid:
183
13.6k
      SaveName(result_id, "void");
184
13.6k
      break;
185
3.54k
    case spv::Op::OpTypeBool:
186
3.54k
      SaveName(result_id, "bool");
187
3.54k
      break;
188
11.7k
    case spv::Op::OpTypeInt: {
189
11.7k
      std::string signedness;
190
11.7k
      std::string root;
191
11.7k
      const auto bit_width = inst.words[2];
192
11.7k
      switch (bit_width) {
193
88
        case 8:
194
88
          root = "char";
195
88
          break;
196
30
        case 16:
197
30
          root = "short";
198
30
          break;
199
10.1k
        case 32:
200
10.1k
          root = "int";
201
10.1k
          break;
202
20
        case 64:
203
20
          root = "long";
204
20
          break;
205
1.40k
        default:
206
1.40k
          root = to_string(bit_width);
207
1.40k
          signedness = "i";
208
1.40k
          break;
209
11.7k
      }
210
11.7k
      if (0 == inst.words[3]) signedness = "u";
211
11.7k
      SaveName(result_id, signedness + root);
212
11.7k
    } break;
213
8.00k
    case spv::Op::OpTypeFloat: {
214
8.00k
      const auto bit_width = inst.words[2];
215
8.00k
      if (inst.num_words > 3) {
216
82
        if (spv::FPEncoding(inst.words[3]) == spv::FPEncoding::BFloat16KHR) {
217
64
          SaveName(result_id, "bfloat16");
218
64
          break;
219
64
        }
220
18
        if (spv::FPEncoding(inst.words[3]) == spv::FPEncoding::Float8E4M3EXT) {
221
7
          SaveName(result_id, "fp8e4m3");
222
7
          break;
223
7
        }
224
11
        if (spv::FPEncoding(inst.words[3]) == spv::FPEncoding::Float8E5M2EXT) {
225
11
          SaveName(result_id, "fp8e5m2");
226
11
          break;
227
11
        }
228
11
      }
229
7.92k
      switch (bit_width) {
230
248
        case 16:
231
248
          SaveName(result_id, "half");
232
248
          break;
233
6.54k
        case 32:
234
6.54k
          SaveName(result_id, "float");
235
6.54k
          break;
236
24
        case 64:
237
24
          SaveName(result_id, "double");
238
24
          break;
239
1.10k
        default:
240
1.10k
          SaveName(result_id, std::string("fp") + to_string(bit_width));
241
1.10k
          break;
242
7.92k
      }
243
7.92k
    } break;
244
10.1k
    case spv::Op::OpTypeVector:
245
10.1k
      SaveName(result_id, std::string("v") + to_string(inst.words[3]) +
246
10.1k
                              NameForId(inst.words[2]));
247
10.1k
      break;
248
3.03k
    case spv::Op::OpTypeMatrix:
249
3.03k
      SaveName(result_id, std::string("mat") + to_string(inst.words[3]) +
250
3.03k
                              NameForId(inst.words[2]));
251
3.03k
      break;
252
3.76k
    case spv::Op::OpTypeArray:
253
3.76k
      SaveName(result_id, std::string("_arr_") + NameForId(inst.words[2]) +
254
3.76k
                              "_" + NameForId(inst.words[3]));
255
3.76k
      break;
256
1.27k
    case spv::Op::OpTypeRuntimeArray:
257
1.27k
      SaveName(result_id,
258
1.27k
               std::string("_runtimearr_") + NameForId(inst.words[2]));
259
1.27k
      break;
260
108
    case spv::Op::OpTypeNodePayloadArrayAMDX:
261
108
      SaveName(result_id,
262
108
               std::string("_payloadarr_") + NameForId(inst.words[2]));
263
108
      break;
264
24.2k
    case spv::Op::OpTypePointer:
265
24.2k
      SaveName(result_id, std::string("_ptr_") +
266
24.2k
                              NameForEnumOperand(SPV_OPERAND_TYPE_STORAGE_CLASS,
267
24.2k
                                                 inst.words[2]) +
268
24.2k
                              "_" + NameForId(inst.words[3]));
269
24.2k
      break;
270
43
    case spv::Op::OpTypeUntypedPointerKHR:
271
43
      SaveName(result_id, std::string("_ptr_") +
272
43
                              NameForEnumOperand(SPV_OPERAND_TYPE_STORAGE_CLASS,
273
43
                                                 inst.words[2]));
274
43
      break;
275
44
    case spv::Op::OpTypePipe:
276
44
      SaveName(result_id,
277
44
               std::string("Pipe") +
278
44
                   NameForEnumOperand(SPV_OPERAND_TYPE_ACCESS_QUALIFIER,
279
44
                                      inst.words[2]));
280
44
      break;
281
305
    case spv::Op::OpTypeEvent:
282
305
      SaveName(result_id, "Event");
283
305
      break;
284
250
    case spv::Op::OpTypeDeviceEvent:
285
250
      SaveName(result_id, "DeviceEvent");
286
250
      break;
287
149
    case spv::Op::OpTypeReserveId:
288
149
      SaveName(result_id, "ReserveId");
289
149
      break;
290
139
    case spv::Op::OpTypeQueue:
291
139
      SaveName(result_id, "Queue");
292
139
      break;
293
110
    case spv::Op::OpTypeOpaque:
294
110
      SaveName(result_id, std::string("Opaque_") +
295
110
                              Sanitize(spvDecodeLiteralStringOperand(inst, 1)));
296
110
      break;
297
39
    case spv::Op::OpTypePipeStorage:
298
39
      SaveName(result_id, "PipeStorage");
299
39
      break;
300
57
    case spv::Op::OpTypeNamedBarrier:
301
57
      SaveName(result_id, "NamedBarrier");
302
57
      break;
303
13.7k
    case spv::Op::OpTypeStruct:
304
      // Structs are mapped rather simplisitically. Just indicate that they
305
      // are a struct and then give the raw Id number.
306
13.7k
      SaveName(result_id, std::string("_struct_") + to_string(result_id));
307
13.7k
      break;
308
797
    case spv::Op::OpConstantTrue:
309
797
      SaveName(result_id, "true");
310
797
      break;
311
651
    case spv::Op::OpConstantFalse:
312
651
      SaveName(result_id, "false");
313
651
      break;
314
54.7k
    case spv::Op::OpConstant: {
315
54.7k
      std::ostringstream value;
316
54.7k
      EmitNumericLiteral(&value, inst, inst.operands[2]);
317
54.7k
      auto value_str = value.str();
318
      // Use 'n' to signify negative. Other invalid characters will be mapped
319
      // to underscore.
320
54.7k
      for (auto& c : value_str)
321
283k
        if (c == '-') c = 'n';
322
54.7k
      SaveName(result_id, NameForId(inst.type_id) + "_" + value_str);
323
54.7k
    } break;
324
1.02M
    default:
325
      // If this instruction otherwise defines an Id, then save a mapping for
326
      // it.  This is needed to ensure uniqueness in there is an OpName with
327
      // string something like "1" that might collide with this result_id.
328
      // We should only do this if a name hasn't already been registered by some
329
      // previous forward reference.
330
1.02M
      if (result_id && name_for_id_.find(result_id) == name_for_id_.end())
331
532k
        SaveName(result_id, to_string(result_id));
332
1.02M
      break;
333
4.40M
  }
334
4.40M
  return SPV_SUCCESS;
335
4.40M
}
336
337
std::string FriendlyNameMapper::NameForEnumOperand(spv_operand_type_t type,
338
24.3k
                                                   uint32_t word) {
339
24.3k
  const spvtools::OperandDesc* desc = nullptr;
340
24.3k
  if (SPV_SUCCESS == spvtools::LookupOperand(type, word, &desc)) {
341
24.3k
    return desc->name().data();
342
24.3k
  } else {
343
    // Invalid input.  Just give something.
344
0
    return std::string("StorageClass") + to_string(word);
345
0
  }
346
24.3k
}
347
348
}  // namespace spvtools