Coverage Report

Created: 2024-09-11 07:09

/src/spirv-tools/source/val/validate_bitwise.cpp
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) 2017 Google Inc.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
// Validates correctness of bitwise instructions.
16
17
#include "source/opcode.h"
18
#include "source/spirv_target_env.h"
19
#include "source/val/instruction.h"
20
#include "source/val/validate.h"
21
#include "source/val/validation_state.h"
22
23
namespace spvtools {
24
namespace val {
25
26
// Validates when base and result need to be the same type
//
// Shared check for the Base operand type of the instructions that call it
// from BitwisePass in this file: OpBitFieldInsert, OpBitFieldSExtract,
// OpBitFieldUExtract, OpBitReverse and OpBitCount.
//
// Checks, in order:
//   1. |base_type| must be an int scalar or int vector type.
//   2. If the target environment is a Vulkan environment, the bit width
//      reported for |base_type| must be 32.
//   3. Unless the opcode is OpBitCount, |base_type| must be the same type id
//      as the instruction's Result Type.
//
// Returns SPV_SUCCESS when every check passes; otherwise returns the
// SPV_ERROR_INVALID_DATA diagnostic of the first failing check. Checks 1 and
// 2 attach _.VkErrorID(4781) to their message.
// NOTE(review): check 1 attaches the Vulkan id even when the target is not
// Vulkan; presumably VkErrorID yields nothing in that case -- confirm.
27
spv_result_t ValidateBaseType(ValidationState_t& _, const Instruction* inst,
28
2.05k
                              const uint32_t base_type) {
29
2.05k
  const spv::Op opcode = inst->opcode();
30
  // Base must be an integer scalar or vector regardless of the opcode.
31
2.05k
  if (!_.IsIntScalarType(base_type) && !_.IsIntVectorType(base_type)) {
32
15
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
33
15
           << _.VkErrorID(4781)
34
15
           << "Expected int scalar or vector type for Base operand: "
35
15
           << spvOpcodeString(opcode);
36
15
  }
37
38
  // Vulkan has a restriction to 32 bit for base
39
2.04k
  if (spvIsVulkanEnv(_.context()->target_env)) {
40
0
    if (_.GetBitWidth(base_type) != 32) {
41
0
      return _.diag(SPV_ERROR_INVALID_DATA, inst)
42
0
             << _.VkErrorID(4781)
43
0
             << "Expected 32-bit int type for Base operand: "
44
0
             << spvOpcodeString(opcode);
45
0
    }
46
0
  }
47
48
  // OpBitCount just needs same number of components
  // (that dimension comparison is done by the OpBitCount case in BitwisePass,
  // not here; this function only skips the type-equality check for it).
49
2.04k
  if (base_type != inst->type_id() && opcode != spv::Op::OpBitCount) {
50
7
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
51
7
           << "Expected Base Type to be equal to Result Type: "
52
7
           << spvOpcodeString(opcode);
53
7
  }
54
55
2.03k
  return SPV_SUCCESS;
56
2.04k
}
57
58
// Validates correctness of bitwise instructions.
//
// Dispatches on inst->opcode() and checks the Result Type and operand types
// of the shift, bitwise logical, bit-field insert/extract, bit-reverse and
// bit-count instructions. Any other opcode reaches the default case and is
// accepted without checks.
//
// Returns SPV_SUCCESS when the instruction passes (or is not handled here);
// otherwise returns the SPV_ERROR_INVALID_DATA diagnostic of the first
// failing check.
//
// NOTE(review): operand indices passed to GetOperandTypeId start at 2 for the
// first input operand; presumably indices 0 and 1 are the Result Type and
// Result <id> -- confirm against Instruction.
59
10.1M
spv_result_t BitwisePass(ValidationState_t& _, const Instruction* inst) {
60
10.1M
  const spv::Op opcode = inst->opcode();
61
10.1M
  const uint32_t result_type = inst->type_id();
62
  // Opcodes that share the same rules are grouped under one case body.
63
10.1M
  switch (opcode) {
64
297
    case spv::Op::OpShiftRightLogical:
65
568
    case spv::Op::OpShiftRightArithmetic:
66
851
    case spv::Op::OpShiftLeftLogical: {
67
851
      if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type))
68
2
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
69
2
               << "Expected int scalar or vector type as Result Type: "
70
2
               << spvOpcodeString(opcode);
71
      // Operand 2 is Base and operand 3 is Shift.
72
849
      const uint32_t result_dimension = _.GetDimension(result_type);
73
849
      const uint32_t base_type = _.GetOperandTypeId(inst, 2);
74
849
      const uint32_t shift_type = _.GetOperandTypeId(inst, 3);
75
      // A type id of 0 is treated as "no type" and rejected with the same error.
76
849
      if (!base_type ||
77
849
          (!_.IsIntScalarType(base_type) && !_.IsIntVectorType(base_type)))
78
3
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
79
3
               << "Expected Base to be int scalar or vector: "
80
3
               << spvOpcodeString(opcode);
81
      // Base must match Result Type in both dimension and bit width.
82
846
      if (_.GetDimension(base_type) != result_dimension)
83
4
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
84
4
               << "Expected Base to have the same dimension "
85
4
               << "as Result Type: " << spvOpcodeString(opcode);
86
87
842
      if (_.GetBitWidth(base_type) != _.GetBitWidth(result_type))
88
0
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
89
0
               << "Expected Base to have the same bit width "
90
0
               << "as Result Type: " << spvOpcodeString(opcode);
91
      // Shift only has to match in dimension; its bit width is not compared.
92
842
      if (!shift_type ||
93
842
          (!_.IsIntScalarType(shift_type) && !_.IsIntVectorType(shift_type)))
94
8
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
95
8
               << "Expected Shift to be int scalar or vector: "
96
8
               << spvOpcodeString(opcode);
97
98
834
      if (_.GetDimension(shift_type) != result_dimension)
99
3
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
100
3
               << "Expected Shift to have the same dimension "
101
3
               << "as Result Type: " << spvOpcodeString(opcode);
102
831
      break;
103
834
    }
104
    // Bitwise logical ops: one loop covers the unary (OpNot) and binary forms.
105
10.4k
    case spv::Op::OpBitwiseOr:
106
10.7k
    case spv::Op::OpBitwiseXor:
107
14.0k
    case spv::Op::OpBitwiseAnd:
108
16.2k
    case spv::Op::OpNot: {
109
16.2k
      if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type))
110
10
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
111
10
               << "Expected int scalar or vector type as Result Type: "
112
10
               << spvOpcodeString(opcode);
113
114
16.2k
      const uint32_t result_dimension = _.GetDimension(result_type);
115
16.2k
      const uint32_t result_bit_width = _.GetBitWidth(result_type);
116
      // Every operand from index 2 onward must be an int scalar/vector with
      // the same dimension and bit width as Result Type.
117
46.5k
      for (size_t operand_index = 2; operand_index < inst->operands().size();
118
30.2k
           ++operand_index) {
119
30.2k
        const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
120
30.2k
        if (!type_id ||
121
30.2k
            (!_.IsIntScalarType(type_id) && !_.IsIntVectorType(type_id)))
122
19
          return _.diag(SPV_ERROR_INVALID_DATA, inst)
123
19
                 << "Expected int scalar or vector as operand: "
124
19
                 << spvOpcodeString(opcode) << " operand index "
125
19
                 << operand_index;
126
127
30.2k
        if (_.GetDimension(type_id) != result_dimension)
128
8
          return _.diag(SPV_ERROR_INVALID_DATA, inst)
129
8
                 << "Expected operands to have the same dimension "
130
8
                 << "as Result Type: " << spvOpcodeString(opcode)
131
8
                 << " operand index " << operand_index;
132
133
30.2k
        if (_.GetBitWidth(type_id) != result_bit_width)
134
0
          return _.diag(SPV_ERROR_INVALID_DATA, inst)
135
0
                 << "Expected operands to have the same bit width "
136
0
                 << "as Result Type: " << spvOpcodeString(opcode)
137
0
                 << " operand index " << operand_index;
138
30.2k
      }
139
16.2k
      break;
140
16.2k
    }
