/src/spirv-tools/source/val/validate_function.cpp
Line | Count | Source |
1 | | // Copyright (c) 2018 Google LLC. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | #include <algorithm> |
16 | | |
17 | | #include "source/opcode.h" |
18 | | #include "source/table2.h" |
19 | | #include "source/val/instruction.h" |
20 | | #include "source/val/validate.h" |
21 | | #include "source/val/validation_state.h" |
22 | | |
23 | | namespace spvtools { |
24 | | namespace val { |
25 | | namespace { |
26 | | |
27 | | // Returns true if |a| and |b| are instructions defining pointers that point to |
28 | | // types logically match and the decorations that apply to |b| are a subset |
29 | | // of the decorations that apply to |a|. |
30 | | bool DoPointeesLogicallyMatch(val::Instruction* a, val::Instruction* b, |
31 | 0 | ValidationState_t& _) { |
32 | 0 | if (a->opcode() != spv::Op::OpTypePointer || |
33 | 0 | b->opcode() != spv::Op::OpTypePointer) { |
34 | 0 | return false; |
35 | 0 | } |
36 | | |
37 | 0 | const auto& dec_a = _.id_decorations(a->id()); |
38 | 0 | const auto& dec_b = _.id_decorations(b->id()); |
39 | 0 | for (const auto& dec : dec_b) { |
40 | 0 | if (std::find(dec_a.begin(), dec_a.end(), dec) == dec_a.end()) { |
41 | 0 | return false; |
42 | 0 | } |
43 | 0 | } |
44 | | |
45 | 0 | uint32_t a_type = a->GetOperandAs<uint32_t>(2); |
46 | 0 | uint32_t b_type = b->GetOperandAs<uint32_t>(2); |
47 | |
|
48 | 0 | if (a_type == b_type) { |
49 | 0 | return true; |
50 | 0 | } |
51 | | |
52 | 0 | Instruction* a_type_inst = _.FindDef(a_type); |
53 | 0 | Instruction* b_type_inst = _.FindDef(b_type); |
54 | |
|
55 | 0 | return _.LogicallyMatch(a_type_inst, b_type_inst, true); |
56 | 0 | } |
57 | | |
58 | 39.9k | spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) { |
59 | 39.9k | const auto function_type_id = inst->GetOperandAs<uint32_t>(3); |
60 | 39.9k | const auto function_type = _.FindDef(function_type_id); |
61 | 39.9k | if (!function_type || spv::Op::OpTypeFunction != function_type->opcode()) { |
62 | 50 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
63 | 50 | << "OpFunction Function Type <id> " << _.getIdName(function_type_id) |
64 | 50 | << " is not a function type."; |
65 | 50 | } |
66 | | |
67 | 39.9k | const auto return_id = function_type->GetOperandAs<uint32_t>(1); |
68 | 39.9k | if (return_id != inst->type_id()) { |
69 | 18 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
70 | 18 | << "OpFunction Result Type <id> " << _.getIdName(inst->type_id()) |
71 | 18 | << " does not match the Function Type's return type <id> " |
72 | 18 | << _.getIdName(return_id) << "."; |
73 | 18 | } |
74 | | |
75 | 39.9k | const std::vector<spv::Op> acceptable = { |
76 | 39.9k | spv::Op::OpGroupDecorate, |
77 | 39.9k | spv::Op::OpDecorate, |
78 | 39.9k | spv::Op::OpEnqueueKernel, |
79 | 39.9k | spv::Op::OpEntryPoint, |
80 | 39.9k | spv::Op::OpExecutionMode, |
81 | 39.9k | spv::Op::OpExecutionModeId, |
82 | 39.9k | spv::Op::OpFunctionCall, |
83 | 39.9k | spv::Op::OpGetKernelNDrangeSubGroupCount, |
84 | 39.9k | spv::Op::OpGetKernelNDrangeMaxSubGroupSize, |
85 | 39.9k | spv::Op::OpGetKernelWorkGroupSize, |
86 | 39.9k | spv::Op::OpGetKernelPreferredWorkGroupSizeMultiple, |
87 | 39.9k | spv::Op::OpGetKernelLocalSizeForSubgroupCount, |
88 | 39.9k | spv::Op::OpGetKernelMaxNumSubgroups, |
89 | 39.9k | spv::Op::OpName, |
90 | 39.9k | spv::Op::OpCooperativeMatrixPerElementOpNV, |
91 | 39.9k | spv::Op::OpCooperativeMatrixReduceNV, |
92 | 39.9k | spv::Op::OpCooperativeMatrixLoadTensorNV, |
93 | 39.9k | spv::Op::OpConditionalEntryPointINTEL, |
94 | 39.9k | spv::Op::OpConstantFunctionPointerINTEL}; |
95 | 95.5k | for (auto& pair : inst->uses()) { |
96 | 95.5k | const auto* use = pair.first; |
97 | 95.5k | if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) == |
98 | 95.5k | acceptable.end() && |
99 | 19 | !use->IsNonSemantic() && !use->IsDebugInfo() && |
100 | 19 | !spvOpcodeIsDecoration(use->opcode())) { |
101 | 19 | return _.diag(SPV_ERROR_INVALID_ID, use) |
102 | 19 | << "Invalid use of function result id " << _.getIdName(inst->id()) |
103 | 19 | << "."; |
104 | 19 | } |
105 | 95.5k | } |
106 | | |
107 | 39.9k | return SPV_SUCCESS; |
108 | 39.9k | } |
109 | | |
110 | | spv_result_t ValidateFunctionParameter(ValidationState_t& _, |
111 | 18.0k | const Instruction* inst) { |
112 | | // NOTE: Find OpFunction & ensure OpFunctionParameter is not out of place. |
113 | 18.0k | size_t param_index = 0; |
114 | 18.0k | size_t inst_num = inst->LineNum() - 1; |
115 | 18.0k | auto func_inst = &_.ordered_instructions()[inst_num]; |
116 | 31.2k | while (--inst_num) { |
117 | 31.2k | func_inst = &_.ordered_instructions()[inst_num]; |
118 | 31.2k | if (func_inst->opcode() == spv::Op::OpFunction) { |
119 | 18.0k | break; |
120 | 18.0k | } else if (func_inst->opcode() == spv::Op::OpFunctionParameter) { |
121 | 12.9k | ++param_index; |
122 | 12.9k | } |
123 | 31.2k | } |
124 | | |
125 | 18.0k | if (func_inst->opcode() != spv::Op::OpFunction) { |
126 | 0 | return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) |
127 | 0 | << "Function parameter must be preceded by a function."; |
128 | 0 | } |
129 | | |
130 | 18.0k | const auto function_type_id = func_inst->GetOperandAs<uint32_t>(3); |
131 | 18.0k | const auto function_type = _.FindDef(function_type_id); |
132 | 18.0k | if (!function_type) { |
133 | 0 | return _.diag(SPV_ERROR_INVALID_ID, func_inst) |
134 | 0 | << "Missing function type definition."; |
135 | 0 | } |
136 | 18.0k | if (param_index >= function_type->words().size() - 3) { |
137 | 4 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
138 | 4 | << "Too many OpFunctionParameters for " << func_inst->id() |
139 | 4 | << ": expected " << function_type->words().size() - 3 |
140 | 4 | << " based on the function's type"; |
141 | 4 | } |
142 | | |
143 | 18.0k | const auto param_type = |
144 | 18.0k | _.FindDef(function_type->GetOperandAs<uint32_t>(param_index + 2)); |
145 | 18.0k | if (!param_type || inst->type_id() != param_type->id()) { |
146 | 18 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
147 | 18 | << "OpFunctionParameter Result Type <id> " |
148 | 18 | << _.getIdName(inst->type_id()) |
149 | 18 | << " does not match the OpTypeFunction parameter " |
150 | 18 | "type of the same index."; |
151 | 18 | } |
152 | | |
153 | 18.0k | return SPV_SUCCESS; |
154 | 18.0k | } |
155 | | |
156 | | spv_result_t ValidateFunctionCall(ValidationState_t& _, |
157 | 31.5k | const Instruction* inst) { |
158 | 31.5k | const auto function_id = inst->GetOperandAs<uint32_t>(2); |
159 | 31.5k | const auto function = _.FindDef(function_id); |
160 | 31.5k | if (!function || spv::Op::OpFunction != function->opcode()) { |
161 | 13 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
162 | 13 | << "OpFunctionCall Function <id> " << _.getIdName(function_id) |
163 | 13 | << " is not a function."; |
164 | 13 | } |
165 | | |
166 | 31.5k | auto return_type = _.FindDef(function->type_id()); |
167 | 31.5k | if (!return_type || return_type->id() != inst->type_id()) { |
168 | 11 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
169 | 11 | << "OpFunctionCall Result Type <id> " << _.getIdName(inst->type_id()) |
170 | 11 | << "s type does not match Function <id> " |
171 | 11 | << _.getIdName(return_type->id()) << "s return type."; |
172 | 11 | } |
173 | 31.5k | if (!_.options()->relax_logical_pointer && |
174 | 31.5k | (_.addressing_model() == spv::AddressingModel::Logical || |
175 | 31.5k | _.addressing_model() == spv::AddressingModel::PhysicalStorageBuffer64)) { |
176 | 31.5k | if (return_type->opcode() == spv::Op::OpTypePointer || |
177 | 31.5k | return_type->opcode() == spv::Op::OpTypeUntypedPointerKHR) { |
178 | 3 | const auto sc = return_type->GetOperandAs<spv::StorageClass>(1); |
179 | 3 | if (sc != spv::StorageClass::PhysicalStorageBuffer) { |
180 | 3 | if (!_.HasCapability(spv::Capability::VariablePointersStorageBuffer) && |
181 | 3 | sc == spv::StorageClass::StorageBuffer) { |
182 | 0 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
183 | 0 | << "In Logical addressing, functions may only return a " |
184 | 0 | "storage buffer pointer if the " |
185 | 0 | "VariablePointersStorageBuffer capability is declared"; |
186 | 3 | } else if (!_.HasCapability(spv::Capability::VariablePointers) && |
187 | 3 | sc == spv::StorageClass::Workgroup) { |
188 | 1 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
189 | 1 | << "In Logical addressing, functions may only return a " |
190 | 1 | "workgroup pointer if the VariablePointers capability is " |
191 | 1 | "declared"; |
192 | 2 | } else if (sc != spv::StorageClass::StorageBuffer && |
193 | 2 | sc != spv::StorageClass::Workgroup) { |
194 | 2 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
195 | 2 | << "In Logical addressing, functions may not return a pointer " |
196 | 2 | "in this storage class"; |
197 | 2 | } |
198 | 3 | } |
199 | 3 | } |
200 | 31.5k | } |
201 | | |
202 | 31.5k | const auto function_type_id = function->GetOperandAs<uint32_t>(3); |
203 | 31.5k | const auto function_type = _.FindDef(function_type_id); |
204 | 31.5k | if (!function_type || function_type->opcode() != spv::Op::OpTypeFunction) { |
205 | 4 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
206 | 4 | << "Missing function type definition."; |
207 | 4 | } |
208 | | |
209 | 31.5k | const auto function_call_arg_count = inst->words().size() - 4; |
210 | 31.5k | const auto function_param_count = function_type->words().size() - 3; |
211 | 31.5k | if (function_param_count != function_call_arg_count) { |
212 | 3 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
213 | 3 | << "OpFunctionCall Function <id>'s parameter count does not match " |
214 | 3 | "the argument count."; |
215 | 3 | } |
216 | | |
217 | 31.5k | for (size_t argument_index = 3, param_index = 2; |
218 | 78.4k | argument_index < inst->operands().size(); |
219 | 46.8k | argument_index++, param_index++) { |
220 | 46.8k | const auto argument_id = inst->GetOperandAs<uint32_t>(argument_index); |
221 | 46.8k | const auto argument = _.FindDef(argument_id); |
222 | 46.8k | if (!argument) { |
223 | 0 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
224 | 0 | << "Missing argument " << argument_index - 3 << " definition."; |
225 | 0 | } |
226 | | |
227 | 46.8k | const auto argument_type = _.FindDef(argument->type_id()); |
228 | 46.8k | if (!argument_type) { |
229 | 0 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
230 | 0 | << "Missing argument " << argument_index - 3 |
231 | 0 | << " type definition."; |
232 | 0 | } |
233 | | |
234 | 46.8k | const auto parameter_type_id = |
235 | 46.8k | function_type->GetOperandAs<uint32_t>(param_index); |
236 | 46.8k | const auto parameter_type = _.FindDef(parameter_type_id); |
237 | 46.8k | if (!parameter_type || argument_type->id() != parameter_type->id()) { |
238 | 26 | if (!parameter_type || !_.options()->before_hlsl_legalization || |
239 | 26 | !DoPointeesLogicallyMatch(argument_type, parameter_type, _)) { |
240 | 26 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
241 | 26 | << "OpFunctionCall Argument <id> " << _.getIdName(argument_id) |
242 | 26 | << "s type does not match Function <id> " |
243 | 26 | << _.getIdName(parameter_type_id) << "s parameter type."; |
244 | 26 | } |
245 | 26 | } |
246 | | |
247 | 46.8k | if (_.addressing_model() == spv::AddressingModel::Logical || |
248 | 46.8k | _.addressing_model() == spv::AddressingModel::PhysicalStorageBuffer64) { |
249 | 46.8k | if ((parameter_type->opcode() == spv::Op::OpTypePointer || |
250 | 265 | parameter_type->opcode() == spv::Op::OpTypeUntypedPointerKHR) && |
251 | 46.5k | !_.options()->relax_logical_pointer) { |
252 | 46.5k | spv::StorageClass sc = |
253 | 46.5k | parameter_type->GetOperandAs<spv::StorageClass>(1u); |
254 | 46.5k | if (sc != spv::StorageClass::PhysicalStorageBuffer) { |
255 | | // Validate which storage classes can be pointer operands. |
256 | 46.5k | switch (sc) { |
257 | 1 | case spv::StorageClass::UniformConstant: |
258 | 46.5k | case spv::StorageClass::Function: |
259 | 46.5k | case spv::StorageClass::Private: |
260 | 46.5k | case spv::StorageClass::Workgroup: |
261 | 46.5k | case spv::StorageClass::AtomicCounter: |
262 | | // SPV_EXT_tile_image |
263 | 46.5k | case spv::StorageClass::TileImageEXT: |
264 | | // SPV_KHR_ray_tracing |
265 | 46.5k | case spv::StorageClass::ShaderRecordBufferKHR: |
266 | | // These are always allowed. |
267 | 46.5k | break; |
268 | 0 | case spv::StorageClass::StorageBuffer: |
269 | 0 | if (!_.features().variable_pointers) { |
270 | 0 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
271 | 0 | << "StorageBuffer pointer operand " |
272 | 0 | << _.getIdName(argument_id) |
273 | 0 | << " requires a variable pointers capability"; |
274 | 0 | } |
275 | 0 | break; |
276 | 4 | default: |
277 | 4 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
278 | 4 | << "Invalid storage class for pointer operand " |
279 | 4 | << _.getIdName(argument_id); |
280 | 46.5k | } |
281 | | |
282 | | // Validate memory object declaration requirements. |
283 | 46.5k | if (argument->opcode() != spv::Op::OpVariable && |
284 | 97 | argument->opcode() != spv::Op::OpUntypedVariableKHR && |
285 | 97 | argument->opcode() != spv::Op::OpFunctionParameter) { |
286 | 6 | const bool ssbo_vptr = |
287 | 6 | _.HasCapability( |
288 | 6 | spv::Capability::VariablePointersStorageBuffer) && |
289 | 0 | sc == spv::StorageClass::StorageBuffer; |
290 | 6 | const bool wg_vptr = |
291 | 6 | _.HasCapability(spv::Capability::VariablePointers) && |
292 | 0 | sc == spv::StorageClass::Workgroup; |
293 | 6 | const bool uc_ptr = sc == spv::StorageClass::UniformConstant; |
294 | 6 | if (!_.options()->before_hlsl_legalization && !ssbo_vptr && |
295 | 6 | !wg_vptr && !uc_ptr) { |
296 | 5 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
297 | 5 | << "Pointer operand " << _.getIdName(argument_id) |
298 | 5 | << " must be a memory object declaration"; |
299 | 5 | } |
300 | 6 | } |
301 | 46.5k | } |
302 | 46.5k | } |
303 | 46.8k | } |
304 | 46.8k | } |
305 | 31.5k | return SPV_SUCCESS; |
306 | 31.5k | } |
307 | | |
308 | | spv_result_t ValidateCooperativeMatrixPerElementOp(ValidationState_t& _, |
309 | 0 | const Instruction* inst) { |
310 | 0 | const auto function_id = inst->GetOperandAs<uint32_t>(3); |
311 | 0 | const auto function = _.FindDef(function_id); |
312 | 0 | if (!function || spv::Op::OpFunction != function->opcode()) { |
313 | 0 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
314 | 0 | << "OpCooperativeMatrixPerElementOpNV Function <id> " |
315 | 0 | << _.getIdName(function_id) << " is not a function."; |
316 | 0 | } |
317 | | |
318 | 0 | const auto matrix_id = inst->GetOperandAs<uint32_t>(2); |
319 | 0 | const auto matrix = _.FindDef(matrix_id); |
320 | 0 | const auto matrix_type_id = matrix->type_id(); |
321 | 0 | if (!_.IsCooperativeMatrixKHRType(matrix_type_id)) { |
322 | 0 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
323 | 0 | << "OpCooperativeMatrixPerElementOpNV Matrix <id> " |
324 | 0 | << _.getIdName(matrix_id) << " is not a cooperative matrix."; |
325 | 0 | } |
326 | | |
327 | 0 | const auto result_type_id = inst->GetOperandAs<uint32_t>(0); |
328 | 0 | if (matrix_type_id != result_type_id) { |
329 | 0 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
330 | 0 | << "OpCooperativeMatrixPerElementOpNV Result Type <id> " |
331 | 0 | << _.getIdName(result_type_id) << " must match matrix type <id> " |
332 | 0 | << _.getIdName(matrix_type_id) << "."; |
333 | 0 | } |
334 | | |
335 | 0 | const auto matrix_comp_type_id = |
336 | 0 | _.FindDef(matrix_type_id)->GetOperandAs<uint32_t>(1); |
337 | 0 | const auto function_type_id = function->GetOperandAs<uint32_t>(3); |
338 | 0 | const auto function_type = _.FindDef(function_type_id); |
339 | 0 | auto return_type_id = function_type->GetOperandAs<uint32_t>(1); |
340 | 0 | if (return_type_id != matrix_comp_type_id) { |
341 | 0 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
342 | 0 | << "OpCooperativeMatrixPerElementOpNV function return type <id> " |
343 | 0 | << _.getIdName(return_type_id) |
344 | 0 | << " must match matrix component type <id> " |
345 | 0 | << _.getIdName(matrix_comp_type_id) << "."; |
346 | 0 | } |
347 | | |
348 | 0 | if (function_type->operands().size() < 5) { |
349 | 0 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
350 | 0 | << "OpCooperativeMatrixPerElementOpNV function type <id> " |
351 | 0 | << _.getIdName(function_type_id) |
352 | 0 | << " must have a least three parameters."; |
353 | 0 | } |
354 | | |
355 | 0 | const auto param0_id = function_type->GetOperandAs<uint32_t>(2); |
356 | 0 | const auto param1_id = function_type->GetOperandAs<uint32_t>(3); |
357 | 0 | const auto param2_id = function_type->GetOperandAs<uint32_t>(4); |
358 | 0 | if (!_.IsIntScalarType(param0_id, 32)) { |
359 | 0 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
360 | 0 | << "OpCooperativeMatrixPerElementOpNV function type first parameter " |
361 | 0 | "type <id> " |
362 | 0 | << _.getIdName(param0_id) << " must be a 32-bit integer."; |
363 | 0 | } |
364 | | |
365 | 0 | if (!_.IsIntScalarType(param1_id, 32)) { |
366 | 0 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
367 | 0 | << "OpCooperativeMatrixPerElementOpNV function type second " |
368 | 0 | "parameter type <id> " |
369 | 0 | << _.getIdName(param1_id) << " must be a 32-bit integer."; |
370 | 0 | } |
371 | | |
372 | 0 | if (param2_id != matrix_comp_type_id) { |
373 | 0 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
374 | 0 | << "OpCooperativeMatrixPerElementOpNV function type third parameter " |
375 | 0 | "type <id> " |
376 | 0 | << _.getIdName(param2_id) << " must match matrix component type."; |
377 | 0 | } |
378 | | |
379 | 0 | return SPV_SUCCESS; |
380 | 0 | } |
381 | | |
382 | | } // namespace |
383 | | |
384 | 14.5M | spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) { |
385 | 14.5M | switch (inst->opcode()) { |
386 | 39.9k | case spv::Op::OpFunction: |
387 | 39.9k | if (auto error = ValidateFunction(_, inst)) return error; |
388 | 39.9k | break; |
389 | 39.9k | case spv::Op::OpFunctionParameter: |
390 | 18.0k | if (auto error = ValidateFunctionParameter(_, inst)) return error; |
391 | 18.0k | break; |
392 | 31.5k | case spv::Op::OpFunctionCall: |
393 | 31.5k | if (auto error = ValidateFunctionCall(_, inst)) return error; |
394 | 31.5k | break; |
395 | 31.5k | case spv::Op::OpCooperativeMatrixPerElementOpNV: |
396 | 0 | if (auto error = ValidateCooperativeMatrixPerElementOp(_, inst)) |
397 | 0 | return error; |
398 | 0 | break; |
399 | 14.4M | default: |
400 | 14.4M | break; |
401 | 14.5M | } |
402 | | |
403 | 14.5M | return SPV_SUCCESS; |
404 | 14.5M | } |
405 | | |
406 | | } // namespace val |
407 | | } // namespace spvtools |