/src/spirv-tools/source/val/validate_graph.cpp
Line | Count | Source |
1 | | // Copyright (c) 2023-2025 Arm Ltd. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | // Validates correctness of graph instructions. |
16 | | |
17 | | #include <deque> |
18 | | |
19 | | #include "source/opcode.h" |
20 | | #include "source/spirv_target_env.h" |
21 | | #include "source/val/validate.h" |
22 | | #include "source/val/validation_state.h" |
23 | | |
24 | | namespace spvtools { |
25 | | namespace val { |
26 | | namespace { |
27 | | |
28 | 0 | bool IsTensorArray(ValidationState_t& _, uint32_t id) { |
29 | 0 | auto def = _.FindDef(id); |
30 | 0 | if (!def || (def->opcode() != spv::Op::OpTypeArray && |
31 | 0 | def->opcode() != spv::Op::OpTypeRuntimeArray)) { |
32 | 0 | return false; |
33 | 0 | } |
34 | 0 | auto tdef = _.FindDef(def->word(2)); |
35 | 0 | if (!tdef || tdef->opcode() != spv::Op::OpTypeTensorARM) { |
36 | 0 | return false; |
37 | 0 | } |
38 | 0 | return true; |
39 | 0 | } |
40 | | |
41 | 0 | bool IsGraphInterfaceType(ValidationState_t& _, uint32_t id) { |
42 | 0 | return _.IsTensorType(id) || IsTensorArray(_, id); |
43 | 0 | } |
44 | | |
45 | 0 | bool IsGraph(ValidationState_t& _, uint32_t id) { |
46 | 0 | auto def = _.FindDef(id); |
47 | 0 | if (!def || def->opcode() != spv::Op::OpGraphARM) { |
48 | 0 | return false; |
49 | 0 | } |
50 | 0 | return true; |
51 | 0 | } |
52 | | |
53 | 0 | bool IsGraphType(ValidationState_t& _, uint32_t id) { |
54 | 0 | auto def = _.FindDef(id); |
55 | 0 | if (!def || def->opcode() != spv::Op::OpTypeGraphARM) { |
56 | 0 | return false; |
57 | 0 | } |
58 | 0 | return true; |
59 | 0 | } |
60 | | |
61 | 0 | bool IsConstantInstruction(ValidationState_t& _, uint32_t id) { |
62 | 0 | auto def = _.FindDef(id); |
63 | 0 | return def && spvOpcodeIsConstant(def->opcode()); |
64 | 0 | } |
65 | | |
66 | | const uint32_t kGraphTypeIOStartWord = 3; |
67 | | |
68 | 0 | uint32_t GraphTypeInstNumIO(const Instruction* inst) { |
69 | 0 | return static_cast<uint32_t>(inst->words().size()) - kGraphTypeIOStartWord; |
70 | 0 | } |
71 | | |
72 | 0 | uint32_t GraphTypeInstNumInputs(const Instruction* inst) { |
73 | 0 | return inst->word(2); |
74 | 0 | } |
75 | | |
76 | 0 | uint32_t GraphTypeInstNumOutputs(const Instruction* inst) { |
77 | 0 | return GraphTypeInstNumIO(inst) - GraphTypeInstNumInputs(inst); |
78 | 0 | } |
79 | | |
80 | | uint32_t GraphTypeInstGetOutputAtIndex(const Instruction* inst, |
81 | 0 | uint64_t index) { |
82 | 0 | return inst->word(kGraphTypeIOStartWord + GraphTypeInstNumInputs(inst) + |
83 | 0 | static_cast<uint32_t>(index)); |
84 | 0 | } |
85 | | |
86 | 0 | uint32_t GraphTypeInstGetInputAtIndex(const Instruction* inst, uint64_t index) { |
87 | 0 | return inst->word(kGraphTypeIOStartWord + static_cast<uint32_t>(index)); |
88 | 0 | } |
89 | | |
90 | 0 | spv_result_t ValidateGraphType(ValidationState_t& _, const Instruction* inst) { |
91 | | // Check there are at least NumInputs types |
92 | 0 | uint32_t NumInputs = GraphTypeInstNumInputs(inst); |
93 | 0 | size_t NumIOTypes = GraphTypeInstNumIO(inst); |
94 | 0 | if (NumIOTypes < NumInputs) { |
95 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
96 | 0 | << NumIOTypes << " I/O types were provided but the graph has " |
97 | 0 | << NumInputs << " inputs."; |
98 | 0 | } |
99 | | |
100 | | // Check there is at least one output |
101 | 0 | if (NumIOTypes == NumInputs) { |
102 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
103 | 0 | << "A graph type must have at least one output."; |
104 | 0 | } |
105 | | |
106 | | // Check all I/O types are graph interface type |
107 | 0 | for (unsigned i = kGraphTypeIOStartWord; i < inst->words().size(); i++) { |
108 | 0 | auto tid = inst->word(i); |
109 | 0 | if (!IsGraphInterfaceType(_, tid)) { |
110 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
111 | 0 | << "I/O type " << _.getIdName(tid) |
112 | 0 | << " is not a Graph Interface Type."; |
113 | 0 | } |
114 | 0 | } |
115 | | |
116 | 0 | return SPV_SUCCESS; |
117 | 0 | } |
118 | | |
119 | | spv_result_t ValidateGraphConstant(ValidationState_t& _, |
120 | 0 | const Instruction* inst) { |
121 | | // Check Result Type |
122 | 0 | if (!_.IsTensorType(inst->type_id())) { |
123 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
124 | 0 | << spvOpcodeString(inst->opcode()) |
125 | 0 | << " must have a Result Type that is a tensor type."; |
126 | 0 | } |
127 | | |
128 | | // Check the instruction is not preceded by another OpGraphConstantARM with |
129 | | // the same ID |
130 | 0 | const uint32_t cst_id = inst->word(3); |
131 | 0 | size_t inst_num = inst->LineNum() - 1; |
132 | 0 | while (--inst_num) { |
133 | 0 | auto prev_inst = &_.ordered_instructions()[inst_num]; |
134 | 0 | if (prev_inst->opcode() == spv::Op::OpGraphConstantARM) { |
135 | 0 | const uint32_t prev_cst_id = prev_inst->word(3); |
136 | 0 | if (prev_cst_id == cst_id) { |
137 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
138 | 0 | << "No two OpGraphConstantARM instructions may have the same " |
139 | 0 | "GraphConstantID"; |
140 | 0 | } |
141 | 0 | } |
142 | 0 | } |
143 | 0 | return SPV_SUCCESS; |
144 | 0 | } |
145 | | |
146 | | spv_result_t ValidateGraphEntryPoint(ValidationState_t& _, |
147 | 0 | const Instruction* inst) { |
148 | | // Graph must be an OpGraphARM |
149 | 0 | uint32_t graph = inst->GetOperandAs<uint32_t>(0); |
150 | 0 | auto graph_inst = _.FindDef(graph); |
151 | 0 | if (!IsGraph(_, graph)) { |
152 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
153 | 0 | << spvOpcodeString(inst->opcode()) |
154 | 0 | << " Graph must be a OpGraphARM but found " |
155 | 0 | << spvOpcodeString(graph_inst->opcode()) << "."; |
156 | 0 | } |
157 | | |
158 | | // Check number of Interface IDs matches number of I/Os of graph |
159 | 0 | auto graph_type_inst = _.FindDef(graph_inst->type_id()); |
160 | 0 | size_t graph_type_num_io = GraphTypeInstNumIO(graph_type_inst); |
161 | 0 | size_t graph_entry_point_num_interface_id = inst->operands().size() - 2; |
162 | 0 | if (graph_type_inst->opcode() != spv::Op::OpTypeGraphARM) { |
163 | | // This is invalid but we want ValidateGraph to report a clear error |
164 | | // so stop validating the graph entry point instruction |
165 | 0 | return SPV_SUCCESS; |
166 | 0 | } |
167 | 0 | if (graph_type_num_io != graph_entry_point_num_interface_id) { |
168 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
169 | 0 | << spvOpcodeString(inst->opcode()) << " Interface list contains " |
170 | 0 | << graph_entry_point_num_interface_id << " IDs but Graph's type " |
171 | 0 | << _.getIdName(graph_inst->type_id()) << " has " << graph_type_num_io |
172 | 0 | << " inputs and outputs."; |
173 | 0 | } |
174 | | |
175 | | // Check Interface IDs |
176 | 0 | for (uint32_t i = 2; i < inst->operands().size(); i++) { |
177 | 0 | uint32_t interface_id = inst->GetOperandAs<uint32_t>(i); |
178 | 0 | auto interface_inst = _.FindDef(interface_id); |
179 | | |
180 | | // Check interface IDs come from OpVariable |
181 | 0 | if ((interface_inst->opcode() != spv::Op::OpVariable) || |
182 | 0 | (interface_inst->GetOperandAs<spv::StorageClass>(2) != |
183 | 0 | spv::StorageClass::UniformConstant)) { |
184 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, interface_inst) |
185 | 0 | << spvOpcodeString(inst->opcode()) << " Interface ID " |
186 | 0 | << _.getIdName(interface_id) |
187 | 0 | << " must come from OpVariable with UniformConstant Storage " |
188 | 0 | "Class."; |
189 | 0 | } |
190 | | |
191 | | // Check type of interface variable matches type of the corresponding graph |
192 | | // I/O |
193 | 0 | uint32_t corresponding_graph_io_type = |
194 | 0 | graph_type_inst->GetOperandAs<uint32_t>(i); |
195 | |
|
196 | 0 | uint32_t interface_ptr_type = interface_inst->type_id(); |
197 | 0 | auto interface_ptr_inst = _.FindDef(interface_ptr_type); |
198 | 0 | auto interface_pointee_type = interface_ptr_inst->GetOperandAs<uint32_t>(2); |
199 | 0 | if (interface_pointee_type != corresponding_graph_io_type) { |
200 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
201 | 0 | << spvOpcodeString(inst->opcode()) << " Interface ID type " |
202 | 0 | << _.getIdName(interface_pointee_type) |
203 | 0 | << " must match the type of the corresponding graph I/O " |
204 | 0 | << _.getIdName(corresponding_graph_io_type); |
205 | 0 | } |
206 | 0 | } |
207 | | |
208 | 0 | return SPV_SUCCESS; |
209 | 0 | } |
210 | | |
211 | 0 | spv_result_t ValidateGraph(ValidationState_t& _, const Instruction* inst) { |
212 | | // Result Type must be an OpTypeGraphARM |
213 | 0 | if (!IsGraphType(_, inst->type_id())) { |
214 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
215 | 0 | << spvOpcodeString(inst->opcode()) |
216 | 0 | << " Result Type must be an OpTypeGraphARM."; |
217 | 0 | } |
218 | | |
219 | 0 | return SPV_SUCCESS; |
220 | 0 | } |
221 | | |
222 | 0 | spv_result_t ValidateGraphInput(ValidationState_t& _, const Instruction* inst) { |
223 | 0 | const uint32_t input_index_id = inst->GetOperandAs<uint32_t>(2); |
224 | | // Check type of InputIndex |
225 | 0 | auto input_index_inst = _.FindDef(input_index_id); |
226 | 0 | if (!input_index_inst || |
227 | 0 | !_.IsIntScalarType(input_index_inst->type_id(), 32)) { |
228 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
229 | 0 | << spvOpcodeString(inst->opcode()) |
230 | 0 | << " InputIndex must be a 32-bit integer."; |
231 | 0 | } |
232 | | |
233 | 0 | bool has_element_index = inst->operands().size() > 3; |
234 | | |
235 | | // Check type of ElementIndex |
236 | 0 | uint32_t element_index_id = 0; |
237 | 0 | if (has_element_index) { |
238 | 0 | element_index_id = inst->GetOperandAs<uint32_t>(3); |
239 | 0 | auto element_index_inst = _.FindDef(element_index_id); |
240 | 0 | if (!element_index_inst || |
241 | 0 | !_.IsIntScalarType(element_index_inst->type_id(), 32)) { |
242 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
243 | 0 | << spvOpcodeString(inst->opcode()) |
244 | 0 | << " ElementIndex must be a 32-bit integer."; |
245 | 0 | } |
246 | 0 | } |
247 | | |
248 | 0 | if (spvIsVulkanEnv(_.context()->target_env)) { |
249 | 0 | if (!IsConstantInstruction(_, input_index_id)) { |
250 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
251 | 0 | << _.VkErrorID(9931) << "OpGraphInputARM InputIndex must be the " |
252 | 0 | << "<id> of a constant instruction."; |
253 | 0 | } |
254 | 0 | if (has_element_index && !IsConstantInstruction(_, element_index_id)) { |
255 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
256 | 0 | << _.VkErrorID(9931) << "OpGraphInputARM ElementIndex must be " |
257 | 0 | << "the <id> of a constant instruction."; |
258 | 0 | } |
259 | 0 | } |
260 | | |
261 | | // Find graph definition |
262 | 0 | size_t inst_num = inst->LineNum() - 1; |
263 | 0 | auto graph_inst = &_.ordered_instructions()[inst_num]; |
264 | 0 | while (--inst_num) { |
265 | 0 | graph_inst = &_.ordered_instructions()[inst_num]; |
266 | 0 | if (graph_inst->opcode() == spv::Op::OpGraphARM) { |
267 | 0 | break; |
268 | 0 | } |
269 | 0 | } |
270 | | |
271 | | // Can the InputIndex be evaluated? |
272 | | // If not, there's nothing more we can validate here. |
273 | 0 | uint64_t input_index; |
274 | 0 | if (!_.EvalConstantValUint64(inst->GetOperandAs<uint32_t>(2), &input_index)) { |
275 | 0 | return SPV_SUCCESS; |
276 | 0 | } |
277 | | |
278 | 0 | auto const graph_type_inst = _.FindDef(graph_inst->type_id()); |
279 | 0 | size_t graph_type_num_inputs = graph_type_inst->GetOperandAs<uint32_t>(1); |
280 | | |
281 | | // Check InputIndex is in range |
282 | 0 | if (input_index >= graph_type_num_inputs) { |
283 | 0 | std::string disassembly = _.Disassemble(*inst); |
284 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, nullptr) |
285 | 0 | << "Type " << _.getIdName(graph_type_inst->id()) << " for graph " |
286 | 0 | << _.getIdName(graph_inst->id()) << " has " << graph_type_num_inputs |
287 | 0 | << " inputs but found an OpGraphInputARM instruction with an " |
288 | 0 | "InputIndex that is " |
289 | 0 | << input_index << ": " << disassembly; |
290 | 0 | } |
291 | | |
292 | 0 | uint32_t graph_type_input_type = |
293 | 0 | GraphTypeInstGetInputAtIndex(graph_type_inst, input_index); |
294 | |
|
295 | 0 | if (has_element_index) { |
296 | | // Check ElementIndex is allowed |
297 | 0 | if (!IsTensorArray(_, graph_type_input_type)) { |
298 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
299 | 0 | << "OpGraphInputARM ElementIndex not allowed when the graph input " |
300 | 0 | "selected by " |
301 | 0 | << "InputIndex is not an OpTypeArray or OpTypeRuntimeArray"; |
302 | 0 | } |
303 | | |
304 | | // Check ElementIndex is in range if it can be evaluated and the input is a |
305 | | // fixed-sized array whose Length can be evaluated |
306 | 0 | uint64_t element_index; |
307 | 0 | if (_.IsArrayType(graph_type_input_type) && |
308 | 0 | _.EvalConstantValUint64(inst->GetOperandAs<uint32_t>(3), |
309 | 0 | &element_index)) { |
310 | 0 | uint64_t array_length; |
311 | 0 | auto graph_type_input_type_inst = _.FindDef(graph_type_input_type); |
312 | 0 | if (_.EvalConstantValUint64( |
313 | 0 | graph_type_input_type_inst->GetOperandAs<uint32_t>(2), |
314 | 0 | &array_length)) { |
315 | 0 | if (element_index >= array_length) { |
316 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
317 | 0 | << "OpGraphInputARM ElementIndex out of range. The type of " |
318 | 0 | "the graph input being accessed " |
319 | 0 | << _.getIdName(graph_type_input_type) << " is an array of " |
320 | 0 | << array_length << " elements but " << "ElementIndex is " |
321 | 0 | << element_index; |
322 | 0 | } |
323 | 0 | } |
324 | 0 | } |
325 | 0 | } |
326 | | |
327 | | // Check result type matches with graph type |
328 | 0 | if (has_element_index) { |
329 | 0 | uint32_t expected_type = _.GetComponentType(graph_type_input_type); |
330 | 0 | if (inst->type_id() != expected_type) { |
331 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
332 | 0 | << "Result Type " << _.getIdName(inst->type_id()) |
333 | 0 | << " of graph input instruction " << _.getIdName(inst->id()) |
334 | 0 | << " does not match the component type " |
335 | 0 | << _.getIdName(expected_type) << " of input " << input_index |
336 | 0 | << " in the graph type."; |
337 | 0 | } |
338 | 0 | } else { |
339 | 0 | if (inst->type_id() != graph_type_input_type) { |
340 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
341 | 0 | << "Result Type " << _.getIdName(inst->type_id()) |
342 | 0 | << " of graph input instruction " << _.getIdName(inst->id()) |
343 | 0 | << " does not match the type " |
344 | 0 | << _.getIdName(graph_type_input_type) << " of input " |
345 | 0 | << input_index << " in the graph type."; |
346 | 0 | } |
347 | 0 | } |
348 | 0 | return SPV_SUCCESS; |
349 | 0 | } |
350 | | |
351 | | spv_result_t ValidateGraphSetOutput(ValidationState_t& _, |
352 | 0 | const Instruction* inst) { |
353 | 0 | const uint32_t output_index_id = inst->GetOperandAs<uint32_t>(1); |
354 | | // Check type of OutputIndex |
355 | 0 | auto output_index_inst = _.FindDef(output_index_id); |
356 | 0 | if (!output_index_inst || |
357 | 0 | !_.IsIntScalarType(output_index_inst->type_id(), 32)) { |
358 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
359 | 0 | << spvOpcodeString(inst->opcode()) |
360 | 0 | << " OutputIndex must be a 32-bit integer."; |
361 | 0 | } |
362 | | |
363 | 0 | bool has_element_index = inst->operands().size() > 2; |
364 | | |
365 | | // Check type of ElementIndex |
366 | 0 | uint32_t element_index_id = 0; |
367 | 0 | if (has_element_index) { |
368 | 0 | element_index_id = inst->GetOperandAs<uint32_t>(2); |
369 | 0 | auto element_index_inst = _.FindDef(element_index_id); |
370 | 0 | if (!element_index_inst || |
371 | 0 | !_.IsIntScalarType(element_index_inst->type_id(), 32)) { |
372 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
373 | 0 | << spvOpcodeString(inst->opcode()) |
374 | 0 | << " ElementIndex must be a 32-bit integer."; |
375 | 0 | } |
376 | 0 | } |
377 | | |
378 | 0 | if (spvIsVulkanEnv(_.context()->target_env)) { |
379 | 0 | if (!IsConstantInstruction(_, output_index_id)) { |
380 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
381 | 0 | << _.VkErrorID(9932) << "OpGraphSetOutputARM OutputIndex must " |
382 | 0 | << "be the <id> of a constant instruction."; |
383 | 0 | } |
384 | 0 | if (has_element_index && !IsConstantInstruction(_, element_index_id)) { |
385 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
386 | 0 | << _.VkErrorID(9932) << "OpGraphSetOutputARM ElementIndex must " |
387 | 0 | << "be the <id> of a constant instruction."; |
388 | 0 | } |
389 | 0 | } |
390 | | |
391 | | // Find graph definition |
392 | 0 | size_t inst_num = inst->LineNum() - 1; |
393 | 0 | auto graph_inst = &_.ordered_instructions()[inst_num]; |
394 | 0 | while (--inst_num) { |
395 | 0 | graph_inst = &_.ordered_instructions()[inst_num]; |
396 | 0 | if (graph_inst->opcode() == spv::Op::OpGraphARM) { |
397 | 0 | break; |
398 | 0 | } |
399 | 0 | } |
400 | | |
401 | | // Can the OutputIndex be evaluated? |
402 | | // If not, there's nothing more we can validate here. |
403 | 0 | uint64_t output_index; |
404 | 0 | if (!_.EvalConstantValUint64(inst->GetOperandAs<uint32_t>(1), |
405 | 0 | &output_index)) { |
406 | 0 | return SPV_SUCCESS; |
407 | 0 | } |
408 | | |
409 | | // Check that the OutputIndex is valid with respect to the graph type |
410 | 0 | auto graph_type_inst = _.FindDef(graph_inst->type_id()); |
411 | 0 | size_t graph_type_num_outputs = GraphTypeInstNumOutputs(graph_type_inst); |
412 | |
|
413 | 0 | if (output_index >= graph_type_num_outputs) { |
414 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
415 | 0 | << spvOpcodeString(inst->opcode()) << " setting OutputIndex " |
416 | 0 | << output_index << " but graph only has " << graph_type_num_outputs |
417 | 0 | << " outputs."; |
418 | 0 | } |
419 | | |
420 | 0 | uint32_t graph_type_output_type = |
421 | 0 | GraphTypeInstGetOutputAtIndex(graph_type_inst, output_index); |
422 | |
|
423 | 0 | if (has_element_index) { |
424 | | // Check ElementIndex is allowed |
425 | 0 | if (!IsTensorArray(_, graph_type_output_type)) { |
426 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
427 | 0 | << "OpGraphSetOutputARM ElementIndex not allowed when the graph " |
428 | 0 | "output selected by " |
429 | 0 | << "OutputIndex is not an OpTypeArray or OpTypeRuntimeArray"; |
430 | 0 | } |
431 | | |
432 | | // Check ElementIndex is in range if it can be evaluated and the output is a |
433 | | // fixed-sized array whose Length can be evaluated |
434 | 0 | uint64_t element_index; |
435 | 0 | if (_.IsArrayType(graph_type_output_type) && |
436 | 0 | _.EvalConstantValUint64(inst->GetOperandAs<uint32_t>(2), |
437 | 0 | &element_index)) { |
438 | 0 | uint64_t array_length; |
439 | 0 | auto graph_type_output_type_inst = _.FindDef(graph_type_output_type); |
440 | 0 | if (_.EvalConstantValUint64( |
441 | 0 | graph_type_output_type_inst->GetOperandAs<uint32_t>(2), |
442 | 0 | &array_length)) { |
443 | 0 | if (element_index >= array_length) { |
444 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
445 | 0 | << "OpGraphSetOutputARM ElementIndex out of range. The type " |
446 | 0 | "of the graph output being accessed " |
447 | 0 | << _.getIdName(graph_type_output_type) << " is an array of " |
448 | 0 | << array_length << " elements but " << "ElementIndex is " |
449 | 0 | << element_index; |
450 | 0 | } |
451 | 0 | } |
452 | 0 | } |
453 | 0 | } |
454 | | |
455 | | // Check Value's type matches with graph type |
456 | 0 | uint32_t value = inst->GetOperandAs<uint32_t>(0); |
457 | 0 | uint32_t value_type = _.FindDef(value)->type_id(); |
458 | 0 | if (has_element_index) { |
459 | 0 | uint32_t expected_type = _.GetComponentType(graph_type_output_type); |
460 | 0 | if (value_type != expected_type) { |
461 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
462 | 0 | << "The type " << _.getIdName(value_type) |
463 | 0 | << " of Value provided to the graph output instruction " |
464 | 0 | << _.getIdName(value) << " does not match the component type " |
465 | 0 | << _.getIdName(expected_type) << " of output " << output_index |
466 | 0 | << " in the graph type."; |
467 | 0 | } |
468 | 0 | } else { |
469 | 0 | if (value_type != graph_type_output_type) { |
470 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
471 | 0 | << "The type " << _.getIdName(value_type) |
472 | 0 | << " of Value provided to the graph output instruction " |
473 | 0 | << _.getIdName(value) << " does not match the type " |
474 | 0 | << _.getIdName(graph_type_output_type) << " of output " |
475 | 0 | << output_index << " in the graph type."; |
476 | 0 | } |
477 | 0 | } |
478 | 0 | return SPV_SUCCESS; |
479 | 0 | } |
480 | | |
481 | | bool InputOutputInstructionsHaveDuplicateIndices( |
482 | | ValidationState_t& _, std::deque<const Instruction*>& inout_insts, |
483 | 0 | const Instruction** first_dup) { |
484 | 0 | std::set<std::pair<uint64_t, uint64_t>> inout_element_indices; |
485 | 0 | for (auto const inst : inout_insts) { |
486 | 0 | const bool is_input = inst->opcode() == spv::Op::OpGraphInputARM; |
487 | 0 | bool has_element_index = inst->operands().size() > (is_input ? 3 : 2); |
488 | 0 | uint64_t inout_index; |
489 | 0 | if (!_.EvalConstantValUint64(inst->GetOperandAs<uint32_t>(is_input ? 2 : 1), |
490 | 0 | &inout_index)) { |
491 | 0 | continue; |
492 | 0 | } |
493 | 0 | uint64_t element_index = -1; // -1 means no ElementIndex |
494 | 0 | if (has_element_index) { |
495 | 0 | if (!_.EvalConstantValUint64( |
496 | 0 | inst->GetOperandAs<uint32_t>(is_input ? 3 : 2), &element_index)) { |
497 | 0 | continue; |
498 | 0 | } |
499 | 0 | } |
500 | 0 | auto inout_element_pair = std::make_pair(inout_index, element_index); |
501 | 0 | auto inout_noelement_pair = std::make_pair(inout_index, -1); |
502 | 0 | if (inout_element_indices.count(inout_element_pair) || |
503 | 0 | inout_element_indices.count(inout_noelement_pair)) { |
504 | 0 | *first_dup = inst; |
505 | 0 | return true; |
506 | 0 | } |
507 | 0 | inout_element_indices.insert(inout_element_pair); |
508 | 0 | } |
509 | 0 | return false; |
510 | 0 | } |
511 | | |
512 | 0 | spv_result_t ValidateGraphEnd(ValidationState_t& _, const Instruction* inst) { |
513 | 0 | size_t end_inst_num = inst->LineNum() - 1; |
514 | | |
515 | | // Gather OpGraphInputARM and OpGraphSetOutputARM instructions |
516 | 0 | std::deque<const Instruction*> graph_inputs, graph_outputs; |
517 | 0 | size_t in_inst_num = end_inst_num; |
518 | 0 | auto graph_inst = &_.ordered_instructions()[in_inst_num]; |
519 | 0 | while (--in_inst_num) { |
520 | 0 | graph_inst = &_.ordered_instructions()[in_inst_num]; |
521 | 0 | if (graph_inst->opcode() == spv::Op::OpGraphInputARM) { |
522 | 0 | graph_inputs.push_front(graph_inst); |
523 | 0 | continue; |
524 | 0 | } |
525 | 0 | if (graph_inst->opcode() == spv::Op::OpGraphSetOutputARM) { |
526 | 0 | graph_outputs.push_front(graph_inst); |
527 | 0 | continue; |
528 | 0 | } |
529 | 0 | if (graph_inst->opcode() == spv::Op::OpGraphARM) { |
530 | 0 | break; |
531 | 0 | } |
532 | 0 | } |
533 | |
|
534 | 0 | const Instruction* first_dup; |
535 | | |
536 | | // Check that there are no duplicate InputIndex and ElementIndex values |
537 | 0 | if (InputOutputInstructionsHaveDuplicateIndices(_, graph_inputs, |
538 | 0 | &first_dup)) { |
539 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, first_dup) |
540 | 0 | << "Two OpGraphInputARM instructions with the same InputIndex " |
541 | 0 | "must not be part of the same " |
542 | 0 | << "graph definition unless ElementIndex is present in both with " |
543 | 0 | "different values."; |
544 | 0 | } |
545 | | |
546 | | // Check that there are no duplicate OutputIndex and ElementIndex values |
547 | 0 | if (InputOutputInstructionsHaveDuplicateIndices(_, graph_outputs, |
548 | 0 | &first_dup)) { |
549 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, first_dup) |
550 | 0 | << "Two OpGraphSetOutputARM instructions with the same " |
551 | 0 | "OutputIndex must not be part of the same " |
552 | 0 | << "graph definition unless ElementIndex is present in both with " |
553 | 0 | "different values."; |
554 | 0 | } |
555 | | |
556 | 0 | return SPV_SUCCESS; |
557 | 0 | } |
558 | | |
559 | | } // namespace |
560 | | |
561 | | // Validates correctness of graph instructions. |
562 | 14.6M | spv_result_t GraphPass(ValidationState_t& _, const Instruction* inst) { |
563 | 14.6M | switch (inst->opcode()) { |
564 | 0 | case spv::Op::OpTypeGraphARM: |
565 | 0 | return ValidateGraphType(_, inst); |
566 | 0 | case spv::Op::OpGraphConstantARM: |
567 | 0 | return ValidateGraphConstant(_, inst); |
568 | 0 | case spv::Op::OpGraphEntryPointARM: |
569 | 0 | return ValidateGraphEntryPoint(_, inst); |
570 | 0 | case spv::Op::OpGraphARM: |
571 | 0 | return ValidateGraph(_, inst); |
572 | 0 | case spv::Op::OpGraphInputARM: |
573 | 0 | return ValidateGraphInput(_, inst); |
574 | 0 | case spv::Op::OpGraphSetOutputARM: |
575 | 0 | return ValidateGraphSetOutput(_, inst); |
576 | 0 | case spv::Op::OpGraphEndARM: |
577 | 0 | return ValidateGraphEnd(_, inst); |
578 | 14.6M | default: |
579 | 14.6M | break; |
580 | 14.6M | } |
581 | 14.6M | return SPV_SUCCESS; |
582 | 14.6M | } |
583 | | |
584 | | } // namespace val |
585 | | } // namespace spvtools |