/src/llvm-project/llvm/lib/CodeGen/ExpandMemCmp.cpp
Line | Count | Source |
1 | | //===--- ExpandMemCmp.cpp - Expand memcmp() to load/stores ----------------===// |
2 | | // |
3 | | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | | // See https://llvm.org/LICENSE.txt for license information. |
5 | | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | | // |
7 | | //===----------------------------------------------------------------------===// |
8 | | // |
9 | | // This pass tries to expand memcmp() calls into optimally-sized loads and |
10 | | // compares for the target. |
11 | | // |
12 | | //===----------------------------------------------------------------------===// |
13 | | |
14 | | #include "llvm/CodeGen/ExpandMemCmp.h" |
15 | | #include "llvm/ADT/Statistic.h" |
16 | | #include "llvm/Analysis/ConstantFolding.h" |
17 | | #include "llvm/Analysis/DomTreeUpdater.h" |
18 | | #include "llvm/Analysis/LazyBlockFrequencyInfo.h" |
19 | | #include "llvm/Analysis/ProfileSummaryInfo.h" |
20 | | #include "llvm/Analysis/TargetLibraryInfo.h" |
21 | | #include "llvm/Analysis/TargetTransformInfo.h" |
22 | | #include "llvm/Analysis/ValueTracking.h" |
23 | | #include "llvm/CodeGen/TargetPassConfig.h" |
24 | | #include "llvm/CodeGen/TargetSubtargetInfo.h" |
25 | | #include "llvm/IR/Dominators.h" |
26 | | #include "llvm/IR/IRBuilder.h" |
27 | | #include "llvm/IR/PatternMatch.h" |
28 | | #include "llvm/InitializePasses.h" |
29 | | #include "llvm/Target/TargetMachine.h" |
30 | | #include "llvm/Transforms/Utils/BasicBlockUtils.h" |
31 | | #include "llvm/Transforms/Utils/Local.h" |
32 | | #include "llvm/Transforms/Utils/SizeOpts.h" |
33 | | #include <optional> |
34 | | |
35 | | using namespace llvm; |
36 | | using namespace llvm::PatternMatch; |
37 | | |
38 | | namespace llvm { |
39 | | class TargetLowering; |
40 | | } |
41 | | |
42 | | #define DEBUG_TYPE "expand-memcmp" |
43 | | |
44 | | STATISTIC(NumMemCmpCalls, "Number of memcmp calls"); |
45 | | STATISTIC(NumMemCmpNotConstant, "Number of memcmp calls without constant size"); |
46 | | STATISTIC(NumMemCmpGreaterThanMax, |
47 | | "Number of memcmp calls with size greater than max size"); |
48 | | STATISTIC(NumMemCmpInlined, "Number of inlined memcmp calls"); |
49 | | |
50 | | static cl::opt<unsigned> MemCmpEqZeroNumLoadsPerBlock( |
51 | | "memcmp-num-loads-per-block", cl::Hidden, cl::init(1), |
52 | | cl::desc("The number of loads per basic block for inline expansion of " |
53 | | "memcmp that is only being compared against zero.")); |
54 | | |
55 | | static cl::opt<unsigned> MaxLoadsPerMemcmp( |
56 | | "max-loads-per-memcmp", cl::Hidden, |
57 | | cl::desc("Set maximum number of loads used in expanded memcmp")); |
58 | | |
59 | | static cl::opt<unsigned> MaxLoadsPerMemcmpOptSize( |
60 | | "max-loads-per-memcmp-opt-size", cl::Hidden, |
61 | | cl::desc("Set maximum number of loads used in expanded memcmp for -Os/Oz")); |
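// Editorial usage note: the three cl::opt flags above are internal knobs.
// cl::Hidden only hides them from -help; they can still be set on the
// command line of a tool that runs this pass, e.g. (illustrative, not from
// this file): llc -max-loads-per-memcmp=4 -memcmp-num-loads-per-block=2 in.ll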
62 | | |
63 | | namespace { |
64 | | |
65 | | |
66 | | // This class provides helper functions to expand a memcmp library call into an |
67 | | // inline expansion. |
68 | | class MemCmpExpansion { |
69 | | struct ResultBlock { |
70 | | BasicBlock *BB = nullptr; |
71 | | PHINode *PhiSrc1 = nullptr; |
72 | | PHINode *PhiSrc2 = nullptr; |
73 | | |
74 | 187 | ResultBlock() = default; |
75 | | }; |
76 | | |
77 | | CallInst *const CI = nullptr; |
78 | | ResultBlock ResBlock; |
79 | | const uint64_t Size; |
80 | | unsigned MaxLoadSize = 0; |
81 | | uint64_t NumLoadsNonOneByte = 0; |
82 | | const uint64_t NumLoadsPerBlockForZeroCmp; |
83 | | std::vector<BasicBlock *> LoadCmpBlocks; |
84 | | BasicBlock *EndBlock = nullptr; |
85 | | PHINode *PhiRes = nullptr; |
86 | | const bool IsUsedForZeroCmp; |
87 | | const DataLayout &DL; |
88 | | DomTreeUpdater *DTU = nullptr; |
89 | | IRBuilder<> Builder; |
90 | | // Represents the decomposition in blocks of the expansion. For example, |
91 | | // comparing 33 bytes on X86+sse can be done with 2x16-byte loads and |
92 | | // 1x1-byte load, which would be represented as [{16, 0}, {16, 16}, {1, 32}]. |
93 | | struct LoadEntry { |
94 | | LoadEntry(unsigned LoadSize, uint64_t Offset) |
95 | 228 | : LoadSize(LoadSize), Offset(Offset) { |
96 | 228 | } |
97 | | |
98 | | // The size of the load for this block, in bytes. |
99 | | unsigned LoadSize; |
100 | | // The offset of this load from the base pointer, in bytes. |
101 | | uint64_t Offset; |
102 | | }; |
103 | | using LoadEntryVector = SmallVector<LoadEntry, 8>; |
104 | | LoadEntryVector LoadSequence; |
105 | | |
106 | | void createLoadCmpBlocks(); |
107 | | void createResultBlock(); |
108 | | void setupResultBlockPHINodes(); |
109 | | void setupEndBlockPHINodes(); |
110 | | Value *getCompareLoadPairs(unsigned BlockIndex, unsigned &LoadIndex); |
111 | | void emitLoadCompareBlock(unsigned BlockIndex); |
112 | | void emitLoadCompareBlockMultipleLoads(unsigned BlockIndex, |
113 | | unsigned &LoadIndex); |
114 | | void emitLoadCompareByteBlock(unsigned BlockIndex, unsigned OffsetBytes); |
115 | | void emitMemCmpResultBlock(); |
116 | | Value *getMemCmpExpansionZeroCase(); |
117 | | Value *getMemCmpEqZeroOneBlock(); |
118 | | Value *getMemCmpOneBlock(); |
119 | | struct LoadPair { |
120 | | Value *Lhs = nullptr; |
121 | | Value *Rhs = nullptr; |
122 | | }; |
123 | | LoadPair getLoadPair(Type *LoadSizeType, Type *BSwapSizeType, |
124 | | Type *CmpSizeType, unsigned OffsetBytes); |
125 | | |
126 | | static LoadEntryVector |
127 | | computeGreedyLoadSequence(uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes, |
128 | | unsigned MaxNumLoads, unsigned &NumLoadsNonOneByte); |
129 | | static LoadEntryVector |
130 | | computeOverlappingLoadSequence(uint64_t Size, unsigned MaxLoadSize, |
131 | | unsigned MaxNumLoads, |
132 | | unsigned &NumLoadsNonOneByte); |
133 | | |
134 | | static void optimiseLoadSequence( |
135 | | LoadEntryVector &LoadSequence, |
136 | | const TargetTransformInfo::MemCmpExpansionOptions &Options, |
137 | | bool IsUsedForZeroCmp); |
138 | | |
139 | | public: |
140 | | MemCmpExpansion(CallInst *CI, uint64_t Size, |
141 | | const TargetTransformInfo::MemCmpExpansionOptions &Options, |
142 | | const bool IsUsedForZeroCmp, const DataLayout &TheDataLayout, |
143 | | DomTreeUpdater *DTU); |
144 | | |
145 | | unsigned getNumBlocks(); |
146 | 662 | uint64_t getNumLoads() const { return LoadSequence.size(); } |
147 | | |
148 | | Value *getMemCmpExpansion(); |
149 | | }; |
150 | | |
151 | | MemCmpExpansion::LoadEntryVector MemCmpExpansion::computeGreedyLoadSequence( |
152 | | uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes, |
153 | 187 | const unsigned MaxNumLoads, unsigned &NumLoadsNonOneByte) { |
154 | 187 | NumLoadsNonOneByte = 0; |
155 | 187 | LoadEntryVector LoadSequence; |
156 | 187 | uint64_t Offset = 0; |
157 | 394 | while (Size && !LoadSizes.empty()) { |
158 | 207 | const unsigned LoadSize = LoadSizes.front(); |
159 | 207 | const uint64_t NumLoadsForThisSize = Size / LoadSize; |
160 | 207 | if (LoadSequence.size() + NumLoadsForThisSize > MaxNumLoads) { |
161 | | // Do not expand if the total number of loads is larger than what the |
162 | | // target allows. Note that it's important that we exit before completing |
163 | | // the expansion to avoid using a ton of memory to store the expansion for |
164 | | // large sizes. |
165 | 0 | return {}; |
166 | 0 | } |
167 | 207 | if (NumLoadsForThisSize > 0) { |
168 | 415 | for (uint64_t I = 0; I < NumLoadsForThisSize; ++I) { |
169 | 208 | LoadSequence.push_back({LoadSize, Offset}); |
170 | 208 | Offset += LoadSize; |
171 | 208 | } |
172 | 207 | if (LoadSize > 1) |
173 | 187 | ++NumLoadsNonOneByte; |
174 | 207 | Size = Size % LoadSize; |
175 | 207 | } |
176 | 207 | LoadSizes = LoadSizes.drop_front(); |
177 | 207 | } |
178 | 187 | return LoadSequence; |
179 | 187 | } |
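// Editorial worked example: for Size == 7 and LoadSizes == {4, 2, 1} with a
// sufficiently large MaxNumLoads, the greedy walk above produces
//   LoadSequence == [{4, 0}, {2, 4}, {1, 6}]
// (a 4-byte load at offset 0, a 2-byte load at offset 4, and a 1-byte load
// at offset 6), with NumLoadsNonOneByte == 2.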
180 | | |
181 | | MemCmpExpansion::LoadEntryVector |
182 | | MemCmpExpansion::computeOverlappingLoadSequence(uint64_t Size, |
183 | | const unsigned MaxLoadSize, |
184 | | const unsigned MaxNumLoads, |
185 | 0 | unsigned &NumLoadsNonOneByte) { |
186 | | // These are already handled by the greedy approach. |
187 | 0 | if (Size < 2 || MaxLoadSize < 2) |
188 | 0 | return {}; |
189 | | |
190 | | // We try to do as many non-overlapping loads as possible starting from the |
191 | | // beginning. |
192 | 0 | const uint64_t NumNonOverlappingLoads = Size / MaxLoadSize; |
193 | 0 | assert(NumNonOverlappingLoads && "there must be at least one load"); |
194 | | // There remain 0 to (MaxLoadSize - 1) bytes to load; this will be done with |
195 | | // an overlapping load. |
196 | 0 | Size = Size - NumNonOverlappingLoads * MaxLoadSize; |
197 | | // Bail if we do not need an overlapping load; this is already handled by |
198 | | // the greedy approach. |
199 | 0 | if (Size == 0) |
200 | 0 | return {}; |
201 | | // Bail if the number of loads (non-overlapping + potential overlapping one) |
202 | | // is larger than the max allowed. |
203 | 0 | if ((NumNonOverlappingLoads + 1) > MaxNumLoads) |
204 | 0 | return {}; |
205 | | |
206 | | // Add non-overlapping loads. |
207 | 0 | LoadEntryVector LoadSequence; |
208 | 0 | uint64_t Offset = 0; |
209 | 0 | for (uint64_t I = 0; I < NumNonOverlappingLoads; ++I) { |
210 | 0 | LoadSequence.push_back({MaxLoadSize, Offset}); |
211 | 0 | Offset += MaxLoadSize; |
212 | 0 | } |
213 | | |
214 | | // Add the last overlapping load. |
215 | 0 | assert(Size > 0 && Size < MaxLoadSize && "broken invariant"); |
216 | 0 | LoadSequence.push_back({MaxLoadSize, Offset - (MaxLoadSize - Size)}); |
217 | 0 | NumLoadsNonOneByte = 1; |
218 | 0 | return LoadSequence; |
219 | 0 | } |
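// Editorial worked example (this path is uncovered in this run): for
// Size == 13 and MaxLoadSize == 8 with MaxNumLoads >= 2, the result is
//   LoadSequence == [{8, 0}, {8, 5}]
// where the second 8-byte load starts at offset 13 - 8 == 5 and overlaps the
// first load by three bytes, replacing what would otherwise be a multi-load
// tail (e.g. 4 + 1 bytes).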
220 | | |
221 | | void MemCmpExpansion::optimiseLoadSequence( |
222 | | LoadEntryVector &LoadSequence, |
223 | | const TargetTransformInfo::MemCmpExpansionOptions &Options, |
224 | 187 | bool IsUsedForZeroCmp) { |
225 | | // This part of the code attempts to optimize the LoadSequence by merging allowed |
226 | | // subsequences into single loads of allowed sizes from |
227 | | // `MemCmpExpansionOptions::AllowedTailExpansions`. If it is for zero |
228 | | // comparison or if no allowed tail expansions are specified, we exit early. |
229 | 187 | if (IsUsedForZeroCmp || Options.AllowedTailExpansions.empty()) |
230 | 20 | return; |
231 | | |
232 | 187 | while (LoadSequence.size() >= 2) { |
233 | 20 | auto Last = LoadSequence[LoadSequence.size() - 1]; |
234 | 20 | auto PreLast = LoadSequence[LoadSequence.size() - 2]; |
235 | | |
236 | | // Exit the loop if the two sequences are not contiguous |
237 | 20 | if (PreLast.Offset + PreLast.LoadSize != Last.Offset) |
238 | 0 | break; |
239 | | |
240 | 20 | auto LoadSize = Last.LoadSize + PreLast.LoadSize; |
241 | 20 | if (find(Options.AllowedTailExpansions, LoadSize) == |
242 | 20 | Options.AllowedTailExpansions.end()) |
243 | 0 | break; |
244 | | |
245 | | // Remove the last two entries and replace them with the combined entry. |
246 | 20 | LoadSequence.pop_back(); |
247 | 20 | LoadSequence.pop_back(); |
248 | 20 | LoadSequence.emplace_back(LoadSize, PreLast.Offset); |
249 | 20 | } |
250 | 167 | } |
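// Editorial worked example: for Size == 7 the greedy sequence is
// [{4, 0}, {2, 4}, {1, 6}]. If AllowedTailExpansions contains 3, the two
// trailing entries are contiguous (4 + 2 == 6) and merge into
//   [{4, 0}, {3, 4}]
// getLoadPair() then loads the 3-byte tail as an i24, widening it to i32
// when a byte swap is needed. A further merge into a single 7-byte load
// would require 7 to be an allowed tail expansion as well.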
251 | | |
252 | | // Initialize the basic block structure required for expansion of memcmp call |
253 | | // with given maximum load size and memcmp size parameter. |
254 | | // This structure includes: |
255 | | // 1. A list of load compare blocks - LoadCmpBlocks. |
256 | | // 2. An EndBlock, split from original instruction point, which is the block to |
257 | | // return from. |
258 | | // 3. ResultBlock, block to branch to for early exit when a |
259 | | // LoadCmpBlock finds a difference. |
260 | | MemCmpExpansion::MemCmpExpansion( |
261 | | CallInst *const CI, uint64_t Size, |
262 | | const TargetTransformInfo::MemCmpExpansionOptions &Options, |
263 | | const bool IsUsedForZeroCmp, const DataLayout &TheDataLayout, |
264 | | DomTreeUpdater *DTU) |
265 | | : CI(CI), Size(Size), NumLoadsPerBlockForZeroCmp(Options.NumLoadsPerBlock), |
266 | | IsUsedForZeroCmp(IsUsedForZeroCmp), DL(TheDataLayout), DTU(DTU), |
267 | 187 | Builder(CI) { |
268 | 187 | assert(Size > 0 && "zero blocks"); |
269 | | // Scale the max size down if the target can load more bytes than we need. |
270 | 0 | llvm::ArrayRef<unsigned> LoadSizes(Options.LoadSizes); |
271 | 393 | while (!LoadSizes.empty() && LoadSizes.front() > Size) { |
272 | 206 | LoadSizes = LoadSizes.drop_front(); |
273 | 206 | } |
274 | 187 | assert(!LoadSizes.empty() && "cannot load Size bytes"); |
275 | 0 | MaxLoadSize = LoadSizes.front(); |
276 | | // Compute the decomposition. |
277 | 187 | unsigned GreedyNumLoadsNonOneByte = 0; |
278 | 187 | LoadSequence = computeGreedyLoadSequence(Size, LoadSizes, Options.MaxNumLoads, |
279 | 187 | GreedyNumLoadsNonOneByte); |
280 | 187 | NumLoadsNonOneByte = GreedyNumLoadsNonOneByte; |
281 | 187 | assert(LoadSequence.size() <= Options.MaxNumLoads && "broken invariant"); |
282 | | // If we allow overlapping loads and the load sequence is not already optimal, |
283 | | // use overlapping loads. |
284 | 187 | if (Options.AllowOverlappingLoads && |
285 | 187 | (LoadSequence.empty() || LoadSequence.size() > 2)) { |
286 | 0 | unsigned OverlappingNumLoadsNonOneByte = 0; |
287 | 0 | auto OverlappingLoads = computeOverlappingLoadSequence( |
288 | 0 | Size, MaxLoadSize, Options.MaxNumLoads, OverlappingNumLoadsNonOneByte); |
289 | 0 | if (!OverlappingLoads.empty() && |
290 | 0 | (LoadSequence.empty() || |
291 | 0 | OverlappingLoads.size() < LoadSequence.size())) { |
292 | 0 | LoadSequence = OverlappingLoads; |
293 | 0 | NumLoadsNonOneByte = OverlappingNumLoadsNonOneByte; |
294 | 0 | } |
295 | 0 | } |
296 | 187 | assert(LoadSequence.size() <= Options.MaxNumLoads && "broken invariant"); |
297 | 0 | optimiseLoadSequence(LoadSequence, Options, IsUsedForZeroCmp); |
298 | 187 | } |
299 | | |
300 | 380 | unsigned MemCmpExpansion::getNumBlocks() { |
301 | 380 | if (IsUsedForZeroCmp) |
302 | 38 | return getNumLoads() / NumLoadsPerBlockForZeroCmp + |
303 | 38 | (getNumLoads() % NumLoadsPerBlockForZeroCmp != 0 ? 1 : 0); |
304 | 342 | return getNumLoads(); |
305 | 380 | } |
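// Editorial example: with 3 loads and NumLoadsPerBlockForZeroCmp == 2, the
// zero-equality form needs ceil(3 / 2) == 2 blocks, while the general form
// always uses one block per load.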
306 | | |
307 | 1 | void MemCmpExpansion::createLoadCmpBlocks() { |
308 | 3 | for (unsigned i = 0; i < getNumBlocks(); i++) { |
309 | 2 | BasicBlock *BB = BasicBlock::Create(CI->getContext(), "loadbb", |
310 | 2 | EndBlock->getParent(), EndBlock); |
311 | 2 | LoadCmpBlocks.push_back(BB); |
312 | 2 | } |
313 | 1 | } |
314 | | |
315 | 1 | void MemCmpExpansion::createResultBlock() { |
316 | 1 | ResBlock.BB = BasicBlock::Create(CI->getContext(), "res_block", |
317 | 1 | EndBlock->getParent(), EndBlock); |
318 | 1 | } |
319 | | |
320 | | MemCmpExpansion::LoadPair MemCmpExpansion::getLoadPair(Type *LoadSizeType, |
321 | | Type *BSwapSizeType, |
322 | | Type *CmpSizeType, |
323 | 188 | unsigned OffsetBytes) { |
324 | | // Get the memory source at offset `OffsetBytes`. |
325 | 188 | Value *LhsSource = CI->getArgOperand(0); |
326 | 188 | Value *RhsSource = CI->getArgOperand(1); |
327 | 188 | Align LhsAlign = LhsSource->getPointerAlignment(DL); |
328 | 188 | Align RhsAlign = RhsSource->getPointerAlignment(DL); |
329 | 188 | if (OffsetBytes > 0) { |
330 | 1 | auto *ByteType = Type::getInt8Ty(CI->getContext()); |
331 | 1 | LhsSource = Builder.CreateConstGEP1_64(ByteType, LhsSource, OffsetBytes); |
332 | 1 | RhsSource = Builder.CreateConstGEP1_64(ByteType, RhsSource, OffsetBytes); |
333 | 1 | LhsAlign = commonAlignment(LhsAlign, OffsetBytes); |
334 | 1 | RhsAlign = commonAlignment(RhsAlign, OffsetBytes); |
335 | 1 | } |
336 | | |
337 | | // Create a constant or a load from the source. |
338 | 188 | Value *Lhs = nullptr; |
339 | 188 | if (auto *C = dyn_cast<Constant>(LhsSource)) |
340 | 77 | Lhs = ConstantFoldLoadFromConstPtr(C, LoadSizeType, DL); |
341 | 188 | if (!Lhs) |
342 | 111 | Lhs = Builder.CreateAlignedLoad(LoadSizeType, LhsSource, LhsAlign); |
343 | | |
344 | 188 | Value *Rhs = nullptr; |
345 | 188 | if (auto *C = dyn_cast<Constant>(RhsSource)) |
346 | 23 | Rhs = ConstantFoldLoadFromConstPtr(C, LoadSizeType, DL); |
347 | 188 | if (!Rhs) |
348 | 165 | Rhs = Builder.CreateAlignedLoad(LoadSizeType, RhsSource, RhsAlign); |
349 | | |
350 | | // Zero-extend first if the byte-swap type is wider than the loaded type. |
351 | 188 | if (BSwapSizeType && LoadSizeType != BSwapSizeType) { |
352 | 20 | Lhs = Builder.CreateZExt(Lhs, BSwapSizeType); |
353 | 20 | Rhs = Builder.CreateZExt(Rhs, BSwapSizeType); |
354 | 20 | } |
355 | | |
356 | | // Swap bytes if required. |
357 | 188 | if (BSwapSizeType) { |
358 | 167 | Function *Bswap = Intrinsic::getDeclaration( |
359 | 167 | CI->getModule(), Intrinsic::bswap, BSwapSizeType); |
360 | 167 | Lhs = Builder.CreateCall(Bswap, Lhs); |
361 | 167 | Rhs = Builder.CreateCall(Bswap, Rhs); |
362 | 167 | } |
363 | | |
364 | | // Zero extend if required. |
365 | 188 | if (CmpSizeType != nullptr && CmpSizeType != Lhs->getType()) { |
366 | 0 | Lhs = Builder.CreateZExt(Lhs, CmpSizeType); |
367 | 0 | Rhs = Builder.CreateZExt(Rhs, CmpSizeType); |
368 | 0 | } |
369 | 188 | return {Lhs, Rhs}; |
370 | 188 | } |
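// Editorial sketch of the IR emitted for one side, assuming a little-endian
// target, a 4-byte load at offset 4, an i64 compare type, and an unknown
// pointer alignment (value names are invented; IRBuilder emits numbered
// temporaries):
//   %addr = getelementptr i8, ptr %lhs, i64 4
//   %val  = load i32, ptr %addr, align 1
//   %swap = call i32 @llvm.bswap.i32(i32 %val)
//   %wide = zext i32 %swap to i64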
371 | | |
372 | | // This function creates the IR instructions for loading and comparing 1 byte. |
373 | | // It loads 1 byte from each source of the memcmp parameters at the given |
374 | | // byte offset, subtracts the two zero-extended values, and adds the difference |
375 | | // (which cannot overflow for bytes) to the phi node selecting the result. |
376 | | void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex, |
377 | 0 | unsigned OffsetBytes) { |
378 | 0 | BasicBlock *BB = LoadCmpBlocks[BlockIndex]; |
379 | 0 | Builder.SetInsertPoint(BB); |
380 | 0 | const LoadPair Loads = |
381 | 0 | getLoadPair(Type::getInt8Ty(CI->getContext()), nullptr, |
382 | 0 | Type::getInt32Ty(CI->getContext()), OffsetBytes); |
383 | 0 | Value *Diff = Builder.CreateSub(Loads.Lhs, Loads.Rhs); |
384 | |
385 | 0 | PhiRes->addIncoming(Diff, BB); |
386 | |
387 | 0 | if (BlockIndex < (LoadCmpBlocks.size() - 1)) { |
388 | | // Branch early to EndBlock if a difference was found. Otherwise, continue |
389 | | // to the next LoadCmpBlock. |
390 | 0 | Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_NE, Diff, |
391 | 0 | ConstantInt::get(Diff->getType(), 0)); |
392 | 0 | BranchInst *CmpBr = |
393 | 0 | BranchInst::Create(EndBlock, LoadCmpBlocks[BlockIndex + 1], Cmp); |
394 | 0 | Builder.Insert(CmpBr); |
395 | 0 | if (DTU) |
396 | 0 | DTU->applyUpdates( |
397 | 0 | {{DominatorTree::Insert, BB, EndBlock}, |
398 | 0 | {DominatorTree::Insert, BB, LoadCmpBlocks[BlockIndex + 1]}}); |
399 | 0 | } else { |
400 | | // The last block has an unconditional branch to EndBlock. |
401 | 0 | BranchInst *CmpBr = BranchInst::Create(EndBlock); |
402 | 0 | Builder.Insert(CmpBr); |
403 | 0 | if (DTU) |
404 | 0 | DTU->applyUpdates({{DominatorTree::Insert, BB, EndBlock}}); |
405 | 0 | } |
406 | 0 | } |
407 | | |
408 | | /// Generate an equality comparison for one or more pairs of loaded values. |
409 | | /// This is used in the case where the memcmp() call is compared equal or not |
410 | | /// equal to zero. |
411 | | Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex, |
412 | 19 | unsigned &LoadIndex) { |
413 | 19 | assert(LoadIndex < getNumLoads() && |
414 | 19 | "getCompareLoadPairs() called with no remaining loads"); |
415 | 0 | std::vector<Value *> XorList, OrList; |
416 | 19 | Value *Diff = nullptr; |
417 | | |
418 | 19 | const unsigned NumLoads = |
419 | 19 | std::min(getNumLoads() - LoadIndex, NumLoadsPerBlockForZeroCmp); |
420 | | |
421 | | // For a single-block expansion, start inserting before the memcmp call. |
422 | 19 | if (LoadCmpBlocks.empty()) |
423 | 19 | Builder.SetInsertPoint(CI); |
424 | 0 | else |
425 | 0 | Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]); |
426 | | |
427 | 19 | Value *Cmp = nullptr; |
428 | | // If we have multiple loads per block, we need to generate a composite |
429 | | // comparison using xor+or. The type for the combinations is the largest load |
430 | | // type. |
431 | 19 | IntegerType *const MaxLoadType = |
432 | 19 | NumLoads == 1 ? nullptr |
433 | 19 | : IntegerType::get(CI->getContext(), MaxLoadSize * 8); |
434 | | |
435 | 38 | for (unsigned i = 0; i < NumLoads; ++i, ++LoadIndex) { |
436 | 19 | const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex]; |
437 | 19 | const LoadPair Loads = getLoadPair( |
438 | 19 | IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8), nullptr, |
439 | 19 | MaxLoadType, CurLoadEntry.Offset); |
440 | | |
441 | 19 | if (NumLoads != 1) { |
442 | | // If we have multiple loads per block, we need to generate a composite |
443 | | // comparison using xor+or. |
444 | 0 | Diff = Builder.CreateXor(Loads.Lhs, Loads.Rhs); |
445 | 0 | Diff = Builder.CreateZExt(Diff, MaxLoadType); |
446 | 0 | XorList.push_back(Diff); |
447 | 19 | } else { |
448 | | // If there's only one load per block, we just compare the loaded values. |
449 | 19 | Cmp = Builder.CreateICmpNE(Loads.Lhs, Loads.Rhs); |
450 | 19 | } |
451 | 19 | } |
452 | | |
453 | 19 | auto pairWiseOr = [&](std::vector<Value *> &InList) -> std::vector<Value *> { |
454 | 0 | std::vector<Value *> OutList; |
455 | 0 | for (unsigned i = 0; i < InList.size() - 1; i = i + 2) { |
456 | 0 | Value *Or = Builder.CreateOr(InList[i], InList[i + 1]); |
457 | 0 | OutList.push_back(Or); |
458 | 0 | } |
459 | 0 | if (InList.size() % 2 != 0) |
460 | 0 | OutList.push_back(InList.back()); |
461 | 0 | return OutList; |
462 | 0 | }; |
463 | | |
464 | 19 | if (!Cmp) { |
465 | | // Pairwise OR the XOR results. |
466 | 0 | OrList = pairWiseOr(XorList); |
467 | | |
468 | | // Pairwise OR the OR results until one result left. |
469 | 0 | while (OrList.size() != 1) { |
470 | 0 | OrList = pairWiseOr(OrList); |
471 | 0 | } |
472 | |
473 | 0 | assert(Diff && "Failed to find comparison diff"); |
474 | 0 | Cmp = Builder.CreateICmpNE(OrList[0], ConstantInt::get(Diff->getType(), 0)); |
475 | 0 | } |
476 | | |
477 | 0 | return Cmp; |
478 | 19 | } |
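// Editorial sketch of the multi-load shape (uncovered in this run, where
// every block has a single load): with two 8-byte loads per block this emits
//   %x0 = xor i64 %lhs0, %rhs0
//   %x1 = xor i64 %lhs1, %rhs1
//   %or = or i64 %x0, %x1
//   %cmp = icmp ne i64 %or, 0
// so any differing bit in either pair makes %cmp true.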
479 | | |
480 | | void MemCmpExpansion::emitLoadCompareBlockMultipleLoads(unsigned BlockIndex, |
481 | 0 | unsigned &LoadIndex) { |
482 | 0 | Value *Cmp = getCompareLoadPairs(BlockIndex, LoadIndex); |
483 | |
484 | 0 | BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1)) |
485 | 0 | ? EndBlock |
486 | 0 | : LoadCmpBlocks[BlockIndex + 1]; |
487 | | // Branch early to ResultBlock if a difference was found. Otherwise, |
488 | | // continue to the next LoadCmpBlock or EndBlock. |
489 | 0 | BasicBlock *BB = Builder.GetInsertBlock(); |
490 | 0 | BranchInst *CmpBr = BranchInst::Create(ResBlock.BB, NextBB, Cmp); |
491 | 0 | Builder.Insert(CmpBr); |
492 | 0 | if (DTU) |
493 | 0 | DTU->applyUpdates({{DominatorTree::Insert, BB, ResBlock.BB}, |
494 | 0 | {DominatorTree::Insert, BB, NextBB}}); |
495 | | |
496 | | // Add a phi edge for the last LoadCmpBlock to EndBlock with a value of 0 |
497 | | // since early exit to ResultBlock was not taken (no difference was found in |
498 | | // any of the bytes). |
499 | 0 | if (BlockIndex == LoadCmpBlocks.size() - 1) { |
500 | 0 | Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0); |
501 | 0 | PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]); |
502 | 0 | } |
503 | 0 | } |
504 | | |
505 | | // This function creates the IR instructions for loading and comparing using the |
506 | | // given LoadSize. It loads the number of bytes specified by LoadSize from each |
507 | | // source of the memcmp parameters. It then does a subtract to see if there was |
508 | | // a difference in the loaded values. If a difference is found, it branches |
509 | | // with an early exit to the ResultBlock for calculating which source was |
510 | | // larger. Otherwise, it falls through to either the next LoadCmpBlock or |
511 | | // the EndBlock if this is the last LoadCmpBlock. Loading 1 byte is handled with |
512 | | // a special case through emitLoadCompareByteBlock. The special handling can |
513 | | // simply subtract the loaded values and add it to the result phi node. |
514 | 2 | void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) { |
515 | | // There is one load per block in this case, BlockIndex == LoadIndex. |
516 | 2 | const LoadEntry &CurLoadEntry = LoadSequence[BlockIndex]; |
517 | | |
518 | 2 | if (CurLoadEntry.LoadSize == 1) { |
519 | 0 | MemCmpExpansion::emitLoadCompareByteBlock(BlockIndex, CurLoadEntry.Offset); |
520 | 0 | return; |
521 | 0 | } |
522 | | |
523 | 2 | Type *LoadSizeType = |
524 | 2 | IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8); |
525 | 2 | Type *BSwapSizeType = |
526 | 2 | DL.isLittleEndian() |
527 | 2 | ? IntegerType::get(CI->getContext(), |
528 | 0 | PowerOf2Ceil(CurLoadEntry.LoadSize * 8)) |
529 | 2 | : nullptr; |
530 | 2 | Type *MaxLoadType = IntegerType::get( |
531 | 2 | CI->getContext(), |
532 | 2 | std::max(MaxLoadSize, (unsigned)PowerOf2Ceil(CurLoadEntry.LoadSize)) * 8); |
533 | 2 | assert(CurLoadEntry.LoadSize <= MaxLoadSize && "Unexpected load type"); |
534 | | |
535 | 0 | Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]); |
536 | | |
537 | 2 | const LoadPair Loads = getLoadPair(LoadSizeType, BSwapSizeType, MaxLoadType, |
538 | 2 | CurLoadEntry.Offset); |
539 | | |
540 | | // Add the loaded values to the phi nodes for calculating memcmp result only |
541 | | // if result is not used in a zero equality. |
542 | 2 | if (!IsUsedForZeroCmp) { |
543 | 2 | ResBlock.PhiSrc1->addIncoming(Loads.Lhs, LoadCmpBlocks[BlockIndex]); |
544 | 2 | ResBlock.PhiSrc2->addIncoming(Loads.Rhs, LoadCmpBlocks[BlockIndex]); |
545 | 2 | } |
546 | | |
547 | 2 | Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Loads.Lhs, Loads.Rhs); |
548 | 2 | BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1)) |
549 | 2 | ? EndBlock |
550 | 2 | : LoadCmpBlocks[BlockIndex + 1]; |
551 | | // Branch early to ResultBlock if a difference was found. Otherwise, continue |
552 | | // to the next LoadCmpBlock or EndBlock. |
553 | 2 | BasicBlock *BB = Builder.GetInsertBlock(); |
554 | 2 | BranchInst *CmpBr = BranchInst::Create(NextBB, ResBlock.BB, Cmp); |
555 | 2 | Builder.Insert(CmpBr); |
556 | 2 | if (DTU) |
557 | 2 | DTU->applyUpdates({{DominatorTree::Insert, BB, NextBB}, |
558 | 2 | {DominatorTree::Insert, BB, ResBlock.BB}}); |
559 | | |
560 | | // Add a phi edge for the last LoadCmpBlock to EndBlock with a value of 0 |
561 | | // since early exit to ResultBlock was not taken (no difference was found in |
562 | | // any of the bytes). |
563 | 2 | if (BlockIndex == LoadCmpBlocks.size() - 1) { |
564 | 1 | Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0); |
565 | 1 | PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]); |
566 | 1 | } |
567 | 2 | } |
568 | | |
569 | | // This function populates the ResultBlock with a sequence to calculate the |
570 | | // memcmp result. It compares the two loaded source values and returns -1 if |
571 | | // src1 < src2 and 1 if src1 > src2. |
572 | 1 | void MemCmpExpansion::emitMemCmpResultBlock() { |
573 | | // Special case: if the memcmp result is used only in a zero equality, the |
574 | | // exact value does not need to be calculated; we can simply return 1. |
575 | 1 | if (IsUsedForZeroCmp) { |
576 | 0 | BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt(); |
577 | 0 | Builder.SetInsertPoint(ResBlock.BB, InsertPt); |
578 | 0 | Value *Res = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 1); |
579 | 0 | PhiRes->addIncoming(Res, ResBlock.BB); |
580 | 0 | BranchInst *NewBr = BranchInst::Create(EndBlock); |
581 | 0 | Builder.Insert(NewBr); |
582 | 0 | if (DTU) |
583 | 0 | DTU->applyUpdates({{DominatorTree::Insert, ResBlock.BB, EndBlock}}); |
584 | 0 | return; |
585 | 0 | } |
586 | 1 | BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt(); |
587 | 1 | Builder.SetInsertPoint(ResBlock.BB, InsertPt); |
588 | | |
589 | 1 | Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_ULT, ResBlock.PhiSrc1, |
590 | 1 | ResBlock.PhiSrc2); |
591 | | |
592 | 1 | Value *Res = |
593 | 1 | Builder.CreateSelect(Cmp, ConstantInt::get(Builder.getInt32Ty(), -1), |
594 | 1 | ConstantInt::get(Builder.getInt32Ty(), 1)); |
595 | | |
596 | 1 | PhiRes->addIncoming(Res, ResBlock.BB); |
597 | 1 | BranchInst *NewBr = BranchInst::Create(EndBlock); |
598 | 1 | Builder.Insert(NewBr); |
599 | 1 | if (DTU) |
600 | 1 | DTU->applyUpdates({{DominatorTree::Insert, ResBlock.BB, EndBlock}}); |
601 | 1 | } |
602 | | |
603 | 1 | void MemCmpExpansion::setupResultBlockPHINodes() { |
604 | 1 | Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8); |
605 | 1 | Builder.SetInsertPoint(ResBlock.BB); |
606 | | // Note: this assumes one load per block. |
607 | 1 | ResBlock.PhiSrc1 = |
608 | 1 | Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src1"); |
609 | 1 | ResBlock.PhiSrc2 = |
610 | 1 | Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src2"); |
611 | 1 | } |
612 | | |
613 | 1 | void MemCmpExpansion::setupEndBlockPHINodes() { |
614 | 1 | Builder.SetInsertPoint(EndBlock, EndBlock->begin()); |
615 | 1 | PhiRes = Builder.CreatePHI(Type::getInt32Ty(CI->getContext()), 2, "phi.res"); |
616 | 1 | } |
617 | | |
618 | 0 | Value *MemCmpExpansion::getMemCmpExpansionZeroCase() { |
619 | 0 | unsigned LoadIndex = 0; |
620 | | // This loop populates each of the LoadCmpBlocks with the IR sequence to |
621 | | // handle multiple loads per block. |
622 | 0 | for (unsigned I = 0; I < getNumBlocks(); ++I) { |
623 | 0 | emitLoadCompareBlockMultipleLoads(I, LoadIndex); |
624 | 0 | } |
625 | |
626 | 0 | emitMemCmpResultBlock(); |
627 | 0 | return PhiRes; |
628 | 0 | } |
629 | | |
630 | | /// A memcmp expansion that compares equality with 0 and only has one block of |
631 | | /// load and compare can bypass the compare, branch, and phi IR that is required |
632 | | /// in the general case. |
633 | 19 | Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() { |
634 | 19 | unsigned LoadIndex = 0; |
635 | 19 | Value *Cmp = getCompareLoadPairs(0, LoadIndex); |
636 | 19 | assert(LoadIndex == getNumLoads() && "some entries were not consumed"); |
637 | 0 | return Builder.CreateZExt(Cmp, Type::getInt32Ty(CI->getContext())); |
638 | 19 | } |
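// Editorial sketch: a use like `memcmp(a, b, 4) == 0` collapses to roughly
//   %l = load i32, ptr %a, align 1
//   %r = load i32, ptr %b, align 1
//   %ne = icmp ne i32 %l, %r
//   %res = zext i1 %ne to i32
// No bswap is needed because equality is endianness-agnostic.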
639 | | |
640 | | /// A memcmp expansion that only has one block of load and compare can bypass |
641 | | /// the compare, branch, and phi IR that is required in the general case. |
642 | | /// This function also analyses users of memcmp, and if there is only one user |
643 | | /// from which we can conclude that only 2 out of 3 memcmp outcomes really |
644 | | /// matter, then it generates more efficient code with only one comparison. |
645 | 167 | Value *MemCmpExpansion::getMemCmpOneBlock() { |
646 | 167 | bool NeedsBSwap = DL.isLittleEndian() && Size != 1; |
647 | 167 | Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8); |
648 | 167 | Type *BSwapSizeType = |
649 | 167 | NeedsBSwap ? IntegerType::get(CI->getContext(), PowerOf2Ceil(Size * 8)) |
650 | 167 | : nullptr; |
651 | 167 | Type *MaxLoadType = |
652 | 167 | IntegerType::get(CI->getContext(), |
653 | 167 | std::max(MaxLoadSize, (unsigned)PowerOf2Ceil(Size)) * 8); |
654 | | |
655 | | // The i8 and i16 cases don't need compares. We zext the loaded values and |
656 | | // subtract them to get the suitable negative, zero, or positive i32 result. |
657 | 167 | if (Size == 1 || Size == 2) { |
658 | 0 | const LoadPair Loads = getLoadPair(LoadSizeType, BSwapSizeType, |
659 | 0 | Builder.getInt32Ty(), /*Offset*/ 0); |
660 | 0 | return Builder.CreateSub(Loads.Lhs, Loads.Rhs); |
661 | 0 | } |
662 | | |
663 | 167 | const LoadPair Loads = getLoadPair(LoadSizeType, BSwapSizeType, MaxLoadType, |
664 | 167 | /*Offset*/ 0); |
665 | | |
666 | | // If a user of memcmp cares only about two outcomes, for example: |
667 | | // bool result = memcmp(a, b, NBYTES) > 0; |
668 | | // we can generate better code with fewer operations. |
669 | 167 | if (CI->hasOneUser()) { |
670 | 14 | auto *UI = cast<Instruction>(*CI->user_begin()); |
671 | 14 | ICmpInst::Predicate Pred = ICmpInst::Predicate::BAD_ICMP_PREDICATE; |
672 | 14 | uint64_t Shift; |
673 | 14 | bool NeedsZExt = false; |
674 | | // This is a special case because instead of checking if the result is less |
675 | | // than zero: |
676 | | // bool result = memcmp(a, b, NBYTES) < 0; |
677 | | // the compiler is clever enough to generate the following code: |
678 | | // bool result = memcmp(a, b, NBYTES) >> 31; |
679 | 14 | if (match(UI, m_LShr(m_Value(), m_ConstantInt(Shift))) && |
680 | 14 | Shift == (CI->getType()->getIntegerBitWidth() - 1)) { |
681 | 0 | Pred = ICmpInst::ICMP_SLT; |
682 | 0 | NeedsZExt = true; |
683 | 14 | } else { |
684 | | // In case of a successful match this call will set `Pred` variable |
685 | | // On a successful match, this call sets the `Pred` variable. |
686 | 14 | } |
687 | | // Generate new code and remove the original memcmp call and the user |
688 | 14 | if (ICmpInst::isSigned(Pred)) { |
689 | 1 | Value *Cmp = Builder.CreateICmp(CmpInst::getUnsignedPredicate(Pred), |
690 | 1 | Loads.Lhs, Loads.Rhs); |
691 | 1 | auto *Result = NeedsZExt ? Builder.CreateZExt(Cmp, UI->getType()) : Cmp; |
692 | 1 | UI->replaceAllUsesWith(Result); |
693 | 1 | UI->eraseFromParent(); |
694 | 1 | CI->eraseFromParent(); |
695 | 1 | return nullptr; |
696 | 1 | } |
697 | 14 | } |
698 | | |
699 | | // The result of memcmp is negative, zero, or positive, so produce that by |
700 | | // subtracting 2 extended compare bits: sub (ugt, ult). |
701 | | // If a target prefers to use selects to get -1/0/1, they should be able |
702 | | // to transform this later. The inverse transform (going from selects to math) |
703 | | // may not be possible in the DAG because the selects got converted into |
704 | | // branches before we got there. |
705 | 166 | Value *CmpUGT = Builder.CreateICmpUGT(Loads.Lhs, Loads.Rhs); |
706 | 166 | Value *CmpULT = Builder.CreateICmpULT(Loads.Lhs, Loads.Rhs); |
707 | 166 | Value *ZextUGT = Builder.CreateZExt(CmpUGT, Builder.getInt32Ty()); |
708 | 166 | Value *ZextULT = Builder.CreateZExt(CmpULT, Builder.getInt32Ty()); |
709 | 166 | return Builder.CreateSub(ZextUGT, ZextULT); |
710 | 167 | } |
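// Editorial sketch of the generic three-way result for an 8-byte compare on
// a little-endian target (names invented, byte-swapped values assumed):
//   %ugt = icmp ugt i64 %lhs, %rhs
//   %ult = icmp ult i64 %lhs, %rhs
//   %zgt = zext i1 %ugt to i32
//   %zlt = zext i1 %ult to i32
//   %res = sub i32 %zgt, %zlt      ; 1, 0, or -1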
711 | | |
712 | | // This function expands the memcmp call into an inline expansion and returns |
713 | | // the memcmp result. Returns nullptr if the memcmp is already replaced. |
714 | 187 | Value *MemCmpExpansion::getMemCmpExpansion() { |
715 | | // Create the basic block framework for a multi-block expansion. |
716 | 187 | if (getNumBlocks() != 1) { |
717 | 1 | BasicBlock *StartBlock = CI->getParent(); |
718 | 1 | EndBlock = SplitBlock(StartBlock, CI, DTU, /*LI=*/nullptr, |
719 | 1 | /*MSSAU=*/nullptr, "endblock"); |
720 | 1 | setupEndBlockPHINodes(); |
721 | 1 | createResultBlock(); |
722 | | |
723 | | // If return value of memcmp is not used in a zero equality, we need to |
724 | | // calculate which source was larger. The calculation requires the |
725 | | // two loaded source values of each load compare block. |
726 | | // These will be saved in the phi nodes created by setupResultBlockPHINodes. |
727 | 1 | if (!IsUsedForZeroCmp) setupResultBlockPHINodes(); |
728 | | |
729 | | // Create the number of required load compare basic blocks. |
730 | 1 | createLoadCmpBlocks(); |
731 | | |
732 | | // Update the terminator added by SplitBlock to branch to the first |
733 | | // LoadCmpBlock. |
734 | 1 | StartBlock->getTerminator()->setSuccessor(0, LoadCmpBlocks[0]); |
735 | 1 | if (DTU) |
736 | 1 | DTU->applyUpdates({{DominatorTree::Insert, StartBlock, LoadCmpBlocks[0]}, |
737 | 1 | {DominatorTree::Delete, StartBlock, EndBlock}}); |
738 | 1 | } |
739 | | |
740 | 187 | Builder.SetCurrentDebugLocation(CI->getDebugLoc()); |
741 | | |
742 | 187 | if (IsUsedForZeroCmp) |
743 | 19 | return getNumBlocks() == 1 ? getMemCmpEqZeroOneBlock() |
744 | 19 | : getMemCmpExpansionZeroCase(); |
745 | | |
746 | 168 | if (getNumBlocks() == 1) |
747 | 167 | return getMemCmpOneBlock(); |
748 | | |
749 | 3 | for (unsigned I = 0; I < getNumBlocks(); ++I) { |
750 | 2 | emitLoadCompareBlock(I); |
751 | 2 | } |
752 | | |
753 | 1 | emitMemCmpResultBlock(); |
754 | 1 | return PhiRes; |
755 | 168 | } |
756 | | |
757 | | // This function checks to see if an expansion of memcmp can be generated. |
758 | | // It checks for a constant compare size that is less than the max inline size. |
759 | | // If an expansion cannot occur, it returns false to leave it as a library call. |
760 | | // Otherwise, the library call is replaced with a new IR instruction sequence. |
761 | | /// We want to transform: |
762 | | /// %call = call signext i32 @memcmp(i8* %0, i8* %1, i64 15) |
763 | | /// To: |
764 | | /// loadbb: |
765 | | /// %0 = bitcast i32* %buffer2 to i8* |
766 | | /// %1 = bitcast i32* %buffer1 to i8* |
767 | | /// %2 = bitcast i8* %1 to i64* |
768 | | /// %3 = bitcast i8* %0 to i64* |
769 | | /// %4 = load i64, i64* %2 |
770 | | /// %5 = load i64, i64* %3 |
771 | | /// %6 = call i64 @llvm.bswap.i64(i64 %4) |
772 | | /// %7 = call i64 @llvm.bswap.i64(i64 %5) |
773 | | /// %8 = sub i64 %6, %7 |
774 | | /// %9 = icmp ne i64 %8, 0 |
775 | | /// br i1 %9, label %res_block, label %loadbb1 |
776 | | /// res_block: ; preds = %loadbb2, |
777 | | /// %loadbb1, %loadbb |
778 | | /// %phi.src1 = phi i64 [ %6, %loadbb ], [ %22, %loadbb1 ], [ %36, %loadbb2 ] |
779 | | /// %phi.src2 = phi i64 [ %7, %loadbb ], [ %23, %loadbb1 ], [ %37, %loadbb2 ] |
780 | | /// %10 = icmp ult i64 %phi.src1, %phi.src2 |
781 | | /// %11 = select i1 %10, i32 -1, i32 1 |
782 | | /// br label %endblock |
783 | | /// loadbb1: ; preds = %loadbb |
784 | | /// %12 = bitcast i32* %buffer2 to i8* |
785 | | /// %13 = bitcast i32* %buffer1 to i8* |
786 | | /// %14 = bitcast i8* %13 to i32* |
787 | | /// %15 = bitcast i8* %12 to i32* |
788 | | /// %16 = getelementptr i32, i32* %14, i32 2 |
789 | | /// %17 = getelementptr i32, i32* %15, i32 2 |
790 | | /// %18 = load i32, i32* %16 |
791 | | /// %19 = load i32, i32* %17 |
792 | | /// %20 = call i32 @llvm.bswap.i32(i32 %18) |
793 | | /// %21 = call i32 @llvm.bswap.i32(i32 %19) |
794 | | /// %22 = zext i32 %20 to i64 |
795 | | /// %23 = zext i32 %21 to i64 |
796 | | /// %24 = sub i64 %22, %23 |
797 | | /// %25 = icmp ne i64 %24, 0 |
798 | | /// br i1 %25, label %res_block, label %loadbb2 |
799 | | /// loadbb2: ; preds = %loadbb1 |
800 | | /// %26 = bitcast i32* %buffer2 to i8* |
801 | | /// %27 = bitcast i32* %buffer1 to i8* |
802 | | /// %28 = bitcast i8* %27 to i16* |
803 | | /// %29 = bitcast i8* %26 to i16* |
804 | | /// %30 = getelementptr i16, i16* %28, i16 6 |
805 | | /// %31 = getelementptr i16, i16* %29, i16 6 |
806 | | /// %32 = load i16, i16* %30 |
807 | | /// %33 = load i16, i16* %31 |
808 | | /// %34 = call i16 @llvm.bswap.i16(i16 %32) |
809 | | /// %35 = call i16 @llvm.bswap.i16(i16 %33) |
810 | | /// %36 = zext i16 %34 to i64 |
811 | | /// %37 = zext i16 %35 to i64 |
812 | | /// %38 = sub i64 %36, %37 |
813 | | /// %39 = icmp ne i64 %38, 0 |
814 | | /// br i1 %39, label %res_block, label %loadbb3 |
815 | | /// loadbb3: ; preds = %loadbb2 |
816 | | /// %40 = bitcast i32* %buffer2 to i8* |
817 | | /// %41 = bitcast i32* %buffer1 to i8* |
818 | | /// %42 = getelementptr i8, i8* %41, i8 14 |
819 | | /// %43 = getelementptr i8, i8* %40, i8 14 |
820 | | /// %44 = load i8, i8* %42 |
821 | | /// %45 = load i8, i8* %43 |
822 | | /// %46 = zext i8 %44 to i32 |
823 | | /// %47 = zext i8 %45 to i32 |
824 | | /// %48 = sub i32 %46, %47 |
825 | | /// br label %endblock |
826 | | /// endblock: ; preds = %res_block, |
827 | | /// %loadbb3 |
828 | | /// %phi.res = phi i32 [ %48, %loadbb3 ], [ %11, %res_block ] |
829 | | /// ret i32 %phi.res |
830 | | static bool expandMemCmp(CallInst *CI, const TargetTransformInfo *TTI, |
831 | | const TargetLowering *TLI, const DataLayout *DL, |
832 | | ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI, |
833 | 223 | DomTreeUpdater *DTU, const bool IsBCmp) { |
834 | 223 | NumMemCmpCalls++; |
835 | | |
836 | | // Early exit from expansion if -Oz. |
837 | 223 | if (CI->getFunction()->hasMinSize()) |
838 | 0 | return false; |
839 | | |
840 | | // Early exit from expansion if size is not a constant. |
841 | 223 | ConstantInt *SizeCast = dyn_cast<ConstantInt>(CI->getArgOperand(2)); |
842 | 223 | if (!SizeCast) { |
843 | 36 | NumMemCmpNotConstant++; |
844 | 36 | return false; |
845 | 36 | } |
846 | 187 | const uint64_t SizeVal = SizeCast->getZExtValue(); |
847 | | |
848 | 187 | if (SizeVal == 0) { |
849 | 0 | return false; |
850 | 0 | } |
851 | | // TTI call to check if the target would like to expand memcmp. Also, get the |
852 | | // available load sizes. |
853 | 187 | const bool IsUsedForZeroCmp = |
854 | 187 | IsBCmp || isOnlyUsedInZeroEqualityComparison(CI); |
855 | 187 | bool OptForSize = CI->getFunction()->hasOptSize() || |
856 | 187 | llvm::shouldOptimizeForSize(CI->getParent(), PSI, BFI); |
857 | 187 | auto Options = TTI->enableMemCmpExpansion(OptForSize, |
858 | 187 | IsUsedForZeroCmp); |
859 | 187 | if (!Options) return false; |
860 | | |
861 | 187 | if (MemCmpEqZeroNumLoadsPerBlock.getNumOccurrences()) |
862 | 0 | Options.NumLoadsPerBlock = MemCmpEqZeroNumLoadsPerBlock; |
863 | | |
864 | 187 | if (OptForSize && |
865 | 187 | MaxLoadsPerMemcmpOptSize.getNumOccurrences()) |
866 | 0 | Options.MaxNumLoads = MaxLoadsPerMemcmpOptSize; |
867 | | |
868 | 187 | if (!OptForSize && MaxLoadsPerMemcmp.getNumOccurrences()) |
869 | 0 | Options.MaxNumLoads = MaxLoadsPerMemcmp; |
870 | | |
871 | 187 | MemCmpExpansion Expansion(CI, SizeVal, Options, IsUsedForZeroCmp, *DL, DTU); |
872 | | |
873 | | // Don't expand if this will require more loads than desired by the target. |
874 | 187 | if (Expansion.getNumLoads() == 0) { |
875 | 0 | NumMemCmpGreaterThanMax++; |
876 | 0 | return false; |
877 | 0 | } |
878 | | |
879 | 187 | NumMemCmpInlined++; |
880 | | |
881 | 187 | if (Value *Res = Expansion.getMemCmpExpansion()) { |
882 | | // Replace call with result of expansion and erase call. |
883 | 186 | CI->replaceAllUsesWith(Res); |
884 | 186 | CI->eraseFromParent(); |
885 | 186 | } |
886 | | |
887 | 187 | return true; |
888 | 187 | } |
889 | | |
890 | | // Returns true if a change was made. |
891 | | static bool runOnBlock(BasicBlock &BB, const TargetLibraryInfo *TLI, |
892 | | const TargetTransformInfo *TTI, const TargetLowering *TL, |
893 | | const DataLayout &DL, ProfileSummaryInfo *PSI, |
894 | | BlockFrequencyInfo *BFI, DomTreeUpdater *DTU); |
895 | | |
896 | | static PreservedAnalyses runImpl(Function &F, const TargetLibraryInfo *TLI, |
897 | | const TargetTransformInfo *TTI, |
898 | | const TargetLowering *TL, |
899 | | ProfileSummaryInfo *PSI, |
900 | | BlockFrequencyInfo *BFI, DominatorTree *DT); |
901 | | |
902 | | class ExpandMemCmpLegacyPass : public FunctionPass { |
903 | | public: |
904 | | static char ID; |
905 | | |
906 | 34.0k | ExpandMemCmpLegacyPass() : FunctionPass(ID) { |
907 | 34.0k | initializeExpandMemCmpLegacyPassPass(*PassRegistry::getPassRegistry()); |
908 | 34.0k | } |
909 | | |
910 | 113k | bool runOnFunction(Function &F) override { |
911 | 113k | if (skipFunction(F)) return false; |
912 | | |
913 | 113k | auto *TPC = getAnalysisIfAvailable<TargetPassConfig>(); |
914 | 113k | if (!TPC) { |
915 | 0 | return false; |
916 | 0 | } |
917 | 113k | const TargetLowering* TL = |
918 | 113k | TPC->getTM<TargetMachine>().getSubtargetImpl(F)->getTargetLowering(); |
919 | | |
920 | 113k | const TargetLibraryInfo *TLI = |
921 | 113k | &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); |
922 | 113k | const TargetTransformInfo *TTI = |
923 | 113k | &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); |
924 | 113k | auto *PSI = &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); |
925 | 113k | auto *BFI = (PSI && PSI->hasProfileSummary()) ? |
926 | 29 | &getAnalysis<LazyBlockFrequencyInfoPass>().getBFI() : |
927 | 113k | nullptr; |
928 | 113k | DominatorTree *DT = nullptr; |
929 | 113k | if (auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>()) |
930 | 113k | DT = &DTWP->getDomTree(); |
931 | 113k | auto PA = runImpl(F, TLI, TTI, TL, PSI, BFI, DT); |
932 | 113k | return !PA.areAllPreserved(); |
933 | 113k | } |
934 | | |
935 | | private: |
936 | 34.0k | void getAnalysisUsage(AnalysisUsage &AU) const override { |
937 | 34.0k | AU.addRequired<TargetLibraryInfoWrapperPass>(); |
938 | 34.0k | AU.addRequired<TargetTransformInfoWrapperPass>(); |
939 | 34.0k | AU.addRequired<ProfileSummaryInfoWrapperPass>(); |
940 | 34.0k | AU.addPreserved<DominatorTreeWrapperPass>(); |
941 | 34.0k | LazyBlockFrequencyInfoPass::getLazyBFIAnalysisUsage(AU); |
942 | 34.0k | FunctionPass::getAnalysisUsage(AU); |
943 | 34.0k | } |
944 | | }; |
945 | | |
946 | | bool runOnBlock(BasicBlock &BB, const TargetLibraryInfo *TLI, |
947 | | const TargetTransformInfo *TTI, const TargetLowering *TL, |
948 | | const DataLayout &DL, ProfileSummaryInfo *PSI, |
949 | 228k | BlockFrequencyInfo *BFI, DomTreeUpdater *DTU) { |
950 | 2.24M | for (Instruction &I : BB) { |
951 | 2.24M | CallInst *CI = dyn_cast<CallInst>(&I); |
952 | 2.24M | if (!CI) { |
953 | 2.21M | continue; |
954 | 2.21M | } |
955 | 28.8k | LibFunc Func; |
956 | 28.8k | if (TLI->getLibFunc(*CI, Func) && |
957 | 28.8k | (Func == LibFunc_memcmp || Func == LibFunc_bcmp) && |
958 | 28.8k | expandMemCmp(CI, TTI, TL, &DL, PSI, BFI, DTU, Func == LibFunc_bcmp)) { |
959 | 187 | return true; |
960 | 187 | } |
961 | 28.8k | } |
962 | 228k | return false; |
963 | 228k | } |
964 | | |
965 | | PreservedAnalyses runImpl(Function &F, const TargetLibraryInfo *TLI, |
966 | | const TargetTransformInfo *TTI, |
967 | | const TargetLowering *TL, ProfileSummaryInfo *PSI, |
968 | 113k | BlockFrequencyInfo *BFI, DominatorTree *DT) { |
969 | 113k | std::optional<DomTreeUpdater> DTU; |
970 | 113k | if (DT) |
971 | 113k | DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy); |
972 | | |
973 | 113k | const DataLayout& DL = F.getParent()->getDataLayout(); |
974 | 113k | bool MadeChanges = false; |
975 | 342k | for (auto BBIt = F.begin(); BBIt != F.end();) { |
976 | 228k | if (runOnBlock(*BBIt, TLI, TTI, TL, DL, PSI, BFI, DTU ? &*DTU : nullptr)) { |
977 | 187 | MadeChanges = true; |
978 | | // If changes were made, restart scanning from the beginning of the |
979 | | // function, since its block structure has changed. |
980 | 187 | BBIt = F.begin(); |
981 | 228k | } else { |
982 | 228k | ++BBIt; |
983 | 228k | } |
984 | 228k | } |
985 | 113k | if (MadeChanges) |
986 | 187 | for (BasicBlock &BB : F) |
987 | 192 | SimplifyInstructionsInBlock(&BB); |
988 | 113k | if (!MadeChanges) |
989 | 113k | return PreservedAnalyses::all(); |
990 | 187 | PreservedAnalyses PA; |
991 | 187 | PA.preserve<DominatorTreeAnalysis>(); |
992 | 187 | return PA; |
993 | 113k | } |
994 | | |
995 | | } // namespace |
996 | | |
997 | | PreservedAnalyses ExpandMemCmpPass::run(Function &F, |
998 | 0 | FunctionAnalysisManager &FAM) { |
999 | 0 | const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering(); |
1000 | 0 | const auto &TLI = FAM.getResult<TargetLibraryAnalysis>(F); |
1001 | 0 | const auto &TTI = FAM.getResult<TargetIRAnalysis>(F); |
1002 | 0 | auto *PSI = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F) |
1003 | 0 | .getCachedResult<ProfileSummaryAnalysis>(*F.getParent()); |
1004 | 0 | BlockFrequencyInfo *BFI = (PSI && PSI->hasProfileSummary()) |
1005 | 0 | ? &FAM.getResult<BlockFrequencyAnalysis>(F) |
1006 | 0 | : nullptr; |
1007 | 0 | auto *DT = FAM.getCachedResult<DominatorTreeAnalysis>(F); |
1008 | |
1009 | 0 | return runImpl(F, &TLI, &TTI, TL, PSI, BFI, DT); |
1010 | 0 | } |
1011 | | |
1012 | | char ExpandMemCmpLegacyPass::ID = 0; |
1013 | 12 | INITIALIZE_PASS_BEGIN(ExpandMemCmpLegacyPass, DEBUG_TYPE, |
1014 | 12 | "Expand memcmp() to load/stores", false, false) |
1015 | 12 | INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) |
1016 | 12 | INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) |
1017 | 12 | INITIALIZE_PASS_DEPENDENCY(LazyBlockFrequencyInfoPass) |
1018 | 12 | INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass) |
1019 | 12 | INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) |
1020 | 12 | INITIALIZE_PASS_END(ExpandMemCmpLegacyPass, DEBUG_TYPE, |
1021 | | "Expand memcmp() to load/stores", false, false) |
1022 | | |
1023 | 34.0k | FunctionPass *llvm::createExpandMemCmpLegacyPass() { |
1024 | 34.0k | return new ExpandMemCmpLegacyPass(); |
1025 | 34.0k | } |