/src/spirv-tools/source/val/validate_arithmetics.cpp
Line | Count | Source |
1 | | // Copyright (c) 2017 Google Inc. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | // Performs validation of arithmetic instructions. |
16 | | |
17 | | #include <vector> |
18 | | |
19 | | #include "source/opcode.h" |
20 | | #include "source/val/instruction.h" |
21 | | #include "source/val/validate.h" |
22 | | #include "source/val/validation_state.h" |
23 | | |
24 | | namespace spvtools { |
25 | | namespace val { |
26 | | |
27 | | spv_result_t ValidateFloat(ValidationState_t& _, const Instruction* inst, |
28 | 111k | uint32_t starting_index = 2) { |
29 | 111k | const spv::Op opcode = inst->opcode(); |
30 | 111k | const uint32_t result_type = inst->type_id(); |
31 | 111k | bool supportsCoopMat = |
32 | 111k | (opcode != spv::Op::OpFMul && opcode != spv::Op::OpFRem && |
33 | 80.2k | opcode != spv::Op::OpFMod); |
34 | 111k | bool supportsCoopVec = |
35 | 111k | (opcode != spv::Op::OpFRem && opcode != spv::Op::OpFMod); |
36 | 111k | if (!_.IsFloatScalarType(result_type) && !_.IsFloatVectorType(result_type) && |
37 | 48 | !(supportsCoopMat && _.IsFloatCooperativeMatrixType(result_type)) && |
38 | 48 | !(opcode == spv::Op::OpFMul && |
39 | 6 | _.IsCooperativeMatrixKHRType(result_type) && |
40 | 0 | _.IsFloatCooperativeMatrixType(result_type)) && |
41 | 48 | !(supportsCoopVec && _.IsFloatCooperativeVectorNVType(result_type))) |
42 | 48 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
43 | 48 | << "Expected floating scalar or vector type as Result Type: " |
44 | 48 | << spvOpcodeString(opcode); |
45 | | |
46 | 111k | for (size_t operand_index = starting_index; |
47 | 332k | operand_index < inst->operands().size(); ++operand_index) { |
48 | 221k | if (supportsCoopVec && _.IsCooperativeVectorNVType(result_type)) { |
49 | 0 | const uint32_t type_id = _.GetOperandTypeId(inst, operand_index); |
50 | 0 | if (!_.IsCooperativeVectorNVType(type_id)) { |
51 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
52 | 0 | << "Expected arithmetic operands to be of Result Type: " |
53 | 0 | << spvOpcodeString(opcode) << " operand index " << operand_index; |
54 | 0 | } |
55 | 0 | spv_result_t ret = |
56 | 0 | _.CooperativeVectorDimensionsMatch(inst, type_id, result_type); |
57 | 0 | if (ret != SPV_SUCCESS) return ret; |
58 | 221k | } else if (supportsCoopMat && _.IsCooperativeMatrixKHRType(result_type)) { |
59 | 0 | const uint32_t type_id = _.GetOperandTypeId(inst, operand_index); |
60 | 0 | if (!_.IsCooperativeMatrixKHRType(type_id) || |
61 | 0 | !_.IsFloatCooperativeMatrixType(type_id)) { |
62 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
63 | 0 | << "Expected arithmetic operands to be of Result Type: " |
64 | 0 | << spvOpcodeString(opcode) << " operand index " << operand_index; |
65 | 0 | } |
66 | 0 | spv_result_t ret = |
67 | 0 | _.CooperativeMatrixShapesMatch(inst, result_type, type_id, false); |
68 | 0 | if (ret != SPV_SUCCESS) return ret; |
69 | 221k | } else if (_.GetOperandTypeId(inst, operand_index) != result_type) |
70 | 53 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
71 | 53 | << "Expected arithmetic operands to be of Result Type: " |
72 | 53 | << spvOpcodeString(opcode) << " operand index " << operand_index; |
73 | 221k | } |
74 | 111k | return SPV_SUCCESS; |
75 | 111k | } |
76 | | |
77 | | spv_result_t ValidateUnsignedInt(ValidationState_t& _, const Instruction* inst, |
78 | 323 | uint32_t starting_index = 2) { |
79 | 323 | const spv::Op opcode = inst->opcode(); |
80 | 323 | const uint32_t result_type = inst->type_id(); |
81 | 323 | bool supportsCoopMat = (opcode == spv::Op::OpUDiv); |
82 | 323 | bool supportsCoopVec = (opcode == spv::Op::OpUDiv); |
83 | 323 | if (!_.IsUnsignedIntScalarType(result_type) && |
84 | 77 | !_.IsUnsignedIntVectorType(result_type) && |
85 | 39 | !(supportsCoopMat && _.IsUnsignedIntCooperativeMatrixType(result_type)) && |
86 | 39 | !(supportsCoopVec && _.IsUnsignedIntCooperativeVectorNVType(result_type))) |
87 | 39 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
88 | 39 | << "Expected unsigned int scalar or vector type as Result Type: " |
89 | 39 | << spvOpcodeString(opcode); |
90 | | |
91 | 284 | for (size_t operand_index = starting_index; |
92 | 817 | operand_index < inst->operands().size(); ++operand_index) { |
93 | 554 | if (supportsCoopVec && _.IsCooperativeVectorNVType(result_type)) { |
94 | 0 | const uint32_t type_id = _.GetOperandTypeId(inst, operand_index); |
95 | 0 | if (!_.IsCooperativeVectorNVType(type_id)) { |
96 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
97 | 0 | << "Expected arithmetic operands to be of Result Type: " |
98 | 0 | << spvOpcodeString(opcode) << " operand index " << operand_index; |
99 | 0 | } |
100 | 0 | spv_result_t ret = |
101 | 0 | _.CooperativeVectorDimensionsMatch(inst, type_id, result_type); |
102 | 0 | if (ret != SPV_SUCCESS) return ret; |
103 | 554 | } else if (supportsCoopMat && _.IsCooperativeMatrixKHRType(result_type)) { |
104 | 0 | const uint32_t type_id = _.GetOperandTypeId(inst, operand_index); |
105 | 0 | if (!_.IsCooperativeMatrixKHRType(type_id) || |
106 | 0 | !_.IsUnsignedIntCooperativeMatrixType(type_id)) { |
107 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
108 | 0 | << "Expected arithmetic operands to be of Result Type: " |
109 | 0 | << spvOpcodeString(opcode) << " operand index " << operand_index; |
110 | 0 | } |
111 | 0 | spv_result_t ret = |
112 | 0 | _.CooperativeMatrixShapesMatch(inst, result_type, type_id, false); |
113 | 0 | if (ret != SPV_SUCCESS) return ret; |
114 | 554 | } else if (_.GetOperandTypeId(inst, operand_index) != result_type) |
115 | 21 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
116 | 21 | << "Expected arithmetic operands to be of Result Type: " |
117 | 21 | << spvOpcodeString(opcode) << " operand index " << operand_index; |
118 | 554 | } |
119 | | |
120 | 263 | return SPV_SUCCESS; |
121 | 284 | } |
122 | | |
123 | | spv_result_t ValidateSignedInt(ValidationState_t& _, const Instruction* inst, |
124 | 80.2k | uint32_t starting_index = 2) { |
125 | 80.2k | const spv::Op opcode = inst->opcode(); |
126 | 80.2k | const uint32_t result_type = inst->type_id(); |
127 | 80.2k | bool supportsCoopMat = |
128 | 80.2k | (opcode != spv::Op::OpIMul && opcode != spv::Op::OpSRem && |
129 | 72.9k | opcode != spv::Op::OpSMod); |
130 | 80.2k | bool supportsCoopVec = |
131 | 80.2k | (opcode != spv::Op::OpSRem && opcode != spv::Op::OpSMod); |
132 | 80.2k | if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type) && |
133 | 51 | !(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type)) && |
134 | 51 | !(opcode == spv::Op::OpIMul && |
135 | 5 | _.IsCooperativeMatrixKHRType(result_type) && |
136 | 0 | _.IsIntCooperativeMatrixType(result_type)) && |
137 | 51 | !(supportsCoopVec && _.IsIntCooperativeVectorNVType(result_type))) |
138 | 51 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
139 | 51 | << "Expected int scalar or vector type as Result Type: " |
140 | 51 | << spvOpcodeString(opcode); |
141 | | |
142 | 80.1k | const uint32_t dimension = _.GetDimension(result_type); |
143 | 80.1k | const uint32_t bit_width = _.GetBitWidth(result_type); |
144 | | |
145 | 80.1k | for (size_t operand_index = starting_index; |
146 | 237k | operand_index < inst->operands().size(); ++operand_index) { |
147 | 157k | const uint32_t type_id = _.GetOperandTypeId(inst, operand_index); |
148 | | |
149 | 157k | if (supportsCoopVec && _.IsCooperativeVectorNVType(result_type)) { |
150 | 0 | if (!_.IsCooperativeVectorNVType(type_id)) { |
151 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
152 | 0 | << "Expected arithmetic operands to be of Result Type: " |
153 | 0 | << spvOpcodeString(opcode) << " operand index " << operand_index; |
154 | 0 | } |
155 | 0 | spv_result_t ret = |
156 | 0 | _.CooperativeVectorDimensionsMatch(inst, type_id, result_type); |
157 | 0 | if (ret != SPV_SUCCESS) return ret; |
158 | 0 | } |
159 | | |
160 | 157k | if (supportsCoopMat && _.IsCooperativeMatrixKHRType(result_type)) { |
161 | 0 | if (!_.IsCooperativeMatrixKHRType(type_id) || |
162 | 0 | !_.IsIntCooperativeMatrixType(type_id)) { |
163 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
164 | 0 | << "Expected arithmetic operands to be of Result Type: " |
165 | 0 | << spvOpcodeString(opcode) << " operand index " << operand_index; |
166 | 0 | } |
167 | 0 | spv_result_t ret = |
168 | 0 | _.CooperativeMatrixShapesMatch(inst, result_type, type_id, false); |
169 | 0 | if (ret != SPV_SUCCESS) return ret; |
170 | 0 | } |
171 | | |
172 | 157k | if (!type_id || |
173 | 157k | (!_.IsIntScalarType(type_id) && !_.IsIntVectorType(type_id) && |
174 | 32 | !(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type)) && |
175 | 32 | !(opcode == spv::Op::OpIMul && |
176 | 3 | _.IsCooperativeMatrixKHRType(result_type) && |
177 | 0 | _.IsIntCooperativeMatrixType(result_type)) && |
178 | 32 | !(supportsCoopVec && _.IsIntCooperativeVectorNVType(result_type)))) |
179 | 36 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
180 | 36 | << "Expected int scalar or vector type as operand: " |
181 | 36 | << spvOpcodeString(opcode) << " operand index " << operand_index; |
182 | | |
183 | 157k | if (_.GetDimension(type_id) != dimension) |
184 | 26 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
185 | 26 | << "Expected arithmetic operands to have the same dimension " |
186 | 26 | << "as Result Type: " << spvOpcodeString(opcode) |
187 | 26 | << " operand index " << operand_index; |
188 | | |
189 | 157k | if (_.GetBitWidth(type_id) != bit_width) |
190 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
191 | 0 | << "Expected arithmetic operands to have the same bit width " |
192 | 0 | << "as Result Type: " << spvOpcodeString(opcode) |
193 | 0 | << " operand index " << operand_index; |
194 | 157k | } |
195 | 80.1k | return SPV_SUCCESS; |
196 | 80.1k | } |
197 | | |
198 | 85 | spv_result_t ValidateDot(ValidationState_t& _, const Instruction* inst) { |
199 | 85 | const spv::Op opcode = inst->opcode(); |
200 | 85 | const uint32_t result_type = inst->type_id(); |
201 | 85 | if (!_.IsFloatScalarType(result_type)) |
202 | 8 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
203 | 8 | << "Expected float scalar type as Result Type: " |
204 | 8 | << spvOpcodeString(opcode); |
205 | | |
206 | 77 | if (_.IsBfloat16ScalarType(result_type)) { |
207 | 0 | if (!_.HasCapability(spv::Capability::BFloat16DotProductKHR)) { |
208 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
209 | 0 | << "OpDot Result Type <id> " << _.getIdName(result_type) |
210 | 0 | << "requires BFloat16DotProductKHR be declared."; |
211 | 0 | } |
212 | 0 | } |
213 | | |
214 | 77 | uint32_t first_vector_num_components = 0; |
215 | | |
216 | 214 | for (size_t operand_index = 2; operand_index < inst->operands().size(); |
217 | 150 | ++operand_index) { |
218 | 150 | const uint32_t type_id = _.GetOperandTypeId(inst, operand_index); |
219 | | |
220 | 150 | if (!type_id || !_.IsFloatVectorType(type_id)) |
221 | 9 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
222 | 9 | << "Expected float vector as operand: " << spvOpcodeString(opcode) |
223 | 9 | << " operand index " << operand_index; |
224 | | |
225 | 141 | const uint32_t component_type = _.GetComponentType(type_id); |
226 | 141 | if (component_type != result_type) |
227 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
228 | 0 | << "Expected component type to be equal to Result Type: " |
229 | 0 | << spvOpcodeString(opcode) << " operand index " << operand_index; |
230 | | |
231 | 141 | const uint32_t num_components = _.GetDimension(type_id); |
232 | 141 | if (operand_index == 2) { |
233 | 73 | first_vector_num_components = num_components; |
234 | 73 | } else if (num_components != first_vector_num_components) { |
235 | 4 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
236 | 4 | << "Expected operands to have the same number of components: " |
237 | 4 | << spvOpcodeString(opcode); |
238 | 4 | } |
239 | 141 | } |
240 | 64 | return SPV_SUCCESS; |
241 | 77 | } |
242 | | |
243 | | spv_result_t ValidateVectorTimesScalar(ValidationState_t& _, |
244 | 8.59k | const Instruction* inst) { |
245 | 8.59k | const spv::Op opcode = inst->opcode(); |
246 | 8.59k | const uint32_t result_type = inst->type_id(); |
247 | 8.59k | if (!_.IsFloatVectorType(result_type) && |
248 | 3 | !_.IsFloatCooperativeVectorNVType(result_type)) |
249 | 3 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
250 | 3 | << "Expected float vector type as Result Type: " |
251 | 3 | << spvOpcodeString(opcode); |
252 | | |
253 | 8.59k | const uint32_t vector_type_id = _.GetOperandTypeId(inst, 2); |
254 | 8.59k | if (result_type != vector_type_id) |
255 | 7 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
256 | 7 | << "Expected vector operand type to be equal to Result Type: " |
257 | 7 | << spvOpcodeString(opcode); |
258 | | |
259 | 8.58k | const uint32_t component_type = _.GetComponentType(vector_type_id); |
260 | | |
261 | 8.58k | const uint32_t scalar_type_id = _.GetOperandTypeId(inst, 3); |
262 | 8.58k | if (component_type != scalar_type_id) |
263 | 3 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
264 | 3 | << "Expected scalar operand type to be equal to the component " |
265 | 3 | << "type of the vector operand: " << spvOpcodeString(opcode); |
266 | | |
267 | 8.58k | return SPV_SUCCESS; |
268 | 8.58k | } |
269 | | |
270 | | spv_result_t ValidateMatrixTimesScalar(ValidationState_t& _, |
271 | 44 | const Instruction* inst) { |
272 | 44 | const spv::Op opcode = inst->opcode(); |
273 | 44 | const uint32_t result_type = inst->type_id(); |
274 | 44 | if (!_.IsFloatMatrixType(result_type) && |
275 | 26 | !(_.IsCooperativeMatrixType(result_type))) |
276 | 26 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
277 | 26 | << "Expected float matrix type as Result Type: " |
278 | 26 | << spvOpcodeString(opcode); |
279 | | |
280 | 18 | const uint32_t matrix_type_id = _.GetOperandTypeId(inst, 2); |
281 | 18 | if (result_type != matrix_type_id) |
282 | 5 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
283 | 5 | << "Expected matrix operand type to be equal to Result Type: " |
284 | 5 | << spvOpcodeString(opcode); |
285 | | |
286 | 13 | const uint32_t component_type = _.GetComponentType(matrix_type_id); |
287 | | |
288 | 13 | const uint32_t scalar_type_id = _.GetOperandTypeId(inst, 3); |
289 | 13 | if (component_type != scalar_type_id) |
290 | 11 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
291 | 11 | << "Expected scalar operand type to be equal to the component " |
292 | 11 | << "type of the matrix operand: " << spvOpcodeString(opcode); |
293 | | |
294 | 2 | return SPV_SUCCESS; |
295 | 13 | } |
296 | | |
297 | | spv_result_t ValidateVectorTimesMatrix(ValidationState_t& _, |
298 | 1.26k | const Instruction* inst) { |
299 | 1.26k | const spv::Op opcode = inst->opcode(); |
300 | 1.26k | const uint32_t result_type = inst->type_id(); |
301 | 1.26k | const uint32_t vector_type_id = _.GetOperandTypeId(inst, 2); |
302 | 1.26k | const uint32_t matrix_type_id = _.GetOperandTypeId(inst, 3); |
303 | | |
304 | 1.26k | if (!_.IsFloatVectorType(result_type)) |
305 | 6 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
306 | 6 | << "Expected float vector type as Result Type: " |
307 | 6 | << spvOpcodeString(opcode); |
308 | | |
309 | 1.26k | const uint32_t res_component_type = _.GetComponentType(result_type); |
310 | | |
311 | 1.26k | if (!vector_type_id || !_.IsFloatVectorType(vector_type_id)) |
312 | 5 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
313 | 5 | << "Expected float vector type as left operand: " |
314 | 5 | << spvOpcodeString(opcode); |
315 | | |
316 | 1.25k | if (res_component_type != _.GetComponentType(vector_type_id)) |
317 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
318 | 0 | << "Expected component types of Result Type and vector to be " |
319 | 0 | << "equal: " << spvOpcodeString(opcode); |
320 | | |
321 | 1.25k | uint32_t matrix_num_rows = 0; |
322 | 1.25k | uint32_t matrix_num_cols = 0; |
323 | 1.25k | uint32_t matrix_col_type = 0; |
324 | 1.25k | uint32_t matrix_component_type = 0; |
325 | 1.25k | if (!_.GetMatrixTypeInfo(matrix_type_id, &matrix_num_rows, &matrix_num_cols, |
326 | 1.25k | &matrix_col_type, &matrix_component_type)) |
327 | 4 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
328 | 4 | << "Expected float matrix type as right operand: " |
329 | 4 | << spvOpcodeString(opcode); |
330 | | |
331 | 1.25k | if (res_component_type != matrix_component_type) |
332 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
333 | 0 | << "Expected component types of Result Type and matrix to be " |
334 | 0 | << "equal: " << spvOpcodeString(opcode); |
335 | | |
336 | 1.25k | if (matrix_num_cols != _.GetDimension(result_type)) |
337 | 4 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
338 | 4 | << "Expected number of columns of the matrix to be equal to " |
339 | 4 | << "Result Type vector size: " << spvOpcodeString(opcode); |
340 | | |
341 | 1.24k | if (matrix_num_rows != _.GetDimension(vector_type_id)) |
342 | 4 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
343 | 4 | << "Expected number of rows of the matrix to be equal to the " |
344 | 4 | << "vector operand size: " << spvOpcodeString(opcode); |
345 | 1.24k | return SPV_SUCCESS; |
346 | 1.24k | } |
347 | | |
348 | | spv_result_t ValidateMatrixTimesVector(ValidationState_t& _, |
349 | 174 | const Instruction* inst) { |
350 | 174 | const spv::Op opcode = inst->opcode(); |
351 | 174 | const uint32_t result_type = inst->type_id(); |
352 | 174 | const uint32_t matrix_type_id = _.GetOperandTypeId(inst, 2); |
353 | 174 | const uint32_t vector_type_id = _.GetOperandTypeId(inst, 3); |
354 | | |
355 | 174 | if (!_.IsFloatVectorType(result_type)) |
356 | 12 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
357 | 12 | << "Expected float vector type as Result Type: " |
358 | 12 | << spvOpcodeString(opcode); |
359 | | |
360 | 162 | uint32_t matrix_num_rows = 0; |
361 | 162 | uint32_t matrix_num_cols = 0; |
362 | 162 | uint32_t matrix_col_type = 0; |
363 | 162 | uint32_t matrix_component_type = 0; |
364 | 162 | if (!_.GetMatrixTypeInfo(matrix_type_id, &matrix_num_rows, &matrix_num_cols, |
365 | 162 | &matrix_col_type, &matrix_component_type)) |
366 | 5 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
367 | 5 | << "Expected float matrix type as left operand: " |
368 | 5 | << spvOpcodeString(opcode); |
369 | | |
370 | 157 | if (result_type != matrix_col_type) |
371 | 3 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
372 | 3 | << "Expected column type of the matrix to be equal to Result " |
373 | 3 | "Type: " |
374 | 3 | << spvOpcodeString(opcode); |
375 | | |
376 | 154 | if (!vector_type_id || !_.IsFloatVectorType(vector_type_id)) |
377 | 6 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
378 | 6 | << "Expected float vector type as right operand: " |
379 | 6 | << spvOpcodeString(opcode); |
380 | | |
381 | 148 | if (matrix_component_type != _.GetComponentType(vector_type_id)) |
382 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
383 | 0 | << "Expected component types of the operands to be equal: " |
384 | 0 | << spvOpcodeString(opcode); |
385 | | |
386 | 148 | if (matrix_num_cols != _.GetDimension(vector_type_id)) |
387 | 4 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
388 | 4 | << "Expected number of columns of the matrix to be equal to the " |
389 | 4 | << "vector size: " << spvOpcodeString(opcode); |
390 | | |
391 | 144 | return SPV_SUCCESS; |
392 | 148 | } |
393 | | |
394 | | spv_result_t ValidateMatrixTimesMatrix(ValidationState_t& _, |
395 | 1.62k | const Instruction* inst) { |
396 | 1.62k | const spv::Op opcode = inst->opcode(); |
397 | 1.62k | const uint32_t result_type = inst->type_id(); |
398 | 1.62k | const uint32_t left_type_id = _.GetOperandTypeId(inst, 2); |
399 | 1.62k | const uint32_t right_type_id = _.GetOperandTypeId(inst, 3); |
400 | | |
401 | 1.62k | uint32_t res_num_rows = 0; |
402 | 1.62k | uint32_t res_num_cols = 0; |
403 | 1.62k | uint32_t res_col_type = 0; |
404 | 1.62k | uint32_t res_component_type = 0; |
405 | 1.62k | if (!_.GetMatrixTypeInfo(result_type, &res_num_rows, &res_num_cols, |
406 | 1.62k | &res_col_type, &res_component_type)) |
407 | 9 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
408 | 9 | << "Expected float matrix type as Result Type: " |
409 | 9 | << spvOpcodeString(opcode); |
410 | | |
411 | 1.61k | uint32_t left_num_rows = 0; |
412 | 1.61k | uint32_t left_num_cols = 0; |
413 | 1.61k | uint32_t left_col_type = 0; |
414 | 1.61k | uint32_t left_component_type = 0; |
415 | 1.61k | if (!_.GetMatrixTypeInfo(left_type_id, &left_num_rows, &left_num_cols, |
416 | 1.61k | &left_col_type, &left_component_type)) |
417 | 4 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
418 | 4 | << "Expected float matrix type as left operand: " |
419 | 4 | << spvOpcodeString(opcode); |
420 | | |
421 | 1.61k | uint32_t right_num_rows = 0; |
422 | 1.61k | uint32_t right_num_cols = 0; |
423 | 1.61k | uint32_t right_col_type = 0; |
424 | 1.61k | uint32_t right_component_type = 0; |
425 | 1.61k | if (!_.GetMatrixTypeInfo(right_type_id, &right_num_rows, &right_num_cols, |
426 | 1.61k | &right_col_type, &right_component_type)) |
427 | 6 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
428 | 6 | << "Expected float matrix type as right operand: " |
429 | 6 | << spvOpcodeString(opcode); |
430 | | |
431 | 1.60k | if (!_.IsFloatScalarType(res_component_type)) |
432 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
433 | 0 | << "Expected float matrix type as Result Type: " |
434 | 0 | << spvOpcodeString(opcode); |
435 | | |
436 | 1.60k | if (res_col_type != left_col_type) |
437 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
438 | 0 | << "Expected column types of Result Type and left matrix to be " |
439 | 0 | << "equal: " << spvOpcodeString(opcode); |
440 | | |
441 | 1.60k | if (res_component_type != right_component_type) |
442 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
443 | 0 | << "Expected component types of Result Type and right matrix to " |
444 | 0 | "be " |
445 | 0 | << "equal: " << spvOpcodeString(opcode); |
446 | | |
447 | 1.60k | if (res_num_cols != right_num_cols) |
448 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
449 | 0 | << "Expected number of columns of Result Type and right matrix " |
450 | 0 | "to " |
451 | 0 | << "be equal: " << spvOpcodeString(opcode); |
452 | | |
453 | 1.60k | if (left_num_cols != right_num_rows) |
454 | 2 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
455 | 2 | << "Expected number of columns of left matrix and number of " |
456 | 2 | "rows " |
457 | 2 | << "of right matrix to be equal: " << spvOpcodeString(opcode); |
458 | | |
459 | 1.60k | assert(left_num_rows == res_num_rows); |
460 | 1.60k | return SPV_SUCCESS; |
461 | 1.60k | } |
462 | | |
463 | | spv_result_t ValidateOuterProduct(ValidationState_t& _, |
464 | 37 | const Instruction* inst) { |
465 | 37 | const spv::Op opcode = inst->opcode(); |
466 | 37 | const uint32_t result_type = inst->type_id(); |
467 | 37 | const uint32_t left_type_id = _.GetOperandTypeId(inst, 2); |
468 | 37 | const uint32_t right_type_id = _.GetOperandTypeId(inst, 3); |
469 | | |
470 | 37 | uint32_t res_num_rows = 0; |
471 | 37 | uint32_t res_num_cols = 0; |
472 | 37 | uint32_t res_col_type = 0; |
473 | 37 | uint32_t res_component_type = 0; |
474 | 37 | if (!_.GetMatrixTypeInfo(result_type, &res_num_rows, &res_num_cols, |
475 | 37 | &res_col_type, &res_component_type)) |
476 | 9 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
477 | 9 | << "Expected float matrix type as Result Type: " |
478 | 9 | << spvOpcodeString(opcode); |
479 | | |
480 | 28 | if (left_type_id != res_col_type) |
481 | 4 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
482 | 4 | << "Expected column type of Result Type to be equal to the type " |
483 | 4 | << "of the left operand: " << spvOpcodeString(opcode); |
484 | | |
485 | 24 | if (!right_type_id || !_.IsFloatVectorType(right_type_id)) |
486 | 5 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
487 | 5 | << "Expected float vector type as right operand: " |
488 | 5 | << spvOpcodeString(opcode); |
489 | | |
490 | 19 | if (res_component_type != _.GetComponentType(right_type_id)) |
491 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
492 | 0 | << "Expected component types of the operands to be equal: " |
493 | 0 | << spvOpcodeString(opcode); |
494 | | |
495 | 19 | if (res_num_cols != _.GetDimension(right_type_id)) |
496 | 3 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
497 | 3 | << "Expected number of columns of the matrix to be equal to the " |
498 | 3 | << "vector size of the right operand: " << spvOpcodeString(opcode); |
499 | | |
500 | 16 | return SPV_SUCCESS; |
501 | 19 | } |
502 | | |
503 | | spv_result_t ValidateExtendedCarry(ValidationState_t& _, |
504 | 237 | const Instruction* inst) { |
505 | 237 | const spv::Op opcode = inst->opcode(); |
506 | 237 | const uint32_t result_type = inst->type_id(); |
507 | | |
508 | 237 | std::vector<uint32_t> result_types; |
509 | 237 | if (!_.GetStructMemberTypes(result_type, &result_types)) |
510 | 41 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
511 | 41 | << "Expected a struct as Result Type: " << spvOpcodeString(opcode); |
512 | | |
513 | 196 | if (result_types.size() != 2) |
514 | 3 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
515 | 3 | << "Expected Result Type struct to have two members: " |
516 | 3 | << spvOpcodeString(opcode); |
517 | | |
518 | 193 | if (opcode == spv::Op::OpSMulExtended) { |
519 | 32 | if (!_.IsIntScalarType(result_types[0]) && |
520 | 4 | !_.IsIntVectorType(result_types[0])) |
521 | 3 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
522 | 3 | << "Expected Result Type struct member types to be integer " |
523 | 3 | "scalar " |
524 | 3 | << "or vector: " << spvOpcodeString(opcode); |
525 | 161 | } else { |
526 | 161 | if (!_.IsUnsignedIntScalarType(result_types[0]) && |
527 | 5 | !_.IsUnsignedIntVectorType(result_types[0])) |
528 | 5 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
529 | 5 | << "Expected Result Type struct member types to be unsigned " |
530 | 5 | << "integer scalar or vector: " << spvOpcodeString(opcode); |
531 | 161 | } |
532 | | |
533 | 185 | if (result_types[0] != result_types[1]) |
534 | 4 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
535 | 4 | << "Expected Result Type struct member types to be identical: " |
536 | 4 | << spvOpcodeString(opcode); |
537 | | |
538 | 181 | const uint32_t left_type_id = _.GetOperandTypeId(inst, 2); |
539 | 181 | const uint32_t right_type_id = _.GetOperandTypeId(inst, 3); |
540 | | |
541 | 181 | if (left_type_id != result_types[0] || right_type_id != result_types[0]) |
542 | 10 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
543 | 10 | << "Expected both operands to be of Result Type member type: " |
544 | 10 | << spvOpcodeString(opcode); |
545 | 171 | return SPV_SUCCESS; |
546 | 181 | } |
547 | | |
548 | | spv_result_t ValidateCooperativeMatrixMulAddNV(ValidationState_t& _, |
549 | 0 | const Instruction* inst) { |
550 | 0 | const spv::Op opcode = inst->opcode(); |
551 | 0 | const uint32_t D_type_id = _.GetOperandTypeId(inst, 1); |
552 | 0 | const uint32_t A_type_id = _.GetOperandTypeId(inst, 2); |
553 | 0 | const uint32_t B_type_id = _.GetOperandTypeId(inst, 3); |
554 | 0 | const uint32_t C_type_id = _.GetOperandTypeId(inst, 4); |
555 | |
|
556 | 0 | if (!_.IsCooperativeMatrixNVType(A_type_id)) { |
557 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
558 | 0 | << "Expected cooperative matrix type as A Type: " |
559 | 0 | << spvOpcodeString(opcode); |
560 | 0 | } |
561 | 0 | if (!_.IsCooperativeMatrixNVType(B_type_id)) { |
562 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
563 | 0 | << "Expected cooperative matrix type as B Type: " |
564 | 0 | << spvOpcodeString(opcode); |
565 | 0 | } |
566 | 0 | if (!_.IsCooperativeMatrixNVType(C_type_id)) { |
567 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
568 | 0 | << "Expected cooperative matrix type as C Type: " |
569 | 0 | << spvOpcodeString(opcode); |
570 | 0 | } |
571 | 0 | if (!_.IsCooperativeMatrixNVType(D_type_id)) { |
572 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
573 | 0 | << "Expected cooperative matrix type as Result Type: " |
574 | 0 | << spvOpcodeString(opcode); |
575 | 0 | } |
576 | | |
577 | 0 | const auto A = _.FindDef(A_type_id); |
578 | 0 | const auto B = _.FindDef(B_type_id); |
579 | 0 | const auto C = _.FindDef(C_type_id); |
580 | 0 | const auto D = _.FindDef(D_type_id); |
581 | |
|
582 | 0 | std::tuple<bool, bool, uint32_t> A_scope, B_scope, C_scope, D_scope, A_rows, |
583 | 0 | B_rows, C_rows, D_rows, A_cols, B_cols, C_cols, D_cols; |
584 | |
|
585 | 0 | A_scope = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(2)); |
586 | 0 | B_scope = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(2)); |
587 | 0 | C_scope = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(2)); |
588 | 0 | D_scope = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(2)); |
589 | |
|
590 | 0 | A_rows = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(3)); |
591 | 0 | B_rows = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(3)); |
592 | 0 | C_rows = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(3)); |
593 | 0 | D_rows = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(3)); |
594 | |
|
595 | 0 | A_cols = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(4)); |
596 | 0 | B_cols = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(4)); |
597 | 0 | C_cols = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(4)); |
598 | 0 | D_cols = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(4)); |
599 | |
|
600 | 0 | const auto notEqual = [](std::tuple<bool, bool, uint32_t> X, |
601 | 0 | std::tuple<bool, bool, uint32_t> Y) { |
602 | 0 | return (std::get<1>(X) && std::get<1>(Y) && |
603 | 0 | std::get<2>(X) != std::get<2>(Y)); |
604 | 0 | }; |
605 | |
|
606 | 0 | if (notEqual(A_scope, B_scope) || notEqual(A_scope, C_scope) || |
607 | 0 | notEqual(A_scope, D_scope) || notEqual(B_scope, C_scope) || |
608 | 0 | notEqual(B_scope, D_scope) || notEqual(C_scope, D_scope)) { |
609 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
610 | 0 | << "Cooperative matrix scopes must match: " |
611 | 0 | << spvOpcodeString(opcode); |
612 | 0 | } |
613 | | |
614 | 0 | if (notEqual(A_rows, C_rows) || notEqual(A_rows, D_rows) || |
615 | 0 | notEqual(C_rows, D_rows)) { |
616 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
617 | 0 | << "Cooperative matrix 'M' mismatch: " << spvOpcodeString(opcode); |
618 | 0 | } |
619 | | |
620 | 0 | if (notEqual(B_cols, C_cols) || notEqual(B_cols, D_cols) || |
621 | 0 | notEqual(C_cols, D_cols)) { |
622 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
623 | 0 | << "Cooperative matrix 'N' mismatch: " << spvOpcodeString(opcode); |
624 | 0 | } |
625 | | |
626 | 0 | if (notEqual(A_cols, B_rows)) { |
627 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
628 | 0 | << "Cooperative matrix 'K' mismatch: " << spvOpcodeString(opcode); |
629 | 0 | } |
630 | 0 | return SPV_SUCCESS; |
631 | 0 | } |
632 | | |
633 | | spv_result_t ValidateCooperativeMatrixMulAddKHR(ValidationState_t& _, |
634 | 0 | const Instruction* inst) { |
635 | 0 | const spv::Op opcode = inst->opcode(); |
636 | 0 | const uint32_t D_type_id = _.GetOperandTypeId(inst, 1); |
637 | 0 | const uint32_t A_type_id = _.GetOperandTypeId(inst, 2); |
638 | 0 | const uint32_t B_type_id = _.GetOperandTypeId(inst, 3); |
639 | 0 | const uint32_t C_type_id = _.GetOperandTypeId(inst, 4); |
640 | |
|
641 | 0 | if (!_.IsCooperativeMatrixAType(A_type_id)) { |
642 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
643 | 0 | << "Cooperative matrix type must be A Type: " |
644 | 0 | << spvOpcodeString(opcode); |
645 | 0 | } |
646 | 0 | if (!_.IsCooperativeMatrixBType(B_type_id)) { |
647 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
648 | 0 | << "Cooperative matrix type must be B Type: " |
649 | 0 | << spvOpcodeString(opcode); |
650 | 0 | } |
651 | 0 | if (!_.IsCooperativeMatrixAccType(C_type_id)) { |
652 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
653 | 0 | << "Cooperative matrix type must be Accumulator Type: " |
654 | 0 | << spvOpcodeString(opcode); |
655 | 0 | } |
656 | 0 | if (!_.IsCooperativeMatrixKHRType(D_type_id)) { |
657 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
658 | 0 | << "Expected cooperative matrix type as Result Type: " |
659 | 0 | << spvOpcodeString(opcode); |
660 | 0 | } |
661 | | |
662 | 0 | const auto A = _.FindDef(A_type_id); |
663 | 0 | const auto B = _.FindDef(B_type_id); |
664 | 0 | const auto C = _.FindDef(C_type_id); |
665 | 0 | const auto D = _.FindDef(D_type_id); |
666 | |
|
667 | 0 | std::tuple<bool, bool, uint32_t> A_scope, B_scope, C_scope, D_scope, A_rows, |
668 | 0 | B_rows, C_rows, D_rows, A_cols, B_cols, C_cols, D_cols; |
669 | |
|
670 | 0 | A_scope = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(2)); |
671 | 0 | B_scope = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(2)); |
672 | 0 | C_scope = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(2)); |
673 | 0 | D_scope = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(2)); |
674 | |
|
675 | 0 | A_rows = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(3)); |
676 | 0 | B_rows = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(3)); |
677 | 0 | C_rows = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(3)); |
678 | 0 | D_rows = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(3)); |
679 | |
|
680 | 0 | A_cols = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(4)); |
681 | 0 | B_cols = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(4)); |
682 | 0 | C_cols = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(4)); |
683 | 0 | D_cols = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(4)); |
684 | |
|
685 | 0 | const auto notEqual = [](std::tuple<bool, bool, uint32_t> X, |
686 | 0 | std::tuple<bool, bool, uint32_t> Y) { |
687 | 0 | return (std::get<1>(X) && std::get<1>(Y) && |
688 | 0 | std::get<2>(X) != std::get<2>(Y)); |
689 | 0 | }; |
690 | |
|
691 | 0 | if (notEqual(A_scope, B_scope) || notEqual(A_scope, C_scope) || |
692 | 0 | notEqual(A_scope, D_scope) || notEqual(B_scope, C_scope) || |
693 | 0 | notEqual(B_scope, D_scope) || notEqual(C_scope, D_scope)) { |
694 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
695 | 0 | << "Cooperative matrix scopes must match: " |
696 | 0 | << spvOpcodeString(opcode); |
697 | 0 | } |
698 | | |
699 | 0 | if (notEqual(A_rows, C_rows) || notEqual(A_rows, D_rows) || |
700 | 0 | notEqual(C_rows, D_rows)) { |
701 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
702 | 0 | << "Cooperative matrix 'M' mismatch: " << spvOpcodeString(opcode); |
703 | 0 | } |
704 | | |
705 | 0 | if (notEqual(B_cols, C_cols) || notEqual(B_cols, D_cols) || |
706 | 0 | notEqual(C_cols, D_cols)) { |
707 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
708 | 0 | << "Cooperative matrix 'N' mismatch: " << spvOpcodeString(opcode); |
709 | 0 | } |
710 | | |
711 | 0 | if (notEqual(A_cols, B_rows)) { |
712 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
713 | 0 | << "Cooperative matrix 'K' mismatch: " << spvOpcodeString(opcode); |
714 | 0 | } |
715 | 0 | return SPV_SUCCESS; |
716 | 0 | } |
717 | | |
718 | | spv_result_t ValidateCooperativeMatrixReduceNV(ValidationState_t& _, |
719 | 0 | const Instruction* inst) { |
720 | 0 | const spv::Op opcode = inst->opcode(); |
721 | 0 | const uint32_t result_type = inst->type_id(); |
722 | 0 | if (!_.IsCooperativeMatrixKHRType(result_type)) { |
723 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
724 | 0 | << "Result Type must be a cooperative matrix type: " |
725 | 0 | << spvOpcodeString(opcode); |
726 | 0 | } |
727 | | |
728 | 0 | const auto result_comp_type_id = |
729 | 0 | _.FindDef(result_type)->GetOperandAs<uint32_t>(1); |
730 | |
|
731 | 0 | const auto matrix_id = inst->GetOperandAs<uint32_t>(2); |
732 | 0 | const auto matrix = _.FindDef(matrix_id); |
733 | 0 | const auto matrix_type_id = matrix->type_id(); |
734 | 0 | if (!_.IsCooperativeMatrixKHRType(matrix_type_id)) { |
735 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
736 | 0 | << "Matrix must have a cooperative matrix type: " |
737 | 0 | << spvOpcodeString(opcode); |
738 | 0 | } |
739 | 0 | const auto matrix_type = _.FindDef(matrix_type_id); |
740 | 0 | const auto matrix_comp_type_id = matrix_type->GetOperandAs<uint32_t>(1); |
741 | 0 | if (matrix_comp_type_id != result_comp_type_id) { |
742 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
743 | 0 | << "Result Type and Matrix type must have the same component " |
744 | 0 | "type: " |
745 | 0 | << spvOpcodeString(opcode); |
746 | 0 | } |
747 | 0 | if (_.FindDef(result_type)->GetOperandAs<uint32_t>(2) != |
748 | 0 | matrix_type->GetOperandAs<uint32_t>(2)) { |
749 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
750 | 0 | << "Result Type and Matrix type must have the same scope: " |
751 | 0 | << spvOpcodeString(opcode); |
752 | 0 | } |
753 | | |
754 | 0 | if (!_.IsCooperativeMatrixAccType(result_type)) { |
755 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
756 | 0 | << "Result Type must have UseAccumulator: " |
757 | 0 | << spvOpcodeString(opcode); |
758 | 0 | } |
759 | 0 | if (!_.IsCooperativeMatrixAccType(matrix_type_id)) { |
760 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
761 | 0 | << "Matrix type must have UseAccumulator: " |
762 | 0 | << spvOpcodeString(opcode); |
763 | 0 | } |
764 | | |
765 | 0 | const auto reduce_value = inst->GetOperandAs<uint32_t>(3); |
766 | |
|
767 | 0 | if ((reduce_value & |
768 | 0 | uint32_t( |
769 | 0 | spv::CooperativeMatrixReduceMask::CooperativeMatrixReduce2x2)) && |
770 | 0 | (reduce_value & uint32_t(spv::CooperativeMatrixReduceMask::Row | |
771 | 0 | spv::CooperativeMatrixReduceMask::Column))) { |
772 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
773 | 0 | << "Reduce 2x2 must not be used with Row/Column: " |
774 | 0 | << spvOpcodeString(opcode); |
775 | 0 | } |
776 | | |
777 | 0 | std::tuple<bool, bool, uint32_t> result_rows, result_cols, matrix_rows, |
778 | 0 | matrix_cols; |
779 | 0 | result_rows = |
780 | 0 | _.EvalInt32IfConst(_.FindDef(result_type)->GetOperandAs<uint32_t>(3)); |
781 | 0 | result_cols = |
782 | 0 | _.EvalInt32IfConst(_.FindDef(result_type)->GetOperandAs<uint32_t>(4)); |
783 | 0 | matrix_rows = _.EvalInt32IfConst(matrix_type->GetOperandAs<uint32_t>(3)); |
784 | 0 | matrix_cols = _.EvalInt32IfConst(matrix_type->GetOperandAs<uint32_t>(4)); |
785 | |
|
786 | 0 | if (reduce_value & |
787 | 0 | uint32_t(spv::CooperativeMatrixReduceMask::CooperativeMatrixReduce2x2)) { |
788 | 0 | if (std::get<1>(result_rows) && std::get<1>(result_cols) && |
789 | 0 | std::get<1>(matrix_rows) && std::get<1>(matrix_cols) && |
790 | 0 | (std::get<2>(result_rows) != std::get<2>(matrix_rows) / 2 || |
791 | 0 | std::get<2>(result_cols) != std::get<2>(matrix_cols) / 2)) { |
792 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
793 | 0 | << "For Reduce2x2, result rows/cols must be half of matrix " |
794 | 0 | "rows/cols: " |
795 | 0 | << spvOpcodeString(opcode); |
796 | 0 | } |
797 | 0 | } |
798 | 0 | if (reduce_value == uint32_t(spv::CooperativeMatrixReduceMask::Row)) { |
799 | 0 | if (std::get<1>(result_rows) && std::get<1>(matrix_rows) && |
800 | 0 | std::get<2>(result_rows) != std::get<2>(matrix_rows)) { |
801 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
802 | 0 | << "For ReduceRow, result rows must match matrix rows: " |
803 | 0 | << spvOpcodeString(opcode); |
804 | 0 | } |
805 | 0 | } |
806 | 0 | if (reduce_value == uint32_t(spv::CooperativeMatrixReduceMask::Column)) { |
807 | 0 | if (std::get<1>(result_cols) && std::get<1>(matrix_cols) && |
808 | 0 | std::get<2>(result_cols) != std::get<2>(matrix_cols)) { |
809 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
810 | 0 | << "For ReduceColumn, result cols must match matrix cols: " |
811 | 0 | << spvOpcodeString(opcode); |
812 | 0 | } |
813 | 0 | } |
814 | | |
815 | 0 | const auto combine_func_id = inst->GetOperandAs<uint32_t>(4); |
816 | 0 | const auto combine_func = _.FindDef(combine_func_id); |
817 | 0 | if (!combine_func || combine_func->opcode() != spv::Op::OpFunction) { |
818 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
819 | 0 | << "CombineFunc must be a function: " << spvOpcodeString(opcode); |
820 | 0 | } |
821 | 0 | const auto function_type_id = combine_func->GetOperandAs<uint32_t>(3); |
822 | 0 | const auto function_type = _.FindDef(function_type_id); |
823 | 0 | if (function_type->operands().size() != 4) { |
824 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
825 | 0 | << "CombineFunc must have two parameters: " |
826 | 0 | << spvOpcodeString(opcode); |
827 | 0 | } |
828 | 0 | for (uint32_t i = 0; i < 3; ++i) { |
829 | | // checks return type and two params |
830 | 0 | const auto param_type_id = function_type->GetOperandAs<uint32_t>(i + 1); |
831 | 0 | if (param_type_id != matrix_comp_type_id) { |
832 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
833 | 0 | << "CombineFunc return type and parameters must match matrix " |
834 | 0 | "component type: " |
835 | 0 | << spvOpcodeString(opcode); |
836 | 0 | } |
837 | 0 | } |
838 | 0 | return SPV_SUCCESS; |
839 | 0 | } |
840 | | |
841 | | // Validates correctness of arithmetic instructions. |
842 | 14.6M | spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) { |
843 | 14.6M | switch (inst->opcode()) { |
844 | 29.8k | case spv::Op::OpFAdd: |
845 | 56.8k | case spv::Op::OpFSub: |
846 | 86.7k | case spv::Op::OpFMul: |
847 | 107k | case spv::Op::OpFDiv: |
848 | 108k | case spv::Op::OpFRem: |
849 | 110k | case spv::Op::OpFMod: |
850 | 111k | case spv::Op::OpFNegate: |
851 | 111k | case spv::Op::OpFmaKHR: |
852 | 111k | return ValidateFloat(_, inst); |
853 | 153 | case spv::Op::OpUDiv: |
854 | 316 | case spv::Op::OpUMod: |
855 | 316 | return ValidateUnsignedInt(_, inst); |
856 | 18.0k | case spv::Op::OpISub: |
857 | 63.8k | case spv::Op::OpIAdd: |
858 | 69.8k | case spv::Op::OpIMul: |
859 | 75.2k | case spv::Op::OpSDiv: |
860 | 76.3k | case spv::Op::OpSMod: |
861 | 77.6k | case spv::Op::OpSRem: |
862 | 80.1k | case spv::Op::OpSNegate: |
863 | 80.1k | return ValidateSignedInt(_, inst); |
864 | 85 | case spv::Op::OpDot: |
865 | 85 | return ValidateDot(_, inst); |
866 | 8.59k | case spv::Op::OpVectorTimesScalar: |
867 | 8.59k | return ValidateVectorTimesScalar(_, inst); |
868 | 44 | case spv::Op::OpMatrixTimesScalar: |
869 | 44 | return ValidateMatrixTimesScalar(_, inst); |
870 | 1.26k | case spv::Op::OpVectorTimesMatrix: |
871 | 1.26k | return ValidateVectorTimesMatrix(_, inst); |
872 | 174 | case spv::Op::OpMatrixTimesVector: |
873 | 174 | return ValidateMatrixTimesVector(_, inst); |
874 | 1.62k | case spv::Op::OpMatrixTimesMatrix: |
875 | 1.62k | return ValidateMatrixTimesMatrix(_, inst); |
876 | 37 | case spv::Op::OpOuterProduct: |
877 | 37 | return ValidateOuterProduct(_, inst); |
878 | 112 | case spv::Op::OpIAddCarry: |
879 | 174 | case spv::Op::OpISubBorrow: |
880 | 192 | case spv::Op::OpUMulExtended: |
881 | 237 | case spv::Op::OpSMulExtended: |
882 | 237 | return ValidateExtendedCarry(_, inst); |
883 | 0 | case spv::Op::OpCooperativeMatrixMulAddNV: |
884 | 0 | return ValidateCooperativeMatrixMulAddNV(_, inst); |
885 | 0 | case spv::Op::OpCooperativeMatrixMulAddKHR: |
886 | 0 | return ValidateCooperativeMatrixMulAddKHR(_, inst); |
887 | 0 | case spv::Op::OpCooperativeMatrixReduceNV: |
888 | 0 | return ValidateCooperativeMatrixReduceNV(_, inst); |
889 | | |
890 | 126 | case spv::Op::OpSpecConstantOp: { |
891 | 126 | switch (inst->GetOperandAs<spv::Op>(2u)) { |
892 | 2 | case spv::Op::OpFAdd: |
893 | 3 | case spv::Op::OpFSub: |
894 | 4 | case spv::Op::OpFMul: |
895 | 5 | case spv::Op::OpFDiv: |
896 | 6 | case spv::Op::OpFRem: |
897 | 7 | case spv::Op::OpFMod: |
898 | 11 | case spv::Op::OpFNegate: |
899 | 11 | return ValidateFloat(_, inst, 3); |
900 | 4 | case spv::Op::OpUDiv: |
901 | 7 | case spv::Op::OpUMod: |
902 | 7 | return ValidateUnsignedInt(_, inst, 3); |
903 | 2 | case spv::Op::OpISub: |
904 | 6 | case spv::Op::OpIAdd: |
905 | 9 | case spv::Op::OpIMul: |
906 | 12 | case spv::Op::OpSDiv: |
907 | 14 | case spv::Op::OpSMod: |
908 | 16 | case spv::Op::OpSRem: |
909 | 27 | case spv::Op::OpSNegate: |
910 | 27 | return ValidateSignedInt(_, inst, 3); |
911 | 81 | default: |
912 | 81 | break; |
913 | 126 | } |
914 | 81 | break; |
915 | 126 | } |
916 | 14.4M | default: |
917 | 14.4M | break; |
918 | 14.6M | } |
919 | | |
920 | 14.4M | return SPV_SUCCESS; |
921 | 14.6M | } |
922 | | |
923 | | } // namespace val |
924 | | } // namespace spvtools |