Coverage Report

Created: 2025-12-31 06:15

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/spirv-tools/source/val/validate_constants.cpp
Line
Count
Source
1
// Copyright (c) 2018 Google LLC.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
#include "source/opcode.h"
16
#include "source/val/instruction.h"
17
#include "source/val/validate.h"
18
#include "source/val/validation_state.h"
19
20
namespace spvtools {
21
namespace val {
22
namespace {
23
24
// Checks that a boolean constant instruction declares an OpTypeBool Result
// Type. Returns SPV_SUCCESS when it does, otherwise an SPV_ERROR_INVALID_ID
// diagnostic naming the offending Result Type <id>.
spv_result_t ValidateConstantBool(ValidationState_t& _,
                                  const Instruction* inst) {
  const Instruction* const bool_type = _.FindDef(inst->type_id());
  const bool has_bool_type =
      bool_type != nullptr && bool_type->opcode() == spv::Op::OpTypeBool;
  if (has_bool_type) return SPV_SUCCESS;

  return _.diag(SPV_ERROR_INVALID_ID, inst)
         << "Op" << spvOpcodeString(inst->opcode()) << " Result Type <id> "
         << _.getIdName(inst->type_id()) << " is not a boolean type.";
}
35
36
31.1k
bool isCompositeType(const Instruction* inst) {
37
31.1k
  bool is_tensor = inst->opcode() == spv::Op::OpTypeTensorARM;
38
31.1k
  bool tensor_is_shaped = inst->words().size() == 5;
39
31.1k
  return spvOpcodeIsComposite(inst->opcode()) ||
40
14
         (is_tensor && tensor_is_shaped);
41
31.1k
}
42
43
// Validates OpConstantComposite and OpSpecConstantComposite.
//
// The Result Type must be a composite type (see isCompositeType). Every
// constituent must be a constant or undef, and the number and types of the
// constituents must agree with the Result Type:
//   - vector / cooperative vector: one constituent per component (when the
//     component count is a known constant), each of the component type;
//   - matrix: one constituent per column, each a vector matching the column
//     type in component type and component count;
//   - array: one constituent per element (when the length is a known
//     constant), each of the element type;
//   - struct: one constituent per member, each of that member's type;
//   - cooperative matrix: exactly one constituent, of the component type;
//   - shaped tensor: one constituent per entry of the outermost dimension,
//     each either an element (rank 1) or a tensor of rank - 1 whose shape
//     matches the inner dimensions of the Result Type.
// Returns SPV_SUCCESS, or a diagnostic describing the first mismatch found.
spv_result_t ValidateConstantComposite(ValidationState_t& _,
                                       const Instruction* inst) {
  std::string opcode_name = std::string("Op") + spvOpcodeString(inst->opcode());

  const auto result_type = _.FindDef(inst->type_id());
  if (!result_type || !isCompositeType(result_type)) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << opcode_name << " Result Type <id> "
           << _.getIdName(inst->type_id()) << " is not a composite type.";
  }

