/src/spirv-tools/source/val/validate_ray_tracing.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | // Copyright (c) 2022 The Khronos Group Inc. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | // Validates ray tracing instructions from SPV_KHR_ray_tracing |
16 | | |
17 | | #include "source/opcode.h" |
18 | | #include "source/val/instruction.h" |
19 | | #include "source/val/validate.h" |
20 | | #include "source/val/validation_state.h" |
21 | | |
22 | | namespace spvtools { |
23 | | namespace val { |
24 | | |
25 | 11.3M | spv_result_t RayTracingPass(ValidationState_t& _, const Instruction* inst) { |
26 | 11.3M | const spv::Op opcode = inst->opcode(); |
27 | 11.3M | const uint32_t result_type = inst->type_id(); |
28 | | |
29 | 11.3M | switch (opcode) { |
30 | 0 | case spv::Op::OpTraceRayKHR: { |
31 | 0 | _.function(inst->function()->id()) |
32 | 0 | ->RegisterExecutionModelLimitation( |
33 | 0 | [](spv::ExecutionModel model, std::string* message) { |
34 | 0 | if (model != spv::ExecutionModel::RayGenerationKHR && |
35 | 0 | model != spv::ExecutionModel::ClosestHitKHR && |
36 | 0 | model != spv::ExecutionModel::MissKHR) { |
37 | 0 | if (message) { |
38 | 0 | *message = |
39 | 0 | "OpTraceRayKHR requires RayGenerationKHR, " |
40 | 0 | "ClosestHitKHR and MissKHR execution models"; |
41 | 0 | } |
42 | 0 | return false; |
43 | 0 | } |
44 | 0 | return true; |
45 | 0 | }); |
46 | |
|
47 | 0 | if (_.GetIdOpcode(_.GetOperandTypeId(inst, 0)) != |
48 | 0 | spv::Op::OpTypeAccelerationStructureKHR) { |
49 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
50 | 0 | << "Expected Acceleration Structure to be of type " |
51 | 0 | "OpTypeAccelerationStructureKHR"; |
52 | 0 | } |
53 | | |
54 | 0 | const uint32_t ray_flags = _.GetOperandTypeId(inst, 1); |
55 | 0 | if (!_.IsIntScalarType(ray_flags) || _.GetBitWidth(ray_flags) != 32) { |
56 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
57 | 0 | << "Ray Flags must be a 32-bit int scalar"; |
58 | 0 | } |
59 | | |
60 | 0 | const uint32_t cull_mask = _.GetOperandTypeId(inst, 2); |
61 | 0 | if (!_.IsIntScalarType(cull_mask) || _.GetBitWidth(cull_mask) != 32) { |
62 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
63 | 0 | << "Cull Mask must be a 32-bit int scalar"; |
64 | 0 | } |
65 | | |
66 | 0 | const uint32_t sbt_offset = _.GetOperandTypeId(inst, 3); |
67 | 0 | if (!_.IsIntScalarType(sbt_offset) || _.GetBitWidth(sbt_offset) != 32) { |
68 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
69 | 0 | << "SBT Offset must be a 32-bit int scalar"; |
70 | 0 | } |
71 | | |
72 | 0 | const uint32_t sbt_stride = _.GetOperandTypeId(inst, 4); |
73 | 0 | if (!_.IsIntScalarType(sbt_stride) || _.GetBitWidth(sbt_stride) != 32) { |
74 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
75 | 0 | << "SBT Stride must be a 32-bit int scalar"; |
76 | 0 | } |
77 | | |
78 | 0 | const uint32_t miss_index = _.GetOperandTypeId(inst, 5); |
79 | 0 | if (!_.IsIntScalarType(miss_index) || _.GetBitWidth(miss_index) != 32) { |
80 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
81 | 0 | << "Miss Index must be a 32-bit int scalar"; |
82 | 0 | } |
83 | | |
84 | 0 | const uint32_t ray_origin = _.GetOperandTypeId(inst, 6); |
85 | 0 | if (!_.IsFloatVectorType(ray_origin) || _.GetDimension(ray_origin) != 3 || |
86 | 0 | _.GetBitWidth(ray_origin) != 32) { |
87 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
88 | 0 | << "Ray Origin must be a 32-bit float 3-component vector"; |
89 | 0 | } |
90 | | |
91 | 0 | const uint32_t ray_tmin = _.GetOperandTypeId(inst, 7); |
92 | 0 | if (!_.IsFloatScalarType(ray_tmin) || _.GetBitWidth(ray_tmin) != 32) { |
93 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
94 | 0 | << "Ray TMin must be a 32-bit float scalar"; |
95 | 0 | } |
96 | | |
97 | 0 | const uint32_t ray_direction = _.GetOperandTypeId(inst, 8); |
98 | 0 | if (!_.IsFloatVectorType(ray_direction) || |
99 | 0 | _.GetDimension(ray_direction) != 3 || |
100 | 0 | _.GetBitWidth(ray_direction) != 32) { |
101 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
102 | 0 | << "Ray Direction must be a 32-bit float 3-component vector"; |
103 | 0 | } |
104 | | |
105 | 0 | const uint32_t ray_tmax = _.GetOperandTypeId(inst, 9); |
106 | 0 | if (!_.IsFloatScalarType(ray_tmax) || _.GetBitWidth(ray_tmax) != 32) { |
107 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
108 | 0 | << "Ray TMax must be a 32-bit float scalar"; |
109 | 0 | } |
110 | | |
111 | 0 | const Instruction* payload = _.FindDef(inst->GetOperandAs<uint32_t>(10)); |
112 | 0 | if (payload->opcode() != spv::Op::OpVariable) { |
113 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
114 | 0 | << "Payload must be the result of a OpVariable"; |
115 | 0 | } else if (payload->GetOperandAs<spv::StorageClass>(2) != |
116 | 0 | spv::StorageClass::RayPayloadKHR && |
117 | 0 | payload->GetOperandAs<spv::StorageClass>(2) != |
118 | 0 | spv::StorageClass::IncomingRayPayloadKHR) { |
119 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
120 | 0 | << "Payload must have storage class RayPayloadKHR or " |
121 | 0 | "IncomingRayPayloadKHR"; |
122 | 0 | } |
123 | 0 | break; |
124 | 0 | } |
125 | | |
126 | 0 | case spv::Op::OpReportIntersectionKHR: { |
127 | 0 | _.function(inst->function()->id()) |
128 | 0 | ->RegisterExecutionModelLimitation( |
129 | 0 | [](spv::ExecutionModel model, std::string* message) { |
130 | 0 | if (model != spv::ExecutionModel::IntersectionKHR) { |
131 | 0 | if (message) { |
132 | 0 | *message = |
133 | 0 | "OpReportIntersectionKHR requires IntersectionKHR " |
134 | 0 | "execution model"; |
135 | 0 | } |
136 | 0 | return false; |
137 | 0 | } |
138 | 0 | return true; |
139 | 0 | }); |
140 | |
|
141 | 0 | if (!_.IsBoolScalarType(result_type)) { |
142 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
143 | 0 | << "expected Result Type to be bool scalar type"; |
144 | 0 | } |
145 | | |
146 | 0 | const uint32_t hit = _.GetOperandTypeId(inst, 2); |
147 | 0 | if (!_.IsFloatScalarType(hit) || _.GetBitWidth(hit) != 32) { |
148 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
149 | 0 | << "Hit must be a 32-bit int scalar"; |
150 | 0 | } |
151 | | |
152 | 0 | const uint32_t hit_kind = _.GetOperandTypeId(inst, 3); |
153 | 0 | if (!_.IsUnsignedIntScalarType(hit_kind) || |
154 | 0 | _.GetBitWidth(hit_kind) != 32) { |
155 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
156 | 0 | << "Hit Kind must be a 32-bit unsigned int scalar"; |
157 | 0 | } |
158 | 0 | break; |
159 | 0 | } |
160 | | |
161 | 0 | case spv::Op::OpExecuteCallableKHR: { |
162 | 0 | _.function(inst->function()->id()) |
163 | 0 | ->RegisterExecutionModelLimitation([](spv::ExecutionModel model, |
164 | 0 | std::string* message) { |
165 | 0 | if (model != spv::ExecutionModel::RayGenerationKHR && |
166 | 0 | model != spv::ExecutionModel::ClosestHitKHR && |
167 | 0 | model != spv::ExecutionModel::MissKHR && |
168 | 0 | model != spv::ExecutionModel::CallableKHR) { |
169 | 0 | if (message) { |
170 | 0 | *message = |
171 | 0 | "OpExecuteCallableKHR requires RayGenerationKHR, " |
172 | 0 | "ClosestHitKHR, MissKHR and CallableKHR execution models"; |
173 | 0 | } |
174 | 0 | return false; |
175 | 0 | } |
176 | 0 | return true; |
177 | 0 | }); |
178 | |
|
179 | 0 | const uint32_t sbt_index = _.GetOperandTypeId(inst, 0); |
180 | 0 | if (!_.IsUnsignedIntScalarType(sbt_index) || |
181 | 0 | _.GetBitWidth(sbt_index) != 32) { |
182 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
183 | 0 | << "SBT Index must be a 32-bit unsigned int scalar"; |
184 | 0 | } |
185 | | |
186 | 0 | const auto callable_data = _.FindDef(inst->GetOperandAs<uint32_t>(1)); |
187 | 0 | if (callable_data->opcode() != spv::Op::OpVariable) { |
188 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
189 | 0 | << "Callable Data must be the result of a OpVariable"; |
190 | 0 | } else if (callable_data->GetOperandAs<spv::StorageClass>(2) != |
191 | 0 | spv::StorageClass::CallableDataKHR && |
192 | 0 | callable_data->GetOperandAs<spv::StorageClass>(2) != |
193 | 0 | spv::StorageClass::IncomingCallableDataKHR) { |
194 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
195 | 0 | << "Callable Data must have storage class CallableDataKHR or " |
196 | 0 | "IncomingCallableDataKHR"; |
197 | 0 | } |
198 | | |
199 | 0 | break; |
200 | 0 | } |
201 | | |
202 | 11.3M | default: |
203 | 11.3M | break; |
204 | 11.3M | } |
205 | | |
206 | 11.3M | return SPV_SUCCESS; |
207 | 11.3M | } |
208 | | } // namespace val |
209 | | } // namespace spvtools |