Coverage Report

Created: 2023-03-01 07:33

/src/spirv-tools/source/val/validate_bitwise.cpp
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) 2017 Google Inc.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
// Validates correctness of bitwise instructions.
16
17
#include "source/diagnostic.h"
18
#include "source/opcode.h"
19
#include "source/spirv_target_env.h"
20
#include "source/val/instruction.h"
21
#include "source/val/validate.h"
22
#include "source/val/validation_state.h"
23
24
namespace spvtools {
25
namespace val {
26
27
// Validates when base and result need to be the same type
28
spv_result_t ValidateBaseType(ValidationState_t& _, const Instruction* inst,
29
2.50k
                              const uint32_t base_type) {
30
2.50k
  const spv::Op opcode = inst->opcode();
31
32
2.50k
  if (!_.IsIntScalarType(base_type) && !_.IsIntVectorType(base_type)) {
33
14
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
34
14
           << _.VkErrorID(4781)
35
14
           << "Expected int scalar or vector type for Base operand: "
36
14
           << spvOpcodeString(opcode);
37
14
  }
38
39
  // Vulkan has a restriction to 32 bit for base
40
2.49k
  if (spvIsVulkanEnv(_.context()->target_env)) {
41
0
    if (_.GetBitWidth(base_type) != 32) {
42
0
      return _.diag(SPV_ERROR_INVALID_DATA, inst)
43
0
             << _.VkErrorID(4781)
44
0
             << "Expected 32-bit int type for Base operand: "
45
0
             << spvOpcodeString(opcode);
46
0
    }
47
0
  }
48
49
  // OpBitCount just needs same number of components
50
2.49k
  if (base_type != inst->type_id() && opcode != spv::Op::OpBitCount) {
51
7
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
52
7
           << "Expected Base Type to be equal to Result Type: "
53
7
           << spvOpcodeString(opcode);
54
7
  }
55
56
2.48k
  return SPV_SUCCESS;
57
2.49k
}
58
59
// Validates correctness of bitwise instructions.
60
2.69M
spv_result_t BitwisePass(ValidationState_t& _, const Instruction* inst) {
61
2.69M
  const spv::Op opcode = inst->opcode();
62
2.69M
  const uint32_t result_type = inst->type_id();
63
64
2.69M
  switch (opcode) {
65
220
    case spv::Op::OpShiftRightLogical:
66
717
    case spv::Op::OpShiftRightArithmetic:
67
949
    case spv::Op::OpShiftLeftLogical: {
68
949
      if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type))
69
6
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
70
6
               << "Expected int scalar or vector type as Result Type: "
71
6
               << spvOpcodeString(opcode);
72
73
943
      const uint32_t result_dimension = _.GetDimension(result_type);
74
943
      const uint32_t base_type = _.GetOperandTypeId(inst, 2);
75
943
      const uint32_t shift_type = _.GetOperandTypeId(inst, 3);
76
77
943
      if (!base_type ||
78
943
          (!_.IsIntScalarType(base_type) && !_.IsIntVectorType(base_type)))
79
1
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
80
1
               << "Expected Base to be int scalar or vector: "
81
1
               << spvOpcodeString(opcode);
82
83
942
      if (_.GetDimension(base_type) != result_dimension)
84
5
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
85
5
               << "Expected Base to have the same dimension "
86
5
               << "as Result Type: " << spvOpcodeString(opcode);
87
88
937
      if (_.GetBitWidth(base_type) != _.GetBitWidth(result_type))
89
0
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
90
0
               << "Expected Base to have the same bit width "
91
0
               << "as Result Type: " << spvOpcodeString(opcode);
92
93
937
      if (!shift_type ||
94
937
          (!_.IsIntScalarType(shift_type) && !_.IsIntVectorType(shift_type)))
95
4
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
96
4
               << "Expected Shift to be int scalar or vector: "
97
4
               << spvOpcodeString(opcode);
98
99
933
      if (_.GetDimension(shift_type) != result_dimension)
100
3
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
101
3
               << "Expected Shift to have the same dimension "
102
3
               << "as Result Type: " << spvOpcodeString(opcode);
103
930
      break;
104
933
    }
105
106
19.1k
    case spv::Op::OpBitwiseOr:
107
19.4k
    case spv::Op::OpBitwiseXor:
108
24.6k
    case spv::Op::OpBitwiseAnd:
109
27.9k
    case spv::Op::OpNot: {
110
27.9k
      if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type))
111
8
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
112
8
               << "Expected int scalar or vector type as Result Type: "
113
8
               << spvOpcodeString(opcode);
114
115
27.9k
      const uint32_t result_dimension = _.GetDimension(result_type);
116
27.9k
      const uint32_t result_bit_width = _.GetBitWidth(result_type);
117
118
80.4k
      for (size_t operand_index = 2; operand_index < inst->operands().size();
119
52.5k
           ++operand_index) {
120
52.5k
        const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
121
52.5k
        if (!type_id ||
122
52.5k
            (!_.IsIntScalarType(type_id) && !_.IsIntVectorType(type_id)))
123
14
          return _.diag(SPV_ERROR_INVALID_DATA, inst)
124
14
                 << "Expected int scalar or vector as operand: "
125
14
                 << spvOpcodeString(opcode) << " operand index "
126
14
                 << operand_index;
127
128
52.5k
        if (_.GetDimension(type_id) != result_dimension)
129
6
          return _.diag(SPV_ERROR_INVALID_DATA, inst)
130
6
                 << "Expected operands to have the same dimension "
131
6
                 << "as Result Type: " << spvOpcodeString(opcode)
132
6
                 << " operand index " << operand_index;
133
134
52.5k
        if (_.GetBitWidth(type_id) != result_bit_width)
135
0
          return _.diag(SPV_ERROR_INVALID_DATA, inst)
136
0
                 << "Expected operands to have the same bit width "
137
0
                 << "as Result Type: " << spvOpcodeString(opcode)
138
0
                 << " operand index " << operand_index;
139
52.5k
      }
140
27.8k
      break;
141
27.9k
    }
