/src/spirv-tools/source/val/validate_constants.cpp
Line | Count | Source |
1 | | // Copyright (c) 2018 Google LLC. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | #include "source/opcode.h" |
16 | | #include "source/val/instruction.h" |
17 | | #include "source/val/validate.h" |
18 | | #include "source/val/validation_state.h" |
19 | | |
20 | | namespace spvtools { |
21 | | namespace val { |
22 | | namespace { |
23 | | |
24 | | spv_result_t ValidateConstantBool(ValidationState_t& _, |
25 | 6.73k | const Instruction* inst) { |
26 | 6.73k | auto type = _.FindDef(inst->type_id()); |
27 | 6.73k | if (!type || type->opcode() != spv::Op::OpTypeBool) { |
28 | 21 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
29 | 21 | << "Op" << spvOpcodeString(inst->opcode()) << " Result Type <id> " |
30 | 21 | << _.getIdName(inst->type_id()) << " is not a boolean type."; |
31 | 21 | } |
32 | | |
33 | 6.71k | return SPV_SUCCESS; |
34 | 6.73k | } |
35 | | |
36 | 31.1k | bool isCompositeType(const Instruction* inst) { |
37 | 31.1k | bool is_tensor = inst->opcode() == spv::Op::OpTypeTensorARM; |
38 | 31.1k | bool tensor_is_shaped = inst->words().size() == 5; |
39 | 31.1k | return spvOpcodeIsComposite(inst->opcode()) || |
40 | 14 | (is_tensor && tensor_is_shaped); |
41 | 31.1k | } |
42 | | |
43 | | spv_result_t ValidateConstantComposite(ValidationState_t& _, |
44 | 31.1k | const Instruction* inst) { |
45 | 31.1k | std::string opcode_name = std::string("Op") + spvOpcodeString(inst->opcode()); |
46 | | |
47 | 31.1k | const auto result_type = _.FindDef(inst->type_id()); |
48 | 31.1k | if (!result_type || !isCompositeType(result_type)) { |
49 | 14 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
50 | 14 | << opcode_name << " Result Type <id> " |
51 | 14 | << _.getIdName(inst->type_id()) << " is not a composite type."; |
52 | 14 | } |
53 | | |
54 | 31.1k | const auto constituent_count = inst->words().size() - 3; |
55 | 31.1k | switch (result_type->opcode()) { |
56 | 26.8k | case spv::Op::OpTypeVector: |
57 | 26.8k | case spv::Op::OpTypeCooperativeVectorNV: { |
58 | 26.8k | uint32_t num_result_components = _.GetDimension(result_type->id()); |
59 | 26.8k | bool comp_is_int32 = true, comp_is_const_int32 = true; |
60 | | |
61 | 26.8k | if (result_type->opcode() == spv::Op::OpTypeCooperativeVectorNV) { |
62 | 0 | uint32_t comp_count_id = result_type->GetOperandAs<uint32_t>(2); |
63 | 0 | std::tie(comp_is_int32, comp_is_const_int32, num_result_components) = |
64 | 0 | _.EvalInt32IfConst(comp_count_id); |
65 | 0 | } |
66 | | |
67 | 26.8k | if (comp_is_const_int32 && num_result_components != constituent_count) { |
68 | | // TODO: Output ID's on diagnostic |
69 | 11 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
70 | 11 | << opcode_name |
71 | 11 | << " Constituent <id> count does not match " |
72 | 11 | "Result Type <id> " |
73 | 11 | << _.getIdName(result_type->id()) << "s vector component count."; |
74 | 11 | } |
75 | 26.7k | const auto component_type = |
76 | 26.7k | _.FindDef(result_type->GetOperandAs<uint32_t>(1)); |
77 | 26.7k | if (!component_type) { |
78 | 0 | return _.diag(SPV_ERROR_INVALID_ID, result_type) |
79 | 0 | << "Component type is not defined."; |
80 | 0 | } |
81 | 26.7k | for (size_t constituent_index = 2; |
82 | 117k | constituent_index < inst->operands().size(); constituent_index++) { |
83 | 90.7k | const auto constituent_id = |
84 | 90.7k | inst->GetOperandAs<uint32_t>(constituent_index); |
85 | 90.7k | const auto constituent = _.FindDef(constituent_id); |
86 | 90.7k | if (!constituent || |
87 | 90.7k | !spvOpcodeIsConstantOrUndef(constituent->opcode())) { |
88 | 9 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
89 | 9 | << opcode_name << " Constituent <id> " |
90 | 9 | << _.getIdName(constituent_id) |
91 | 9 | << " is not a constant or undef."; |
92 | 9 | } |
93 | 90.7k | const auto constituent_result_type = _.FindDef(constituent->type_id()); |
94 | 90.7k | if (!constituent_result_type || |
95 | 90.7k | component_type->id() != constituent_result_type->id()) { |
96 | 14 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
97 | 14 | << opcode_name << " Constituent <id> " |
98 | 14 | << _.getIdName(constituent_id) |
99 | 14 | << "s type does not match Result Type <id> " |
100 | 14 | << _.getIdName(result_type->id()) << "s vector element type."; |
101 | 14 | } |
102 | 90.7k | } |
103 | 26.7k | } break; |
104 | 26.7k | case spv::Op::OpTypeMatrix: { |
105 | 122 | const auto column_count = result_type->GetOperandAs<uint32_t>(2); |
106 | 122 | if (column_count != constituent_count) { |
107 | | // TODO: Output ID's on diagnostic |
108 | 8 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
109 | 8 | << opcode_name |
110 | 8 | << " Constituent <id> count does not match " |
111 | 8 | "Result Type <id> " |
112 | 8 | << _.getIdName(result_type->id()) << "s matrix column count."; |
113 | 8 | } |
114 | | |
115 | 114 | const auto column_type = _.FindDef(result_type->words()[2]); |
116 | 114 | if (!column_type) { |
117 | 0 | return _.diag(SPV_ERROR_INVALID_ID, result_type) |
118 | 0 | << "Column type is not defined."; |
119 | 0 | } |
120 | 114 | const auto component_count = column_type->GetOperandAs<uint32_t>(2); |
121 | 114 | const auto component_type = |
122 | 114 | _.FindDef(column_type->GetOperandAs<uint32_t>(1)); |
123 | 114 | if (!component_type) { |
124 | 0 | return _.diag(SPV_ERROR_INVALID_ID, column_type) |
125 | 0 | << "Component type is not defined."; |
126 | 0 | } |
127 | | |
128 | 114 | for (size_t constituent_index = 2; |
129 | 342 | constituent_index < inst->operands().size(); constituent_index++) { |
130 | 250 | const auto constituent_id = |
131 | 250 | inst->GetOperandAs<uint32_t>(constituent_index); |
132 | 250 | const auto constituent = _.FindDef(constituent_id); |
133 | 250 | if (!constituent || |
134 | 250 | !spvOpcodeIsConstantOrUndef(constituent->opcode())) { |
135 | | // The message says "... or undef" because the spec does not say |
136 | | // undef is a constant. |
137 | 6 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
138 | 6 | << opcode_name << " Constituent <id> " |
139 | 6 | << _.getIdName(constituent_id) |
140 | 6 | << " is not a constant or undef."; |
141 | 6 | } |
142 | 244 | const auto vector = _.FindDef(constituent->type_id()); |
143 | 244 | if (!vector) { |
144 | 0 | return _.diag(SPV_ERROR_INVALID_ID, constituent) |
145 | 0 | << "Result type is not defined."; |
146 | 0 | } |
147 | 244 | if (column_type->opcode() != vector->opcode()) { |
148 | 9 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
149 | 9 | << opcode_name << " Constituent <id> " |
150 | 9 | << _.getIdName(constituent_id) |
151 | 9 | << " type does not match Result Type <id> " |
152 | 9 | << _.getIdName(result_type->id()) << "s matrix column type."; |
153 | 9 | } |
154 | 235 | const auto vector_component_type = |
155 | 235 | _.FindDef(vector->GetOperandAs<uint32_t>(1)); |
156 | 235 | if (component_type->id() != vector_component_type->id()) { |
157 | 2 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
158 | 2 | << opcode_name << " Constituent <id> " |
159 | 2 | << _.getIdName(constituent_id) |
160 | 2 | << " component type does not match Result Type <id> " |
161 | 2 | << _.getIdName(result_type->id()) |
162 | 2 | << "s matrix column component type."; |
163 | 2 | } |
164 | 233 | if (component_count != vector->words()[3]) { |
165 | 5 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
166 | 5 | << opcode_name << " Constituent <id> " |
167 | 5 | << _.getIdName(constituent_id) |
168 | 5 | << " vector component count does not match Result Type <id> " |
169 | 5 | << _.getIdName(result_type->id()) |
170 | 5 | << "s vector component count."; |
171 | 5 | } |
172 | 233 | } |
173 | 114 | } break; |
174 | 1.56k | case spv::Op::OpTypeArray: { |
175 | 1.56k | auto element_type = _.FindDef(result_type->GetOperandAs<uint32_t>(1)); |
176 | 1.56k | if (!element_type) { |
177 | 0 | return _.diag(SPV_ERROR_INVALID_ID, result_type) |
178 | 0 | << "Element type is not defined."; |
179 | 0 | } |
180 | 1.56k | const auto length = _.FindDef(result_type->GetOperandAs<uint32_t>(2)); |
181 | 1.56k | if (!length) { |
182 | 0 | return _.diag(SPV_ERROR_INVALID_ID, result_type) |
183 | 0 | << "Length is not defined."; |
184 | 0 | } |
185 | 1.56k | bool is_int32; |
186 | 1.56k | bool is_const; |
187 | 1.56k | uint32_t value; |
188 | 1.56k | std::tie(is_int32, is_const, value) = _.EvalInt32IfConst(length->id()); |
189 | 1.56k | if (is_int32 && is_const && value != constituent_count) { |
190 | 50 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
191 | 50 | << opcode_name |
192 | 50 | << " Constituent count does not match " |
193 | 50 | "Result Type <id> " |
194 | 50 | << _.getIdName(result_type->id()) << "s array length."; |
195 | 50 | } |
196 | 1.51k | for (size_t constituent_index = 2; |
197 | 21.8k | constituent_index < inst->operands().size(); constituent_index++) { |
198 | 20.3k | const auto constituent_id = |
199 | 20.3k | inst->GetOperandAs<uint32_t>(constituent_index); |
200 | 20.3k | const auto constituent = _.FindDef(constituent_id); |
201 | 20.3k | if (!constituent || |
202 | 20.3k | !spvOpcodeIsConstantOrUndef(constituent->opcode())) { |
203 | 5 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
204 | 5 | << opcode_name << " Constituent <id> " |
205 | 5 | << _.getIdName(constituent_id) |
206 | 5 | << " is not a constant or undef."; |
207 | 5 | } |
208 | 20.3k | const auto constituent_type = _.FindDef(constituent->type_id()); |
209 | 20.3k | if (!constituent_type) { |
210 | 0 | return _.diag(SPV_ERROR_INVALID_ID, constituent) |
211 | 0 | << "Result type is not defined."; |
212 | 0 | } |
213 | 20.3k | if (element_type->id() != constituent_type->id()) { |
214 | 25 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
215 | 25 | << opcode_name << " Constituent <id> " |
216 | 25 | << _.getIdName(constituent_id) |
217 | 25 | << "s type does not match Result Type <id> " |
218 | 25 | << _.getIdName(result_type->id()) << "s array element type."; |
219 | 25 | } |
220 | 20.3k | } |
221 | 1.51k | } break; |
222 | 2.66k | case spv::Op::OpTypeStruct: { |
223 | 2.66k | const auto member_count = result_type->words().size() - 2; |
224 | 2.66k | if (member_count != constituent_count) { |
225 | 10 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
226 | 10 | << opcode_name << " Constituent <id> " |
227 | 10 | << _.getIdName(inst->type_id()) |
228 | 10 | << " count does not match Result Type <id> " |
229 | 10 | << _.getIdName(result_type->id()) << "s struct member count."; |
230 | 10 | } |
231 | 2.65k | for (uint32_t constituent_index = 2, member_index = 1; |
232 | 10.2k | constituent_index < inst->operands().size(); |
233 | 7.61k | constituent_index++, member_index++) { |
234 | 7.61k | const auto constituent_id = |
235 | 7.61k | inst->GetOperandAs<uint32_t>(constituent_index); |
236 | 7.61k | const auto constituent = _.FindDef(constituent_id); |
237 | 7.61k | if (!constituent || |
238 | 7.61k | !spvOpcodeIsConstantOrUndef(constituent->opcode())) { |
239 | 7 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
240 | 7 | << opcode_name << " Constituent <id> " |
241 | 7 | << _.getIdName(constituent_id) |
242 | 7 | << " is not a constant or undef."; |
243 | 7 | } |
244 | 7.60k | const auto constituent_type = _.FindDef(constituent->type_id()); |
245 | 7.60k | if (!constituent_type) { |
246 | 0 | return _.diag(SPV_ERROR_INVALID_ID, constituent) |
247 | 0 | << "Result type is not defined."; |
248 | 0 | } |
249 | | |
250 | 7.60k | const auto member_type_id = |
251 | 7.60k | result_type->GetOperandAs<uint32_t>(member_index); |
252 | 7.60k | const auto member_type = _.FindDef(member_type_id); |
253 | 7.60k | if (!member_type || member_type->id() != constituent_type->id()) { |
254 | 18 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
255 | 18 | << opcode_name << " Constituent <id> " |
256 | 18 | << _.getIdName(constituent_id) |
257 | 18 | << " type does not match the Result Type <id> " |
258 | 18 | << _.getIdName(result_type->id()) << "s member type."; |
259 | 18 | } |
260 | 7.60k | } |
261 | 2.65k | } break; |
262 | 2.62k | case spv::Op::OpTypeCooperativeMatrixKHR: |
263 | 0 | case spv::Op::OpTypeCooperativeMatrixNV: { |
264 | 0 | if (1 != constituent_count) { |
265 | 0 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
266 | 0 | << opcode_name << " Constituent <id> " |
267 | 0 | << _.getIdName(inst->type_id()) << " count must be one."; |
268 | 0 | } |
269 | 0 | const auto constituent_id = inst->GetOperandAs<uint32_t>(2); |
270 | 0 | const auto constituent = _.FindDef(constituent_id); |
271 | 0 | if (!constituent || !spvOpcodeIsConstantOrUndef(constituent->opcode())) { |
272 | 0 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
273 | 0 | << opcode_name << " Constituent <id> " |
274 | 0 | << _.getIdName(constituent_id) << " is not a constant or undef."; |
275 | 0 | } |
276 | 0 | const auto constituent_type = _.FindDef(constituent->type_id()); |
277 | 0 | if (!constituent_type) { |
278 | 0 | return _.diag(SPV_ERROR_INVALID_ID, constituent) |
279 | 0 | << "Result type is not defined."; |
280 | 0 | } |
281 | | |
282 | 0 | const auto component_type_id = result_type->GetOperandAs<uint32_t>(1); |
283 | 0 | const auto component_type = _.FindDef(component_type_id); |
284 | 0 | if (!component_type || component_type->id() != constituent_type->id()) { |
285 | 0 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
286 | 0 | << opcode_name << " Constituent <id> " |
287 | 0 | << _.getIdName(constituent_id) |
288 | 0 | << " type does not match the Result Type <id> " |
289 | 0 | << _.getIdName(result_type->id()) << "s component type."; |
290 | 0 | } |
291 | 0 | } break; |
292 | 0 | case spv::Op::OpTypeTensorARM: { |
293 | 0 | auto inst_element_type = |
294 | 0 | _.FindDef(result_type->GetOperandAs<uint32_t>(1)); |
295 | 0 | if (!inst_element_type) { |
296 | 0 | return _.diag(SPV_ERROR_INVALID_ID, result_type) |
297 | 0 | << "Element type is not defined."; |
298 | 0 | } |
299 | 0 | const auto inst_rank = _.FindDef(result_type->GetOperandAs<uint32_t>(2)); |
300 | 0 | if (!inst_rank) { |
301 | 0 | return _.diag(SPV_ERROR_INVALID_ID, result_type) |
302 | 0 | << "Rank is not defined."; |
303 | 0 | } |
304 | 0 | const auto inst_shape = _.FindDef(result_type->GetOperandAs<uint32_t>(3)); |
305 | 0 | if (!inst_shape) { |
306 | 0 | return _.diag(SPV_ERROR_INVALID_ID, result_type) |
307 | 0 | << "Shape is not defined."; |
308 | 0 | } |
309 | | |
310 | 0 | uint64_t rank = 0; |
311 | 0 | _.EvalConstantValUint64(inst_rank->id(), &rank); |
312 | |
|
313 | 0 | uint64_t outermost_shape = 0; |
314 | 0 | if (_.EvalConstantValUint64(inst_shape->GetOperandAs<uint32_t>(2), |
315 | 0 | &outermost_shape) && |
316 | 0 | (outermost_shape != constituent_count)) { |
317 | 0 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
318 | 0 | << opcode_name |
319 | 0 | << " Constituent count does not match " |
320 | 0 | "the shape of Result Type <id> " |
321 | 0 | << _.getIdName(result_type->id()) |
322 | 0 | << " along its outermost dimension, " << "expected " |
323 | 0 | << outermost_shape << " but got " << constituent_count << "."; |
324 | 0 | } |
325 | | |
326 | 0 | for (size_t constituent_index = 2; |
327 | 0 | constituent_index < inst->operands().size(); constituent_index++) { |
328 | 0 | const auto constituent_id = |
329 | 0 | inst->GetOperandAs<uint32_t>(constituent_index); |
330 | 0 | const auto constituent = _.FindDef(constituent_id); |
331 | 0 | if (!constituent || |
332 | 0 | !spvOpcodeIsConstantOrUndef(constituent->opcode())) { |
333 | 0 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
334 | 0 | << opcode_name << " Constituent <id> " |
335 | 0 | << _.getIdName(constituent_id) |
336 | 0 | << " is not a constant or undef."; |
337 | 0 | } |
338 | 0 | const auto constituent_type = _.FindDef(constituent->type_id()); |
339 | 0 | if (!constituent_type) { |
340 | 0 | return _.diag(SPV_ERROR_INVALID_ID, constituent) |
341 | 0 | << "Type of Constituent " << constituent_index - 2 |
342 | 0 | << " is not defined."; |
343 | 0 | } |
344 | | |
345 | 0 | if (rank == 0) { |
346 | | // The rank of the returned tensor constant is not known. |
347 | | // Skip rank-dependent validation. |
348 | 0 | continue; |
349 | 0 | } |
350 | | |
351 | 0 | if (rank == 1) { |
352 | 0 | if (inst_element_type->id() != constituent_type->id()) { |
353 | 0 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
354 | 0 | << opcode_name << " Constituent <id> " |
355 | 0 | << _.getIdName(constituent_id) |
356 | 0 | << " type does not match the element type of the tensor (" |
357 | 0 | << _.getIdName(result_type->id()) << ")."; |
358 | 0 | } |
359 | 0 | } else { |
360 | 0 | if (constituent_type->opcode() != spv::Op::OpTypeTensorARM) { |
361 | 0 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
362 | 0 | << opcode_name << " Constituent <id> " |
363 | 0 | << _.getIdName(constituent_id) |
364 | 0 | << " must be an OpTypeTensorARM."; |
365 | 0 | } |
366 | 0 | auto inst_constituent_element_type = |
367 | 0 | _.FindDef(constituent_type->GetOperandAs<uint32_t>(1)); |
368 | 0 | if (!inst_constituent_element_type || |
369 | 0 | inst_constituent_element_type->id() != inst_element_type->id()) { |
370 | 0 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
371 | 0 | << opcode_name << " Constituent <id> " |
372 | 0 | << _.getIdName(constituent_id) |
373 | 0 | << " must have the same Element Type as Result Type <id> " |
374 | 0 | << _.getIdName(result_type->id()) << "."; |
375 | 0 | } |
376 | 0 | auto inst_constituent_rank = |
377 | 0 | _.FindDef(constituent_type->GetOperandAs<uint32_t>(2)); |
378 | 0 | uint64_t constituent_rank; |
379 | 0 | if (inst_constituent_rank && |
380 | 0 | _.EvalConstantValUint64(inst_constituent_rank->id(), |
381 | 0 | &constituent_rank) && |
382 | 0 | (constituent_rank != rank - 1)) { |
383 | 0 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
384 | 0 | << opcode_name << " Constituent <id> " |
385 | 0 | << _.getIdName(constituent_id) |
386 | 0 | << " must have a Rank that is 1 less than the Rank of " |
387 | 0 | "Result Type <id> " |
388 | 0 | << _.getIdName(result_type->id()) << ", expected " |
389 | 0 | << rank - 1 << " but got " << constituent_rank << "."; |
390 | 0 | } |
391 | | |
392 | 0 | auto inst_constituent_shape = |
393 | 0 | _.FindDef(constituent_type->GetOperandAs<uint32_t>(3)); |
394 | 0 | if (!inst_constituent_shape) { |
395 | 0 | return _.diag(SPV_ERROR_INVALID_ID, result_type) |
396 | 0 | << "Shape of Constituent " << constituent_index - 2 |
397 | 0 | << " is not defined."; |
398 | 0 | } |
399 | 0 | for (size_t constituent_shape_index = 2; |
400 | 0 | constituent_shape_index < |
401 | 0 | inst_constituent_shape->operands().size(); |
402 | 0 | constituent_shape_index++) { |
403 | 0 | size_t shape_index = constituent_shape_index + 1; |
404 | 0 | uint64_t constituent_shape = 0, shape = 1; |
405 | 0 | if (_.EvalConstantValUint64( |
406 | 0 | inst_constituent_shape->GetOperandAs<uint32_t>( |
407 | 0 | constituent_shape_index), |
408 | 0 | &constituent_shape) && |
409 | 0 | _.EvalConstantValUint64( |
410 | 0 | inst_shape->GetOperandAs<uint32_t>(shape_index), &shape) && |
411 | 0 | (constituent_shape != shape)) { |
412 | 0 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
413 | 0 | << opcode_name << " Constituent <id> " |
414 | 0 | << _.getIdName(constituent_id) |
415 | 0 | << " must have a Shape that matches that of Result Type " |
416 | 0 | "<id> " |
417 | 0 | << _.getIdName(result_type->id()) |
418 | 0 | << " along all inner dimensions of Result Type, expected " |
419 | 0 | << shape << " for dimension " |
420 | 0 | << constituent_shape_index - 2 |
421 | 0 | << " of Constituent but got " << constituent_shape << "."; |
422 | 0 | } |
423 | 0 | } |
424 | 0 | } |
425 | 0 | } |
426 | 0 | } break; |
427 | 25 | default: |
428 | 25 | break; |
429 | 31.1k | } |
430 | 30.9k | return SPV_SUCCESS; |
431 | 31.1k | } |
432 | | |
433 | | spv_result_t ValidateConstantSampler(ValidationState_t& _, |
434 | 0 | const Instruction* inst) { |
435 | 0 | const auto result_type = _.FindDef(inst->type_id()); |
436 | 0 | if (!result_type || result_type->opcode() != spv::Op::OpTypeSampler) { |
437 | 0 | return _.diag(SPV_ERROR_INVALID_ID, result_type) |
438 | 0 | << "OpConstantSampler Result Type <id> " |
439 | 0 | << _.getIdName(inst->type_id()) << " is not a sampler type."; |
440 | 0 | } |
441 | | |
442 | 0 | return SPV_SUCCESS; |
443 | 0 | } |
444 | | |
445 | | // True if instruction defines a type that can have a null value, as defined by |
446 | | // the SPIR-V spec. Tracks composite-type components through module to check |
447 | | // nullability transitively. |
448 | | bool IsTypeNullable(const std::vector<uint32_t>& instruction, |
449 | 5.27k | const ValidationState_t& _) { |
450 | 5.27k | uint16_t opcode; |
451 | 5.27k | uint16_t word_count; |
452 | 5.27k | spvOpcodeSplit(instruction[0], &word_count, &opcode); |
453 | 5.27k | switch (static_cast<spv::Op>(opcode)) { |
454 | 1.24k | case spv::Op::OpTypeBool: |
455 | 3.04k | case spv::Op::OpTypeInt: |
456 | 3.79k | case spv::Op::OpTypeFloat: |
457 | 3.83k | case spv::Op::OpTypeEvent: |
458 | 3.85k | case spv::Op::OpTypeDeviceEvent: |
459 | 3.92k | case spv::Op::OpTypeReserveId: |
460 | 3.96k | case spv::Op::OpTypeQueue: |
461 | 3.96k | return true; |
462 | 53 | case spv::Op::OpTypeArray: |
463 | 93 | case spv::Op::OpTypeMatrix: |
464 | 93 | case spv::Op::OpTypeCooperativeMatrixNV: |
465 | 93 | case spv::Op::OpTypeCooperativeMatrixKHR: |
466 | 93 | case spv::Op::OpTypeCooperativeVectorNV: |
467 | 555 | case spv::Op::OpTypeVector: { |
468 | 555 | auto base_type = _.FindDef(instruction[2]); |
469 | 555 | return base_type && IsTypeNullable(base_type->words(), _); |
470 | 93 | } |
471 | 490 | case spv::Op::OpTypeStruct: { |
472 | 1.15k | for (size_t elementIndex = 2; elementIndex < instruction.size(); |
473 | 679 | ++elementIndex) { |
474 | 679 | auto element = _.FindDef(instruction[elementIndex]); |
475 | 679 | if (!element || !IsTypeNullable(element->words(), _)) return false; |
476 | 679 | } |
477 | 479 | return true; |
478 | 490 | } |
479 | 0 | case spv::Op::OpTypeUntypedPointerKHR: |
480 | 257 | case spv::Op::OpTypePointer: |
481 | 257 | if (spv::StorageClass(instruction[2]) == |
482 | 257 | spv::StorageClass::PhysicalStorageBuffer) { |
483 | 0 | return false; |
484 | 0 | } |
485 | 257 | return true; |
486 | 0 | case spv::Op::OpTypeTensorARM: { |
487 | 0 | auto elem_type = _.FindDef(instruction[2]); |
488 | 0 | return (instruction.size() > 4) && elem_type && |
489 | 0 | IsTypeNullable(elem_type->words(), _); |
490 | 257 | } |
491 | 15 | default: |
492 | 15 | return false; |
493 | 5.27k | } |
494 | 5.27k | } |
495 | | |
496 | | spv_result_t ValidateConstantNull(ValidationState_t& _, |
497 | 4.04k | const Instruction* inst) { |
498 | 4.04k | const auto result_type = _.FindDef(inst->type_id()); |
499 | 4.04k | if (!result_type || !IsTypeNullable(result_type->words(), _)) { |
500 | 15 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
501 | 15 | << "OpConstantNull Result Type <id> " << _.getIdName(inst->type_id()) |
502 | 15 | << " cannot have a null value."; |
503 | 15 | } |
504 | | |
505 | 4.02k | return SPV_SUCCESS; |
506 | 4.04k | } |
507 | | |
508 | | // Validates that OpSpecConstant specializes to either int or float type. |
509 | | spv_result_t ValidateSpecConstant(ValidationState_t& _, |
510 | 2.87k | const Instruction* inst) { |
511 | | // Operand 0 is the <id> of the type that we're specializing to. |
512 | 2.87k | auto type_id = inst->GetOperandAs<const uint32_t>(0); |
513 | 2.87k | auto type_instruction = _.FindDef(type_id); |
514 | 2.87k | auto type_opcode = type_instruction->opcode(); |
515 | 2.87k | if (type_opcode != spv::Op::OpTypeInt && |
516 | 1.12k | type_opcode != spv::Op::OpTypeFloat) { |
517 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Specialization constant " |
518 | 0 | "must be an integer or " |
519 | 0 | "floating-point number."; |
520 | 0 | } |
521 | 2.87k | return SPV_SUCCESS; |
522 | 2.87k | } |
523 | | |
524 | | spv_result_t ValidateSpecConstantOp(ValidationState_t& _, |
525 | 124 | const Instruction* inst) { |
526 | 124 | const auto op = inst->GetOperandAs<spv::Op>(2); |
527 | | |
528 | | // The binary parser already ensures that the op is valid for *some* |
529 | | // environment. Here we check restrictions. |
530 | 124 | switch (op) { |
531 | 10 | case spv::Op::OpQuantizeToF16: |
532 | 10 | if (!_.HasCapability(spv::Capability::Shader)) { |
533 | 3 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
534 | 3 | << "Specialization constant operation " << spvOpcodeString(op) |
535 | 3 | << " requires Shader capability"; |
536 | 3 | } |
537 | 7 | break; |
538 | | |
539 | 7 | case spv::Op::OpUConvert: |
540 | 6 | if (!_.features().uconvert_spec_constant_op && |
541 | 6 | !_.HasCapability(spv::Capability::Kernel)) { |
542 | 3 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
543 | 3 | << "Prior to SPIR-V 1.4, specialization constant operation " |
544 | 3 | "UConvert requires Kernel capability or extension " |
545 | 3 | "SPV_AMD_gpu_shader_int16"; |
546 | 3 | } |
547 | 3 | break; |
548 | | |
549 | 3 | case spv::Op::OpConvertFToS: |
550 | 6 | case spv::Op::OpConvertSToF: |
551 | 9 | case spv::Op::OpConvertFToU: |
552 | 12 | case spv::Op::OpConvertUToF: |
553 | 15 | case spv::Op::OpConvertPtrToU: |
554 | 18 | case spv::Op::OpConvertUToPtr: |
555 | 21 | case spv::Op::OpGenericCastToPtr: |
556 | 24 | case spv::Op::OpPtrCastToGeneric: |
557 | 27 | case spv::Op::OpBitcast: |
558 | 30 | case spv::Op::OpFNegate: |
559 | 33 | case spv::Op::OpFAdd: |
560 | 36 | case spv::Op::OpFSub: |
561 | 39 | case spv::Op::OpFMul: |
562 | 41 | case spv::Op::OpFDiv: |
563 | 44 | case spv::Op::OpFRem: |
564 | 47 | case spv::Op::OpFMod: |
565 | 50 | case spv::Op::OpAccessChain: |
566 | 53 | case spv::Op::OpInBoundsAccessChain: |
567 | 56 | case spv::Op::OpPtrAccessChain: |
568 | 59 | case spv::Op::OpInBoundsPtrAccessChain: |
569 | 59 | if (!_.HasCapability(spv::Capability::Kernel)) { |
570 | 44 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
571 | 44 | << "Specialization constant operation " << spvOpcodeString(op) |
572 | 44 | << " requires Kernel capability"; |
573 | 44 | } |
574 | 15 | break; |
575 | | |
576 | 49 | default: |
577 | 49 | break; |
578 | 124 | } |
579 | | |
580 | | // TODO(dneto): Validate result type and arguments to the various operations. |
581 | 74 | return SPV_SUCCESS; |
582 | 124 | } |
583 | | |
584 | | } // namespace |
585 | | |
586 | 13.1M | spv_result_t ConstantPass(ValidationState_t& _, const Instruction* inst) { |
587 | 13.1M | switch (inst->opcode()) { |
588 | 3.31k | case spv::Op::OpConstantTrue: |
589 | 5.54k | case spv::Op::OpConstantFalse: |
590 | 6.20k | case spv::Op::OpSpecConstantTrue: |
591 | 6.73k | case spv::Op::OpSpecConstantFalse: |
592 | 6.73k | if (auto error = ValidateConstantBool(_, inst)) return error; |
593 | 6.71k | break; |
594 | 30.7k | case spv::Op::OpConstantComposite: |
595 | 31.1k | case spv::Op::OpSpecConstantComposite: |
596 | 31.1k | if (auto error = ValidateConstantComposite(_, inst)) return error; |
597 | 30.9k | break; |
598 | 30.9k | case spv::Op::OpConstantSampler: |
599 | 0 | if (auto error = ValidateConstantSampler(_, inst)) return error; |
600 | 0 | break; |
601 | 4.04k | case spv::Op::OpConstantNull: |
602 | 4.04k | if (auto error = ValidateConstantNull(_, inst)) return error; |
603 | 4.02k | break; |
604 | 4.02k | case spv::Op::OpSpecConstant: |
605 | 2.87k | if (auto error = ValidateSpecConstant(_, inst)) return error; |
606 | 2.87k | break; |
607 | 2.87k | case spv::Op::OpSpecConstantOp: |
608 | 124 | if (auto error = ValidateSpecConstantOp(_, inst)) return error; |
609 | 74 | break; |
610 | 13.0M | default: |
611 | 13.0M | break; |
612 | 13.1M | } |
613 | | |
614 | | // Generally disallow creating 8- or 16-bit constants unless the full |
615 | | // capabilities are present. |
616 | 13.1M | if (spvOpcodeIsConstant(inst->opcode()) && |
617 | 228k | _.HasCapability(spv::Capability::Shader) && |
618 | 227k | !_.IsPointerType(inst->type_id()) && |
619 | 227k | _.ContainsLimitedUseIntOrFloatType(inst->type_id())) { |
620 | 3 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
621 | 3 | << "Cannot form constants of 8- or 16-bit types"; |
622 | 3 | } |
623 | | |
624 | 13.1M | return SPV_SUCCESS; |
625 | 13.1M | } |
626 | | |
627 | | } // namespace val |
628 | | } // namespace spvtools |