/src/llvm-project/llvm/lib/Target/NVPTX/NVVMReflect.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | //===- NVVMReflect.cpp - NVVM Emulate conditional compilation -------------===// |
2 | | // |
3 | | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | | // See https://llvm.org/LICENSE.txt for license information. |
5 | | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | | // |
7 | | //===----------------------------------------------------------------------===// |
8 | | // |
9 | | // This pass replaces occurrences of __nvvm_reflect("foo") and llvm.nvvm.reflect |
10 | | // with an integer. |
11 | | // |
12 | | // We choose the value we use by looking at metadata in the module itself. Note |
13 | | // that we intentionally only have one way to choose these values, because other |
14 | | // parts of LLVM (particularly, InstCombineCall) rely on being able to predict |
15 | | // the values chosen by this pass. |
16 | | // |
17 | | // If we see an unknown string, we replace its call with 0. |
18 | | // |
19 | | //===----------------------------------------------------------------------===// |
20 | | |
#include "NVPTX.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_os_ostream.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include <sstream>
#include <string>
41 | 1.47k | #define NVVM_REFLECT_FUNCTION "__nvvm_reflect" |
42 | 738 | #define NVVM_REFLECT_OCL_FUNCTION "__nvvm_reflect_ocl" |
43 | | |
44 | | using namespace llvm; |
45 | | |
46 | | #define DEBUG_TYPE "nvptx-reflect" |
47 | | |
48 | | namespace llvm { void initializeNVVMReflectPass(PassRegistry &); } |
49 | | |
50 | | namespace { |
51 | | class NVVMReflect : public FunctionPass { |
52 | | public: |
53 | | static char ID; |
54 | | unsigned int SmVersion; |
55 | 0 | NVVMReflect() : NVVMReflect(0) {} |
56 | 740 | explicit NVVMReflect(unsigned int Sm) : FunctionPass(ID), SmVersion(Sm) { |
57 | 740 | initializeNVVMReflectPass(*PassRegistry::getPassRegistry()); |
58 | 740 | } |
59 | | |
60 | | bool runOnFunction(Function &) override; |
61 | | }; |
62 | | } |
63 | | |
64 | 740 | FunctionPass *llvm::createNVVMReflectPass(unsigned int SmVersion) { |
65 | 740 | return new NVVMReflect(SmVersion); |
66 | 740 | } |
67 | | |
68 | | static cl::opt<bool> |
69 | | NVVMReflectEnabled("nvvm-reflect-enable", cl::init(true), cl::Hidden, |
70 | | cl::desc("NVVM reflection, enabled by default")); |
71 | | |
72 | | char NVVMReflect::ID = 0; |
73 | | INITIALIZE_PASS(NVVMReflect, "nvvm-reflect", |
74 | | "Replace occurrences of __nvvm_reflect() calls with 0/1", false, |
75 | | false) |
76 | | |
77 | 738 | static bool runNVVMReflect(Function &F, unsigned SmVersion) { |
78 | 738 | if (!NVVMReflectEnabled) |
79 | 0 | return false; |
80 | | |
81 | 738 | if (F.getName() == NVVM_REFLECT_FUNCTION || |
82 | 738 | F.getName() == NVVM_REFLECT_OCL_FUNCTION) { |
83 | 0 | assert(F.isDeclaration() && "_reflect function should not have a body"); |
84 | 0 | assert(F.getReturnType()->isIntegerTy() && |
85 | 0 | "_reflect's return type should be integer"); |
86 | 0 | return false; |
87 | 0 | } |
88 | | |
89 | 738 | SmallVector<Instruction *, 4> ToRemove; |
90 | | |
91 | | // Go through the calls in this function. Each call to __nvvm_reflect or |
92 | | // llvm.nvvm.reflect should be a CallInst with a ConstantArray argument. |
93 | | // First validate that. If the c-string corresponding to the ConstantArray can |
94 | | // be found successfully, see if it can be found in VarMap. If so, replace the |
95 | | // uses of CallInst with the value found in VarMap. If not, replace the use |
96 | | // with value 0. |
97 | | |
98 | | // The IR for __nvvm_reflect calls differs between CUDA versions. |
99 | | // |
100 | | // CUDA 6.5 and earlier uses this sequence: |
101 | | // %ptr = tail call i8* @llvm.nvvm.ptr.constant.to.gen.p0i8.p4i8 |
102 | | // (i8 addrspace(4)* getelementptr inbounds |
103 | | // ([8 x i8], [8 x i8] addrspace(4)* @str, i32 0, i32 0)) |
104 | | // %reflect = tail call i32 @__nvvm_reflect(i8* %ptr) |
105 | | // |
106 | | // The value returned by Sym->getOperand(0) is a Constant with a |
107 | | // ConstantDataSequential operand which can be converted to string and used |
108 | | // for lookup. |
109 | | // |
110 | | // CUDA 7.0 does it slightly differently: |
111 | | // %reflect = call i32 @__nvvm_reflect(i8* addrspacecast |
112 | | // (i8 addrspace(1)* getelementptr inbounds |
113 | | // ([8 x i8], [8 x i8] addrspace(1)* @str, i32 0, i32 0) to i8*)) |
114 | | // |
115 | | // In this case, we get a Constant with a GlobalVariable operand and we need |
116 | | // to dig deeper to find its initializer with the string we'll use for lookup. |
117 | 17.9k | for (Instruction &I : instructions(F)) { |
118 | 17.9k | CallInst *Call = dyn_cast<CallInst>(&I); |
119 | 17.9k | if (!Call) |
120 | 17.9k | continue; |
121 | 0 | Function *Callee = Call->getCalledFunction(); |
122 | 0 | if (!Callee || (Callee->getName() != NVVM_REFLECT_FUNCTION && |
123 | 0 | Callee->getName() != NVVM_REFLECT_OCL_FUNCTION && |
124 | 0 | Callee->getIntrinsicID() != Intrinsic::nvvm_reflect)) |
125 | 0 | continue; |
126 | | |
127 | | // FIXME: Improve error handling here and elsewhere in this pass. |
128 | 0 | assert(Call->getNumOperands() == 2 && |
129 | 0 | "Wrong number of operands to __nvvm_reflect function"); |
130 | | |
131 | | // In cuda 6.5 and earlier, we will have an extra constant-to-generic |
132 | | // conversion of the string. |
133 | 0 | const Value *Str = Call->getArgOperand(0); |
134 | 0 | if (const CallInst *ConvCall = dyn_cast<CallInst>(Str)) { |
135 | | // FIXME: Add assertions about ConvCall. |
136 | 0 | Str = ConvCall->getArgOperand(0); |
137 | 0 | } |
138 | | // Pre opaque pointers we have a constant expression wrapping the constant |
139 | | // string. |
140 | 0 | Str = Str->stripPointerCasts(); |
141 | 0 | assert(isa<Constant>(Str) && |
142 | 0 | "Format of __nvvm_reflect function not recognized"); |
143 | | |
144 | 0 | const Value *Operand = cast<Constant>(Str)->getOperand(0); |
145 | 0 | if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(Operand)) { |
146 | | // For CUDA-7.0 style __nvvm_reflect calls, we need to find the operand's |
147 | | // initializer. |
148 | 0 | assert(GV->hasInitializer() && |
149 | 0 | "Format of _reflect function not recognized"); |
150 | 0 | const Constant *Initializer = GV->getInitializer(); |
151 | 0 | Operand = Initializer; |
152 | 0 | } |
153 | | |
154 | 0 | assert(isa<ConstantDataSequential>(Operand) && |
155 | 0 | "Format of _reflect function not recognized"); |
156 | 0 | assert(cast<ConstantDataSequential>(Operand)->isCString() && |
157 | 0 | "Format of _reflect function not recognized"); |
158 | | |
159 | 0 | StringRef ReflectArg = cast<ConstantDataSequential>(Operand)->getAsString(); |
160 | 0 | ReflectArg = ReflectArg.substr(0, ReflectArg.size() - 1); |
161 | 0 | LLVM_DEBUG(dbgs() << "Arg of _reflect : " << ReflectArg << "\n"); |
162 | |
|
163 | 0 | int ReflectVal = 0; // The default value is 0 |
164 | 0 | if (ReflectArg == "__CUDA_FTZ") { |
165 | | // Try to pull __CUDA_FTZ from the nvvm-reflect-ftz module flag. Our |
166 | | // choice here must be kept in sync with AutoUpgrade, which uses the same |
167 | | // technique to detect whether ftz is enabled. |
168 | 0 | if (auto *Flag = mdconst::extract_or_null<ConstantInt>( |
169 | 0 | F.getParent()->getModuleFlag("nvvm-reflect-ftz"))) |
170 | 0 | ReflectVal = Flag->getSExtValue(); |
171 | 0 | } else if (ReflectArg == "__CUDA_ARCH") { |
172 | 0 | ReflectVal = SmVersion * 10; |
173 | 0 | } |
174 | 0 | Call->replaceAllUsesWith(ConstantInt::get(Call->getType(), ReflectVal)); |
175 | 0 | ToRemove.push_back(Call); |
176 | 0 | } |
177 | | |
178 | 738 | for (Instruction *I : ToRemove) |
179 | 0 | I->eraseFromParent(); |
180 | | |
181 | 738 | return ToRemove.size() > 0; |
182 | 738 | } |
183 | | |
184 | 738 | bool NVVMReflect::runOnFunction(Function &F) { |
185 | 738 | return runNVVMReflect(F, SmVersion); |
186 | 738 | } |
187 | | |
188 | 0 | NVVMReflectPass::NVVMReflectPass() : NVVMReflectPass(0) {} |
189 | | |
190 | | PreservedAnalyses NVVMReflectPass::run(Function &F, |
191 | 0 | FunctionAnalysisManager &AM) { |
192 | 0 | return runNVVMReflect(F, SmVersion) ? PreservedAnalyses::none() |
193 | 0 | : PreservedAnalyses::all(); |
194 | 0 | } |