141
    // Operands 2..5 are Base, Insert, Offset and Count.
142
16.2k
    case spv::Op::OpBitFieldInsert: {
143
483
      const uint32_t base_type = _.GetOperandTypeId(inst, 2);
144
483
      const uint32_t insert_type = _.GetOperandTypeId(inst, 3);
145
483
      const uint32_t offset_type = _.GetOperandTypeId(inst, 4);
146
483
      const uint32_t count_type = _.GetOperandTypeId(inst, 5);
147
      // ValidateBaseType also requires Base type == Result Type for this opcode.
148
483
      if (spv_result_t error = ValidateBaseType(_, inst, base_type)) {
149
3
        return error;
150
3
      }
151
      // Insert is compared by type id, so it must be the very same type.
152
480
      if (insert_type != result_type)
153
1
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
154
1
               << "Expected Insert Type to be equal to Result Type: "
155
1
               << spvOpcodeString(opcode);
156
      // Offset and Count only need to be int scalars; nothing else is checked.
157
479
      if (!offset_type || !_.IsIntScalarType(offset_type))
158
4
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
159
4
               << "Expected Offset Type to be int scalar: "
160
4
               << spvOpcodeString(opcode);
161
162
475
      if (!count_type || !_.IsIntScalarType(count_type))
163
4
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
164
4
               << "Expected Count Type to be int scalar: "
165
4
               << spvOpcodeString(opcode);
166
471
      break;
167
475
    }
