Coverage Report

Created: 2024-12-11 06:33

/src/spirv-tools/source/val/validate_tensor_layout.cpp
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) 2024 NVIDIA Corporation
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
// Validate instructions that manipulate tensor layout and view objects
16
17
#include "source/opcode.h"
18
#include "source/spirv_target_env.h"
19
#include "source/val/instruction.h"
20
#include "source/val/validate.h"
21
#include "source/val/validation_state.h"
22
23
namespace spvtools {
24
namespace val {
25
namespace {
26
27
// Checks that operand 0 of |inst| (its Result Type <id>) refers to a
// definition whose opcode is OpTypeTensorLayoutNV.
//
// Returns SPV_SUCCESS when it does. Otherwise (the id has no definition, or
// the definition is some other opcode) reports an SPV_ERROR_INVALID_ID
// diagnostic against |inst| and returns that result.
spv_result_t ValidateTensorLayoutResultTypeNV(ValidationState_t& _,
                                              const Instruction* inst) {
  const auto type_id = inst->GetOperandAs<uint32_t>(0);
  const auto type_def = _.FindDef(type_id);

  // Accept only a known definition that is exactly a tensor layout type.
  if (type_def && type_def->opcode() == spv::Op::OpTypeTensorLayoutNV) {
    return SPV_SUCCESS;
  }

  return _.diag(SPV_ERROR_INVALID_ID, inst)
         << spvOpcodeString(inst->opcode()) << " Result Type <id> "
         << _.getIdName(type_id) << " is not a tensor layout type.";
}
40
41
// Checks that operand 0 of |inst| (its Result Type <id>) refers to a
// definition whose opcode is OpTypeTensorViewNV.
//
// Returns SPV_SUCCESS when it does. Otherwise (the id has no definition, or
// the definition is some other opcode) reports an SPV_ERROR_INVALID_ID
// diagnostic against |inst| and returns that result.
spv_result_t ValidateTensorViewResultTypeNV(ValidationState_t& _,
                                            const Instruction* inst) {
  const auto type_id = inst->GetOperandAs<uint32_t>(0);
  const auto type_def = _.FindDef(type_id);

  // Accept only a known definition that is exactly a tensor view type.
  if (type_def && type_def->opcode() == spv::Op::OpTypeTensorViewNV) {
    return SPV_SUCCESS;
  }

  return _.diag(SPV_ERROR_INVALID_ID, inst)
         << spvOpcodeString(inst->opcode()) << " Result Type <id> "
         << _.getIdName(type_id) << " is not a tensor view type.";
}
54
55
// Validates OpCreateTensorLayoutNV. The only check performed is that the
// instruction's result type is a tensor layout type; the outcome of that
// check is returned unchanged.
spv_result_t ValidateCreateTensorLayoutNV(ValidationState_t& _,
                                          const Instruction* inst) {
  return ValidateTensorLayoutResultTypeNV(_, inst);
}
61
62
// Validates OpCreateTensorViewNV. The only check performed is that the
// instruction's result type is a tensor view type; the outcome of that check
// is returned unchanged.
spv_result_t ValidateCreateTensorViewNV(ValidationState_t& _,
                                        const Instruction* inst) {
  return ValidateTensorViewResultTypeNV(_, inst);
}
68
69
// Describes how many value operands an instruction must supply after its
// tensor layout/view operand. Kept as an unscoped enum because the
// enumerators are used unqualified elsewhere in this file.
enum ExpectedNumValues {
  DIM = 0,    // As many values as the type's dimension constant.
  DIMx2 = 1,  // Twice the type's dimension constant.
  ONE = 2,    // Exactly one value.
  FOUR = 3,   // Exactly four values.
};
75
76
// Shared validation for instructions that take a tensor layout or tensor view
// object followed by a list of integer value operands.
//
// Operand layout read by this function:
//   operand 0      - Result Type <id>
//   operand 2      - <id> of the tensor layout/view object being modified
//   operands 3..   - the value <id>s
//
// Checks, in order:
//   1. The result type is a tensor view type (|is_view| true) or a tensor
//      layout type (|is_view| false).
//   2. The object in operand 2 is defined and its type <id> equals the result
//      type <id>.
//   3. If operand 1 of the result type evaluates to a constant, the number of
//      value operands equals the count selected by |expected|. When it does
//      not evaluate to a constant, the count is not checked.
//   4. Every value operand is defined and has a 32-bit integer scalar type.
//
// Returns SPV_SUCCESS if all checks pass; otherwise reports an
// SPV_ERROR_INVALID_ID diagnostic (or propagates the result-type error).
spv_result_t ValidateTensorTypeWithDimValuesNV(ValidationState_t& _,
                                               const Instruction* inst,
                                               ExpectedNumValues expected,
                                               bool is_view) {
  if (auto error = is_view ? ValidateTensorViewResultTypeNV(_, inst)
                           : ValidateTensorLayoutResultTypeNV(_, inst)) {
    return error;
  }
  // Name used in the mismatch diagnostic below.
  const char* const kind_name = is_view ? "TensorView" : "TensorLayout";

  const auto type_id = inst->GetOperandAs<uint32_t>(0);
  const auto object_id = inst->GetOperandAs<uint32_t>(2);
  const auto object_def = _.FindDef(object_id);
  const bool object_has_result_type =
      object_def && object_def->type_id() == type_id;
  if (!object_has_result_type) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << spvOpcodeString(inst->opcode()) << " Result Type <id> "
           << _.getIdName(type_id) << " does not match " << kind_name
           << " type.";
  }

