/src/spirv-tools/source/opt/fix_func_call_arguments.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | // Copyright (c) 2022 Advanced Micro Devices, Inc. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | | // you may not use this file except in compliance with the License. |
5 | | // You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software |
10 | | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | // See the License for the specific language governing permissions and |
13 | | // limitations under the License. |
14 | | |
15 | | #include "fix_func_call_arguments.h" |
16 | | |
17 | | #include "ir_builder.h" |
18 | | |
19 | | using namespace spvtools; |
20 | | using namespace opt; |
21 | | |
22 | 0 | bool FixFuncCallArgumentsPass::ModuleHasASingleFunction() { |
23 | 0 | auto funcsNum = get_module()->end() - get_module()->begin(); |
24 | 0 | return funcsNum == 1; |
25 | 0 | } |
26 | | |
27 | 0 | Pass::Status FixFuncCallArgumentsPass::Process() { |
28 | 0 | bool modified = false; |
29 | 0 | if (ModuleHasASingleFunction()) return Status::SuccessWithoutChange; |
30 | 0 | for (auto& func : *get_module()) { |
31 | 0 | func.ForEachInst([this, &modified](Instruction* inst) { |
32 | 0 | if (inst->opcode() == spv::Op::OpFunctionCall) { |
33 | 0 | modified |= FixFuncCallArguments(inst); |
34 | 0 | } |
35 | 0 | }); |
36 | 0 | } |
37 | 0 | return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; |
38 | 0 | } |
39 | | |
40 | | bool FixFuncCallArgumentsPass::FixFuncCallArguments( |
41 | 0 | Instruction* func_call_inst) { |
42 | 0 | bool modified = false; |
43 | 0 | for (uint32_t i = 0; i < func_call_inst->NumInOperands(); ++i) { |
44 | 0 | Operand& op = func_call_inst->GetInOperand(i); |
45 | 0 | if (op.type != SPV_OPERAND_TYPE_ID) continue; |
46 | 0 | Instruction* operand_inst = get_def_use_mgr()->GetDef(op.AsId()); |
47 | 0 | if (operand_inst->opcode() == spv::Op::OpAccessChain) { |
48 | 0 | uint32_t var_id = |
49 | 0 | ReplaceAccessChainFuncCallArguments(func_call_inst, operand_inst); |
50 | 0 | func_call_inst->SetInOperand(i, {var_id}); |
51 | 0 | modified = true; |
52 | 0 | } |
53 | 0 | } |
54 | 0 | if (modified) { |
55 | 0 | context()->UpdateDefUse(func_call_inst); |
56 | 0 | } |
57 | 0 | return modified; |
58 | 0 | } |
59 | | |
60 | | uint32_t FixFuncCallArgumentsPass::ReplaceAccessChainFuncCallArguments( |
61 | 0 | Instruction* func_call_inst, Instruction* operand_inst) { |
62 | 0 | InstructionBuilder builder( |
63 | 0 | context(), func_call_inst, |
64 | 0 | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); |
65 | |
|
66 | 0 | Instruction* next_insert_point = func_call_inst->NextNode(); |
67 | | // Get Variable insertion point |
68 | 0 | Function* func = context()->get_instr_block(func_call_inst)->GetParent(); |
69 | 0 | Instruction* variable_insertion_point = &*(func->begin()->begin()); |
70 | 0 | Instruction* op_ptr_type = get_def_use_mgr()->GetDef(operand_inst->type_id()); |
71 | 0 | Instruction* op_type = |
72 | 0 | get_def_use_mgr()->GetDef(op_ptr_type->GetSingleWordInOperand(1)); |
73 | 0 | uint32_t varType = context()->get_type_mgr()->FindPointerToType( |
74 | 0 | op_type->result_id(), spv::StorageClass::Function); |
75 | | // Create new variable |
76 | 0 | builder.SetInsertPoint(variable_insertion_point); |
77 | 0 | Instruction* var = |
78 | 0 | builder.AddVariable(varType, uint32_t(spv::StorageClass::Function)); |
79 | | // Load access chain to the new variable before function call |
80 | 0 | builder.SetInsertPoint(func_call_inst); |
81 | |
|
82 | 0 | uint32_t operand_id = operand_inst->result_id(); |
83 | 0 | Instruction* load = builder.AddLoad(op_type->result_id(), operand_id); |
84 | 0 | builder.AddStore(var->result_id(), load->result_id()); |
85 | | // Load return value to the acesschain after function call |
86 | 0 | builder.SetInsertPoint(next_insert_point); |
87 | 0 | load = builder.AddLoad(op_type->result_id(), var->result_id()); |
88 | 0 | builder.AddStore(operand_id, load->result_id()); |
89 | |
|
90 | 0 | return var->result_id(); |
91 | 0 | } |