Coverage Report

Created: 2025-07-23 06:18

/src/spirv-tools/source/val/validate_function.cpp
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) 2018 Google LLC.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
#include <algorithm>
16
17
#include "source/opcode.h"
18
#include "source/table2.h"
19
#include "source/val/instruction.h"
20
#include "source/val/validate.h"
21
#include "source/val/validation_state.h"
22
23
namespace spvtools {
24
namespace val {
25
namespace {
26
27
// Returns true if |a| and |b| are instructions defining pointers that point to
28
// types logically match and the decorations that apply to |b| are a subset
29
// of the decorations that apply to |a|.
30
bool DoPointeesLogicallyMatch(val::Instruction* a, val::Instruction* b,
31
0
                              ValidationState_t& _) {
32
0
  if (a->opcode() != spv::Op::OpTypePointer ||
33
0
      b->opcode() != spv::Op::OpTypePointer) {
34
0
    return false;
35
0
  }
36
37
0
  const auto& dec_a = _.id_decorations(a->id());
38
0
  const auto& dec_b = _.id_decorations(b->id());
39
0
  for (const auto& dec : dec_b) {
40
0
    if (std::find(dec_a.begin(), dec_a.end(), dec) == dec_a.end()) {
41
0
      return false;
42
0
    }
43
0
  }
44
45
0
  uint32_t a_type = a->GetOperandAs<uint32_t>(2);
46
0
  uint32_t b_type = b->GetOperandAs<uint32_t>(2);
47
48
0
  if (a_type == b_type) {
49
0
    return true;
50
0
  }
51
52
0
  Instruction* a_type_inst = _.FindDef(a_type);
53
0
  Instruction* b_type_inst = _.FindDef(b_type);
54
55
0
  return _.LogicallyMatch(a_type_inst, b_type_inst, true);
56
0
}
57
58
33.1k
// Validates an OpFunction instruction:
//  * its Function Type operand (operand 3) must name an OpTypeFunction,
//  * its Result Type must equal that function type's return type, and
//  * its result id may only be used by instructions that take a function
//    operand (listed below), or by non-semantic / debug-info instructions.
spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) {
  const auto function_type_id = inst->GetOperandAs<uint32_t>(3);
  const auto function_type = _.FindDef(function_type_id);
  if (!function_type || spv::Op::OpTypeFunction != function_type->opcode()) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpFunction Function Type <id> " << _.getIdName(function_type_id)
           << " is not a function type.";
  }

  // Operand 1 of OpTypeFunction is the return type.
  const auto return_id = function_type->GetOperandAs<uint32_t>(1);
  if (return_id != inst->type_id()) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpFunction Result Type <id> " << _.getIdName(inst->type_id())
           << " does not match the Function Type's return type <id> "
           << _.getIdName(return_id) << ".";
  }

