/src/spirv-tools/source/val/validate_derivatives.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | // Copyright (c) 2017 Google Inc. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | // Validates correctness of derivative SPIR-V instructions. |
16 | | |
17 | | #include <string> |
18 | | |
19 | | #include "source/opcode.h" |
20 | | #include "source/val/instruction.h" |
21 | | #include "source/val/validate.h" |
22 | | #include "source/val/validation_state.h" |
23 | | |
24 | | namespace spvtools { |
25 | | namespace val { |
26 | | |
27 | | // Validates correctness of derivative instructions. |
28 | 10.9M | spv_result_t DerivativesPass(ValidationState_t& _, const Instruction* inst) { |
29 | 10.9M | const spv::Op opcode = inst->opcode(); |
30 | 10.9M | const uint32_t result_type = inst->type_id(); |
31 | | |
32 | 10.9M | switch (opcode) { |
33 | 324 | case spv::Op::OpDPdx: |
34 | 685 | case spv::Op::OpDPdy: |
35 | 738 | case spv::Op::OpFwidth: |
36 | 743 | case spv::Op::OpDPdxFine: |
37 | 756 | case spv::Op::OpDPdyFine: |
38 | 760 | case spv::Op::OpFwidthFine: |
39 | 766 | case spv::Op::OpDPdxCoarse: |
40 | 773 | case spv::Op::OpDPdyCoarse: |
41 | 777 | case spv::Op::OpFwidthCoarse: { |
42 | 777 | if (!_.IsFloatScalarOrVectorType(result_type)) { |
43 | 7 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
44 | 7 | << "Expected Result Type to be float scalar or vector type: " |
45 | 7 | << spvOpcodeString(opcode); |
46 | 7 | } |
47 | 770 | if (!_.ContainsSizedIntOrFloatType(result_type, spv::Op::OpTypeFloat, |
48 | 770 | 32)) { |
49 | 0 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
50 | 0 | << "Result type component width must be 32 bits"; |
51 | 0 | } |
52 | | |
53 | 770 | const uint32_t p_type = _.GetOperandTypeId(inst, 2); |
54 | 770 | if (p_type != result_type) { |
55 | 5 | return _.diag(SPV_ERROR_INVALID_DATA, inst) |
56 | 5 | << "Expected P type and Result Type to be the same: " |
57 | 5 | << spvOpcodeString(opcode); |
58 | 5 | } |
59 | 765 | _.function(inst->function()->id()) |
60 | 765 | ->RegisterExecutionModelLimitation([opcode](spv::ExecutionModel model, |
61 | 765 | std::string* message) { |
62 | 489 | if (model != spv::ExecutionModel::Fragment && |
63 | 489 | model != spv::ExecutionModel::GLCompute && |
64 | 489 | model != spv::ExecutionModel::MeshEXT && |
65 | 489 | model != spv::ExecutionModel::TaskEXT) { |
66 | 2 | if (message) { |
67 | 2 | *message = |
68 | 2 | std::string( |
69 | 2 | "Derivative instructions require Fragment, GLCompute, " |
70 | 2 | "MeshEXT or TaskEXT execution model: ") + |
71 | 2 | spvOpcodeString(opcode); |
72 | 2 | } |
73 | 2 | return false; |
74 | 2 | } |
75 | 487 | return true; |
76 | 489 | }); |
77 | 765 | _.function(inst->function()->id()) |
78 | 765 | ->RegisterLimitation([opcode](const ValidationState_t& state, |
79 | 765 | const Function* entry_point, |
80 | 765 | std::string* message) { |
81 | 487 | const auto* models = state.GetExecutionModels(entry_point->id()); |
82 | 487 | const auto* modes = state.GetExecutionModes(entry_point->id()); |
83 | 487 | if (models && |
84 | 487 | (models->find(spv::ExecutionModel::GLCompute) != |
85 | 487 | models->end() || |
86 | 487 | models->find(spv::ExecutionModel::MeshEXT) != models->end() || |
87 | 487 | models->find(spv::ExecutionModel::TaskEXT) != models->end()) && |
88 | 487 | (!modes || |
89 | 8 | (modes->find(spv::ExecutionMode::DerivativeGroupLinearKHR) == |
90 | 3 | modes->end() && |
91 | 3 | modes->find(spv::ExecutionMode::DerivativeGroupQuadsKHR) == |
92 | 8 | modes->end()))) { |
93 | 8 | if (message) { |
94 | 8 | *message = |
95 | 8 | std::string( |
96 | 8 | "Derivative instructions require " |
97 | 8 | "DerivativeGroupQuadsKHR " |
98 | 8 | "or DerivativeGroupLinearKHR execution mode for " |
99 | 8 | "GLCompute, MeshEXT or TaskEXT execution model: ") + |
100 | 8 | spvOpcodeString(opcode); |
101 | 8 | } |
102 | 8 | return false; |
103 | 8 | } |
104 | 479 | return true; |
105 | 487 | }); |
106 | 765 | break; |
107 | 770 | } |
108 | | |
109 | 10.9M | default: |
110 | 10.9M | break; |
111 | 10.9M | } |
112 | | |
113 | 10.9M | return SPV_SUCCESS; |
114 | 10.9M | } |
115 | | |
116 | | } // namespace val |
117 | | } // namespace spvtools |