  // Everything after the first three operands is a value operand.
  const auto value_count = inst->operands().size() - 3;

  // The result type was validated above, so its definition is present here.
  const auto type_def = _.FindDef(type_id);
  const auto dim_id = type_def->GetOperandAs<uint32_t>(1);
  uint64_t dim = 0;
  if (_.EvalConstantValUint64(dim_id, &dim)) {
    uint64_t required_count = 0;
    if (expected == DIM) {
      required_count = dim;
    } else if (expected == DIMx2) {
      required_count = dim * 2;
    } else if (expected == ONE) {
      required_count = 1;
    } else if (expected == FOUR) {
      required_count = 4;
    }

    if (value_count != required_count) {
      return _.diag(SPV_ERROR_INVALID_ID, inst)
             << spvOpcodeString(inst->opcode())
             << " unexpected number of operands.";
    }
  }

  for (uint32_t value_index = 0; value_index < value_count; ++value_index) {
    const auto value_id = inst->GetOperandAs<uint32_t>(value_index + 3);
    const auto value_def = _.FindDef(value_id);
    const bool is_int32 = value_def &&
                          _.IsIntScalarType(value_def->type_id()) &&
                          _.GetBitWidth(value_def->type_id()) == 32;
    if (!is_int32) {
      return _.diag(SPV_ERROR_INVALID_ID, inst)
             << spvOpcodeString(inst->opcode()) << " operand <id> "
             << _.getIdName(value_id) << " is not a 32-bit integer.";
    }
  }

  return SPV_SUCCESS;
}
142
143
}  // namespace
144
145
2.95M
// Entry point of the tensor layout/view validation pass. Dispatches on the
// opcode of |inst| to the matching validator and returns its result;
// instructions with any other opcode are accepted with SPV_SUCCESS.
spv_result_t TensorLayoutPass(ValidationState_t& _, const Instruction* inst) {
  switch (inst->opcode()) {
    case spv::Op::OpCreateTensorLayoutNV:
      return ValidateCreateTensorLayoutNV(_, inst);

    case spv::Op::OpCreateTensorViewNV:
      return ValidateCreateTensorViewNV(_, inst);

    // Layout instructions taking one value per dimension.
    case spv::Op::OpTensorLayoutSetBlockSizeNV:
    case spv::Op::OpTensorLayoutSetDimensionNV:
    case spv::Op::OpTensorLayoutSetStrideNV:
      return ValidateTensorTypeWithDimValuesNV(_, inst, DIM, false);

    // Layout slice takes two values per dimension.
    case spv::Op::OpTensorLayoutSliceNV:
      return ValidateTensorTypeWithDimValuesNV(_, inst, DIMx2, false);

    case spv::Op::OpTensorLayoutSetClampValueNV:
      return ValidateTensorTypeWithDimValuesNV(_, inst, ONE, false);

    // View instructions taking one value per dimension.
    case spv::Op::OpTensorViewSetDimensionNV:
    case spv::Op::OpTensorViewSetStrideNV:
      return ValidateTensorTypeWithDimValuesNV(_, inst, DIM, true);

    case spv::Op::OpTensorViewSetClipNV:
      return ValidateTensorTypeWithDimValuesNV(_, inst, FOUR, true);

    default:
      break;
  }

  return SPV_SUCCESS;
}
182
183
}  // namespace val
184
}  // namespace spvtools