  // Opcodes permitted to reference a function's result id. The table is
  // invariant, so build it once rather than allocating it for every
  // OpFunction in the module.
  static const std::vector<spv::Op> acceptable = {
      spv::Op::OpGroupDecorate,
      spv::Op::OpDecorate,
      spv::Op::OpEnqueueKernel,
      spv::Op::OpEntryPoint,
      spv::Op::OpExecutionMode,
      spv::Op::OpExecutionModeId,
      spv::Op::OpFunctionCall,
      spv::Op::OpGetKernelNDrangeSubGroupCount,
      spv::Op::OpGetKernelNDrangeMaxSubGroupSize,
      spv::Op::OpGetKernelWorkGroupSize,
      spv::Op::OpGetKernelPreferredWorkGroupSizeMultiple,
      spv::Op::OpGetKernelLocalSizeForSubgroupCount,
      spv::Op::OpGetKernelMaxNumSubgroups,
      spv::Op::OpName,
      spv::Op::OpCooperativeMatrixPerElementOpNV,
      spv::Op::OpCooperativeMatrixReduceNV,
      spv::Op::OpCooperativeMatrixLoadTensorNV};
  for (const auto& pair : inst->uses()) {
    const auto* use = pair.first;
    if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) ==
            acceptable.end() &&
        !use->IsNonSemantic() && !use->IsDebugInfo()) {
      return _.diag(SPV_ERROR_INVALID_ID, use)
             << "Invalid use of function result id " << _.getIdName(inst->id())
             << ".";
    }
  }

  return SPV_SUCCESS;
}
106
107
// Validates an OpFunctionParameter instruction. It checks that the parameter
// follows an OpFunction, that the owning function's type declares at least as
// many parameters as precede this one, and that its Result Type equals the
// function type's parameter type at the same position.
spv_result_t ValidateFunctionParameter(ValidationState_t& _,
                                       const Instruction* inst) {
  // NOTE: Find OpFunction & ensure OpFunctionParameter is not out of place.
  const size_t position = inst->LineNum() - 1;
  if (position == 0) {
    return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
           << "Function parameter cannot be the first instruction.";
  }

  // Scan backwards toward the owning OpFunction (index 0 is never examined),
  // counting the sibling parameters that come before this one. That count is
  // this parameter's index in the function signature.
  const auto& ordered = _.ordered_instructions();
  const Instruction* owner = &ordered[position];
  size_t preceding_params = 0;
  for (size_t idx = position - 1; idx != 0; --idx) {
    owner = &ordered[idx];
    const spv::Op op = owner->opcode();
    if (op == spv::Op::OpFunction) break;
    if (op == spv::Op::OpFunctionParameter) ++preceding_params;
  }

  if (owner->opcode() != spv::Op::OpFunction) {
    return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
           << "Function parameter must be preceded by a function.";
  }

  const auto function_type_id = owner->GetOperandAs<uint32_t>(3);
  const auto function_type = _.FindDef(function_type_id);
  if (function_type == nullptr) {
    return _.diag(SPV_ERROR_INVALID_ID, owner)
           << "Missing function type definition.";
  }

  // OpTypeFunction words: opcode, result id, return type, then one word per
  // parameter type.
  const size_t declared_params = function_type->words().size() - 3;
  if (preceding_params >= declared_params) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "Too many OpFunctionParameters for " << owner->id()
           << ": expected " << declared_params
           << " based on the function's type";
  }

  // Operands: 0 = result id, 1 = return type, 2.. = parameter types.
  const auto expected_type =
      _.FindDef(function_type->GetOperandAs<uint32_t>(preceding_params + 2));
  const bool type_matches =
      expected_type && inst->type_id() == expected_type->id();
  if (!type_matches) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpFunctionParameter Result Type <id> "
           << _.getIdName(inst->type_id())
           << " does not match the OpTypeFunction parameter "
              "type of the same index.";
  }

  return SPV_SUCCESS;
}
157
158
// Validates an OpFunctionCall instruction:
//  * the Function operand (operand 2) must be an OpFunction,
//  * the call's Result Type must equal the callee's return type,
//  * the callee's type must be an OpTypeFunction whose parameter count equals
//    the call's argument count,
//  * each argument's type must equal the corresponding parameter type (or,
//    when validating before HLSL legalization, be a pointer whose pointee
//    logically matches the parameter's pointee), and
//  * in the Logical addressing model (unless relax_logical_pointer is set),
//    pointer parameters must use a permitted storage class, and pointer
//    arguments must be memory object declarations apart from the listed
//    exceptions.
spv_result_t ValidateFunctionCall(ValidationState_t& _,
                                  const Instruction* inst) {
  const auto function_id = inst->GetOperandAs<uint32_t>(2);
  const auto function = _.FindDef(function_id);
  if (!function || spv::Op::OpFunction != function->opcode()) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpFunctionCall Function <id> " << _.getIdName(function_id)
           << " is not a function.";
  }

  const auto return_type = _.FindDef(function->type_id());
  if (!return_type || return_type->id() != inst->type_id()) {
    // |return_type| may be null on this path, so name the callee's declared
    // return type id directly instead of dereferencing it.
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpFunctionCall Result Type <id> " << _.getIdName(inst->type_id())
           << "s type does not match Function <id> "
           << _.getIdName(function->type_id()) << "s return type.";
  }

  // The callee may be a forward reference whose OpFunction has not been
  // validated yet, so its function type is checked again here.
  const auto function_type_id = function->GetOperandAs<uint32_t>(3);
  const auto function_type = _.FindDef(function_type_id);
  if (!function_type || function_type->opcode() != spv::Op::OpTypeFunction) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "Missing function type definition.";
  }

  // OpFunctionCall words: opcode, result type, result id, function, args...
  // OpTypeFunction words: opcode, result id, return type, param types...
  const auto function_call_arg_count = inst->words().size() - 4;
  const auto function_param_count = function_type->words().size() - 3;
  if (function_param_count != function_call_arg_count) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpFunctionCall Function <id>'s parameter count does not match "
              "the argument count.";
  }

  // Call arguments start at operand 3; parameter types start at operand 2 of
  // the function type.
  for (size_t argument_index = 3, param_index = 2;
       argument_index < inst->operands().size();
       argument_index++, param_index++) {
    const auto argument_id = inst->GetOperandAs<uint32_t>(argument_index);
    const auto argument = _.FindDef(argument_id);
    if (!argument) {
      return _.diag(SPV_ERROR_INVALID_ID, inst)
             << "Missing argument " << argument_index - 3 << " definition.";
    }

    const auto argument_type = _.FindDef(argument->type_id());
    if (!argument_type) {
      return _.diag(SPV_ERROR_INVALID_ID, inst)
             << "Missing argument " << argument_index - 3
             << " type definition.";
    }

    const auto parameter_type_id =
        function_type->GetOperandAs<uint32_t>(param_index);
    const auto parameter_type = _.FindDef(parameter_type_id);
    if (!parameter_type || argument_type->id() != parameter_type->id()) {
      // Before HLSL legalization, pointers to logically matching types are
      // tolerated.
      if (!parameter_type || !_.options()->before_hlsl_legalization ||
          !DoPointeesLogicallyMatch(argument_type, parameter_type, _)) {
        return _.diag(SPV_ERROR_INVALID_ID, inst)
               << "OpFunctionCall Argument <id> " << _.getIdName(argument_id)
               << "s type does not match Function <id> "
               << _.getIdName(parameter_type_id) << "s parameter type.";
      }
    }

    if (_.addressing_model() == spv::AddressingModel::Logical) {
      if ((parameter_type->opcode() == spv::Op::OpTypePointer ||
           parameter_type->opcode() == spv::Op::OpTypeUntypedPointerKHR) &&
          !_.options()->relax_logical_pointer) {
        // Operand 1 of both pointer type forms is the storage class.
        spv::StorageClass sc =
            parameter_type->GetOperandAs<spv::StorageClass>(1u);
        // Validate which storage classes can be pointer operands.
        switch (sc) {
          case spv::StorageClass::UniformConstant:
          case spv::StorageClass::Function:
          case spv::StorageClass::Private:
          case spv::StorageClass::Workgroup:
          case spv::StorageClass::AtomicCounter:
            // These are always allowed.
            break;
          case spv::StorageClass::StorageBuffer:
            if (!_.features().variable_pointers) {
              return _.diag(SPV_ERROR_INVALID_ID, inst)
                     << "StorageBuffer pointer operand "
                     << _.getIdName(argument_id)
                     << " requires a variable pointers capability";
            }
            break;
          default:
            return _.diag(SPV_ERROR_INVALID_ID, inst)
                   << "Invalid storage class for pointer operand "
                   << _.getIdName(argument_id);
        }

        // Validate memory object declaration requirements.
        if (argument->opcode() != spv::Op::OpVariable &&
            argument->opcode() != spv::Op::OpUntypedVariableKHR &&
            argument->opcode() != spv::Op::OpFunctionParameter) {
          const bool ssbo_vptr =
              _.HasCapability(spv::Capability::VariablePointersStorageBuffer) &&
              sc == spv::StorageClass::StorageBuffer;
          const bool wg_vptr =
              _.HasCapability(spv::Capability::VariablePointers) &&
              sc == spv::StorageClass::Workgroup;
          const bool uc_ptr = sc == spv::StorageClass::UniformConstant;
          if (!_.options()->before_hlsl_legalization && !ssbo_vptr &&
              !wg_vptr && !uc_ptr) {
            return _.diag(SPV_ERROR_INVALID_ID, inst)
                   << "Pointer operand " << _.getIdName(argument_id)
                   << " must be a memory object declaration";
          }
        }
      }
    }
  }
  return SPV_SUCCESS;
}
273
274
// Validates OpCooperativeMatrixPerElementOpNV:
//  * Function (operand 3) must be an OpFunction whose type is an
//    OpTypeFunction,
//  * Matrix (operand 2) must have a cooperative matrix type, and that type
//    must equal the Result Type, and
//  * the function must return the matrix component type and take at least
//    three parameters: two 32-bit integers followed by the component type.
spv_result_t ValidateCooperativeMatrixPerElementOp(ValidationState_t& _,
                                                   const Instruction* inst) {
  const auto function_id = inst->GetOperandAs<uint32_t>(3);
  const auto function = _.FindDef(function_id);
  if (!function || spv::Op::OpFunction != function->opcode()) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpCooperativeMatrixPerElementOpNV Function <id> "
           << _.getIdName(function_id) << " is not a function.";
  }

  const auto matrix_id = inst->GetOperandAs<uint32_t>(2);
  const auto matrix = _.FindDef(matrix_id);
  // Guard against an unresolved matrix id before reading its type.
  if (!matrix || !_.IsCooperativeMatrixKHRType(matrix->type_id())) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpCooperativeMatrixPerElementOpNV Matrix <id> "
           << _.getIdName(matrix_id) << " is not a cooperative matrix.";
  }
  const auto matrix_type_id = matrix->type_id();

  const auto result_type_id = inst->GetOperandAs<uint32_t>(0);
  if (matrix_type_id != result_type_id) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpCooperativeMatrixPerElementOpNV Result Type <id> "
           << _.getIdName(result_type_id) << " must match matrix type <id> "
           << _.getIdName(matrix_type_id) << ".";
  }

  // Operand 1 of the cooperative matrix type is its component type.
  const auto matrix_comp_type_id =
      _.FindDef(matrix_type_id)->GetOperandAs<uint32_t>(1);

  // The function may be a forward reference whose own OpFunction has not been
  // validated yet, so its type is not guaranteed to be an OpTypeFunction.
  const auto function_type_id = function->GetOperandAs<uint32_t>(3);
  const auto function_type = _.FindDef(function_type_id);
  if (!function_type || function_type->opcode() != spv::Op::OpTypeFunction) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "Missing function type definition.";
  }

  // Operand 1 of OpTypeFunction is the return type.
  auto return_type_id = function_type->GetOperandAs<uint32_t>(1);
  if (return_type_id != matrix_comp_type_id) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpCooperativeMatrixPerElementOpNV function return type <id> "
           << _.getIdName(return_type_id)
           << " must match matrix component type <id> "
           << _.getIdName(matrix_comp_type_id) << ".";
  }

  // Operands: result id, return type, then parameters, so three parameters
  // need five operands.
  if (function_type->operands().size() < 5) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpCooperativeMatrixPerElementOpNV function type <id> "
           << _.getIdName(function_type_id)
           << " must have a least three parameters.";
  }

  const auto param0_id = function_type->GetOperandAs<uint32_t>(2);
  const auto param1_id = function_type->GetOperandAs<uint32_t>(3);
  const auto param2_id = function_type->GetOperandAs<uint32_t>(4);
  if (!_.IsIntScalarType(param0_id) || _.GetBitWidth(param0_id) != 32) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpCooperativeMatrixPerElementOpNV function type first parameter "
              "type <id> "
           << _.getIdName(param0_id) << " must be a 32-bit integer.";
  }

  if (!_.IsIntScalarType(param1_id) || _.GetBitWidth(param1_id) != 32) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpCooperativeMatrixPerElementOpNV function type second "
              "parameter type <id> "
           << _.getIdName(param1_id) << " must be a 32-bit integer.";
  }

  if (param2_id != matrix_comp_type_id) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpCooperativeMatrixPerElementOpNV function type third parameter "
              "type <id> "
           << _.getIdName(param2_id) << " must match matrix component type.";
  }

  return SPV_SUCCESS;
}
347
348
}  // namespace
349
350
10.9M
// Entry point for function-related validation. It routes OpFunction,
// OpFunctionParameter, OpFunctionCall and OpCooperativeMatrixPerElementOpNV
// to their dedicated checks and propagates each check's result. Every other
// opcode is accepted as-is.
spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) {
  switch (inst->opcode()) {
    case spv::Op::OpFunction:
      return ValidateFunction(_, inst);
    case spv::Op::OpFunctionParameter:
      return ValidateFunctionParameter(_, inst);
    case spv::Op::OpFunctionCall:
      return ValidateFunctionCall(_, inst);
    case spv::Op::OpCooperativeMatrixPerElementOpNV:
      return ValidateCooperativeMatrixPerElementOp(_, inst);
    default:
      return SPV_SUCCESS;
  }
}
371
372
}  // namespace val
373
}  // namespace spvtools