Coverage Report

Created: 2025-06-13 06:49

/src/spirv-tools/source/val/validate_composites.cpp
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) 2017 Google Inc.
2
// Modifications Copyright (C) 2024 Advanced Micro Devices, Inc. All rights
3
// reserved.
4
//
5
// Licensed under the Apache License, Version 2.0 (the "License");
6
// you may not use this file except in compliance with the License.
7
// You may obtain a copy of the License at
8
//
9
//     http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing, software
12
// distributed under the License is distributed on an "AS IS" BASIS,
13
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
// See the License for the specific language governing permissions and
15
// limitations under the License.
16
17
// Validates correctness of composite SPIR-V instructions.
18
19
#include "source/opcode.h"
20
#include "source/spirv_target_env.h"
21
#include "source/val/instruction.h"
22
#include "source/val/validate.h"
23
#include "source/val/validation_state.h"
24
25
namespace spvtools {
26
namespace val {
27
namespace {
28
29
// Returns the type of the value accessed by OpCompositeExtract or
30
// OpCompositeInsert instruction. The function traverses the hierarchy of
31
// nested data structures (structs, arrays, vectors, matrices) as directed by
32
// the sequence of indices in the instruction. May return error if traversal
33
// fails (encountered non-composite, out of bounds, no indices, nesting too
34
// deep).
35
spv_result_t GetExtractInsertValueType(ValidationState_t& _,
36
                                       const Instruction* inst,
37
45.4k
                                       uint32_t* member_type) {
38
45.4k
  const spv::Op opcode = inst->opcode();
39
45.4k
  assert(opcode == spv::Op::OpCompositeExtract ||
40
45.4k
         opcode == spv::Op::OpCompositeInsert);
41
45.4k
  uint32_t word_index = opcode == spv::Op::OpCompositeExtract ? 4 : 5;
42
45.4k
  const uint32_t num_words = static_cast<uint32_t>(inst->words().size());
43
45.4k
  const uint32_t composite_id_index = word_index - 1;
44
45.4k
  const uint32_t num_indices = num_words - word_index;
45
45.4k
  const uint32_t kCompositeExtractInsertMaxNumIndices = 255;
46
47
45.4k
  if (num_indices == 0) {
48
9
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
49
9
           << "Expected at least one index to Op"
50
9
           << spvOpcodeString(inst->opcode()) << ", zero found";
51
52
45.4k
  } else if (num_indices > kCompositeExtractInsertMaxNumIndices) {
53
1
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
54
1
           << "The number of indexes in Op" << spvOpcodeString(opcode)
55
1
           << " may not exceed " << kCompositeExtractInsertMaxNumIndices
56
1
           << ". Found " << num_indices << " indexes.";
57
1
  }
58
59
45.4k
  *member_type = _.GetTypeId(inst->word(composite_id_index));
60
45.4k
  if (*member_type == 0) {
61
0
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
62
0
           << "Expected Composite to be an object of composite type";
63
0
  }
64
65
90.7k
  for (; word_index < num_words; ++word_index) {
66
45.4k
    const uint32_t component_index = inst->word(word_index);
67
45.4k
    const Instruction* const type_inst = _.FindDef(*member_type);
68
45.4k
    assert(type_inst);
69
45.4k
    switch (type_inst->opcode()) {
70
41.0k
      case spv::Op::OpTypeVector: {
71
41.0k
        *member_type = type_inst->word(2);
72
41.0k
        const uint32_t vector_size = type_inst->word(3);
73
41.0k
        if (component_index >= vector_size) {
74
38
          return _.diag(SPV_ERROR_INVALID_DATA, inst)
75
38
                 << "Vector access is out of bounds, vector size is "
76
38
                 << vector_size << ", but access index is " << component_index;
77
38
        }
78
41.0k
        break;
79
41.0k
      }
80
41.0k
      case spv::Op::OpTypeMatrix: {
81
313
        *member_type = type_inst->word(2);
82
313
        const uint32_t num_cols = type_inst->word(3);
83
313
        if (component_index >= num_cols) {
84
45
          return _.diag(SPV_ERROR_INVALID_DATA, inst)
85
45
                 << "Matrix access is out of bounds, matrix has " << num_cols
86
45
                 << " columns, but access index is " << component_index;
87
45
        }
88
268
        break;
89
313
      }
90
268
      case spv::Op::OpTypeArray: {
91
179
        uint64_t array_size = 0;
92
179
        auto size = _.FindDef(type_inst->word(3));
93
179
        *member_type = type_inst->word(2);
94
179
        if (spvOpcodeIsSpecConstant(size->opcode())) {
95
          // Cannot verify against the size of this array.
96
76
          break;
97
76
        }
98
99
103
        if (!_.EvalConstantValUint64(type_inst->word(3), &array_size)) {
100
0
          assert(0 && "Array type definition is corrupt");
101
0
        }
102
103
        if (component_index >= array_size) {
103
35
          return _.diag(SPV_ERROR_INVALID_DATA, inst)
104
35
                 << "Array access is out of bounds, array size is "
105
35
                 << array_size << ", but access index is " << component_index;
106
35
        }
107
68
        break;
108
103
      }
109
68
      case spv::Op::OpTypeRuntimeArray:
110
4
      case spv::Op::OpTypeNodePayloadArrayAMDX: {
111
4
        *member_type = type_inst->word(2);
112
        // Array size is unknown.
113
4
        break;
114
4
      }
115
3.84k
      case spv::Op::OpTypeStruct: {
116
3.84k
        const size_t num_struct_members = type_inst->words().size() - 2;
117
3.84k
        if (component_index >= num_struct_members) {
118
15
          return _.diag(SPV_ERROR_INVALID_DATA, inst)
119
15
                 << "Index is out of bounds, can not find index "
120
15
                 << component_index << " in the structure <id> '"
121
15
                 << type_inst->id() << "'. This structure has "
122
15
                 << num_struct_members << " members. Largest valid index is "
123
15
                 << num_struct_members - 1 << ".";
124
15
        }
125
3.82k
        *member_type = type_inst->word(component_index + 2);
126
3.82k
        break;
127
3.84k
      }
128
0
      case spv::Op::OpTypeCooperativeVectorNV:
129
0
      case spv::Op::OpTypeCooperativeMatrixKHR:
130
0
      case spv::Op::OpTypeCooperativeMatrixNV: {
131
0
        *member_type = type_inst->word(2);
132
0
        break;
133
0
      }
134
25
      default:
135
25
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
136
25
               << "Reached non-composite type while indexes still remain to "
137
25
                  "be traversed.";
138
45.4k
    }
139
45.4k
  }
140
141
45.2k
  return SPV_SUCCESS;
142
45.4k
}
143
144
spv_result_t ValidateVectorExtractDynamic(ValidationState_t& _,
145
106
                                          const Instruction* inst) {
146
106
  const uint32_t result_type = inst->type_id();
147
106
  const spv::Op result_opcode = _.GetIdOpcode(result_type);
148
106
  if (!spvOpcodeIsScalarType(result_opcode)) {
149
3
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
150
3
           << "Expected Result Type to be a scalar type";
151
3
  }
152
153
103
  const uint32_t vector_type = _.GetOperandTypeId(inst, 2);
154
103
  const spv::Op vector_opcode = _.GetIdOpcode(vector_type);
155
103
  if (vector_opcode != spv::Op::OpTypeVector &&
156
103
      vector_opcode != spv::Op::OpTypeCooperativeVectorNV) {
157
4
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
158
4
           << "Expected Vector type to be OpTypeVector";
159
4
  }
160
161
99
  if (_.GetComponentType(vector_type) != result_type) {
162
4
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
163
4
           << "Expected Vector component type to be equal to Result Type";
164
4
  }
165
166
95
  const auto index = _.FindDef(inst->GetOperandAs<uint32_t>(3));
167
95
  if (!index || index->type_id() == 0 || !_.IsIntScalarType(index->type_id())) {
168
4
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
169
4
           << "Expected Index to be int scalar";
170
4
  }
171
172
91
  if (_.HasCapability(spv::Capability::Shader) &&
173
91
      _.ContainsLimitedUseIntOrFloatType(inst->type_id())) {
174
0
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
175
0
           << "Cannot extract from a vector of 8- or 16-bit types";
176
0
  }
177
91
  return SPV_SUCCESS;
178
91
}
179
180
spv_result_t ValidateVectorInsertDyanmic(ValidationState_t& _,
181
68
                                         const Instruction* inst) {
182
68
  const uint32_t result_type = inst->type_id();
183
68
  const spv::Op result_opcode = _.GetIdOpcode(result_type);
184
68
  if (result_opcode != spv::Op::OpTypeVector &&
185
68
      result_opcode != spv::Op::OpTypeCooperativeVectorNV) {
186
3
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
187
3
           << "Expected Result Type to be OpTypeVector";
188
3
  }
189
190
65
  const uint32_t vector_type = _.GetOperandTypeId(inst, 2);
191
65
  if (vector_type != result_type) {
192
4
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
193
4
           << "Expected Vector type to be equal to Result Type";
194
4
  }
195
196
61
  const uint32_t component_type = _.GetOperandTypeId(inst, 3);
197
61
  if (_.GetComponentType(result_type) != component_type) {
198
3
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
199
3
           << "Expected Component type to be equal to Result Type "
200
3
           << "component type";
201
3
  }
202
203
58
  const uint32_t index_type = _.GetOperandTypeId(inst, 4);
204
58
  if (!_.IsIntScalarType(index_type)) {
205
3
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
206
3
           << "Expected Index to be int scalar";
207
3
  }
208
209
55
  if (_.HasCapability(spv::Capability::Shader) &&
210
55
      _.ContainsLimitedUseIntOrFloatType(inst->type_id())) {
211
0
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
212
0
           << "Cannot insert into a vector of 8- or 16-bit types";
213
0
  }
214
55
  return SPV_SUCCESS;
215
55
}
216
217
spv_result_t ValidateCompositeConstruct(ValidationState_t& _,
218
21.5k
                                        const Instruction* inst) {
219
21.5k
  const uint32_t num_operands = static_cast<uint32_t>(inst->operands().size());
220
21.5k
  const uint32_t result_type = inst->type_id();
221
21.5k
  const spv::Op result_opcode = _.GetIdOpcode(result_type);
222
21.5k
  switch (result_opcode) {
223
20.5k
    case spv::Op::OpTypeVector:
224
20.5k
    case spv::Op::OpTypeCooperativeVectorNV: {
225
20.5k
      uint32_t num_result_components = _.GetDimension(result_type);
226
20.5k
      const uint32_t result_component_type = _.GetComponentType(result_type);
227
20.5k
      uint32_t given_component_count = 0;
228
229
20.5k
      bool comp_is_int32 = true, comp_is_const_int32 = true;
230
231
20.5k
      if (result_opcode == spv::Op::OpTypeVector) {
232
20.5k
        if (num_operands <= 3) {
233
4
          return _.diag(SPV_ERROR_INVALID_DATA, inst)
234
4
                 << "Expected number of constituents to be at least 2";
235
4
        }
236
20.5k
      } else {
237
0
        uint32_t comp_count_id =
238
0
            _.FindDef(result_type)->GetOperandAs<uint32_t>(2);
239
0
        std::tie(comp_is_int32, comp_is_const_int32, num_result_components) =
240
0
            _.EvalInt32IfConst(comp_count_id);
241
0
      }
242
243
80.9k
      for (uint32_t operand_index = 2; operand_index < num_operands;
244
60.4k
           ++operand_index) {
245
60.4k
        const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index);
246
60.4k
        if (operand_type == result_component_type) {
247
60.3k
          ++given_component_count;
248
60.3k
        } else {
249
89
          if (_.GetIdOpcode(operand_type) != spv::Op::OpTypeVector ||
250
89
              _.GetComponentType(operand_type) != result_component_type) {
251
20
            return _.diag(SPV_ERROR_INVALID_DATA, inst)
252
20
                   << "Expected Constituents to be scalars or vectors of"
253
20
                   << " the same type as Result Type components";
254
20
          }
255
256
69
          given_component_count += _.GetDimension(operand_type);
257
69
        }
258
60.4k
      }
259
260
20.5k
      if (comp_is_const_int32 &&
261
20.5k
          num_result_components != given_component_count) {
262
16
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
263
16
               << "Expected total number of given components to be equal "
264
16
               << "to the size of Result Type vector";
265
16
      }
266
267
20.5k
      break;
268
20.5k
    }
269
20.5k
    case spv::Op::OpTypeMatrix: {
270
392
      uint32_t result_num_rows = 0;
271
392
      uint32_t result_num_cols = 0;
272
392
      uint32_t result_col_type = 0;
273
392
      uint32_t result_component_type = 0;
274
392
      if (!_.GetMatrixTypeInfo(result_type, &result_num_rows, &result_num_cols,
275
392
                               &result_col_type, &result_component_type)) {
276
0
        assert(0);
277
0
      }
278
279
392
      if (result_num_cols + 2 != num_operands) {
280
14
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
281
14
               << "Expected total number of Constituents to be equal "
282
14
               << "to the number of columns of Result Type matrix";
283
14
      }
284
285
1.25k
      for (uint32_t operand_index = 2; operand_index < num_operands;
286
882
           ++operand_index) {
287
882
        const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index);
288
882
        if (operand_type != result_col_type) {
289
8
          return _.diag(SPV_ERROR_INVALID_DATA, inst)
290
8
                 << "Expected Constituent type to be equal to the column "
291
8
                 << "type Result Type matrix";
292
8
        }
293
882
      }
294
295
370
      break;
296
378
    }
297
370
    case spv::Op::OpTypeArray: {
298
62
      const Instruction* const array_inst = _.FindDef(result_type);
299
62
      assert(array_inst);
300
62
      assert(array_inst->opcode() == spv::Op::OpTypeArray);
301
302
62
      auto size = _.FindDef(array_inst->word(3));
303
62
      if (spvOpcodeIsSpecConstant(size->opcode())) {
304
        // Cannot verify against the size of this array.
305
11
        break;
306
11
      }
307
308
51
      uint64_t array_size = 0;
309
51
      if (!_.EvalConstantValUint64(array_inst->word(3), &array_size)) {
310
0
        assert(0 && "Array type definition is corrupt");
311
0
      }
312
313
51
      if (array_size + 2 != num_operands) {
314
12
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
315
12
               << "Expected total number of Constituents to be equal "
316
12
               << "to the number of elements of Result Type array";
317
12
      }
318
319
39
      const uint32_t result_component_type = array_inst->word(2);
320
605
      for (uint32_t operand_index = 2; operand_index < num_operands;
321
585
           ++operand_index) {
322
585
        const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index);
323
585
        if (operand_type != result_component_type) {
324
19
          return _.diag(SPV_ERROR_INVALID_DATA, inst)
325
19
                 << "Expected Constituent type to be equal to the column "
326
19
                 << "type Result Type array";
327
19
        }
328
585
      }
329
330
20
      break;
331
39
    }
332
509
    case spv::Op::OpTypeStruct: {
333
509
      const Instruction* const struct_inst = _.FindDef(result_type);
334
509
      assert(struct_inst);
335
509
      assert(struct_inst->opcode() == spv::Op::OpTypeStruct);
336
337
509
      if (struct_inst->operands().size() + 1 != num_operands) {
338
3
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
339
3
               << "Expected total number of Constituents to be equal "
340
3
               << "to the number of members of Result Type struct";
341
3
      }
342
343
1.04k
      for (uint32_t operand_index = 2; operand_index < num_operands;
344
547
           ++operand_index) {
345
547
        const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index);
346
547
        const uint32_t member_type = struct_inst->word(operand_index);
347
547
        if (operand_type != member_type) {
348
6
          return _.diag(SPV_ERROR_INVALID_DATA, inst)
349
6
                 << "Expected Constituent type to be equal to the "
350
6
                 << "corresponding member type of Result Type struct";
351
6
        }
352
547
      }
353
354
500
      break;
355
506
    }
356
500
    case spv::Op::OpTypeCooperativeMatrixKHR: {
357
0
      const auto result_type_inst = _.FindDef(result_type);
358
0
      assert(result_type_inst);
359
0
      const auto component_type_id =
360
0
          result_type_inst->GetOperandAs<uint32_t>(1);
361
362
0
      if (3 != num_operands) {
363
0
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
364
0
               << "Must be only one constituent";
365
0
      }
366
367
0
      const uint32_t operand_type_id = _.GetOperandTypeId(inst, 2);
368
369
0
      if (operand_type_id != component_type_id) {
370
0
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
371
0
               << "Expected Constituent type to be equal to the component type";
372
0
      }
373
0
      break;
374
0
    }
375
0
    case spv::Op::OpTypeCooperativeMatrixNV: {
376
0
      const auto result_type_inst = _.FindDef(result_type);
377
0
      assert(result_type_inst);
378
0
      const auto component_type_id =
379
0
          result_type_inst->GetOperandAs<uint32_t>(1);
380
381
0
      if (3 != num_operands) {
382
0
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
383
0
               << "Expected single constituent";
384
0
      }
385
386
0
      const uint32_t operand_type_id = _.GetOperandTypeId(inst, 2);
387
388
0
      if (operand_type_id != component_type_id) {
389
0
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
390
0
               << "Expected Constituent type to be equal to the component type";
391
0
      }
392
393
0
      break;
394
0
    }
395
13
    default: {
396
13
      return _.diag(SPV_ERROR_INVALID_DATA, inst)
397
13
             << "Expected Result Type to be a composite type";
398
0
    }
399
21.5k
  }
