Coverage Report

Created: 2025-07-18 06:38

/src/shaderc/third_party/glslang/SPIRV/SpvPostProcess.cpp
Line
Count
Source (jump to first uncovered line)
1
//
2
// Copyright (C) 2018 Google, Inc.
3
//
4
// All rights reserved.
5
//
6
// Redistribution and use in source and binary forms, with or without
7
// modification, are permitted provided that the following conditions
8
// are met:
9
//
10
//    Redistributions of source code must retain the above copyright
11
//    notice, this list of conditions and the following disclaimer.
12
//
13
//    Redistributions in binary form must reproduce the above
14
//    copyright notice, this list of conditions and the following
15
//    disclaimer in the documentation and/or other materials provided
16
//    with the distribution.
17
//
18
//    Neither the name of 3Dlabs Inc. Ltd. nor the names of its
19
//    contributors may be used to endorse or promote products derived
20
//    from this software without specific prior written permission.
21
//
22
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
23
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
24
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
25
// FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
26
// COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
27
// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
28
// BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
29
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
30
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
31
// LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
32
// ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
33
// POSSIBILITY OF SUCH DAMAGE.
34
35
//
36
// Post-processing for SPIR-V IR, in internal form, not standard binary form.
37
//
38
39
#include <cassert>
40
#include <cstdlib>
41
42
#include <unordered_map>
43
#include <unordered_set>
44
#include <algorithm>
45
46
#include "SpvBuilder.h"
47
#include "spirv.hpp11"
48
#include "spvUtil.h"
49
50
namespace spv {
51
    #include "GLSL.std.450.h"
52
    #include "GLSL.ext.KHR.h"
53
    #include "GLSL.ext.EXT.h"
54
    #include "GLSL.ext.AMD.h"
55
    #include "GLSL.ext.NV.h"
56
    #include "GLSL.ext.ARM.h"
57
    #include "GLSL.ext.QCOM.h"
58
}
59
60
namespace spv {
61
62
// Hook to visit each operand type and result type of an instruction.
63
// Will be called multiple times for one instruction, once for each typed
64
// operand and the result.
65
void Builder::postProcessType(const Instruction& inst, Id typeId)
66
483k
{
67
    // Characterize the type being questioned
68
483k
    Op basicTypeOp = getMostBasicTypeClass(typeId);
69
483k
    int width = 0;
70
483k
    if (basicTypeOp == Op::OpTypeFloat || basicTypeOp == Op::OpTypeInt)
71
413k
        width = getScalarTypeWidth(typeId);
72
73
    // Do opcode-specific checks
74
483k
    switch (inst.getOpCode()) {
75
117k
    case Op::OpLoad:
76
179k
    case Op::OpStore:
77
179k
        if (basicTypeOp == Op::OpTypeStruct) {
78
3.17k
            if (containsType(typeId, Op::OpTypeInt, 8))
79
0
                addCapability(Capability::Int8);
80
3.17k
            if (containsType(typeId, Op::OpTypeInt, 16))
81
0
                addCapability(Capability::Int16);
82
3.17k
            if (containsType(typeId, Op::OpTypeFloat, 16))
83
0
                addCapability(Capability::Float16);
84
176k
        } else {
85
176k
            StorageClass storageClass = getStorageClass(inst.getIdOperand(0));
86
176k
            if (width == 8) {
87
556
                switch (storageClass) {
88
84
                case StorageClass::PhysicalStorageBufferEXT:
89
420
                case StorageClass::Uniform:
90
420
                case StorageClass::StorageBuffer:
91
420
                case StorageClass::PushConstant:
92
420
                    break;
93
136
                default:
94
136
                    addCapability(Capability::Int8);
95
136
                    break;
96
556
                }
97
176k
            } else if (width == 16) {
98
18.0k
                switch (storageClass) {
99
0
                case StorageClass::PhysicalStorageBufferEXT:
100
6.80k
                case StorageClass::Uniform:
101
6.80k
                case StorageClass::StorageBuffer:
102
6.80k
                case StorageClass::PushConstant:
103
7.92k
                case StorageClass::Input:
104
7.92k
                case StorageClass::Output:
105
7.92k
                    break;
106
10.0k
                default:
107
10.0k
                    if (basicTypeOp == Op::OpTypeInt)
108
4.50k
                        addCapability(Capability::Int16);
109
10.0k
                    if (basicTypeOp == Op::OpTypeFloat)
110
5.58k
                        addCapability(Capability::Float16);
111
10.0k
                    break;
112
18.0k
                }
113
18.0k
            }
114
176k
        }
115
179k
        break;
116
179k
    case Op::OpCopyObject:
117
24
        break;
118
100
    case Op::OpFConvert:
119
694
    case Op::OpSConvert:
120
1.08k
    case Op::OpUConvert:
121
        // Look for any 8/16-bit storage capabilities. If there are none, assume that
122
        // the convert instruction requires the Float16/Int8/16 capability.
123
1.08k
        if (containsType(typeId, Op::OpTypeFloat, 16) || containsType(typeId, Op::OpTypeInt, 16)) {
124
297
            bool foundStorage = false;
125
1.34k
            for (auto it = capabilities.begin(); it != capabilities.end(); ++it) {
126
1.34k
                spv::Capability cap = *it;
127
1.34k
                if (cap == spv::Capability::StorageInputOutput16 ||
128
1.34k
                    cap == spv::Capability::StoragePushConstant16 ||
129
1.34k
                    cap == spv::Capability::StorageUniformBufferBlock16 ||
130
1.34k
                    cap == spv::Capability::StorageUniform16) {
131
297
                    foundStorage = true;
132
297
                    break;
133
297
                }
134
1.34k
            }
135
297
            if (!foundStorage) {
136
0
                if (containsType(typeId, Op::OpTypeFloat, 16))
137
0
                    addCapability(Capability::Float16);
138
0
                if (containsType(typeId, Op::OpTypeInt, 16))
139
0
                    addCapability(Capability::Int16);
140
0
            }
141
297
        }
142
1.08k
        if (containsType(typeId, Op::OpTypeInt, 8)) {
143
126
            bool foundStorage = false;
144
268
            for (auto it = capabilities.begin(); it != capabilities.end(); ++it) {
145
268
                spv::Capability cap = *it;
146
268
                if (cap == spv::Capability::StoragePushConstant8 ||
147
268
                    cap == spv::Capability::UniformAndStorageBuffer8BitAccess ||
148
268
                    cap == spv::Capability::StorageBuffer8BitAccess) {
149
126
                    foundStorage = true;
150
126
                    break;
151
126
                }
152
268
            }
153
126
            if (!foundStorage) {
154
0
                addCapability(Capability::Int8);
155
0
            }
156
126
        }
157
1.08k
        break;
158
7.59k
    case Op::OpExtInst:
159
7.59k
        switch (inst.getImmediateOperand(1)) {
160
0
        case GLSLstd450Frexp:
161
20
        case GLSLstd450FrexpStruct:
162
20
            if (getSpvVersion() < spv::Spv_1_3 && containsType(typeId, Op::OpTypeInt, 16))
163
10
                addExtension(spv::E_SPV_AMD_gpu_shader_int16);
164
20
            break;
165
0
        case GLSLstd450InterpolateAtCentroid:
166
0
        case GLSLstd450InterpolateAtSample:
167
0
        case GLSLstd450InterpolateAtOffset:
168
0
            if (getSpvVersion() < spv::Spv_1_3 && containsType(typeId, Op::OpTypeFloat, 16))
169
0
                addExtension(spv::E_SPV_AMD_gpu_shader_half_float);
170
0
            break;
171
7.57k
        default:
172
7.57k
            break;
173
7.59k
        }
174
7.59k
        break;
175
67.4k
    case Op::OpAccessChain:
176
67.4k
    case Op::OpPtrAccessChain:
177
67.4k
        if (isPointerType(typeId))
178
42.5k
            break;
179
24.8k
        if (basicTypeOp == Op::OpTypeInt) {
180
24.8k
            if (width == 16)
181
0
                addCapability(Capability::Int16);
182
24.8k
            else if (width == 8)
183
0
                addCapability(Capability::Int8);
184
24.8k
        }
185
24.8k
        break;
186
227k
    default:
187
227k
        if (basicTypeOp == Op::OpTypeInt) {
188
76.3k
            if (width == 16)
189
4.45k
                addCapability(Capability::Int16);
190
71.8k
            else if (width == 8)
191
1.07k
                addCapability(Capability::Int8);
192
70.7k
            else if (width == 64)
193
10.6k
                addCapability(Capability::Int64);
194
151k
        } else if (basicTypeOp == Op::OpTypeFloat) {
195
111k
            if (width == 16)
196
13.8k
                addCapability(Capability::Float16);
197
97.5k
            else if (width == 64)
198
60
                addCapability(Capability::Float64);
199
111k
        }
200
227k
        break;
201
483k
    }
202
483k
}
203
204
// Called for each instruction that resides in a block.
205
void Builder::postProcess(Instruction& inst)
206
237k
{
207
    // Add capabilities based simply on the opcode.
208
237k
    switch (inst.getOpCode()) {
209
2.79k
    case Op::OpExtInst:
210
2.79k
        switch (inst.getImmediateOperand(1)) {
211
0
        case GLSLstd450InterpolateAtCentroid:
212
0
        case GLSLstd450InterpolateAtSample:
213
0
        case GLSLstd450InterpolateAtOffset:
214
0
            addCapability(Capability::InterpolationFunction);
215
0
            break;
216
2.79k
        default:
217
2.79k
            break;
218
2.79k
        }
219
2.79k
        break;
220
2.79k
    case Op::OpDPdxFine:
221
0
    case Op::OpDPdyFine:
222
0
    case Op::OpFwidthFine:
223
0
    case Op::OpDPdxCoarse:
224
0
    case Op::OpDPdyCoarse:
225
0
    case Op::OpFwidthCoarse:
226
0
        addCapability(Capability::DerivativeControl);
227
0
        break;
228
229
52
    case Op::OpImageQueryLod:
230
103
    case Op::OpImageQuerySize:
231
340
    case Op::OpImageQuerySizeLod:
232
357
    case Op::OpImageQuerySamples:
233
407
    case Op::OpImageQueryLevels:
234
407
        addCapability(Capability::ImageQuery);
235
407
        break;
236
237
0
    case Op::OpGroupNonUniformPartitionNV:
238
0
        addExtension(E_SPV_NV_shader_subgroup_partitioned);
239
0
        addCapability(Capability::GroupNonUniformPartitionedNV);
240
0
        break;
241
242
58.5k
    case Op::OpLoad:
243
89.9k
    case Op::OpStore:
244
89.9k
        {
245
            // For any load/store to a PhysicalStorageBufferEXT, walk the accesschain
246
            // index list to compute the misalignment. The pre-existing alignment value
247
            // (set via Builder::AccessChain::alignment) only accounts for the base of
248
            // the reference type and any scalar component selection in the accesschain,
249
            // and this function computes the rest from the SPIR-V Offset decorations.
250
89.9k
            Instruction *accessChain = module.getInstruction(inst.getIdOperand(0));
251
89.9k
            if (accessChain->getOpCode() == Op::OpAccessChain) {
252
20.8k
                Instruction *base = module.getInstruction(accessChain->getIdOperand(0));
253
                // Get the type of the base of the access chain. It must be a pointer type.
254
20.8k
                Id typeId = base->getTypeId();
255
20.8k
                Instruction *type = module.getInstruction(typeId);
256
20.8k
                assert(type->getOpCode() == Op::OpTypePointer);
257
20.8k
                if (type->getImmediateOperand(0) != StorageClass::PhysicalStorageBufferEXT) {
258
20.4k
                    break;
259
20.4k
                }
260
                // Get the pointee type.
261
356
                typeId = type->getIdOperand(1);
262
356
                type = module.getInstruction(typeId);
263
                // Walk the index list for the access chain. For each index, find any
264
                // misalignment that can apply when accessing the member/element via
265
                // Offset/ArrayStride/MatrixStride decorations, and bitwise OR them all
266
                // together.
267
356
                int alignment = 0;
268
735
                for (int i = 1; i < accessChain->getNumOperands(); ++i) {
269
381
                    Instruction *idx = module.getInstruction(accessChain->getIdOperand(i));
270
381
                    if (type->getOpCode() == Op::OpTypeStruct) {
271
356
                        assert(idx->getOpCode() == Op::OpConstant);
272
356
                        unsigned int c = idx->getImmediateOperand(0);
273
274
5.66k
                        const auto function = [&](const std::unique_ptr<Instruction>& decoration) {
275
5.66k
                            if (decoration.get()->getOpCode() == Op::OpMemberDecorate &&
276
5.66k
                                decoration.get()->getIdOperand(0) == typeId &&
277
5.66k
                                decoration.get()->getImmediateOperand(1) == c &&
278
5.66k
                                (decoration.get()->getImmediateOperand(2) == Decoration::Offset ||
279
420
                                 decoration.get()->getImmediateOperand(2) == Decoration::MatrixStride)) {
280
372
                                alignment |= decoration.get()->getImmediateOperand(3);
281
372
                            }
282
5.66k
                        };
283
356
                        std::for_each(decorations.begin(), decorations.end(), function);
284
                        // get the next member type
285
356
                        typeId = type->getIdOperand(c);
286
356
                        type = module.getInstruction(typeId);
287
356
                    } else if (type->getOpCode() == Op::OpTypeArray ||
288
25
                               type->getOpCode() == Op::OpTypeRuntimeArray) {
289
486
                        const auto function = [&](const std::unique_ptr<Instruction>& decoration) {
290
486
                            if (decoration.get()->getOpCode() == Op::OpDecorate &&
291
486
                                decoration.get()->getIdOperand(0) == typeId &&
292
486
                                decoration.get()->getImmediateOperand(1) == Decoration::ArrayStride) {
293
23
                                alignment |= decoration.get()->getImmediateOperand(2);
294
23
                            }
295
486
                        };
296
23
                        std::for_each(decorations.begin(), decorations.end(), function);
297
                        // Get the element type
298
23
                        typeId = type->getIdOperand(0);
299
23
                        type = module.getInstruction(typeId);
300
23
                    } else {
301
                        // Once we get to any non-aggregate type, we're done.
302
2
                        break;
303
2
                    }
304
381
                }
305
356
                assert(inst.getNumOperands() >= 3);
306
356
                auto const memoryAccess = (MemoryAccessMask)inst.getImmediateOperand((inst.getOpCode() == Op::OpStore) ? 2 : 1);
307
356
                assert(anySet(memoryAccess, MemoryAccessMask::Aligned));
308
356
                static_cast<void>(memoryAccess);
309
                // Compute the index of the alignment operand.
310
356
                int alignmentIdx = 2;
311
356
                if (inst.getOpCode() == Op::OpStore)
312
190
                    alignmentIdx++;
313
                // Merge new and old (mis)alignment
314
356
                alignment |= inst.getImmediateOperand(alignmentIdx);
315
                // Pick the LSB
316
356
                alignment = alignment & ~(alignment & (alignment-1));
317
                // update the Aligned operand
318
356
                inst.setImmediateOperand(alignmentIdx, alignment);
319
356
            }
320
69.4k
            break;
321
89.9k
        }
322
323
144k
    default:
324
144k
        break;
325
237k
    }
326
327
    // Checks based on type
328
237k
    if (inst.getTypeId() != NoType)
329
157k
        postProcessType(inst, inst.getTypeId());
330
635k
    for (int op = 0; op < inst.getNumOperands(); ++op) {
331
397k
        if (inst.isIdOperand(op)) {
332
            // In blocks, these are always result ids, but we are relying on
333
            // getTypeId() to return NoType for things like OpLabel.
334
357k
            if (getTypeId(inst.getIdOperand(op)) != NoType)
335
325k
                postProcessType(inst, getTypeId(inst.getIdOperand(op)));
336
357k
        }
337
397k
    }
338
237k
}
339
340
// comment in header
341
void Builder::postProcessCFG()
342
2.63k
{
343
    // reachableBlocks is the set of blockss reached via control flow, or which are
344
    // unreachable continue targert or unreachable merge.
345
2.63k
    std::unordered_set<const Block*> reachableBlocks;
346
2.63k
    std::unordered_map<Block*, Block*> headerForUnreachableContinue;
347
2.63k
    std::unordered_set<Block*> unreachableMerges;
348
2.63k
    std::unordered_set<Id> unreachableDefinitions;
349
    // Collect IDs defined in unreachable blocks. For each function, label the
350
    // reachable blocks first. Then for each unreachable block, collect the
351
    // result IDs of the instructions in it.
352
6.48k
    for (auto fi = module.getFunctions().cbegin(); fi != module.getFunctions().cend(); fi++) {
353
3.85k
        Function* f = *fi;
354
3.85k
        Block* entry = f->getEntryBlock();
355
3.85k
        inReadableOrder(entry,
356
3.85k
            [&reachableBlocks, &unreachableMerges, &headerForUnreachableContinue]
357
18.7k
            (Block* b, ReachReason why, Block* header) {
358
18.7k
               reachableBlocks.insert(b);
359
18.7k
               if (why == ReachDeadContinue) headerForUnreachableContinue[b] = header;
360
18.7k
               if (why == ReachDeadMerge) unreachableMerges.insert(b);
361
18.7k
            });
362
25.6k
        for (auto bi = f->getBlocks().cbegin(); bi != f->getBlocks().cend(); bi++) {
363
21.7k
            Block* b = *bi;
364
21.7k
            if (unreachableMerges.count(b) != 0 || headerForUnreachableContinue.count(b) != 0) {
365
36
                auto ii = b->getInstructions().cbegin();
366
36
                ++ii; // Keep potential decorations on the label.
367
72
                for (; ii != b->getInstructions().cend(); ++ii)
368
36
                    unreachableDefinitions.insert(ii->get()->getResultId());
369
21.7k
            } else if (reachableBlocks.count(b) == 0) {
370
                // The normal case for unreachable code.  All definitions are considered dead.
371
9.58k
                for (auto ii = b->getInstructions().cbegin(); ii != b->getInstructions().cend(); ++ii)
372
6.59k
                    unreachableDefinitions.insert(ii->get()->getResultId());
373
2.98k
            }
374
21.7k
        }
375
3.85k
    }
376
377
    // Modify unreachable merge blocks and unreachable continue targets.
378
    // Delete their contents.
379
2.66k
    for (auto mergeIter = unreachableMerges.begin(); mergeIter != unreachableMerges.end(); ++mergeIter) {
380
36
        (*mergeIter)->rewriteAsCanonicalUnreachableMerge();
381
36
    }
382
2.63k
    for (auto continueIter = headerForUnreachableContinue.begin();
383
2.63k
         continueIter != headerForUnreachableContinue.end();
384
2.63k
         ++continueIter) {
385
0
        Block* continue_target = continueIter->first;
386
0
        Block* header = continueIter->second;
387
0
        continue_target->rewriteAsCanonicalUnreachableContinue(header);
388
0
    }
389
390
    // Remove unneeded decorations, for unreachable instructions
391
28.6k
    for (auto decorationIter = decorations.begin(); decorationIter != decorations.end();) {
392
26.0k
        Id decorationId = (*decorationIter)->getIdOperand(0);
393
26.0k
        if (unreachableDefinitions.count(decorationId) != 0) {
394
0
            decorationIter = decorations.erase(decorationIter);
395
26.0k
        } else {
396
26.0k
            ++decorationIter;
397
26.0k
        }
398
26.0k
    }
399
2.63k
}
400
401
// comment in header
402
2.63k
void Builder::postProcessFeatures() {
403
    // Add per-instruction capabilities, extensions, etc.,
404
405
    // Look for any 8/16 bit type in physical storage buffer class, and set the
406
    // appropriate capability. This happens in createSpvVariable for other storage
407
    // classes, but there isn't always a variable for physical storage buffer.
408
18.7k
    for (int t = 0; t < (int)groupedTypes[enumCast(Op::OpTypePointer)].size(); ++t) {
409
16.1k
        Instruction* type = groupedTypes[enumCast(Op::OpTypePointer)][t];
410
16.1k
        if (type->getImmediateOperand(0) == (unsigned)StorageClass::PhysicalStorageBufferEXT) {
411
322
            if (containsType(type->getIdOperand(1), Op::OpTypeInt, 8)) {
412
42
                addIncorporatedExtension(spv::E_SPV_KHR_8bit_storage, spv::Spv_1_5);
413
42
                addCapability(spv::Capability::StorageBuffer8BitAccess);
414
42
            }
415
322
            if (containsType(type->getIdOperand(1), Op::OpTypeInt, 16) ||
416
322
                containsType(type->getIdOperand(1), Op::OpTypeFloat, 16)) {
417
16
                addIncorporatedExtension(spv::E_SPV_KHR_16bit_storage, spv::Spv_1_3);
418
16
                addCapability(spv::Capability::StorageBuffer16BitAccess);
419
16
            }
420
322
        }
421
16.1k
    }
422
423
    // process all block-contained instructions
424
6.48k
    for (auto fi = module.getFunctions().cbegin(); fi != module.getFunctions().cend(); fi++) {
425
3.85k
        Function* f = *fi;
426
25.6k
        for (auto bi = f->getBlocks().cbegin(); bi != f->getBlocks().cend(); bi++) {
427
21.7k
            Block* b = *bi;
428
259k
            for (auto ii = b->getInstructions().cbegin(); ii != b->getInstructions().cend(); ii++)
429
237k
                postProcess(*ii->get());
430
431
            // For all local variables that contain pointers to PhysicalStorageBufferEXT, check whether
432
            // there is an existing restrict/aliased decoration. If we don't find one, add Aliased as the
433
            // default.
434
37.4k
            for (auto vi = b->getLocalVariables().cbegin(); vi != b->getLocalVariables().cend(); vi++) {
435
15.7k
                const Instruction& inst = *vi->get();
436
15.7k
                Id resultId = inst.getResultId();
437
15.7k
                if (containsPhysicalStorageBufferOrArray(getDerefTypeId(resultId))) {
438
299
                    bool foundDecoration = false;
439
4.35k
                    const auto function = [&](const std::unique_ptr<Instruction>& decoration) {
440
4.35k
                        if (decoration.get()->getIdOperand(0) == resultId &&
441
4.35k
                            decoration.get()->getOpCode() == Op::OpDecorate &&
442
4.35k
                            (decoration.get()->getImmediateOperand(1) == spv::Decoration::AliasedPointerEXT ||
443
294
                             decoration.get()->getImmediateOperand(1) == spv::Decoration::RestrictPointerEXT)) {
444
290
                            foundDecoration = true;
445
290
                        }
446
4.35k
                    };
447
299
                    std::for_each(decorations.begin(), decorations.end(), function);
448
299
                    if (!foundDecoration) {
449
9
                        addDecoration(resultId, spv::Decoration::AliasedPointerEXT);
450
9
                    }
451
299
                }
452
15.7k
            }
453
21.7k
        }
454
3.85k
    }
455
456
    // If any Vulkan memory model-specific functionality is used, update the
457
    // OpMemoryModel to match.
458
2.63k
    if (capabilities.find(spv::Capability::VulkanMemoryModelKHR) != capabilities.end()) {
459
43
        memoryModel = spv::MemoryModel::VulkanKHR;
460
43
        addIncorporatedExtension(spv::E_SPV_KHR_vulkan_memory_model, spv::Spv_1_5);
461
43
    }
462
463
    // Add Aliased decoration if there's more than one Workgroup Block variable.
464
2.63k
    if (capabilities.find(spv::Capability::WorkgroupMemoryExplicitLayoutKHR) != capabilities.end()) {
465
0
        assert(entryPoints.size() == 1);
466
0
        auto &ep = entryPoints[0];
467
468
0
        std::vector<Id> workgroup_variables;
469
0
        for (int i = 0; i < (int)ep->getNumOperands(); i++) {
470
0
            if (!ep->isIdOperand(i))
471
0
                continue;
472
473
0
            const Id id = ep->getIdOperand(i);
474
0
            const Instruction *instr = module.getInstruction(id);
475
0
            if (instr->getOpCode() != spv::Op::OpVariable)
476
0
                continue;
477
478
0
            if (instr->getImmediateOperand(0) == spv::StorageClass::Workgroup)
479
0
                workgroup_variables.push_back(id);
480
0
        }
481
482
0
        if (workgroup_variables.size() > 1) {
483
0
            for (size_t i = 0; i < workgroup_variables.size(); i++)
484
0
                addDecoration(workgroup_variables[i], spv::Decoration::Aliased);
485
0
        }
486
0
    }
487
2.63k
}
488
489
// SPIR-V requires that any instruction consuming the result of an OpSampledImage
490
// be in the same block as the OpSampledImage instruction. This pass goes finds
491
// uses of OpSampledImage where that is not the case and duplicates the
492
// OpSampledImage to be immediately before the instruction that consumes it.
493
// The old OpSampledImage is left in place, potentially with no users.
494
void Builder::postProcessSamplers()
495
2.63k
{
496
    // first, find all OpSampledImage instructions and store them in a map.
497
2.63k
    std::map<Id, Instruction*> sampledImageInstrs;
498
3.85k
    for (auto f: module.getFunctions()) {
499
21.7k
  for (auto b: f->getBlocks()) {
500
237k
      for (auto &i: b->getInstructions()) {
501
237k
        if (i->getOpCode() == spv::Op::OpSampledImage) {
502
359
        sampledImageInstrs[i->getResultId()] = i.get();
503
359
    }
504
237k
      }
505
21.7k
  }
506
3.85k
    }
507
    // next find all uses of the given ids and rewrite them if needed.
508
3.85k
    for (auto f: module.getFunctions()) {
509
21.7k
  for (auto b: f->getBlocks()) {
510
21.7k
            auto &instrs = b->getInstructions();
511
259k
            for (size_t idx = 0; idx < instrs.size(); idx++) {
512
237k
                Instruction *i = instrs[idx].get();
513
635k
                for (int opnum = 0; opnum < i->getNumOperands(); opnum++) {
514
                    // Is this operand of the current instruction the result of an OpSampledImage?
515
397k
                    if (i->isIdOperand(opnum) &&
516
397k
                        sampledImageInstrs.count(i->getIdOperand(opnum)))
517
353
                    {
518
353
                        Instruction *opSampImg = sampledImageInstrs[i->getIdOperand(opnum)];
519
353
                        if (i->getBlock() != opSampImg->getBlock()) {
520
1
                            Instruction *newInstr = new Instruction(getUniqueId(),
521
1
                                                                    opSampImg->getTypeId(),
522
1
                                                                    spv::Op::OpSampledImage);
523
1
                            newInstr->addIdOperand(opSampImg->getIdOperand(0));
524
1
                            newInstr->addIdOperand(opSampImg->getIdOperand(1));
525
1
                            newInstr->setBlock(b);
526
527
                            // rewrite the user of the OpSampledImage to use the new instruction.
528
1
                            i->setIdOperand(opnum, newInstr->getResultId());
529
                            // insert the new OpSampledImage right before the current instruction.
530
1
                            instrs.insert(instrs.begin() + idx,
531
1
                                    std::unique_ptr<Instruction>(newInstr));
532
1
                            idx++;
533
1
                        }
534
353
                    }
535
397k
                }
536
237k
            }
537
21.7k
  }
538
3.85k
    }
539
2.63k
}
540
541
// comment in header
542
void Builder::postProcess(bool compileOnly)
543
2.63k
{
544
    // postProcessCFG needs an entrypoint to determine what is reachable, but if we are not creating an "executable" shader, we don't have an entrypoint
545
2.63k
    if (!compileOnly)
546
2.63k
        postProcessCFG();
547
548
2.63k
    postProcessFeatures();
549
2.63k
    postProcessSamplers();
550
2.63k
}
551
552
} // end spv namespace