142
143
27.8k
    case spv::Op::OpBitFieldInsert: {
144
554
      const uint32_t base_type = _.GetOperandTypeId(inst, 2);
145
554
      const uint32_t insert_type = _.GetOperandTypeId(inst, 3);
146
554
      const uint32_t offset_type = _.GetOperandTypeId(inst, 4);
147
554
      const uint32_t count_type = _.GetOperandTypeId(inst, 5);
148
149
554
      if (spv_result_t error = ValidateBaseType(_, inst, base_type)) {
150
5
        return error;
151
5
      }
152
153
549
      if (insert_type != result_type)
154
4
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
155
4
               << "Expected Insert Type to be equal to Result Type: "
156
4
               << spvOpcodeString(opcode);
157
158
545
      if (!offset_type || !_.IsIntScalarType(offset_type))
159
3
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
160
3
               << "Expected Offset Type to be int scalar: "
161
3
               << spvOpcodeString(opcode);
162
163
542
      if (!count_type || !_.IsIntScalarType(count_type))
164
4
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
165
4
               << "Expected Count Type to be int scalar: "
166
4
               << spvOpcodeString(opcode);
167
538
      break;
168
542
    }
169
170
587
    case spv::Op::OpBitFieldSExtract:
171
827
    case spv::Op::OpBitFieldUExtract: {
172
827
      const uint32_t base_type = _.GetOperandTypeId(inst, 2);
173
827
      const uint32_t offset_type = _.GetOperandTypeId(inst, 3);
174
827
      const uint32_t count_type = _.GetOperandTypeId(inst, 4);
175
176
827
      if (spv_result_t error = ValidateBaseType(_, inst, base_type)) {
177
7
        return error;
178
7
      }
179
180
820
      if (!offset_type || !_.IsIntScalarType(offset_type))
181
3
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
182
3
               << "Expected Offset Type to be int scalar: "
183
3
               << spvOpcodeString(opcode);
184
185
817
      if (!count_type || !_.IsIntScalarType(count_type))
186
3
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
187
3
               << "Expected Count Type to be int scalar: "
188
3
               << spvOpcodeString(opcode);
189
814
      break;
190
817
    }
191
192
814
    case spv::Op::OpBitReverse: {
193
318
      const uint32_t base_type = _.GetOperandTypeId(inst, 2);
194
195
318
      if (spv_result_t error = ValidateBaseType(_, inst, base_type)) {
196
5
        return error;
197
5
      }
198
199
313
      break;
200
318
    }
201
202
812
    case spv::Op::OpBitCount: {
203
812
      if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type))
204
3
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
205
3
               << "Expected int scalar or vector type as Result Type: "
206
3
               << spvOpcodeString(opcode);
207
208
809
      const uint32_t base_type = _.GetOperandTypeId(inst, 2);
209
210
809
      if (spv_result_t error = ValidateBaseType(_, inst, base_type)) {
211
4
        return error;
212
4
      }
213
214
805
      const uint32_t base_dimension = _.GetDimension(base_type);
215
805
      const uint32_t result_dimension = _.GetDimension(result_type);
216
217
805
      if (base_dimension != result_dimension)
218
0
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
219
0
               << "Expected Base dimension to be equal to Result Type "
220
0
                  "dimension: "
221
0
               << spvOpcodeString(opcode);
222
805
      break;
223
805
    }
224
225
2.66M
    default:
226
2.66M
      break;
227
2.69M
  }
228
229
2.69M
  return SPV_SUCCESS;
230
2.69M
}
231
232
}  // namespace val
233
}  // namespace spvtools