Coverage Report

Created: 2026-03-31 06:42

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/spirv-tools/source/val/validate_function.cpp
Line
Count
Source
1
// Copyright (c) 2018 Google LLC.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
#include <algorithm>
16
17
#include "source/opcode.h"
18
#include "source/table2.h"
19
#include "source/val/instruction.h"
20
#include "source/val/validate.h"
21
#include "source/val/validation_state.h"
22
23
namespace spvtools {
24
namespace val {
25
namespace {
26
27
// Returns true if |a| and |b| are instructions defining pointers that point to
28
// types logically match and the decorations that apply to |b| are a subset
29
// of the decorations that apply to |a|.
30
bool DoPointeesLogicallyMatch(val::Instruction* a, val::Instruction* b,
31
0
                              ValidationState_t& _) {
32
0
  if (a->opcode() != spv::Op::OpTypePointer ||
33
0
      b->opcode() != spv::Op::OpTypePointer) {
34
0
    return false;
35
0
  }
36
37
0
  const auto& dec_a = _.id_decorations(a->id());
38
0
  const auto& dec_b = _.id_decorations(b->id());
39
0
  for (const auto& dec : dec_b) {
40
0
    if (std::find(dec_a.begin(), dec_a.end(), dec) == dec_a.end()) {
41
0
      return false;
42
0
    }
43
0
  }
44
45
0
  uint32_t a_type = a->GetOperandAs<uint32_t>(2);
46
0
  uint32_t b_type = b->GetOperandAs<uint32_t>(2);
47
48
0
  if (a_type == b_type) {
49
0
    return true;
50
0
  }
51
52
0
  Instruction* a_type_inst = _.FindDef(a_type);
53
0
  Instruction* b_type_inst = _.FindDef(b_type);
54
55
0
  return _.LogicallyMatch(a_type_inst, b_type_inst, true);
56
0
}
57
58
39.9k
spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) {
59
39.9k
  const auto function_type_id = inst->GetOperandAs<uint32_t>(3);
60
39.9k
  const auto function_type = _.FindDef(function_type_id);
61
39.9k
  if (!function_type || spv::Op::OpTypeFunction != function_type->opcode()) {
62
50
    return _.diag(SPV_ERROR_INVALID_ID, inst)
63
50
           << "OpFunction Function Type <id> " << _.getIdName(function_type_id)
64
50
           << " is not a function type.";
65
50
  }
66
67
39.9k
  const auto return_id = function_type->GetOperandAs<uint32_t>(1);
68
39.9k
  if (return_id != inst->type_id()) {
69
18
    return _.diag(SPV_ERROR_INVALID_ID, inst)
70
18
           << "OpFunction Result Type <id> " << _.getIdName(inst->type_id())
71
18
           << " does not match the Function Type's return type <id> "
72
18
           << _.getIdName(return_id) << ".";
73
18
  }
74
75
39.9k
  const std::vector<spv::Op> acceptable = {
76
39.9k
      spv::Op::OpGroupDecorate,
77
39.9k
      spv::Op::OpDecorate,
78
39.9k
      spv::Op::OpEnqueueKernel,
79
39.9k
      spv::Op::OpEntryPoint,
80
39.9k
      spv::Op::OpExecutionMode,
81
39.9k
      spv::Op::OpExecutionModeId,
82
39.9k
      spv::Op::OpFunctionCall,
83
39.9k
      spv::Op::OpGetKernelNDrangeSubGroupCount,
84
39.9k
      spv::Op::OpGetKernelNDrangeMaxSubGroupSize,
85
39.9k
      spv::Op::OpGetKernelWorkGroupSize,
86
39.9k
      spv::Op::OpGetKernelPreferredWorkGroupSizeMultiple,
87
39.9k
      spv::Op::OpGetKernelLocalSizeForSubgroupCount,
88
39.9k
      spv::Op::OpGetKernelMaxNumSubgroups,
89
39.9k
      spv::Op::OpName,
90
39.9k
      spv::Op::OpCooperativeMatrixPerElementOpNV,
91
39.9k
      spv::Op::OpCooperativeMatrixReduceNV,
92
39.9k
      spv::Op::OpCooperativeMatrixLoadTensorNV,
93
39.9k
      spv::Op::OpConditionalEntryPointINTEL,
94
39.9k
      spv::Op::OpConstantFunctionPointerINTEL};
95
95.5k
  for (auto& pair : inst->uses()) {
96
95.5k
    const auto* use = pair.first;
97
95.5k
    if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) ==
98
95.5k
            acceptable.end() &&
99
19
        !use->IsNonSemantic() && !use->IsDebugInfo() &&
100
19
        !spvOpcodeIsDecoration(use->opcode())) {
101
19
      return _.diag(SPV_ERROR_INVALID_ID, use)
102
19
             << "Invalid use of function result id " << _.getIdName(inst->id())
103
19
             << ".";
104
19
    }
105
95.5k
  }
106
107
39.9k
  return SPV_SUCCESS;
108
39.9k
}
109
110
// Validates an OpFunctionParameter instruction:
//  - it must be preceded (in module order) by an OpFunction;
//  - that function's type must be an OpTypeFunction with enough parameters;
//  - its Result Type must equal the OpTypeFunction parameter type at the same
//    position.
spv_result_t ValidateFunctionParameter(ValidationState_t& _,
                                       const Instruction* inst) {
  // NOTE: Find OpFunction & ensure OpFunctionParameter is not out of place.
  // LineNum() - 1 is taken to be |inst|'s own index in ordered_instructions().
  // Walk backwards from there, counting the OpFunctionParameters passed, until
  // the owning OpFunction is found. Decrementing before indexing means index 0
  // is inspected too and the index can never wrap around when |inst| is the
  // first instruction.
  const auto& ordered = _.ordered_instructions();
  size_t param_index = 0;
  const Instruction* func_inst = nullptr;
  size_t inst_num = inst->LineNum() - 1;
  while (inst_num > 0) {
    --inst_num;
    const Instruction* candidate = &ordered[inst_num];
    if (candidate->opcode() == spv::Op::OpFunction) {
      func_inst = candidate;
      break;
    }
    if (candidate->opcode() == spv::Op::OpFunctionParameter) {
      ++param_index;
    }
  }

  if (!func_inst) {
    return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
           << "Function parameter must be preceded by a function.";
  }

  const auto function_type_id = func_inst->GetOperandAs<uint32_t>(3);
  const auto function_type = _.FindDef(function_type_id);
  // Also require OpTypeFunction: the "words().size() - 3" arithmetic below
  // only describes the parameter count for that opcode.
  if (!function_type || function_type->opcode() != spv::Op::OpTypeFunction) {
    return _.diag(SPV_ERROR_INVALID_ID, func_inst)
           << "Missing function type definition.";
  }
  // OpTypeFunction words: opcode/length, result id, return type, params...
  if (param_index >= function_type->words().size() - 3) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "Too many OpFunctionParameters for " << func_inst->id()
           << ": expected " << function_type->words().size() - 3
           << " based on the function's type";
  }

  // Parameter types start at operand 2 of OpTypeFunction.
  const auto param_type =
      _.FindDef(function_type->GetOperandAs<uint32_t>(param_index + 2));
  if (!param_type || inst->type_id() != param_type->id()) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpFunctionParameter Result Type <id> "
           << _.getIdName(inst->type_id())
           << " does not match the OpTypeFunction parameter "
              "type of the same index.";
  }

  return SPV_SUCCESS;
}
155
156
// Validates an OpFunctionCall instruction:
//  - the callee (operand 2) must be an OpFunction;
//  - the call's Result Type must equal the callee's result type;
//  - under Logical / PhysicalStorageBuffer64 addressing (unless
//    relax_logical_pointer is set), returned pointers are restricted by
//    storage class and declared capabilities;
//  - the argument count must match the callee's OpTypeFunction parameter
//    count, and each argument's type must match the parameter type (or, when
//    before_hlsl_legalization is set, point to a logically matching type);
//  - under Logical / PhysicalStorageBuffer64 addressing, pointer arguments are
//    restricted by storage class and must be memory object declarations unless
//    an exemption applies.
spv_result_t ValidateFunctionCall(ValidationState_t& _,
                                  const Instruction* inst) {
  const auto function_id = inst->GetOperandAs<uint32_t>(2);
  const auto function = _.FindDef(function_id);
  if (!function || spv::Op::OpFunction != function->opcode()) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpFunctionCall Function <id> " << _.getIdName(function_id)
           << " is not a function.";
  }

  auto return_type = _.FindDef(function->type_id());
  if (!return_type || return_type->id() != inst->type_id()) {
    // |return_type| may be null here, so name the callee's declared result
    // type <id> directly instead of dereferencing |return_type|. When it is
    // non-null, its id() equals function->type_id(), so the message is the
    // same as before.
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpFunctionCall Result Type <id> " << _.getIdName(inst->type_id())
           << "s type does not match Function <id> "
           << _.getIdName(function->type_id()) << "s return type.";
  }
  if (!_.options()->relax_logical_pointer &&
      (_.addressing_model() == spv::AddressingModel::Logical ||
       _.addressing_model() == spv::AddressingModel::PhysicalStorageBuffer64)) {
    if (return_type->opcode() == spv::Op::OpTypePointer ||
        return_type->opcode() == spv::Op::OpTypeUntypedPointerKHR) {
      // Operand 1 of both pointer type opcodes is the storage class.
      const auto sc = return_type->GetOperandAs<spv::StorageClass>(1);
      if (sc != spv::StorageClass::PhysicalStorageBuffer) {
        if (!_.HasCapability(spv::Capability::VariablePointersStorageBuffer) &&
            sc == spv::StorageClass::StorageBuffer) {
          return _.diag(SPV_ERROR_INVALID_ID, inst)
                 << "In Logical addressing, functions may only return a "
                    "storage buffer pointer if the "
                    "VariablePointersStorageBuffer capability is declared";
        } else if (!_.HasCapability(spv::Capability::VariablePointers) &&
                   sc == spv::StorageClass::Workgroup) {
          return _.diag(SPV_ERROR_INVALID_ID, inst)
                 << "In Logical addressing, functions may only return a "
                    "workgroup pointer if the VariablePointers capability is "
                    "declared";
        } else if (sc != spv::StorageClass::StorageBuffer &&
                   sc != spv::StorageClass::Workgroup) {
          return _.diag(SPV_ERROR_INVALID_ID, inst)
                 << "In Logical addressing, functions may not return a pointer "
                    "in this storage class";
        }
      }
    }
  }

  const auto function_type_id = function->GetOperandAs<uint32_t>(3);
  const auto function_type = _.FindDef(function_type_id);
  if (!function_type || function_type->opcode() != spv::Op::OpTypeFunction) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "Missing function type definition.";
  }

  // OpFunctionCall words: opcode/length, result type, result id, function,
  // arguments... OpTypeFunction words: opcode/length, result id, return type,
  // parameters...
  const auto function_call_arg_count = inst->words().size() - 4;
  const auto function_param_count = function_type->words().size() - 3;
  if (function_param_count != function_call_arg_count) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpFunctionCall Function <id>'s parameter count does not match "
              "the argument count.";
  }

  // Arguments start at operand 3 of the call; parameter types start at
  // operand 2 of the function type.
  for (size_t argument_index = 3, param_index = 2;
       argument_index < inst->operands().size();
       argument_index++, param_index++) {
    const auto argument_id = inst->GetOperandAs<uint32_t>(argument_index);
    const auto argument = _.FindDef(argument_id);
    if (!argument) {
      return _.diag(SPV_ERROR_INVALID_ID, inst)
             << "Missing argument " << argument_index - 3 << " definition.";
    }

    const auto argument_type = _.FindDef(argument->type_id());
    if (!argument_type) {
      return _.diag(SPV_ERROR_INVALID_ID, inst)
             << "Missing argument " << argument_index - 3
             << " type definition.";
    }

    const auto parameter_type_id =
        function_type->GetOperandAs<uint32_t>(param_index);
    const auto parameter_type = _.FindDef(parameter_type_id);
    if (!parameter_type || argument_type->id() != parameter_type->id()) {
      // Before HLSL legalization, pointers to logically matching types are
      // tolerated.
      if (!parameter_type || !_.options()->before_hlsl_legalization ||
          !DoPointeesLogicallyMatch(argument_type, parameter_type, _)) {
        return _.diag(SPV_ERROR_INVALID_ID, inst)
               << "OpFunctionCall Argument <id> " << _.getIdName(argument_id)
               << "s type does not match Function <id> "
               << _.getIdName(parameter_type_id) << "s parameter type.";
      }
    }

    if (_.addressing_model() == spv::AddressingModel::Logical ||
        _.addressing_model() == spv::AddressingModel::PhysicalStorageBuffer64) {
      if ((parameter_type->opcode() == spv::Op::OpTypePointer ||
           parameter_type->opcode() == spv::Op::OpTypeUntypedPointerKHR) &&
          !_.options()->relax_logical_pointer) {
        spv::StorageClass sc =
            parameter_type->GetOperandAs<spv::StorageClass>(1u);
        if (sc != spv::StorageClass::PhysicalStorageBuffer) {
          // Validate which storage classes can be pointer operands.
          switch (sc) {
            case spv::StorageClass::UniformConstant:
            case spv::StorageClass::Function:
            case spv::StorageClass::Private:
            case spv::StorageClass::Workgroup:
            case spv::StorageClass::AtomicCounter:
            // SPV_EXT_tile_image
            case spv::StorageClass::TileImageEXT:
            // SPV_KHR_ray_tracing
            case spv::StorageClass::ShaderRecordBufferKHR:
              // These are always allowed.
              break;
            case spv::StorageClass::StorageBuffer:
              if (!_.features().variable_pointers) {
                return _.diag(SPV_ERROR_INVALID_ID, inst)
                       << "StorageBuffer pointer operand "
                       << _.getIdName(argument_id)
                       << " requires a variable pointers capability";
              }
              break;
            default:
              return _.diag(SPV_ERROR_INVALID_ID, inst)
                     << "Invalid storage class for pointer operand "
                     << _.getIdName(argument_id);
          }

          // Validate memory object declaration requirements.
          if (argument->opcode() != spv::Op::OpVariable &&
              argument->opcode() != spv::Op::OpUntypedVariableKHR &&
              argument->opcode() != spv::Op::OpFunctionParameter) {
            const bool ssbo_vptr =
                _.HasCapability(
                    spv::Capability::VariablePointersStorageBuffer) &&
                sc == spv::StorageClass::StorageBuffer;
            const bool wg_vptr =
                _.HasCapability(spv::Capability::VariablePointers) &&
                sc == spv::StorageClass::Workgroup;
            const bool uc_ptr = sc == spv::StorageClass::UniformConstant;
            if (!_.options()->before_hlsl_legalization && !ssbo_vptr &&
                !wg_vptr && !uc_ptr) {
              return _.diag(SPV_ERROR_INVALID_ID, inst)
                     << "Pointer operand " << _.getIdName(argument_id)
                     << " must be a memory object declaration";
            }
          }
        }
      }
    }
  }
  return SPV_SUCCESS;
}
307
308
// Validates an OpCooperativeMatrixPerElementOpNV instruction:
//  - operand 3 must name an OpFunction;
//  - operand 2 (the matrix) must be a cooperative matrix whose type equals
//    the instruction's Result Type;
//  - the function's type must be an OpTypeFunction returning the matrix
//    component type and taking at least three parameters: two 32-bit integers
//    followed by the matrix component type.
spv_result_t ValidateCooperativeMatrixPerElementOp(ValidationState_t& _,
                                                   const Instruction* inst) {
  const auto function_id = inst->GetOperandAs<uint32_t>(3);
  const auto function = _.FindDef(function_id);
  if (!function || spv::Op::OpFunction != function->opcode()) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpCooperativeMatrixPerElementOpNV Function <id> "
           << _.getIdName(function_id) << " is not a function.";
  }

  const auto matrix_id = inst->GetOperandAs<uint32_t>(2);
  const auto matrix = _.FindDef(matrix_id);
  // Guard against an undefined matrix <id> before dereferencing it.
  if (!matrix || !_.IsCooperativeMatrixKHRType(matrix->type_id())) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpCooperativeMatrixPerElementOpNV Matrix <id> "
           << _.getIdName(matrix_id) << " is not a cooperative matrix.";
  }
  const auto matrix_type_id = matrix->type_id();

  const auto result_type_id = inst->GetOperandAs<uint32_t>(0);
  if (matrix_type_id != result_type_id) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpCooperativeMatrixPerElementOpNV Result Type <id> "
           << _.getIdName(result_type_id) << " must match matrix type <id> "
           << _.getIdName(matrix_type_id) << ".";
  }

  // Operand 1 of the cooperative matrix type is its component type.
  const auto matrix_comp_type_id =
      _.FindDef(matrix_type_id)->GetOperandAs<uint32_t>(1);
  const auto function_type_id = function->GetOperandAs<uint32_t>(3);
  const auto function_type = _.FindDef(function_type_id);
  // The callee may be defined later in the module, so its type may not have
  // been validated yet; check it before reading its operands.
  if (!function_type || function_type->opcode() != spv::Op::OpTypeFunction) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "Missing function type definition.";
  }
  auto return_type_id = function_type->GetOperandAs<uint32_t>(1);
  if (return_type_id != matrix_comp_type_id) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpCooperativeMatrixPerElementOpNV function return type <id> "
           << _.getIdName(return_type_id)
           << " must match matrix component type <id> "
           << _.getIdName(matrix_comp_type_id) << ".";
  }

  // Operands: result id, return type, then the (>= 3) parameter types.
  if (function_type->operands().size() < 5) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpCooperativeMatrixPerElementOpNV function type <id> "
           << _.getIdName(function_type_id)
           << " must have a least three parameters.";
  }

  const auto param0_id = function_type->GetOperandAs<uint32_t>(2);
  const auto param1_id = function_type->GetOperandAs<uint32_t>(3);
  const auto param2_id = function_type->GetOperandAs<uint32_t>(4);
  if (!_.IsIntScalarType(param0_id, 32)) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpCooperativeMatrixPerElementOpNV function type first parameter "
              "type <id> "
           << _.getIdName(param0_id) << " must be a 32-bit integer.";
  }

  if (!_.IsIntScalarType(param1_id, 32)) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpCooperativeMatrixPerElementOpNV function type second "
              "parameter type <id> "
           << _.getIdName(param1_id) << " must be a 32-bit integer.";
  }

  if (param2_id != matrix_comp_type_id) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpCooperativeMatrixPerElementOpNV function type third parameter "
              "type <id> "
           << _.getIdName(param2_id) << " must match matrix component type.";
  }

  return SPV_SUCCESS;
}
381
382
}  // namespace
383
384
14.5M
// Entry point of the function-related validation pass. Dispatches
// OpFunction, OpFunctionParameter, OpFunctionCall and
// OpCooperativeMatrixPerElementOpNV to their dedicated validators and reports
// SPV_SUCCESS for every other opcode.
spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) {
  spv_result_t status = SPV_SUCCESS;
  switch (inst->opcode()) {
    case spv::Op::OpFunction:
      status = ValidateFunction(_, inst);
      break;
    case spv::Op::OpFunctionParameter:
      status = ValidateFunctionParameter(_, inst);
      break;
    case spv::Op::OpFunctionCall:
      status = ValidateFunctionCall(_, inst);
      break;
    case spv::Op::OpCooperativeMatrixPerElementOpNV:
      status = ValidateCooperativeMatrixPerElementOp(_, inst);
      break;
    default:
      // Not handled by this pass.
      break;
  }
  return status;
}
405
406
}  // namespace val
407
}  // namespace spvtools