Coverage Report

Created: 2025-07-23 06:18

/src/spirv-tools/source/val/validate_derivatives.cpp
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) 2017 Google Inc.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
// Validates correctness of derivative SPIR-V instructions.
16
17
#include <string>
18
19
#include "source/opcode.h"
20
#include "source/val/instruction.h"
21
#include "source/val/validate.h"
22
#include "source/val/validation_state.h"
23
24
namespace spvtools {
25
namespace val {
26
27
// Validates correctness of derivative instructions.
28
10.9M
spv_result_t DerivativesPass(ValidationState_t& _, const Instruction* inst) {
29
10.9M
  const spv::Op opcode = inst->opcode();
30
10.9M
  const uint32_t result_type = inst->type_id();
31
32
10.9M
  switch (opcode) {
33
324
    case spv::Op::OpDPdx:
34
685
    case spv::Op::OpDPdy:
35
738
    case spv::Op::OpFwidth:
36
743
    case spv::Op::OpDPdxFine:
37
756
    case spv::Op::OpDPdyFine:
38
760
    case spv::Op::OpFwidthFine:
39
766
    case spv::Op::OpDPdxCoarse:
40
773
    case spv::Op::OpDPdyCoarse:
41
777
    case spv::Op::OpFwidthCoarse: {
42
777
      if (!_.IsFloatScalarOrVectorType(result_type)) {
43
7
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
44
7
               << "Expected Result Type to be float scalar or vector type: "
45
7
               << spvOpcodeString(opcode);
46
7
      }
47
770
      if (!_.ContainsSizedIntOrFloatType(result_type, spv::Op::OpTypeFloat,
48
770
                                         32)) {
49
0
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
50
0
               << "Result type component width must be 32 bits";
51
0
      }
52
53
770
      const uint32_t p_type = _.GetOperandTypeId(inst, 2);
54
770
      if (p_type != result_type) {
55
5
        return _.diag(SPV_ERROR_INVALID_DATA, inst)
56
5
               << "Expected P type and Result Type to be the same: "
57
5
               << spvOpcodeString(opcode);
58
5
      }
59
765
      _.function(inst->function()->id())
60
765
          ->RegisterExecutionModelLimitation([opcode](spv::ExecutionModel model,
61
765
                                                      std::string* message) {
62
489
            if (model != spv::ExecutionModel::Fragment &&
63
489
                model != spv::ExecutionModel::GLCompute &&
64
489
                model != spv::ExecutionModel::MeshEXT &&
65
489
                model != spv::ExecutionModel::TaskEXT) {
66
2
              if (message) {
67
2
                *message =
68
2
                    std::string(
69
2
                        "Derivative instructions require Fragment, GLCompute, "
70
2
                        "MeshEXT or TaskEXT execution model: ") +
71
2
                    spvOpcodeString(opcode);
72
2
              }
73
2
              return false;
74
2
            }
75
487
            return true;
76
489
          });
77
765
      _.function(inst->function()->id())
78
765
          ->RegisterLimitation([opcode](const ValidationState_t& state,
79
765
                                        const Function* entry_point,
80
765
                                        std::string* message) {
81
487
            const auto* models = state.GetExecutionModels(entry_point->id());
82
487
            const auto* modes = state.GetExecutionModes(entry_point->id());
83
487
            if (models &&
84
487
                (models->find(spv::ExecutionModel::GLCompute) !=
85
487
                     models->end() ||
86
487
                 models->find(spv::ExecutionModel::MeshEXT) != models->end() ||
87
487
                 models->find(spv::ExecutionModel::TaskEXT) != models->end()) &&
88
487
                (!modes ||
89
8
                 (modes->find(spv::ExecutionMode::DerivativeGroupLinearKHR) ==
90
3
                      modes->end() &&
91
3
                  modes->find(spv::ExecutionMode::DerivativeGroupQuadsKHR) ==
92
8
                      modes->end()))) {
93
8
              if (message) {
94
8
                *message =
95
8
                    std::string(
96
8
                        "Derivative instructions require "
97
8
                        "DerivativeGroupQuadsKHR "
98
8
                        "or DerivativeGroupLinearKHR execution mode for "
99
8
                        "GLCompute, MeshEXT or TaskEXT execution model: ") +
100
8
                    spvOpcodeString(opcode);
101
8
              }
102
8
              return false;
103
8
            }
104
479
            return true;
105
487
          });
106
765
      break;
107
770
    }
108
109
10.9M
    default:
110
10.9M
      break;
111
10.9M
  }
112
113
10.9M
  return SPV_SUCCESS;
114
10.9M
}
115
116
}  // namespace val
117
}  // namespace spvtools