400
401
21.4k
  if (_.HasCapability(spv::Capability::Shader) &&
402
21.4k
      _.ContainsLimitedUseIntOrFloatType(inst->type_id())) {
403
0
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
404
0
           << "Cannot create a composite containing 8- or 16-bit types";
405
0
  }
406
21.4k
  return SPV_SUCCESS;
407
21.4k
}
408
409
spv_result_t ValidateCompositeExtract(ValidationState_t& _,
410
37.4k
                                      const Instruction* inst) {
411
37.4k
  uint32_t member_type = 0;
412
37.4k
  if (spv_result_t error = GetExtractInsertValueType(_, inst, &member_type)) {
413
159
    return error;
414
159
  }
415
416
37.3k
  const uint32_t result_type = inst->type_id();
417
37.3k
  if (result_type != member_type) {
418
9
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
419
9
           << "Result type (Op" << spvOpcodeString(_.GetIdOpcode(result_type))
420
9
           << ") does not match the type that results from indexing into "
421
9
              "the composite (Op"
422
9
           << spvOpcodeString(_.GetIdOpcode(member_type)) << ").";
423
9
  }
424
425
37.3k
  if (_.HasCapability(spv::Capability::Shader) &&
426
37.3k
      _.ContainsLimitedUseIntOrFloatType(inst->type_id())) {
427
0
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
428
0
           << "Cannot extract from a composite of 8- or 16-bit types";
429
0
  }
430
431
37.3k
  return SPV_SUCCESS;
432
37.3k
}
433
434
spv_result_t ValidateCompositeInsert(ValidationState_t& _,
435
7.93k
                                     const Instruction* inst) {
436
7.93k
  const uint32_t object_type = _.GetOperandTypeId(inst, 2);
437
7.93k
  const uint32_t composite_type = _.GetOperandTypeId(inst, 3);
438
7.93k
  const uint32_t result_type = inst->type_id();
439
7.93k
  if (result_type != composite_type) {
440
8
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
441
8
           << "The Result Type must be the same as Composite type in Op"
442
8
           << spvOpcodeString(inst->opcode()) << " yielding Result Id "
443
8
           << result_type << ".";
444
8
  }
445
446
7.92k
  uint32_t member_type = 0;
447
7.92k
  if (spv_result_t error = GetExtractInsertValueType(_, inst, &member_type)) {
448
9
    return error;
449
9
  }
450
451
7.91k
  if (object_type != member_type) {
452
5
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
453
5
           << "The Object type (Op"
454
5
           << spvOpcodeString(_.GetIdOpcode(object_type))
455
5
           << ") does not match the type that results from indexing into the "
456
5
              "Composite (Op"
457
5
           << spvOpcodeString(_.GetIdOpcode(member_type)) << ").";
458
5
  }
459
460
7.91k
  if (_.HasCapability(spv::Capability::Shader) &&
461
7.91k
      _.ContainsLimitedUseIntOrFloatType(inst->type_id())) {
462
0
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
463
0
           << "Cannot insert into a composite of 8- or 16-bit types";
464
0
  }
465
466
7.91k
  return SPV_SUCCESS;
467
7.91k
}
468
469
278
spv_result_t ValidateCopyObject(ValidationState_t& _, const Instruction* inst) {
470
278
  const uint32_t result_type = inst->type_id();
471
278
  const uint32_t operand_type = _.GetOperandTypeId(inst, 2);
472
278
  if (operand_type != result_type) {
473
3
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
474
3
           << "Expected Result Type and Operand type to be the same";
475
3
  }
476
275
  if (_.IsVoidType(result_type)) {
477
0
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
478
0
           << "OpCopyObject cannot have void result type";
479
0
  }
480
275
  return SPV_SUCCESS;
481
275
}
482
483
64
spv_result_t ValidateTranspose(ValidationState_t& _, const Instruction* inst) {
484
64
  uint32_t result_num_rows = 0;
485
64
  uint32_t result_num_cols = 0;
486
64
  uint32_t result_col_type = 0;
487
64
  uint32_t result_component_type = 0;
488
64
  const uint32_t result_type = inst->type_id();
489
64
  if (!_.GetMatrixTypeInfo(result_type, &result_num_rows, &result_num_cols,
490
64
                           &result_col_type, &result_component_type)) {
491
8
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
492
8
           << "Expected Result Type to be a matrix type";
493
8
  }
494
495
56
  const uint32_t matrix_type = _.GetOperandTypeId(inst, 2);
496
56
  uint32_t matrix_num_rows = 0;
497
56
  uint32_t matrix_num_cols = 0;
498
56
  uint32_t matrix_col_type = 0;
499
56
  uint32_t matrix_component_type = 0;
500
56
  if (!_.GetMatrixTypeInfo(matrix_type, &matrix_num_rows, &matrix_num_cols,
501
56
                           &matrix_col_type, &matrix_component_type)) {
502
3
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
503
3
           << "Expected Matrix to be of type OpTypeMatrix";
504
3
  }
505
506
53
  if (result_component_type != matrix_component_type) {
507
0
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
508
0
           << "Expected component types of Matrix and Result Type to be "
509
0
           << "identical";
510
0
  }
511
512
53
  if (result_num_rows != matrix_num_cols ||
513
53
      result_num_cols != matrix_num_rows) {
514
7
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
515
7
           << "Expected number of columns and the column size of Matrix "
516
7
           << "to be the reverse of those of Result Type";
517
7
  }
518
519
46
  if (_.HasCapability(spv::Capability::Shader) &&
520
46
      _.ContainsLimitedUseIntOrFloatType(inst->type_id())) {
521
0
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
522
0
           << "Cannot transpose matrices of 16-bit floats";
523
0
  }
524
46
  return SPV_SUCCESS;
525
46
}
526
527
spv_result_t ValidateVectorShuffle(ValidationState_t& _,
528
10.2k
                                   const Instruction* inst) {
529
10.2k
  auto resultType = _.FindDef(inst->type_id());
530
10.2k
  if (!resultType || resultType->opcode() != spv::Op::OpTypeVector) {
531
5
    return _.diag(SPV_ERROR_INVALID_ID, inst)
532
5
           << "The Result Type of OpVectorShuffle must be"
533
5
           << " OpTypeVector. Found Op"
534
5
           << spvOpcodeString(static_cast<spv::Op>(resultType->opcode()))
535
5
           << ".";
536
5
  }
537
538
  // The number of components in Result Type must be the same as the number of
539
  // Component operands.
540
10.2k
  auto componentCount = inst->operands().size() - 4;
541
10.2k
  auto resultVectorDimension = resultType->GetOperandAs<uint32_t>(2);
542
10.2k
  if (componentCount != resultVectorDimension) {
543
3
    return _.diag(SPV_ERROR_INVALID_ID, inst)
544
3
           << "OpVectorShuffle component literals count does not match "
545
3
              "Result Type <id> "
546
3
           << _.getIdName(resultType->id()) << "s vector component count.";
547
3
  }
548
549
  // Vector 1 and Vector 2 must both have vector types, with the same Component
550
  // Type as Result Type.
551
10.2k
  auto vector1Object = _.FindDef(inst->GetOperandAs<uint32_t>(2));
552
10.2k
  auto vector1Type = _.FindDef(vector1Object->type_id());
553
10.2k
  auto vector2Object = _.FindDef(inst->GetOperandAs<uint32_t>(3));
554
10.2k
  auto vector2Type = _.FindDef(vector2Object->type_id());
555
10.2k
  if (!vector1Type || vector1Type->opcode() != spv::Op::OpTypeVector) {
556
4
    return _.diag(SPV_ERROR_INVALID_ID, inst)
557
4
           << "The type of Vector 1 must be OpTypeVector.";
558
4
  }
559
10.2k
  if (!vector2Type || vector2Type->opcode() != spv::Op::OpTypeVector) {
560
5
    return _.diag(SPV_ERROR_INVALID_ID, inst)
561
5
           << "The type of Vector 2 must be OpTypeVector.";
562
5
  }
563
564
10.2k
  auto resultComponentType = resultType->GetOperandAs<uint32_t>(1);
565
10.2k
  if (vector1Type->GetOperandAs<uint32_t>(1) != resultComponentType) {
566
5
    return _.diag(SPV_ERROR_INVALID_ID, inst)
567
5
           << "The Component Type of Vector 1 must be the same as ResultType.";
568
5
  }
569
10.2k
  if (vector2Type->GetOperandAs<uint32_t>(1) != resultComponentType) {
570
4
    return _.diag(SPV_ERROR_INVALID_ID, inst)
571
4
           << "The Component Type of Vector 2 must be the same as ResultType.";
572
4
  }
573
574
  // All Component literals must either be FFFFFFFF or in [0, N - 1].
575
10.2k
  auto vector1ComponentCount = vector1Type->GetOperandAs<uint32_t>(2);
576
10.2k
  auto vector2ComponentCount = vector2Type->GetOperandAs<uint32_t>(2);
577
10.2k
  auto N = vector1ComponentCount + vector2ComponentCount;
578
10.2k
  auto firstLiteralIndex = 4;
579
33.0k
  for (size_t i = firstLiteralIndex; i < inst->operands().size(); ++i) {
580
22.9k
    auto literal = inst->GetOperandAs<uint32_t>(i);
581
22.9k
    if (literal != 0xFFFFFFFF && literal >= N) {
582
106
      return _.diag(SPV_ERROR_INVALID_ID, inst)
583
106
             << "Component index " << literal << " is out of bounds for "
584
106
             << "combined (Vector1 + Vector2) size of " << N << ".";
585
106
    }
586
22.9k
  }
587
588
10.1k
  if (_.HasCapability(spv::Capability::Shader) &&
589
10.1k
      _.ContainsLimitedUseIntOrFloatType(inst->type_id())) {
590
0
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
591
0
           << "Cannot shuffle a vector of 8- or 16-bit types";
592
0
  }
593
594
10.1k
  return SPV_SUCCESS;
595
10.1k
}
596
597
spv_result_t ValidateCopyLogical(ValidationState_t& _,
598
0
                                 const Instruction* inst) {
599
0
  const auto result_type = _.FindDef(inst->type_id());
600
0
  const auto source = _.FindDef(inst->GetOperandAs<uint32_t>(2u));
601
0
  const auto source_type = _.FindDef(source->type_id());
602
0
  if (!source_type || !result_type || source_type == result_type) {
603
0
    return _.diag(SPV_ERROR_INVALID_ID, inst)
604
0
           << "Result Type must not equal the Operand type";
605
0
  }
606
607
0
  if (!_.LogicallyMatch(source_type, result_type, false)) {
608
0
    return _.diag(SPV_ERROR_INVALID_ID, inst)
609
0
           << "Result Type does not logically match the Operand type";
610
0
  }
611
612
0
  if (_.HasCapability(spv::Capability::Shader) &&
613
0
      _.ContainsLimitedUseIntOrFloatType(inst->type_id())) {
614
0
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
615
0
           << "Cannot copy composites of 8- or 16-bit types";
616
0
  }
617
618
0
  return SPV_SUCCESS;
619
0
}
620
621
}  // anonymous namespace
622
623
// Validates correctness of composite instructions.
624
11.3M
spv_result_t CompositesPass(ValidationState_t& _, const Instruction* inst) {
625
11.3M
  switch (inst->opcode()) {
626
106
    case spv::Op::OpVectorExtractDynamic:
627
106
      return ValidateVectorExtractDynamic(_, inst);
628
68
    case spv::Op::OpVectorInsertDynamic:
629
68
      return ValidateVectorInsertDyanmic(_, inst);
630
10.2k
    case spv::Op::OpVectorShuffle:
631
10.2k
      return ValidateVectorShuffle(_, inst);
632
21.5k
    case spv::Op::OpCompositeConstruct:
633
21.5k
      return ValidateCompositeConstruct(_, inst);
634
37.4k
    case spv::Op::OpCompositeExtract:
635
37.4k
      return ValidateCompositeExtract(_, inst);
636
7.93k
    case spv::Op::OpCompositeInsert:
637
7.93k
      return ValidateCompositeInsert(_, inst);
638
278
    case spv::Op::OpCopyObject:
639
278
      return ValidateCopyObject(_, inst);
640
64
    case spv::Op::OpTranspose:
641
64
      return ValidateTranspose(_, inst);
642
0
    case spv::Op::OpCopyLogical:
643
0
      return ValidateCopyLogical(_, inst);
644
11.2M
    default:
645
11.2M
      break;
646
11.3M
  }
647
648
11.2M
  return SPV_SUCCESS;
649
11.3M
}
650
651
}  // namespace val
652
}  // namespace spvtools