/src/spirv-tools/source/val/validate_function.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | // Copyright (c) 2018 Google LLC. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | #include <algorithm> |
16 | | |
17 | | #include "source/opcode.h" |
18 | | #include "source/table2.h" |
19 | | #include "source/val/instruction.h" |
20 | | #include "source/val/validate.h" |
21 | | #include "source/val/validation_state.h" |
22 | | |
23 | | namespace spvtools { |
24 | | namespace val { |
25 | | namespace { |
26 | | |
27 | | // Returns true if |a| and |b| are instructions defining pointers that point to |
28 | | // types logically match and the decorations that apply to |b| are a subset |
29 | | // of the decorations that apply to |a|. |
30 | | bool DoPointeesLogicallyMatch(val::Instruction* a, val::Instruction* b, |
31 | 0 | ValidationState_t& _) { |
32 | 0 | if (a->opcode() != spv::Op::OpTypePointer || |
33 | 0 | b->opcode() != spv::Op::OpTypePointer) { |
34 | 0 | return false; |
35 | 0 | } |
36 | | |
37 | 0 | const auto& dec_a = _.id_decorations(a->id()); |
38 | 0 | const auto& dec_b = _.id_decorations(b->id()); |
39 | 0 | for (const auto& dec : dec_b) { |
40 | 0 | if (std::find(dec_a.begin(), dec_a.end(), dec) == dec_a.end()) { |
41 | 0 | return false; |
42 | 0 | } |
43 | 0 | } |
44 | | |
45 | 0 | uint32_t a_type = a->GetOperandAs<uint32_t>(2); |
46 | 0 | uint32_t b_type = b->GetOperandAs<uint32_t>(2); |
47 | |
|
48 | 0 | if (a_type == b_type) { |
49 | 0 | return true; |
50 | 0 | } |
51 | | |
52 | 0 | Instruction* a_type_inst = _.FindDef(a_type); |
53 | 0 | Instruction* b_type_inst = _.FindDef(b_type); |
54 | |
|
55 | 0 | return _.LogicallyMatch(a_type_inst, b_type_inst, true); |
56 | 0 | } |
57 | | |
58 | 33.1k | spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) { |
59 | 33.1k | const auto function_type_id = inst->GetOperandAs<uint32_t>(3); |
60 | 33.1k | const auto function_type = _.FindDef(function_type_id); |
61 | 33.1k | if (!function_type || spv::Op::OpTypeFunction != function_type->opcode()) { |
62 | 64 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
63 | 64 | << "OpFunction Function Type <id> " << _.getIdName(function_type_id) |
64 | 64 | << " is not a function type."; |
65 | 64 | } |
66 | | |
67 | 33.1k | const auto return_id = function_type->GetOperandAs<uint32_t>(1); |
68 | 33.1k | if (return_id != inst->type_id()) { |
69 | 16 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
70 | 16 | << "OpFunction Result Type <id> " << _.getIdName(inst->type_id()) |
71 | 16 | << " does not match the Function Type's return type <id> " |
72 | 16 | << _.getIdName(return_id) << "."; |
73 | 16 | } |
74 | | |
75 | 33.1k | const std::vector<spv::Op> acceptable = { |
76 | 33.1k | spv::Op::OpGroupDecorate, |
77 | 33.1k | spv::Op::OpDecorate, |
78 | 33.1k | spv::Op::OpEnqueueKernel, |
79 | 33.1k | spv::Op::OpEntryPoint, |
80 | 33.1k | spv::Op::OpExecutionMode, |
81 | 33.1k | spv::Op::OpExecutionModeId, |
82 | 33.1k | spv::Op::OpFunctionCall, |
83 | 33.1k | spv::Op::OpGetKernelNDrangeSubGroupCount, |
84 | 33.1k | spv::Op::OpGetKernelNDrangeMaxSubGroupSize, |
85 | 33.1k | spv::Op::OpGetKernelWorkGroupSize, |
86 | 33.1k | spv::Op::OpGetKernelPreferredWorkGroupSizeMultiple, |
87 | 33.1k | spv::Op::OpGetKernelLocalSizeForSubgroupCount, |
88 | 33.1k | spv::Op::OpGetKernelMaxNumSubgroups, |
89 | 33.1k | spv::Op::OpName, |
90 | 33.1k | spv::Op::OpCooperativeMatrixPerElementOpNV, |
91 | 33.1k | spv::Op::OpCooperativeMatrixReduceNV, |
92 | 33.1k | spv::Op::OpCooperativeMatrixLoadTensorNV}; |
93 | 72.2k | for (auto& pair : inst->uses()) { |
94 | 72.2k | const auto* use = pair.first; |
95 | 72.2k | if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) == |
96 | 72.2k | acceptable.end() && |
97 | 72.2k | !use->IsNonSemantic() && !use->IsDebugInfo()) { |
98 | 16 | return _.diag(SPV_ERROR_INVALID_ID, use) |
99 | 16 | << "Invalid use of function result id " << _.getIdName(inst->id()) |
100 | 16 | << "."; |
101 | 16 | } |
102 | 72.2k | } |
103 | | |
104 | 33.0k | return SPV_SUCCESS; |
105 | 33.1k | } |
106 | | |
107 | | spv_result_t ValidateFunctionParameter(ValidationState_t& _, |
108 | 13.5k | const Instruction* inst) { |
109 | | // NOTE: Find OpFunction & ensure OpFunctionParameter is not out of place. |
110 | 13.5k | size_t param_index = 0; |
111 | 13.5k | size_t inst_num = inst->LineNum() - 1; |
112 | 13.5k | if (inst_num == 0) { |
113 | 0 | return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) |
114 | 0 | << "Function parameter cannot be the first instruction."; |
115 | 0 | } |
116 | | |
117 | 13.5k | auto func_inst = &_.ordered_instructions()[inst_num]; |
118 | 23.6k | while (--inst_num) { |
119 | 23.6k | func_inst = &_.ordered_instructions()[inst_num]; |
120 | 23.6k | if (func_inst->opcode() == spv::Op::OpFunction) { |
121 | 13.5k | break; |
122 | 13.5k | } else if (func_inst->opcode() == spv::Op::OpFunctionParameter) { |
123 | 10.0k | ++param_index; |
124 | 10.0k | } |
125 | 23.6k | } |
126 | | |
127 | 13.5k | if (func_inst->opcode() != spv::Op::OpFunction) { |
128 | 0 | return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) |
129 | 0 | << "Function parameter must be preceded by a function."; |
130 | 0 | } |
131 | | |
132 | 13.5k | const auto function_type_id = func_inst->GetOperandAs<uint32_t>(3); |
133 | 13.5k | const auto function_type = _.FindDef(function_type_id); |
134 | 13.5k | if (!function_type) { |
135 | 0 | return _.diag(SPV_ERROR_INVALID_ID, func_inst) |
136 | 0 | << "Missing function type definition."; |
137 | 0 | } |
138 | 13.5k | if (param_index >= function_type->words().size() - 3) { |
139 | 3 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
140 | 3 | << "Too many OpFunctionParameters for " << func_inst->id() |
141 | 3 | << ": expected " << function_type->words().size() - 3 |
142 | 3 | << " based on the function's type"; |
143 | 3 | } |
144 | | |
145 | 13.4k | const auto param_type = |
146 | 13.4k | _.FindDef(function_type->GetOperandAs<uint32_t>(param_index + 2)); |
147 | 13.4k | if (!param_type || inst->type_id() != param_type->id()) { |
148 | 17 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
149 | 17 | << "OpFunctionParameter Result Type <id> " |
150 | 17 | << _.getIdName(inst->type_id()) |
151 | 17 | << " does not match the OpTypeFunction parameter " |
152 | 17 | "type of the same index."; |
153 | 17 | } |
154 | | |
155 | 13.4k | return SPV_SUCCESS; |
156 | 13.4k | } |
157 | | |
158 | | spv_result_t ValidateFunctionCall(ValidationState_t& _, |
159 | 23.7k | const Instruction* inst) { |
160 | 23.7k | const auto function_id = inst->GetOperandAs<uint32_t>(2); |
161 | 23.7k | const auto function = _.FindDef(function_id); |
162 | 23.7k | if (!function || spv::Op::OpFunction != function->opcode()) { |
163 | 13 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
164 | 13 | << "OpFunctionCall Function <id> " << _.getIdName(function_id) |
165 | 13 | << " is not a function."; |
166 | 13 | } |
167 | | |
168 | 23.7k | auto return_type = _.FindDef(function->type_id()); |
169 | 23.7k | if (!return_type || return_type->id() != inst->type_id()) { |
170 | 12 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
171 | 12 | << "OpFunctionCall Result Type <id> " << _.getIdName(inst->type_id()) |
172 | 12 | << "s type does not match Function <id> " |
173 | 12 | << _.getIdName(return_type->id()) << "s return type."; |
174 | 12 | } |
175 | | |
176 | 23.6k | const auto function_type_id = function->GetOperandAs<uint32_t>(3); |
177 | 23.6k | const auto function_type = _.FindDef(function_type_id); |
178 | 23.6k | if (!function_type || function_type->opcode() != spv::Op::OpTypeFunction) { |
179 | 3 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
180 | 3 | << "Missing function type definition."; |
181 | 3 | } |
182 | | |
183 | 23.6k | const auto function_call_arg_count = inst->words().size() - 4; |
184 | 23.6k | const auto function_param_count = function_type->words().size() - 3; |
185 | 23.6k | if (function_param_count != function_call_arg_count) { |
186 | 3 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
187 | 3 | << "OpFunctionCall Function <id>'s parameter count does not match " |
188 | 3 | "the argument count."; |
189 | 3 | } |
190 | | |
191 | 23.6k | for (size_t argument_index = 3, param_index = 2; |
192 | 59.9k | argument_index < inst->operands().size(); |
193 | 36.2k | argument_index++, param_index++) { |
194 | 36.2k | const auto argument_id = inst->GetOperandAs<uint32_t>(argument_index); |
195 | 36.2k | const auto argument = _.FindDef(argument_id); |
196 | 36.2k | if (!argument) { |
197 | 0 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
198 | 0 | << "Missing argument " << argument_index - 3 << " definition."; |
199 | 0 | } |
200 | | |
201 | 36.2k | const auto argument_type = _.FindDef(argument->type_id()); |
202 | 36.2k | if (!argument_type) { |
203 | 0 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
204 | 0 | << "Missing argument " << argument_index - 3 |
205 | 0 | << " type definition."; |
206 | 0 | } |
207 | | |
208 | 36.2k | const auto parameter_type_id = |
209 | 36.2k | function_type->GetOperandAs<uint32_t>(param_index); |
210 | 36.2k | const auto parameter_type = _.FindDef(parameter_type_id); |
211 | 36.2k | if (!parameter_type || argument_type->id() != parameter_type->id()) { |
212 | 25 | if (!parameter_type || !_.options()->before_hlsl_legalization || |
213 | 25 | !DoPointeesLogicallyMatch(argument_type, parameter_type, _)) { |
214 | 25 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
215 | 25 | << "OpFunctionCall Argument <id> " << _.getIdName(argument_id) |
216 | 25 | << "s type does not match Function <id> " |
217 | 25 | << _.getIdName(parameter_type_id) << "s parameter type."; |
218 | 25 | } |
219 | 25 | } |
220 | | |
221 | 36.2k | if (_.addressing_model() == spv::AddressingModel::Logical) { |
222 | 36.2k | if ((parameter_type->opcode() == spv::Op::OpTypePointer || |
223 | 36.2k | parameter_type->opcode() == spv::Op::OpTypeUntypedPointerKHR) && |
224 | 36.2k | !_.options()->relax_logical_pointer) { |
225 | 36.0k | spv::StorageClass sc = |
226 | 36.0k | parameter_type->GetOperandAs<spv::StorageClass>(1u); |
227 | | // Validate which storage classes can be pointer operands. |
228 | 36.0k | switch (sc) { |
229 | 1 | case spv::StorageClass::UniformConstant: |
230 | 35.9k | case spv::StorageClass::Function: |
231 | 35.9k | case spv::StorageClass::Private: |
232 | 35.9k | case spv::StorageClass::Workgroup: |
233 | 35.9k | case spv::StorageClass::AtomicCounter: |
234 | | // These are always allowed. |
235 | 35.9k | break; |
236 | 0 | case spv::StorageClass::StorageBuffer: |
237 | 0 | if (!_.features().variable_pointers) { |
238 | 0 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
239 | 0 | << "StorageBuffer pointer operand " |
240 | 0 | << _.getIdName(argument_id) |
241 | 0 | << " requires a variable pointers capability"; |
242 | 0 | } |
243 | 0 | break; |
244 | 4 | default: |
245 | 4 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
246 | 4 | << "Invalid storage class for pointer operand " |
247 | 4 | << _.getIdName(argument_id); |
248 | 36.0k | } |
249 | | |
250 | | // Validate memory object declaration requirements. |
251 | 35.9k | if (argument->opcode() != spv::Op::OpVariable && |
252 | 35.9k | argument->opcode() != spv::Op::OpUntypedVariableKHR && |
253 | 35.9k | argument->opcode() != spv::Op::OpFunctionParameter) { |
254 | 6 | const bool ssbo_vptr = |
255 | 6 | _.HasCapability(spv::Capability::VariablePointersStorageBuffer) && |
256 | 6 | sc == spv::StorageClass::StorageBuffer; |
257 | 6 | const bool wg_vptr = |
258 | 6 | _.HasCapability(spv::Capability::VariablePointers) && |
259 | 6 | sc == spv::StorageClass::Workgroup; |
260 | 6 | const bool uc_ptr = sc == spv::StorageClass::UniformConstant; |
261 | 6 | if (!_.options()->before_hlsl_legalization && !ssbo_vptr && |
262 | 6 | !wg_vptr && !uc_ptr) { |
263 | 5 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
264 | 5 | << "Pointer operand " << _.getIdName(argument_id) |
265 | 5 | << " must be a memory object declaration"; |
266 | 5 | } |
267 | 6 | } |
268 | 35.9k | } |
269 | 36.2k | } |
270 | 36.2k | } |
271 | 23.6k | return SPV_SUCCESS; |
272 | 23.6k | } |
273 | | |
274 | | spv_result_t ValidateCooperativeMatrixPerElementOp(ValidationState_t& _, |
275 | 0 | const Instruction* inst) { |
276 | 0 | const auto function_id = inst->GetOperandAs<uint32_t>(3); |
277 | 0 | const auto function = _.FindDef(function_id); |
278 | 0 | if (!function || spv::Op::OpFunction != function->opcode()) { |
279 | 0 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
280 | 0 | << "OpCooperativeMatrixPerElementOpNV Function <id> " |
281 | 0 | << _.getIdName(function_id) << " is not a function."; |
282 | 0 | } |
283 | | |
284 | 0 | const auto matrix_id = inst->GetOperandAs<uint32_t>(2); |
285 | 0 | const auto matrix = _.FindDef(matrix_id); |
286 | 0 | const auto matrix_type_id = matrix->type_id(); |
287 | 0 | if (!_.IsCooperativeMatrixKHRType(matrix_type_id)) { |
288 | 0 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
289 | 0 | << "OpCooperativeMatrixPerElementOpNV Matrix <id> " |
290 | 0 | << _.getIdName(matrix_id) << " is not a cooperative matrix."; |
291 | 0 | } |
292 | | |
293 | 0 | const auto result_type_id = inst->GetOperandAs<uint32_t>(0); |
294 | 0 | if (matrix_type_id != result_type_id) { |
295 | 0 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
296 | 0 | << "OpCooperativeMatrixPerElementOpNV Result Type <id> " |
297 | 0 | << _.getIdName(result_type_id) << " must match matrix type <id> " |
298 | 0 | << _.getIdName(matrix_type_id) << "."; |
299 | 0 | } |
300 | | |
301 | 0 | const auto matrix_comp_type_id = |
302 | 0 | _.FindDef(matrix_type_id)->GetOperandAs<uint32_t>(1); |
303 | 0 | const auto function_type_id = function->GetOperandAs<uint32_t>(3); |
304 | 0 | const auto function_type = _.FindDef(function_type_id); |
305 | 0 | auto return_type_id = function_type->GetOperandAs<uint32_t>(1); |
306 | 0 | if (return_type_id != matrix_comp_type_id) { |
307 | 0 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
308 | 0 | << "OpCooperativeMatrixPerElementOpNV function return type <id> " |
309 | 0 | << _.getIdName(return_type_id) |
310 | 0 | << " must match matrix component type <id> " |
311 | 0 | << _.getIdName(matrix_comp_type_id) << "."; |
312 | 0 | } |
313 | | |
314 | 0 | if (function_type->operands().size() < 5) { |
315 | 0 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
316 | 0 | << "OpCooperativeMatrixPerElementOpNV function type <id> " |
317 | 0 | << _.getIdName(function_type_id) |
318 | 0 | << " must have a least three parameters."; |
319 | 0 | } |
320 | | |
321 | 0 | const auto param0_id = function_type->GetOperandAs<uint32_t>(2); |
322 | 0 | const auto param1_id = function_type->GetOperandAs<uint32_t>(3); |
323 | 0 | const auto param2_id = function_type->GetOperandAs<uint32_t>(4); |
324 | 0 | if (!_.IsIntScalarType(param0_id) || _.GetBitWidth(param0_id) != 32) { |
325 | 0 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
326 | 0 | << "OpCooperativeMatrixPerElementOpNV function type first parameter " |
327 | 0 | "type <id> " |
328 | 0 | << _.getIdName(param0_id) << " must be a 32-bit integer."; |
329 | 0 | } |
330 | | |
331 | 0 | if (!_.IsIntScalarType(param1_id) || _.GetBitWidth(param1_id) != 32) { |
332 | 0 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
333 | 0 | << "OpCooperativeMatrixPerElementOpNV function type second " |
334 | 0 | "parameter type <id> " |
335 | 0 | << _.getIdName(param1_id) << " must be a 32-bit integer."; |
336 | 0 | } |
337 | | |
338 | 0 | if (param2_id != matrix_comp_type_id) { |
339 | 0 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
340 | 0 | << "OpCooperativeMatrixPerElementOpNV function type third parameter " |
341 | 0 | "type <id> " |
342 | 0 | << _.getIdName(param2_id) << " must match matrix component type."; |
343 | 0 | } |
344 | | |
345 | 0 | return SPV_SUCCESS; |
346 | 0 | } |
347 | | |
348 | | } // namespace |
349 | | |
350 | 10.9M | spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) { |
351 | 10.9M | switch (inst->opcode()) { |
352 | 33.1k | case spv::Op::OpFunction: |
353 | 33.1k | if (auto error = ValidateFunction(_, inst)) return error; |
354 | 33.0k | break; |
355 | 33.0k | case spv::Op::OpFunctionParameter: |
356 | 13.5k | if (auto error = ValidateFunctionParameter(_, inst)) return error; |
357 | 13.4k | break; |
358 | 23.7k | case spv::Op::OpFunctionCall: |
359 | 23.7k | if (auto error = ValidateFunctionCall(_, inst)) return error; |
360 | 23.6k | break; |
361 | 23.6k | case spv::Op::OpCooperativeMatrixPerElementOpNV: |
362 | 0 | if (auto error = ValidateCooperativeMatrixPerElementOp(_, inst)) |
363 | 0 | return error; |
364 | 0 | break; |
365 | 10.9M | default: |
366 | 10.9M | break; |
367 | 10.9M | } |
368 | | |
369 | 10.9M | return SPV_SUCCESS; |
370 | 10.9M | } |
371 | | |
372 | | } // namespace val |
373 | | } // namespace spvtools |