168
    // Operands 2..4 are Base, Offset and Count.
169
486
    case spv::Op::OpBitFieldSExtract:
170
714
    case spv::Op::OpBitFieldUExtract: {
171
714
      const uint32_t base_type = _.GetOperandTypeId(inst, 2);
172
714
      const uint32_t offset_type = _.GetOperandTypeId(inst, 3);
173
714
      const uint32_t count_type = _.GetOperandTypeId(inst, 4);
174
      // Same Base / Offset / Count rules as OpBitFieldInsert, minus Insert.
175
714
      if (spv_result_t error = ValidateBaseType(_, inst, base_type)) {
176
12
        return error;
177
12
      }
178
179
702
      if (!offset_type || !_.IsIntScalarType(offset_type))
180
3
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
181
3
               << "Expected Offset Type to be int scalar: "
182
3
               << spvOpcodeString(opcode);
183
184
699
      if (!count_type || !_.IsIntScalarType(count_type))
185
2
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
186
2
               << "Expected Count Type to be int scalar: "
187
2
               << spvOpcodeString(opcode);
188
697
      break;
189
699
    }
190
    // Operand 2 is Base.
191
697
    case spv::Op::OpBitReverse: {
192
312
      const uint32_t base_type = _.GetOperandTypeId(inst, 2);
193
      // All checks for this opcode are the ones in ValidateBaseType.
194
312
      if (spv_result_t error = ValidateBaseType(_, inst, base_type)) {
195
4
        return error;
196
4
      }
197
198
308
      break;
199
312
    }
200
    // Result Type may differ from Base here, so it is validated on its own.
201
552
    case spv::Op::OpBitCount: {
202
552
      if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type))
203
4
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
204
4
               << "Expected int scalar or vector type as Result Type: "
205
4
               << spvOpcodeString(opcode);
206
207
548
      const uint32_t base_type = _.GetOperandTypeId(inst, 2);
208
      // ValidateBaseType skips its Base == Result Type check for OpBitCount.
209
548
      if (spv_result_t error = ValidateBaseType(_, inst, base_type)) {
210
3
        return error;
211
3
      }
212
      // Instead, Base and Result Type must have the same dimension.
213
545
      const uint32_t base_dimension = _.GetDimension(base_type);
214
545
      const uint32_t result_dimension = _.GetDimension(result_type);
215
216
545
      if (base_dimension != result_dimension)
217
2
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
218
2
               << "Expected Base dimension to be equal to Result Type "
219
2
                  "dimension: "
220
2
               << spvOpcodeString(opcode);
221
543
      break;
222
545
    }
223
    // Any opcode not listed above is not checked by this function.
224
10.1M
    default:
225
10.1M
      break;
226
10.1M
  }
227
228
10.1M
  return SPV_SUCCESS;
229
10.1M
}
230
231
}  // namespace val
232
}  // namespace spvtools