  // Words 0-2 hold the opcode/word count, Result Type and Result <id>; all
  // remaining words are constituent <id>s.
  const auto constituent_count = inst->words().size() - 3;
  switch (result_type->opcode()) {
    case spv::Op::OpTypeVector:
    case spv::Op::OpTypeCooperativeVectorNV: {
      uint32_t num_result_components = _.GetDimension(result_type->id());
      bool comp_is_int32 = true, comp_is_const_int32 = true;

      if (result_type->opcode() == spv::Op::OpTypeCooperativeVectorNV) {
        // A cooperative vector's component count is an <id>; it is only
        // compared below when it evaluates to a constant.
        uint32_t comp_count_id = result_type->GetOperandAs<uint32_t>(2);
        std::tie(comp_is_int32, comp_is_const_int32, num_result_components) =
            _.EvalInt32IfConst(comp_count_id);
      }

      if (comp_is_const_int32 && num_result_components != constituent_count) {
        // TODO: Output ID's on diagnostic
        return _.diag(SPV_ERROR_INVALID_ID, inst)
               << opcode_name
               << " Constituent <id> count does not match "
                  "Result Type <id> "
               << _.getIdName(result_type->id()) << "s vector component count.";
      }
      const auto component_type =
          _.FindDef(result_type->GetOperandAs<uint32_t>(1));
      if (!component_type) {
        return _.diag(SPV_ERROR_INVALID_ID, result_type)
               << "Component type is not defined.";
      }
      for (size_t constituent_index = 2;
           constituent_index < inst->operands().size(); constituent_index++) {
        const auto constituent_id =
            inst->GetOperandAs<uint32_t>(constituent_index);
        const auto constituent = _.FindDef(constituent_id);
        if (!constituent ||
            !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
          return _.diag(SPV_ERROR_INVALID_ID, inst)
                 << opcode_name << " Constituent <id> "
                 << _.getIdName(constituent_id)
                 << " is not a constant or undef.";
        }
        const auto constituent_result_type = _.FindDef(constituent->type_id());
        if (!constituent_result_type ||
            component_type->id() != constituent_result_type->id()) {
          return _.diag(SPV_ERROR_INVALID_ID, inst)
                 << opcode_name << " Constituent <id> "
                 << _.getIdName(constituent_id)
                 << "s type does not match Result Type <id> "
                 << _.getIdName(result_type->id()) << "s vector element type.";
        }
      }
    } break;
    case spv::Op::OpTypeMatrix: {
      const auto column_count = result_type->GetOperandAs<uint32_t>(2);
      if (column_count != constituent_count) {
        // TODO: Output ID's on diagnostic
        return _.diag(SPV_ERROR_INVALID_ID, inst)
               << opcode_name
               << " Constituent <id> count does not match "
                  "Result Type <id> "
               << _.getIdName(result_type->id()) << "s matrix column count.";
      }

      // Word 2 of OpTypeMatrix is the column type.
      const auto column_type = _.FindDef(result_type->words()[2]);
      if (!column_type) {
        return _.diag(SPV_ERROR_INVALID_ID, result_type)
               << "Column type is not defined.";
      }
      const auto component_count = column_type->GetOperandAs<uint32_t>(2);
      const auto component_type =
          _.FindDef(column_type->GetOperandAs<uint32_t>(1));
      if (!component_type) {
        return _.diag(SPV_ERROR_INVALID_ID, column_type)
               << "Component type is not defined.";
      }

      for (size_t constituent_index = 2;
           constituent_index < inst->operands().size(); constituent_index++) {
        const auto constituent_id =
            inst->GetOperandAs<uint32_t>(constituent_index);
        const auto constituent = _.FindDef(constituent_id);
        if (!constituent ||
            !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
          // The message says "... or undef" because the spec does not say
          // undef is a constant.
          return _.diag(SPV_ERROR_INVALID_ID, inst)
                 << opcode_name << " Constituent <id> "
                 << _.getIdName(constituent_id)
                 << " is not a constant or undef.";
        }
        const auto vector = _.FindDef(constituent->type_id());
        if (!vector) {
          return _.diag(SPV_ERROR_INVALID_ID, constituent)
                 << "Result type is not defined.";
        }
        if (column_type->opcode() != vector->opcode()) {
          return _.diag(SPV_ERROR_INVALID_ID, inst)
                 << opcode_name << " Constituent <id> "
                 << _.getIdName(constituent_id)
                 << " type does not match Result Type <id> "
                 << _.getIdName(result_type->id()) << "s matrix column type.";
        }
        // Guard the lookup like every other FindDef in this function: an
        // undefined component type is reported as a mismatch instead of
        // being dereferenced.
        const auto vector_component_type =
            _.FindDef(vector->GetOperandAs<uint32_t>(1));
        if (!vector_component_type ||
            component_type->id() != vector_component_type->id()) {
          return _.diag(SPV_ERROR_INVALID_ID, inst)
                 << opcode_name << " Constituent <id> "
                 << _.getIdName(constituent_id)
                 << " component type does not match Result Type <id> "
                 << _.getIdName(result_type->id())
                 << "s matrix column component type.";
        }
        // Word 3 of the (vector) constituent type is its component count.
        if (component_count != vector->words()[3]) {
          return _.diag(SPV_ERROR_INVALID_ID, inst)
                 << opcode_name << " Constituent <id> "
                 << _.getIdName(constituent_id)
                 << " vector component count does not match Result Type <id> "
                 << _.getIdName(result_type->id())
                 << "s vector component count.";
        }
      }
    } break;
    case spv::Op::OpTypeArray: {
      auto element_type = _.FindDef(result_type->GetOperandAs<uint32_t>(1));
      if (!element_type) {
        return _.diag(SPV_ERROR_INVALID_ID, result_type)
               << "Element type is not defined.";
      }
      const auto length = _.FindDef(result_type->GetOperandAs<uint32_t>(2));
      if (!length) {
        return _.diag(SPV_ERROR_INVALID_ID, result_type)
               << "Length is not defined.";
      }
      // The length is only compared when it is a known 32-bit constant.
      bool is_int32 = false;
      bool is_const = false;
      uint32_t value = 0;
      std::tie(is_int32, is_const, value) = _.EvalInt32IfConst(length->id());
      if (is_int32 && is_const && value != constituent_count) {
        return _.diag(SPV_ERROR_INVALID_ID, inst)
               << opcode_name
               << " Constituent count does not match "
                  "Result Type <id> "
               << _.getIdName(result_type->id()) << "s array length.";
      }
      for (size_t constituent_index = 2;
           constituent_index < inst->operands().size(); constituent_index++) {
        const auto constituent_id =
            inst->GetOperandAs<uint32_t>(constituent_index);
        const auto constituent = _.FindDef(constituent_id);
        if (!constituent ||
            !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
          return _.diag(SPV_ERROR_INVALID_ID, inst)
                 << opcode_name << " Constituent <id> "
                 << _.getIdName(constituent_id)
                 << " is not a constant or undef.";
        }
        const auto constituent_type = _.FindDef(constituent->type_id());
        if (!constituent_type) {
          return _.diag(SPV_ERROR_INVALID_ID, constituent)
                 << "Result type is not defined.";
        }
        if (element_type->id() != constituent_type->id()) {
          return _.diag(SPV_ERROR_INVALID_ID, inst)
                 << opcode_name << " Constituent <id> "
                 << _.getIdName(constituent_id)
                 << "s type does not match Result Type <id> "
                 << _.getIdName(result_type->id()) << "s array element type.";
        }
      }
    } break;
    case spv::Op::OpTypeStruct: {
      // OpTypeStruct is opcode word + Result <id> + one word per member.
      const auto member_count = result_type->words().size() - 2;
      if (member_count != constituent_count) {
        return _.diag(SPV_ERROR_INVALID_ID, inst)
               << opcode_name << " Constituent <id> "
               << _.getIdName(inst->type_id())
               << " count does not match Result Type <id> "
               << _.getIdName(result_type->id()) << "s struct member count.";
      }
      // Constituent operands start at 2; member type operands start at 1.
      for (uint32_t constituent_index = 2, member_index = 1;
           constituent_index < inst->operands().size();
           constituent_index++, member_index++) {
        const auto constituent_id =
            inst->GetOperandAs<uint32_t>(constituent_index);
        const auto constituent = _.FindDef(constituent_id);
        if (!constituent ||
            !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
          return _.diag(SPV_ERROR_INVALID_ID, inst)
                 << opcode_name << " Constituent <id> "
                 << _.getIdName(constituent_id)
                 << " is not a constant or undef.";
        }
        const auto constituent_type = _.FindDef(constituent->type_id());
        if (!constituent_type) {
          return _.diag(SPV_ERROR_INVALID_ID, constituent)
                 << "Result type is not defined.";
        }

        const auto member_type_id =
            result_type->GetOperandAs<uint32_t>(member_index);
        const auto member_type = _.FindDef(member_type_id);
        if (!member_type || member_type->id() != constituent_type->id()) {
          return _.diag(SPV_ERROR_INVALID_ID, inst)
                 << opcode_name << " Constituent <id> "
                 << _.getIdName(constituent_id)
                 << " type does not match the Result Type <id> "
                 << _.getIdName(result_type->id()) << "s member type.";
        }
      }
    } break;
    case spv::Op::OpTypeCooperativeMatrixKHR:
    case spv::Op::OpTypeCooperativeMatrixNV: {
      if (1 != constituent_count) {
        return _.diag(SPV_ERROR_INVALID_ID, inst)
               << opcode_name << " Constituent <id> "
               << _.getIdName(inst->type_id()) << " count must be one.";
      }
      const auto constituent_id = inst->GetOperandAs<uint32_t>(2);
      const auto constituent = _.FindDef(constituent_id);
      if (!constituent || !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
        return _.diag(SPV_ERROR_INVALID_ID, inst)
               << opcode_name << " Constituent <id> "
               << _.getIdName(constituent_id) << " is not a constant or undef.";
      }
      const auto constituent_type = _.FindDef(constituent->type_id());
      if (!constituent_type) {
        return _.diag(SPV_ERROR_INVALID_ID, constituent)
               << "Result type is not defined.";
      }

      const auto component_type_id = result_type->GetOperandAs<uint32_t>(1);
      const auto component_type = _.FindDef(component_type_id);
      if (!component_type || component_type->id() != constituent_type->id()) {
        return _.diag(SPV_ERROR_INVALID_ID, inst)
               << opcode_name << " Constituent <id> "
               << _.getIdName(constituent_id)
               << " type does not match the Result Type <id> "
               << _.getIdName(result_type->id()) << "s component type.";
      }
    } break;
    case spv::Op::OpTypeTensorARM: {
      // isCompositeType() only admits shaped tensors, so the Result Type has
      // Element Type (1), Rank (2) and Shape (3) operands.
      auto inst_element_type =
          _.FindDef(result_type->GetOperandAs<uint32_t>(1));
      if (!inst_element_type) {
        return _.diag(SPV_ERROR_INVALID_ID, result_type)
               << "Element type is not defined.";
      }
      const auto inst_rank = _.FindDef(result_type->GetOperandAs<uint32_t>(2));
      if (!inst_rank) {
        return _.diag(SPV_ERROR_INVALID_ID, result_type)
               << "Rank is not defined.";
      }
      const auto inst_shape = _.FindDef(result_type->GetOperandAs<uint32_t>(3));
      if (!inst_shape) {
        return _.diag(SPV_ERROR_INVALID_ID, result_type)
               << "Shape is not defined.";
      }

      // rank == 0 is treated below as "rank not known"; it stays 0 unless
      // the Rank evaluates to a constant.
      uint64_t rank = 0;
      _.EvalConstantValUint64(inst_rank->id(), &rank);

      // The Shape is a constant whose constituents (operands 2 and up) give
      // the size of each dimension, outermost first. Every access is bounded
      // by its operand count so a Shape without constituents cannot be
      // indexed past its end.
      const size_t shape_operand_count = inst_shape->operands().size();
      uint64_t outermost_shape = 0;
      if (shape_operand_count > 2 &&
          _.EvalConstantValUint64(inst_shape->GetOperandAs<uint32_t>(2),
                                  &outermost_shape) &&
          (outermost_shape != constituent_count)) {
        return _.diag(SPV_ERROR_INVALID_ID, inst)
               << opcode_name
               << " Constituent count does not match "
                  "the shape of Result Type <id> "
               << _.getIdName(result_type->id())
               << " along its outermost dimension, " << "expected "
               << outermost_shape << " but got " << constituent_count << ".";
      }

      for (size_t constituent_index = 2;
           constituent_index < inst->operands().size(); constituent_index++) {
        const auto constituent_id =
            inst->GetOperandAs<uint32_t>(constituent_index);
        const auto constituent = _.FindDef(constituent_id);
        if (!constituent ||
            !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
          return _.diag(SPV_ERROR_INVALID_ID, inst)
                 << opcode_name << " Constituent <id> "
                 << _.getIdName(constituent_id)
                 << " is not a constant or undef.";
        }
        const auto constituent_type = _.FindDef(constituent->type_id());
        if (!constituent_type) {
          return _.diag(SPV_ERROR_INVALID_ID, constituent)
                 << "Type of Constituent " << constituent_index - 2
                 << " is not defined.";
        }

        if (rank == 0) {
          // The rank of the returned tensor constant is not known.
          // Skip rank-dependent validation.
          continue;
        }

        if (rank == 1) {
          // A rank-1 tensor is built directly from elements.
          if (inst_element_type->id() != constituent_type->id()) {
            return _.diag(SPV_ERROR_INVALID_ID, inst)
                   << opcode_name << " Constituent <id> "
                   << _.getIdName(constituent_id)
                   << " type does not match the element type of the tensor ("
                   << _.getIdName(result_type->id()) << ").";
          }
          continue;
        }

        // rank >= 2: each constituent must itself be a tensor of rank - 1.
        if (constituent_type->opcode() != spv::Op::OpTypeTensorARM) {
          return _.diag(SPV_ERROR_INVALID_ID, inst)
                 << opcode_name << " Constituent <id> "
                 << _.getIdName(constituent_id)
                 << " must be an OpTypeTensorARM.";
        }
        auto inst_constituent_element_type =
            _.FindDef(constituent_type->GetOperandAs<uint32_t>(1));
        if (!inst_constituent_element_type ||
            inst_constituent_element_type->id() != inst_element_type->id()) {
          return _.diag(SPV_ERROR_INVALID_ID, inst)
                 << opcode_name << " Constituent <id> "
                 << _.getIdName(constituent_id)
                 << " must have the same Element Type as Result Type <id> "
                 << _.getIdName(result_type->id()) << ".";
        }

        // Rank (operand 2) and Shape (operand 3) are optional on
        // OpTypeTensorARM (e.g. the type of an OpUndef constituent), so they
        // are only compared when the constituent's type declares them.
        const size_t constituent_type_operand_count =
            constituent_type->operands().size();

        if (constituent_type_operand_count > 2) {
          auto inst_constituent_rank =
              _.FindDef(constituent_type->GetOperandAs<uint32_t>(2));
          uint64_t constituent_rank = 0;
          if (inst_constituent_rank &&
              _.EvalConstantValUint64(inst_constituent_rank->id(),
                                      &constituent_rank) &&
              (constituent_rank != rank - 1)) {
            return _.diag(SPV_ERROR_INVALID_ID, inst)
                   << opcode_name << " Constituent <id> "
                   << _.getIdName(constituent_id)
                   << " must have a Rank that is 1 less than the Rank of "
                      "Result Type <id> "
                   << _.getIdName(result_type->id()) << ", expected "
                   << rank - 1 << " but got " << constituent_rank << ".";
          }
        }

        if (constituent_type_operand_count > 3) {
          auto inst_constituent_shape =
              _.FindDef(constituent_type->GetOperandAs<uint32_t>(3));
          if (!inst_constituent_shape) {
            return _.diag(SPV_ERROR_INVALID_ID, result_type)
                   << "Shape of Constituent " << constituent_index - 2
                   << " is not defined.";
          }
          for (size_t constituent_shape_index = 2;
               constituent_shape_index <
               inst_constituent_shape->operands().size();
               constituent_shape_index++) {
            // Dimension i of the constituent lines up with dimension i + 1
            // of the Result Type. Stop if the Result Type's Shape has no
            // such dimension (possible when the rank could not be checked).
            const size_t shape_index = constituent_shape_index + 1;
            if (shape_index >= shape_operand_count) break;
            uint64_t constituent_shape = 0, shape = 1;
            if (_.EvalConstantValUint64(
                    inst_constituent_shape->GetOperandAs<uint32_t>(
                        constituent_shape_index),
                    &constituent_shape) &&
                _.EvalConstantValUint64(
                    inst_shape->GetOperandAs<uint32_t>(shape_index), &shape) &&
                (constituent_shape != shape)) {
              return _.diag(SPV_ERROR_INVALID_ID, inst)
                     << opcode_name << " Constituent <id> "
                     << _.getIdName(constituent_id)
                     << " must have a Shape that matches that of Result Type "
                        "<id> "
                     << _.getIdName(result_type->id())
                     << " along all inner dimensions of Result Type, expected "
                     << shape << " for dimension "
                     << constituent_shape_index - 2
                     << " of Constituent but got " << constituent_shape << ".";
            }
          }
        }
      }
    } break;
    default:
      // Any other composite Result Type is accepted without further checks.
      break;
  }
  return SPV_SUCCESS;
}
432
433
// Validates OpConstantSampler: its Result Type must be an OpTypeSampler.
// Returns SPV_SUCCESS, or an SPV_ERROR_INVALID_ID diagnostic attached to the
// OpConstantSampler instruction itself.
spv_result_t ValidateConstantSampler(ValidationState_t& _,
                                     const Instruction* inst) {
  const auto result_type = _.FindDef(inst->type_id());
  if (!result_type || result_type->opcode() != spv::Op::OpTypeSampler) {
    // Report against |inst|: |result_type| is null on the undefined-type
    // path, and even when present it is not the instruction in error.
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpConstantSampler Result Type <id> "
           << _.getIdName(inst->type_id()) << " is not a sampler type.";
  }

  return SPV_SUCCESS;
}
444
445
// True if instruction defines a type that can have a null value, as defined by
446
// the SPIR-V spec.  Tracks composite-type components through module to check
447
// nullability transitively.
448
bool IsTypeNullable(const std::vector<uint32_t>& instruction,
449
5.27k
                    const ValidationState_t& _) {
450
5.27k
  uint16_t opcode;
451
5.27k
  uint16_t word_count;
452
5.27k
  spvOpcodeSplit(instruction[0], &word_count, &opcode);
453
5.27k
  switch (static_cast<spv::Op>(opcode)) {
454
1.24k
    case spv::Op::OpTypeBool:
455
3.04k
    case spv::Op::OpTypeInt:
456
3.79k
    case spv::Op::OpTypeFloat:
457
3.83k
    case spv::Op::OpTypeEvent:
458
3.85k
    case spv::Op::OpTypeDeviceEvent:
459
3.92k
    case spv::Op::OpTypeReserveId:
460
3.96k
    case spv::Op::OpTypeQueue:
461
3.96k
      return true;
462
53
    case spv::Op::OpTypeArray:
463
93
    case spv::Op::OpTypeMatrix:
464
93
    case spv::Op::OpTypeCooperativeMatrixNV:
465
93
    case spv::Op::OpTypeCooperativeMatrixKHR:
466
93
    case spv::Op::OpTypeCooperativeVectorNV:
467
555
    case spv::Op::OpTypeVector: {
468
555
      auto base_type = _.FindDef(instruction[2]);
469
555
      return base_type && IsTypeNullable(base_type->words(), _);
470
93
    }
471
490
    case spv::Op::OpTypeStruct: {
472
1.15k
      for (size_t elementIndex = 2; elementIndex < instruction.size();
473
679
           ++elementIndex) {
474
679
        auto element = _.FindDef(instruction[elementIndex]);
475
679
        if (!element || !IsTypeNullable(element->words(), _)) return false;
476
679
      }
477
479
      return true;
478
490
    }
479
0
    case spv::Op::OpTypeUntypedPointerKHR:
480
257
    case spv::Op::OpTypePointer:
481
257
      if (spv::StorageClass(instruction[2]) ==
482
257
          spv::StorageClass::PhysicalStorageBuffer) {
483
0
        return false;
484
0
      }
485
257
      return true;
486
0
    case spv::Op::OpTypeTensorARM: {
487
0
      auto elem_type = _.FindDef(instruction[2]);
488
0
      return (instruction.size() > 4) && elem_type &&
489
0
             IsTypeNullable(elem_type->words(), _);
490
257
    }
491
15
    default:
492
15
      return false;
493
5.27k
  }
494
5.27k
}
495
496
// Checks that the Result Type of OpConstantNull is defined and admits a null
// value according to IsTypeNullable. Returns SPV_SUCCESS or an
// SPV_ERROR_INVALID_ID diagnostic.
spv_result_t ValidateConstantNull(ValidationState_t& _,
                                  const Instruction* inst) {
  const uint32_t null_type_id = inst->type_id();
  const auto* null_type = _.FindDef(null_type_id);
  const bool nullable =
      null_type != nullptr && IsTypeNullable(null_type->words(), _);
  if (nullable) return SPV_SUCCESS;

  return _.diag(SPV_ERROR_INVALID_ID, inst)
         << "OpConstantNull Result Type <id> " << _.getIdName(null_type_id)
         << " cannot have a null value.";
}
507
508
// Validates that OpSpecConstant specializes to either int or float type.
509
spv_result_t ValidateSpecConstant(ValidationState_t& _,
510
2.87k
                                  const Instruction* inst) {
511
  // Operand 0 is the <id> of the type that we're specializing to.
512
2.87k
  auto type_id = inst->GetOperandAs<const uint32_t>(0);
513
2.87k
  auto type_instruction = _.FindDef(type_id);
514
2.87k
  auto type_opcode = type_instruction->opcode();
515
2.87k
  if (type_opcode != spv::Op::OpTypeInt &&
516
1.12k
      type_opcode != spv::Op::OpTypeFloat) {
517
0
    return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Specialization constant "
518
0
                                                   "must be an integer or "
519
0
                                                   "floating-point number.";
520
0
  }
521
2.87k
  return SPV_SUCCESS;
522
2.87k
}
523
524
// Enforces capability restrictions on the operation wrapped by an
// OpSpecConstantOp (operand 2). Result type and operand checks for the
// wrapped operation are not performed here.
spv_result_t ValidateSpecConstantOp(ValidationState_t& _,
                                    const Instruction* inst) {
  const spv::Op spec_op = inst->GetOperandAs<spv::Op>(2);

  // The binary parser has already established that |spec_op| is legal in
  // *some* environment; only the environment-specific restrictions remain.
  switch (spec_op) {
    case spv::Op::OpQuantizeToF16:
      if (_.HasCapability(spv::Capability::Shader)) break;
      return _.diag(SPV_ERROR_INVALID_ID, inst)
             << "Specialization constant operation " << spvOpcodeString(spec_op)
             << " requires Shader capability";

    case spv::Op::OpUConvert:
      // Accepted when the uconvert_spec_constant_op feature is enabled (per
      // the diagnostic, SPIR-V 1.4 onward) or the Kernel capability is
      // declared.
      if (_.features().uconvert_spec_constant_op ||
          _.HasCapability(spv::Capability::Kernel)) {
        break;
      }
      return _.diag(SPV_ERROR_INVALID_ID, inst)
             << "Prior to SPIR-V 1.4, specialization constant operation "
                "UConvert requires Kernel capability or extension "
                "SPV_AMD_gpu_shader_int16";

    // Conversions, pointer casts, floating-point arithmetic and access
    // chains are only permitted as spec-constant operations under Kernel.
    case spv::Op::OpConvertFToS:
    case spv::Op::OpConvertSToF:
    case spv::Op::OpConvertFToU:
    case spv::Op::OpConvertUToF:
    case spv::Op::OpConvertPtrToU:
    case spv::Op::OpConvertUToPtr:
    case spv::Op::OpGenericCastToPtr:
    case spv::Op::OpPtrCastToGeneric:
    case spv::Op::OpBitcast:
    case spv::Op::OpFNegate:
    case spv::Op::OpFAdd:
    case spv::Op::OpFSub:
    case spv::Op::OpFMul:
    case spv::Op::OpFDiv:
    case spv::Op::OpFRem:
    case spv::Op::OpFMod:
    case spv::Op::OpAccessChain:
    case spv::Op::OpInBoundsAccessChain:
    case spv::Op::OpPtrAccessChain:
    case spv::Op::OpInBoundsPtrAccessChain:
      if (_.HasCapability(spv::Capability::Kernel)) break;
      return _.diag(SPV_ERROR_INVALID_ID, inst)
             << "Specialization constant operation " << spvOpcodeString(spec_op)
             << " requires Kernel capability";

    default:
      break;
  }

  // TODO(dneto): Validate result type and arguments to the various operations.
  return SPV_SUCCESS;
}
583
584
}  // namespace
585
586
13.1M
// Validation pass for constant-defining instructions.
//
// Dispatches each constant opcode to its dedicated validator and returns the
// first non-success result. Afterwards, any constant opcode in a module with
// the Shader capability is rejected if its (non-pointer) type contains an 8-
// or 16-bit int/float type that the module cannot fully use, as judged by
// ContainsLimitedUseIntOrFloatType.
spv_result_t ConstantPass(ValidationState_t& _, const Instruction* inst) {
  const spv::Op opcode = inst->opcode();
  spv_result_t status = SPV_SUCCESS;

  switch (opcode) {
    case spv::Op::OpConstantTrue:
    case spv::Op::OpConstantFalse:
    case spv::Op::OpSpecConstantTrue:
    case spv::Op::OpSpecConstantFalse:
      status = ValidateConstantBool(_, inst);
      break;
    case spv::Op::OpConstantComposite:
    case spv::Op::OpSpecConstantComposite:
      status = ValidateConstantComposite(_, inst);
      break;
    case spv::Op::OpConstantSampler:
      status = ValidateConstantSampler(_, inst);
      break;
    case spv::Op::OpConstantNull:
      status = ValidateConstantNull(_, inst);
      break;
    case spv::Op::OpSpecConstant:
      status = ValidateSpecConstant(_, inst);
      break;
    case spv::Op::OpSpecConstantOp:
      status = ValidateSpecConstantOp(_, inst);
      break;
    default:
      break;
  }
  if (status != SPV_SUCCESS) return status;

  // Generally disallow creating 8- or 16-bit constants unless the full
  // capabilities are present.
  const bool is_limited_width_constant =
      spvOpcodeIsConstant(opcode) && _.HasCapability(spv::Capability::Shader) &&
      !_.IsPointerType(inst->type_id()) &&
      _.ContainsLimitedUseIntOrFloatType(inst->type_id());
  if (!is_limited_width_constant) return SPV_SUCCESS;

  return _.diag(SPV_ERROR_INVALID_ID, inst)
         << "Cannot form constants of 8- or 16-bit types";
}
626
627
}  // namespace val
628
}  // namespace spvtools