Coverage Report

Created: 2026-04-12 06:23

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/shaderc/third_party/glslang/SPIRV/SpvPostProcess.cpp
Line
Count
Source
1
//
2
// Copyright (C) 2018 Google, Inc.
3
//
4
// All rights reserved.
5
//
6
// Redistribution and use in source and binary forms, with or without
7
// modification, are permitted provided that the following conditions
8
// are met:
9
//
10
//    Redistributions of source code must retain the above copyright
11
//    notice, this list of conditions and the following disclaimer.
12
//
13
//    Redistributions in binary form must reproduce the above
14
//    copyright notice, this list of conditions and the following
15
//    disclaimer in the documentation and/or other materials provided
16
//    with the distribution.
17
//
18
//    Neither the name of 3Dlabs Inc. Ltd. nor the names of its
19
//    contributors may be used to endorse or promote products derived
20
//    from this software without specific prior written permission.
21
//
22
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
23
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
24
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
25
// FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
26
// COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
27
// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
28
// BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
29
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
30
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
31
// LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
32
// ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
33
// POSSIBILITY OF SUCH DAMAGE.
34
35
//
36
// Post-processing for SPIR-V IR, in internal form, not standard binary form.
37
//
38
39
#include <cassert>
40
#include <cstdlib>
41
42
#include <unordered_map>
43
#include <unordered_set>
44
#include <algorithm>
45
46
#include "SPIRV/spvIR.h"
47
#include "SpvBuilder.h"
48
#include "spirv.hpp11"
49
#include "spvUtil.h"
50
51
namespace spv {
52
    #include "GLSL.std.450.h"
53
    #include "GLSL.ext.KHR.h"
54
    #include "GLSL.ext.EXT.h"
55
    #include "GLSL.ext.AMD.h"
56
    #include "GLSL.ext.NV.h"
57
    #include "GLSL.ext.ARM.h"
58
    #include "GLSL.ext.QCOM.h"
59
}
60
61
namespace spv {
62
63
// Hook to visit each operand type and result type of an instruction.
64
// Will be called multiple times for one instruction, once for each typed
65
// operand and the result.
66
void Builder::postProcessType(const Instruction& inst, Id typeId)
67
128k
{
68
    // Characterize the type being questioned
69
128k
    Op basicTypeOp = getMostBasicTypeClass(typeId);
70
128k
    int width = 0;
71
128k
    if (basicTypeOp == Op::OpTypeFloat || basicTypeOp == Op::OpTypeInt)
72
106k
        width = getScalarTypeWidth(typeId);
73
74
    // Do opcode-specific checks
75
128k
    switch (inst.getOpCode()) {
76
30.6k
    case Op::OpLoad:
77
46.4k
    case Op::OpStore:
78
46.4k
        if (basicTypeOp == Op::OpTypeStruct) {
79
926
            if (containsType(typeId, Op::OpTypeInt, 8))
80
0
                addCapability(Capability::Int8);
81
926
            if (containsType(typeId, Op::OpTypeInt, 16))
82
0
                addCapability(Capability::Int16);
83
926
            if (containsType(typeId, Op::OpTypeFloat, 16))
84
0
                addCapability(Capability::Float16);
85
45.4k
        } else {
86
45.4k
            StorageClass storageClass = StorageClass::Max;
87
45.4k
            if (module.getInstruction(inst.getIdOperand(0))->getOpCode() != Op::OpUntypedAccessChainKHR) {
88
45.4k
                storageClass = getStorageClass(inst.getIdOperand(0));
89
45.4k
            }
90
45.4k
            if (width == 8) {
91
562
                switch (storageClass) {
92
0
                case StorageClass::PhysicalStorageBufferEXT:
93
328
                case StorageClass::Uniform:
94
328
                case StorageClass::StorageBuffer:
95
328
                case StorageClass::PushConstant:
96
328
                    break;
97
234
                default:
98
234
                    addCapability(Capability::Int8);
99
234
                    break;
100
562
                }
101
44.9k
            } else if (width == 16) {
102
15.3k
                switch (storageClass) {
103
0
                case StorageClass::PhysicalStorageBufferEXT:
104
5.28k
                case StorageClass::Uniform:
105
5.28k
                case StorageClass::StorageBuffer:
106
5.28k
                case StorageClass::PushConstant:
107
6.40k
                case StorageClass::Input:
108
6.40k
                case StorageClass::Output:
109
6.40k
                    break;
110
8.95k
                default:
111
8.95k
                    if (basicTypeOp == Op::OpTypeInt)
112
3.95k
                        addCapability(Capability::Int16);
113
8.95k
                    if (basicTypeOp == Op::OpTypeFloat)
114
4.99k
                        addCapability(Capability::Float16);
115
8.95k
                    break;
116
15.3k
                }
117
15.3k
            }
118
45.4k
        }
119
46.4k
        break;
120
46.4k
    case Op::OpCopyObject:
121
0
        break;
122
144
    case Op::OpFConvert:
123
586
    case Op::OpSConvert:
124
882
    case Op::OpUConvert:
125
        // Look for any 8/16-bit storage capabilities. If there are none, assume that
126
        // the convert instruction requires the Float16/Int8/16 capability.
127
882
        if (containsType(typeId, Op::OpTypeFloat, 16) || containsType(typeId, Op::OpTypeInt, 16)) {
128
283
            bool foundStorage = false;
129
1.33k
            for (auto it = capabilities.begin(); it != capabilities.end(); ++it) {
130
1.31k
                spv::Capability cap = *it;
131
1.31k
                if (cap == spv::Capability::StorageInputOutput16 ||
132
1.31k
                    cap == spv::Capability::StoragePushConstant16 ||
133
1.31k
                    cap == spv::Capability::StorageUniformBufferBlock16 ||
134
1.21k
                    cap == spv::Capability::StorageUniform16) {
135
265
                    foundStorage = true;
136
265
                    break;
137
265
                }
138
1.31k
            }
139
283
            if (!foundStorage) {
140
18
                if (containsType(typeId, Op::OpTypeFloat, 16))
141
18
                    addCapability(Capability::Float16);
142
18
                if (containsType(typeId, Op::OpTypeInt, 16))
143
0
                    addCapability(Capability::Int16);
144
18
            }
145
283
        }
146
882
        if (containsType(typeId, Op::OpTypeInt, 8)) {
147
77
            bool foundStorage = false;
148
199
            for (auto it = capabilities.begin(); it != capabilities.end(); ++it) {
149
190
                spv::Capability cap = *it;
150
190
                if (cap == spv::Capability::StoragePushConstant8 ||
151
190
                    cap == spv::Capability::UniformAndStorageBuffer8BitAccess ||
152
122
                    cap == spv::Capability::StorageBuffer8BitAccess) {
153
68
                    foundStorage = true;
154
68
                    break;
155
68
                }
156
190
            }
157
77
            if (!foundStorage) {
158
9
                addCapability(Capability::Int8);
159
9
            }
160
77
        }
161
882
        break;
162
911
    case Op::OpExtInst:
163
911
        switch (inst.getImmediateOperand(1)) {
164
0
        case GLSLstd450Frexp:
165
18
        case GLSLstd450FrexpStruct:
166
18
            if (getSpvVersion() < spv::Spv_1_3 && containsType(typeId, Op::OpTypeInt, 16))
167
9
                addExtension(spv::E_SPV_AMD_gpu_shader_int16);
168
18
            break;
169
0
        case GLSLstd450InterpolateAtCentroid:
170
0
        case GLSLstd450InterpolateAtSample:
171
0
        case GLSLstd450InterpolateAtOffset:
172
0
            if (getSpvVersion() < spv::Spv_1_3 && containsType(typeId, Op::OpTypeFloat, 16))
173
0
                addExtension(spv::E_SPV_AMD_gpu_shader_half_float);
174
0
            break;
175
893
        default:
176
893
            break;
177
911
        }
178
911
        break;
179
22.1k
    case Op::OpAccessChain:
180
22.1k
    case Op::OpPtrAccessChain:
181
22.1k
        if (isPointerType(typeId))
182
14.4k
            break;
183
7.75k
        if (basicTypeOp == Op::OpTypeInt) {
184
7.75k
            if (width == 16)
185
0
                addCapability(Capability::Int16);
186
7.75k
            else if (width == 8)
187
0
                addCapability(Capability::Int8);
188
7.75k
        }
189
7.75k
        break;
190
57.7k
    default:
191
57.7k
        if (basicTypeOp == Op::OpTypeInt) {
192
23.4k
            if (width == 16)
193
3.53k
                addCapability(Capability::Int16);
194
19.9k
            else if (width == 8)
195
552
                addCapability(Capability::Int8);
196
19.4k
            else if (width == 64)
197
4.38k
                addCapability(Capability::Int64);
198
34.2k
        } else if (basicTypeOp == Op::OpTypeFloat) {
199
24.4k
            if (width == 16)
200
11.3k
                addCapability(Capability::Float16);
201
13.0k
            else if (width == 64)
202
63
                addCapability(Capability::Float64);
203
24.4k
        }
204
57.7k
        break;
205
128k
    }
206
128k
}
207
208
unsigned int Builder::postProcessGetLargestScalarSize(const Instruction& type)
209
0
{
210
0
    switch (type.getOpCode()) {
211
0
    case Op::OpTypeBool:
212
0
        return 1;
213
0
    case Op::OpTypeInt:
214
0
    case Op::OpTypeFloat:
215
0
        return type.getImmediateOperand(0) / 8;
216
0
    case Op::OpTypePointer:
217
0
        return 8;
218
0
    case Op::OpTypeVector:
219
0
    case Op::OpTypeMatrix:
220
0
    case Op::OpTypeArray:
221
0
    case Op::OpTypeRuntimeArray: {
222
0
        const Instruction* elem_type = module.getInstruction(type.getIdOperand(0));
223
0
        return postProcessGetLargestScalarSize(*elem_type);
224
0
    }
225
0
    case Op::OpTypeStruct: {
226
0
        unsigned int largest = 0;
227
0
        for (int i = 0; i < type.getNumOperands(); ++i) {
228
0
            const Instruction* elem_type = module.getInstruction(type.getIdOperand(i));
229
0
            unsigned int elem_size = postProcessGetLargestScalarSize(*elem_type);
230
0
            largest = std::max(largest, elem_size);
231
0
        }
232
0
        return largest;
233
0
    }
234
0
    default:
235
0
        return 0;
236
0
    }
237
0
}
238
239
// Called for each instruction that resides in a block.
240
void Builder::postProcess(Instruction& inst)
241
56.6k
{
242
    // Add capabilities based simply on the opcode.
243
56.6k
    switch (inst.getOpCode()) {
244
318
    case Op::OpExtInst:
245
318
        switch (inst.getImmediateOperand(1)) {
246
0
        case GLSLstd450InterpolateAtCentroid:
247
0
        case GLSLstd450InterpolateAtSample:
248
0
        case GLSLstd450InterpolateAtOffset:
249
0
            addCapability(Capability::InterpolationFunction);
250
0
            break;
251
318
        default:
252
318
            break;
253
318
        }
254
318
        break;
255
318
    case Op::OpDPdxFine:
256
0
    case Op::OpDPdyFine:
257
0
    case Op::OpFwidthFine:
258
0
    case Op::OpDPdxCoarse:
259
0
    case Op::OpDPdyCoarse:
260
0
    case Op::OpFwidthCoarse:
261
0
        addCapability(Capability::DerivativeControl);
262
0
        break;
263
264
52
    case Op::OpImageQueryLod:
265
67
    case Op::OpImageQuerySize:
266
118
    case Op::OpImageQuerySizeLod:
267
124
    case Op::OpImageQuerySamples:
268
163
    case Op::OpImageQueryLevels:
269
163
        addCapability(Capability::ImageQuery);
270
163
        break;
271
272
0
    case Op::OpGroupNonUniformPartitionNV:
273
0
        addExtension(E_SPV_NV_shader_subgroup_partitioned);
274
0
        addCapability(Capability::GroupNonUniformPartitionedNV);
275
0
        break;
276
277
15.3k
    case Op::OpLoad:
278
23.2k
    case Op::OpStore:
279
23.2k
        {
280
            // For any load/store to a PhysicalStorageBufferEXT, walk the accesschain
281
            // index list to compute the misalignment. The pre-existing alignment value
282
            // (set via Builder::AccessChain::alignment) only accounts for the base of
283
            // the reference type and any scalar component selection in the accesschain,
284
            // and this function computes the rest from the SPIR-V Offset decorations.
285
23.2k
            Instruction *accessChain = module.getInstruction(inst.getIdOperand(0));
286
23.2k
            if (accessChain->getOpCode() == Op::OpAccessChain) {
287
7.06k
                const Instruction* base = module.getInstruction(accessChain->getIdOperand(0));
288
                // Get the type of the base of the access chain. It must be a pointer type.
289
7.06k
                Id typeId = base->getTypeId();
290
7.06k
                Instruction *type = module.getInstruction(typeId);
291
7.06k
                assert(type->getOpCode() == Op::OpTypePointer);
292
7.06k
                if (type->getImmediateOperand(0) != StorageClass::PhysicalStorageBuffer) {
293
6.98k
                    break;
294
6.98k
                }
295
                // Get the pointee type.
296
72
                typeId = type->getIdOperand(1);
297
72
                type = module.getInstruction(typeId);
298
                // Walk the index list for the access chain. For each index, find any
299
                // misalignment that can apply when accessing the member/element via
300
                // Offset/ArrayStride/MatrixStride decorations, and bitwise OR them all
301
                // together.
302
72
                int alignment = 0;
303
72
                bool first_struct_elem = false;
304
144
                for (int i = 1; i < accessChain->getNumOperands(); ++i) {
305
72
                    Instruction *idx = module.getInstruction(accessChain->getIdOperand(i));
306
72
                    if (type->getOpCode() == Op::OpTypeStruct) {
307
72
                        assert(idx->getOpCode() == Op::OpConstant);
308
72
                        unsigned int c = idx->getImmediateOperand(0);
309
310
1.17k
                        const auto function = [&](const std::unique_ptr<Instruction>& decoration) {
311
1.17k
                            if (decoration.get()->getOpCode() == Op::OpMemberDecorate &&
312
378
                                decoration.get()->getIdOperand(0) == typeId &&
313
162
                                decoration.get()->getImmediateOperand(1) == c &&
314
108
                                (decoration.get()->getImmediateOperand(2) == Decoration::Offset ||
315
81
                                 decoration.get()->getImmediateOperand(2) == Decoration::MatrixStride)) {
316
81
                                unsigned int opernad_value = decoration.get()->getImmediateOperand(3);
317
81
                                alignment |= opernad_value;
318
81
                                if (opernad_value == 0 &&
319
63
                                    decoration.get()->getImmediateOperand(2) == Decoration::Offset) {
320
63
                                    first_struct_elem = true;
321
63
                                }
322
81
                            }
323
1.17k
                        };
324
72
                        std::for_each(decorations.begin(), decorations.end(), function);
325
                        // get the next member type
326
72
                        typeId = type->getIdOperand(c);
327
72
                        type = module.getInstruction(typeId);
328
72
                    } else if (type->getOpCode() == Op::OpTypeArray ||
329
0
                               type->getOpCode() == Op::OpTypeRuntimeArray) {
330
0
                        const auto function = [&](const std::unique_ptr<Instruction>& decoration) {
331
0
                            if (decoration.get()->getOpCode() == Op::OpDecorate &&
332
0
                                decoration.get()->getIdOperand(0) == typeId &&
333
0
                                decoration.get()->getImmediateOperand(1) == Decoration::ArrayStride) {
334
0
                                alignment |= decoration.get()->getImmediateOperand(2);
335
0
                            }
336
0
                        };
337
0
                        std::for_each(decorations.begin(), decorations.end(), function);
338
                        // Get the element type
339
0
                        typeId = type->getIdOperand(0);
340
0
                        type = module.getInstruction(typeId);
341
0
                    } else {
342
                        // Once we get to any non-aggregate type, we're done.
343
0
                        break;
344
0
                    }
345
72
                }
346
72
                assert(inst.getNumOperands() >= 3);
347
72
                const bool is_store = inst.getOpCode() == Op::OpStore;
348
72
                auto const memoryAccess = (MemoryAccessMask)inst.getImmediateOperand(is_store ? 2 : 1);
349
72
                assert(anySet(memoryAccess, MemoryAccessMask::Aligned));
350
72
                static_cast<void>(memoryAccess);
351
352
                // Compute the index of the alignment operand.
353
72
                int alignmentIdx = 2;
354
72
                if (is_store)
355
54
                    alignmentIdx++;
356
                // Merge new and old (mis)alignment
357
72
                alignment |= inst.getImmediateOperand(alignmentIdx);
358
359
72
                if (!is_store) {
360
18
                    Instruction* inst_type = module.getInstruction(inst.getTypeId());
361
18
                    if (inst_type->getOpCode() == Op::OpTypePointer &&
362
0
                        inst_type->getImmediateOperand(0) == StorageClass::PhysicalStorageBuffer) {
363
                        // This means we are loading a pointer which means need to ensure it is at least 8-byte aligned
364
                        // See https://github.com/KhronosGroup/glslang/issues/4084
365
                        // In case the alignment is currently 4, need to ensure it is 8 before grabbing the LSB
366
0
                        alignment |= 8;
367
0
                        alignment &= 8;
368
0
                    }
369
18
                }
370
371
                // Pick the LSB
372
72
                alignment = alignment & ~(alignment & (alignment-1));
373
374
                // The edge case we find is when copying a struct to another struct, we never find the alignment anywhere,
375
                // so in this case, fallback to doing a full size lookup on the type
376
72
                if (alignment == 0 && first_struct_elem) {
377
                    // Quick get the struct type back
378
0
                    const Instruction* pointer_type = module.getInstruction(base->getTypeId());
379
0
                    const Instruction* struct_type = module.getInstruction(pointer_type->getIdOperand(1));
380
0
                    assert(struct_type->getOpCode() == Op::OpTypeStruct);
381
382
0
                    const Instruction* elem_type = module.getInstruction(struct_type->getIdOperand(0));
383
0
                    unsigned int largest_scalar = postProcessGetLargestScalarSize(*elem_type);
384
0
                    if (largest_scalar != 0) {
385
0
                        alignment = largest_scalar;
386
0
                    } else {
387
0
                        alignment = 16; // fallback if can't determine a good alignment
388
0
                    }
389
0
                }
390
                // update the Aligned operand
391
72
                assert(alignment != 0);
392
72
                inst.setImmediateOperand(alignmentIdx, alignment);
393
72
            }
394
16.2k
            break;
395
23.2k
        }
396
397
32.9k
    default:
398
32.9k
        break;
399
56.6k
    }
400
401
    // Checks based on type
402
56.6k
    if (inst.getTypeId() != NoType)
403
40.8k
        postProcessType(inst, inst.getTypeId());
404
156k
    for (int op = 0; op < inst.getNumOperands(); ++op) {
405
99.4k
        if (inst.isIdOperand(op)) {
406
            // In blocks, these are always result ids, but we are relying on
407
            // getTypeId() to return NoType for things like OpLabel.
408
92.2k
            if (getTypeId(inst.getIdOperand(op)) != NoType)
409
87.3k
                postProcessType(inst, getTypeId(inst.getIdOperand(op)));
410
92.2k
        }
411
99.4k
    }
412
56.6k
}
413
414
// comment in header
415
void Builder::postProcessCFG()
416
364
{
417
    // reachableBlocks is the set of blocks reached via control flow, or which are
418
    // unreachable continue targets or unreachable merges.
419
364
    std::unordered_set<const Block*> reachableBlocks;
420
364
    std::unordered_map<Block*, Block*> headerForUnreachableContinue;
421
364
    std::unordered_set<Block*> unreachableMerges;
422
364
    std::unordered_set<Id> unreachableDefinitions;
423
    // Collect IDs defined in unreachable blocks. For each function, label the
424
    // reachable blocks first. Then for each unreachable block, collect the
425
    // result IDs of the instructions in it.
426
1.05k
    for (auto fi = module.getFunctions().cbegin(); fi != module.getFunctions().cend(); fi++) {
427
692
        Function* f = *fi;
428
692
        Block* entry = f->getEntryBlock();
429
692
        inReadableOrder(entry,
430
692
            [&reachableBlocks, &unreachableMerges, &headerForUnreachableContinue]
431
3.04k
            (Block* b, ReachReason why, Block* header) {
432
3.04k
               reachableBlocks.insert(b);
433
3.04k
               if (why == ReachDeadContinue) headerForUnreachableContinue[b] = header;
434
3.04k
               if (why == ReachDeadMerge) unreachableMerges.insert(b);
435
3.04k
            });
436
4.24k
        for (auto bi = f->getBlocks().cbegin(); bi != f->getBlocks().cend(); bi++) {
437
3.55k
            Block* b = *bi;
438
3.55k
            if (unreachableMerges.count(b) != 0 || headerForUnreachableContinue.count(b) != 0) {
439
18
                auto ii = b->getInstructions().cbegin();
440
18
                ++ii; // Keep potential decorations on the label.
441
36
                for (; ii != b->getInstructions().cend(); ++ii)
442
18
                    unreachableDefinitions.insert(ii->get()->getResultId());
443
3.53k
            } else if (reachableBlocks.count(b) == 0) {
444
                // The normal case for unreachable code.  All definitions are considered dead.
445
1.67k
                for (auto ii = b->getInstructions().cbegin(); ii != b->getInstructions().cend(); ++ii)
446
1.16k
                    unreachableDefinitions.insert(ii->get()->getResultId());
447
510
            }
448
3.55k
        }
449
692
    }
450
451
    // Modify unreachable merge blocks and unreachable continue targets.
452
    // Delete their contents.
453
382
    for (auto mergeIter = unreachableMerges.begin(); mergeIter != unreachableMerges.end(); ++mergeIter) {
454
18
        (*mergeIter)->rewriteAsCanonicalUnreachableMerge();
455
18
    }
456
364
    for (auto continueIter = headerForUnreachableContinue.begin();
457
364
         continueIter != headerForUnreachableContinue.end();
458
364
         ++continueIter) {
459
0
        Block* continue_target = continueIter->first;
460
0
        Block* header = continueIter->second;
461
0
        continue_target->rewriteAsCanonicalUnreachableContinue(header);
462
0
    }
463
464
    // Remove unneeded decorations, for unreachable instructions
465
7.44k
    for (auto decorationIter = decorations.begin(); decorationIter != decorations.end();) {
466
7.08k
        Id decorationId = (*decorationIter)->getIdOperand(0);
467
7.08k
        if (unreachableDefinitions.count(decorationId) != 0) {
468
0
            decorationIter = decorations.erase(decorationIter);
469
7.08k
        } else {
470
7.08k
            ++decorationIter;
471
7.08k
        }
472
7.08k
    }
473
364
}
474
475
// comment in header
476
364
void Builder::postProcessFeatures() {
477
    // Add per-instruction capabilities, extensions, etc.,
478
479
    // Look for any 8/16 bit type in physical storage buffer class, and set the
480
    // appropriate capability. This happens in createSpvVariable for other storage
481
    // classes, but there isn't always a variable for physical storage buffer.
482
3.87k
    for (int t = 0; t < (int)groupedTypes[enumCast(Op::OpTypePointer)].size(); ++t) {
483
3.50k
        Instruction* type = groupedTypes[enumCast(Op::OpTypePointer)][t];
484
3.50k
        if (type->getImmediateOperand(0) == (unsigned)StorageClass::PhysicalStorageBufferEXT) {
485
90
            if (containsType(type->getIdOperand(1), Op::OpTypeInt, 8)) {
486
0
                addIncorporatedExtension(spv::E_SPV_KHR_8bit_storage, spv::Spv_1_5);
487
0
                addCapability(spv::Capability::StorageBuffer8BitAccess);
488
0
            }
489
90
            if (containsType(type->getIdOperand(1), Op::OpTypeInt, 16) ||
490
90
                containsType(type->getIdOperand(1), Op::OpTypeFloat, 16)) {
491
9
                addIncorporatedExtension(spv::E_SPV_KHR_16bit_storage, spv::Spv_1_3);
492
9
                addCapability(spv::Capability::StorageBuffer16BitAccess);
493
9
            }
494
90
        }
495
3.50k
    }
496
497
    // process all block-contained instructions
498
1.05k
    for (auto fi = module.getFunctions().cbegin(); fi != module.getFunctions().cend(); fi++) {
499
692
        Function* f = *fi;
500
4.24k
        for (auto bi = f->getBlocks().cbegin(); bi != f->getBlocks().cend(); bi++) {
501
3.55k
            Block* b = *bi;
502
60.2k
            for (auto ii = b->getInstructions().cbegin(); ii != b->getInstructions().cend(); ii++)
503
56.6k
                postProcess(*ii->get());
504
505
            // For all local variables that contain pointers to PhysicalStorageBufferEXT, check whether
506
            // there is an existing restrict/aliased decoration. If we don't find one, add Aliased as the
507
            // default.
508
5.55k
            for (auto vi = b->getLocalVariables().cbegin(); vi != b->getLocalVariables().cend(); vi++) {
509
1.99k
                const Instruction& inst = *vi->get();
510
1.99k
                Id resultId = inst.getResultId();
511
1.99k
                if (containsPhysicalStorageBufferOrArray(getDerefTypeId(resultId))) {
512
108
                    bool foundDecoration = false;
513
1.44k
                    const auto function = [&](const std::unique_ptr<Instruction>& decoration) {
514
1.44k
                        if (decoration.get()->getIdOperand(0) == resultId &&
515
108
                            decoration.get()->getOpCode() == Op::OpDecorate &&
516
108
                            (decoration.get()->getImmediateOperand(1) == spv::Decoration::AliasedPointerEXT ||
517
108
                             decoration.get()->getImmediateOperand(1) == spv::Decoration::RestrictPointerEXT)) {
518
108
                            foundDecoration = true;
519
108
                        }
520
1.44k
                    };
521
108
                    std::for_each(decorations.begin(), decorations.end(), function);
522
108
                    if (!foundDecoration) {
523
0
                        addDecoration(resultId, spv::Decoration::AliasedPointerEXT);
524
0
                    }
525
108
                }
526
1.99k
            }
527
3.55k
        }
528
692
    }
529
530
    // If any Vulkan memory model-specific functionality is used, update the
531
    // OpMemoryModel to match.
532
364
    if (capabilities.find(spv::Capability::VulkanMemoryModelKHR) != capabilities.end()) {
533
0
        memoryModel = spv::MemoryModel::VulkanKHR;
534
0
        addIncorporatedExtension(spv::E_SPV_KHR_vulkan_memory_model, spv::Spv_1_5);
535
0
    }
536
537
    // Add Aliased decoration if there's more than one Workgroup Block variable.
538
364
    if (capabilities.find(spv::Capability::WorkgroupMemoryExplicitLayoutKHR) != capabilities.end()) {
539
0
        assert(entryPoints.size() == 1);
540
0
        auto &ep = entryPoints[0];
541
542
0
        std::vector<Id> workgroup_variables;
543
0
        for (int i = 0; i < (int)ep->getNumOperands(); i++) {
544
0
            if (!ep->isIdOperand(i))
545
0
                continue;
546
547
0
            const Id id = ep->getIdOperand(i);
548
0
            const Instruction *instr = module.getInstruction(id);
549
0
            if (instr->getOpCode() != spv::Op::OpVariable)
550
0
                continue;
551
552
0
            if (instr->getImmediateOperand(0) == spv::StorageClass::Workgroup)
553
0
                workgroup_variables.push_back(id);
554
0
        }
555
556
0
        if (workgroup_variables.size() > 1) {
557
0
            for (size_t i = 0; i < workgroup_variables.size(); i++)
558
0
                addDecoration(workgroup_variables[i], spv::Decoration::Aliased);
559
0
        }
560
0
    }
561
364
}
562
563
// SPIR-V requires that any instruction consuming the result of an OpSampledImage
564
// be in the same block as the OpSampledImage instruction. This pass finds
565
// uses of OpSampledImage where that is not the case and duplicates the
566
// OpSampledImage to be immediately before the instruction that consumes it.
567
// The old OpSampledImage is left in place, potentially with no users.
568
void Builder::postProcessSamplers()
569
364
{
570
    // first, find all OpSampledImage instructions and store them in a map.
571
364
    std::map<Id, Instruction*> sampledImageInstrs;
572
692
    for (auto f: module.getFunctions()) {
573
3.55k
  for (auto b: f->getBlocks()) {
574
56.6k
      for (auto &i: b->getInstructions()) {
575
56.6k
        if (i->getOpCode() == spv::Op::OpSampledImage) {
576
249
        sampledImageInstrs[i->getResultId()] = i.get();
577
249
    }
578
56.6k
      }
579
3.55k
  }
580
692
    }
581
    // next find all uses of the given ids and rewrite them if needed.
582
692
    for (auto f: module.getFunctions()) {
583
3.55k
  for (auto b: f->getBlocks()) {
584
3.55k
            auto &instrs = b->getInstructions();
585
60.2k
            for (size_t idx = 0; idx < instrs.size(); idx++) {
586
56.6k
                Instruction *i = instrs[idx].get();
587
156k
                for (int opnum = 0; opnum < i->getNumOperands(); opnum++) {
588
                    // Is this operand of the current instruction the result of an OpSampledImage?
589
99.4k
                    if (i->isIdOperand(opnum) &&
590
92.2k
                        sampledImageInstrs.count(i->getIdOperand(opnum)))
591
249
                    {
592
249
                        Instruction *opSampImg = sampledImageInstrs[i->getIdOperand(opnum)];
593
249
                        if (i->getBlock() != opSampImg->getBlock()) {
594
0
                            Instruction *newInstr = new Instruction(getUniqueId(),
595
0
                                                                    opSampImg->getTypeId(),
596
0
                                                                    spv::Op::OpSampledImage);
597
0
                            newInstr->addIdOperand(opSampImg->getIdOperand(0));
598
0
                            newInstr->addIdOperand(opSampImg->getIdOperand(1));
599
0
                            newInstr->setBlock(b);
600
601
                            // rewrite the user of the OpSampledImage to use the new instruction.
602
0
                            i->setIdOperand(opnum, newInstr->getResultId());
603
                            // insert the new OpSampledImage right before the current instruction.
604
0
                            instrs.insert(instrs.begin() + idx,
605
0
                                    std::unique_ptr<Instruction>(newInstr));
606
0
                            idx++;
607
0
                        }
608
249
                    }
609
99.4k
                }
610
56.6k
            }
611
3.55k
  }
612
692
    }
613
364
}
614
615
// comment in header
616
void Builder::postProcess(bool compileOnly)
617
364
{
618
    // postProcessCFG needs an entrypoint to determine what is reachable, but if we are not creating an "executable" shader, we don't have an entrypoint
619
364
    if (!compileOnly)
620
364
        postProcessCFG();
621
622
364
    postProcessFeatures();
623
364
    postProcessSamplers();
624
364
}
625
626
} // end spv namespace