Coverage Report

Created: 2024-01-17 10:31

/src/llvm-project/llvm/lib/Target/ARM/MVETailPredication.cpp
Line
Count
Source (jump to first uncovered line)
1
//===- MVETailPredication.cpp - MVE Tail Predication ------------*- C++ -*-===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
/// \file
10
/// Armv8.1m introduced MVE, M-Profile Vector Extension, and low-overhead
11
/// branches to help accelerate DSP applications. These two extensions,
12
/// combined with a new form of predication called tail-predication, can be used
13
/// to provide implicit vector predication within a low-overhead loop.
14
/// This is implicit because the predicate of active/inactive lanes is
15
/// calculated by hardware, and thus does not need to be explicitly passed
16
/// to vector instructions. The instructions responsible for this are the
17
/// DLSTP and WLSTP instructions, which set up a tail-predicated loop and
18
/// the total number of data elements processed by the loop. The loop-end
19
/// LETP instruction is responsible for decrementing and setting the remaining
20
/// elements to be processed and generating the mask of active lanes.
21
///
22
/// The HardwareLoops pass inserts intrinsics identifying loops that the
23
/// backend will attempt to convert into a low-overhead loop. The vectorizer is
24
/// responsible for generating a vectorized loop in which the lanes are
25
/// predicated upon a get.active.lane.mask intrinsic. This pass looks at these
26
/// get.active.lane.mask intrinsics and attempts to convert them to VCTP
27
/// instructions. This will be picked up by the ARM Low-overhead loop pass later
28
/// in the backend, which performs the final transformation to a DLSTP or WLSTP
29
/// tail-predicated loop.
30
//
31
//===----------------------------------------------------------------------===//
32
33
#include "ARM.h"
34
#include "ARMSubtarget.h"
35
#include "ARMTargetTransformInfo.h"
36
#include "llvm/Analysis/LoopInfo.h"
37
#include "llvm/Analysis/LoopPass.h"
38
#include "llvm/Analysis/ScalarEvolution.h"
39
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
40
#include "llvm/Analysis/TargetLibraryInfo.h"
41
#include "llvm/Analysis/TargetTransformInfo.h"
42
#include "llvm/Analysis/ValueTracking.h"
43
#include "llvm/CodeGen/TargetPassConfig.h"
44
#include "llvm/IR/IRBuilder.h"
45
#include "llvm/IR/Instructions.h"
46
#include "llvm/IR/IntrinsicsARM.h"
47
#include "llvm/IR/PatternMatch.h"
48
#include "llvm/InitializePasses.h"
49
#include "llvm/Support/Debug.h"
50
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
51
#include "llvm/Transforms/Utils/Local.h"
52
#include "llvm/Transforms/Utils/LoopUtils.h"
53
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
54
55
using namespace llvm;
56
57
#define DEBUG_TYPE "mve-tail-predication"
58
#define DESC "Transform predicated vector loops to use MVE tail predication"
59
60
cl::opt<TailPredication::Mode> EnableTailPredication(
61
   "tail-predication", cl::desc("MVE tail-predication pass options"),
62
   cl::init(TailPredication::Enabled),
63
   cl::values(clEnumValN(TailPredication::Disabled, "disabled",
64
                         "Don't tail-predicate loops"),
65
              clEnumValN(TailPredication::EnabledNoReductions,
66
                         "enabled-no-reductions",
67
                         "Enable tail-predication, but not for reduction loops"),
68
              clEnumValN(TailPredication::Enabled,
69
                         "enabled",
70
                         "Enable tail-predication, including reduction loops"),
71
              clEnumValN(TailPredication::ForceEnabledNoReductions,
72
                         "force-enabled-no-reductions",
73
                         "Enable tail-predication, but not for reduction loops, "
74
                         "and force this which might be unsafe"),
75
              clEnumValN(TailPredication::ForceEnabled,
76
                         "force-enabled",
77
                         "Enable tail-predication, including reduction loops, "
78
                         "and force this which might be unsafe")));
79
80
81
namespace {
82
83
class MVETailPredication : public LoopPass {
84
  SmallVector<IntrinsicInst*, 4> MaskedInsts;
85
  Loop *L = nullptr;
86
  ScalarEvolution *SE = nullptr;
87
  TargetTransformInfo *TTI = nullptr;
88
  const ARMSubtarget *ST = nullptr;
89
90
public:
91
  static char ID;
92
93
2.47k
  MVETailPredication() : LoopPass(ID) { }
94
95
2.47k
  void getAnalysisUsage(AnalysisUsage &AU) const override {
96
2.47k
    AU.addRequired<ScalarEvolutionWrapperPass>();
97
2.47k
    AU.addRequired<LoopInfoWrapperPass>();
98
2.47k
    AU.addRequired<TargetPassConfig>();
99
2.47k
    AU.addRequired<TargetTransformInfoWrapperPass>();
100
2.47k
    AU.addPreserved<LoopInfoWrapperPass>();
101
2.47k
    AU.setPreservesCFG();
102
2.47k
  }
103
104
  bool runOnLoop(Loop *L, LPPassManager&) override;
105
106
private:
107
  /// Perform the relevant checks on the loop and convert active lane masks if
108
  /// possible.
109
  bool TryConvertActiveLaneMask(Value *TripCount);
110
111
  /// Perform several checks on the arguments of @llvm.get.active.lane.mask
112
  /// intrinsic. E.g., check that the loop induction variable and the element
113
  /// count are of the form we expect, and also perform overflow checks for
114
  /// the new expressions that are created.
115
  const SCEV *IsSafeActiveMask(IntrinsicInst *ActiveLaneMask, Value *TripCount);
116
117
  /// Insert the intrinsic to represent the effect of tail predication.
118
  void InsertVCTPIntrinsic(IntrinsicInst *ActiveLaneMask, Value *Start);
119
};
120
121
} // end namespace
122
123
143
/// Pass entry point: locate the hardware-loop setup intrinsic feeding \p L and
/// try to turn its active-lane-mask calls into VCTPs.
///
/// Returns false without touching the loop when the pass is skipped or
/// disabled, the subtarget lacks v8.1-M Mainline + MVE, the loop has no
/// preheader, or no (test.)start.loop.iterations call is found in the
/// preheader or its single predecessor.
bool MVETailPredication::runOnLoop(Loop *L, LPPassManager&) {
  if (skipLoop(L) || !EnableTailPredication)
    return false;

  MaskedInsts.clear();
  Function &F = *L->getHeader()->getParent();
  auto &PassConfig = getAnalysis<TargetPassConfig>();
  auto &Target = PassConfig.getTM<TargetMachine>();
  ST = &Target.getSubtarget<ARMSubtarget>(F);
  TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
  this->L = L;

  // Tail predication needs both MVE and the v8.1-M low-overhead-branch
  // extension; VCTP itself only needs v8.1m, but we require the full set here.
  const bool HasTPSupport =
      ST->hasMVEIntegerOps() && ST->hasV8_1MMainlineOps();
  if (!HasTPSupport) {
    LLVM_DEBUG(dbgs() << "ARM TP: Not a v8.1m.main+mve target.\n");
    return false;
  }

  BasicBlock *Preheader = L->getLoopPreheader();
  if (!Preheader)
    return false;

  // Returns the first (test.)start.loop.iterations call in a block, if any.
  auto FindLoopIterations = [](BasicBlock *BB) -> IntrinsicInst * {
    for (Instruction &Inst : *BB) {
      auto *II = dyn_cast<IntrinsicInst>(&Inst);
      if (!II)
        continue;
      Intrinsic::ID IID = II->getIntrinsicID();
      bool IsSetup = IID == Intrinsic::start_loop_iterations ||
                     IID == Intrinsic::test_start_loop_iterations;
      if (IsSetup)
        return II;
    }
    return nullptr;
  };

  IntrinsicInst *Setup = FindLoopIterations(Preheader);
  if (!Setup) {
    // The test.set variant may have been placed in the pre-preheader.
    BasicBlock *PrePreheader = Preheader->getSinglePredecessor();
    if (!PrePreheader)
      return false;
    Setup = FindLoopIterations(PrePreheader);
    if (!Setup)
      return false;
  }

  LLVM_DEBUG(dbgs() << "ARM TP: Running on Loop: " << *L << *Setup << "\n");

  // Operand 0 of the setup intrinsic is the hardware-loop trip count.
  return TryConvertActiveLaneMask(Setup->getArgOperand(0));
}
179
180
// The active lane intrinsic has this form:
181
//
182
//    @llvm.get.active.lane.mask(IV, TC)
183
//
184
// Here we perform checks that this intrinsic behaves as expected,
185
// which means:
186
//
187
// 1) Check that the TripCount (TC) belongs to this loop (originally).
188
// 2) The element count (TC) needs to be sufficiently large that the decrement
189
//    of element counter doesn't overflow, which means that we need to prove:
190
//        ceil(ElementCount / VectorWidth) >= TripCount
191
//    by rounding up ElementCount up:
192
//        ((ElementCount + (VectorWidth - 1)) / VectorWidth
193
//    and evaluate if expression isKnownNonNegative:
194
//        (((ElementCount + (VectorWidth - 1)) / VectorWidth) - TripCount
195
// 3) The IV must be an induction phi with an increment equal to the
196
//    vector width.
197
const SCEV *MVETailPredication::IsSafeActiveMask(IntrinsicInst *ActiveLaneMask,
198
0
                                                 Value *TripCount) {
199
0
  bool ForceTailPredication =
200
0
    EnableTailPredication == TailPredication::ForceEnabledNoReductions ||
201
0
    EnableTailPredication == TailPredication::ForceEnabled;
202
203
0
  Value *ElemCount = ActiveLaneMask->getOperand(1);
204
0
  bool Changed = false;
205
0
  if (!L->makeLoopInvariant(ElemCount, Changed))
206
0
    return nullptr;
207
208
0
  auto *EC= SE->getSCEV(ElemCount);
209
0
  auto *TC = SE->getSCEV(TripCount);
210
0
  int VectorWidth =
211
0
      cast<FixedVectorType>(ActiveLaneMask->getType())->getNumElements();
212
0
  if (VectorWidth != 2 && VectorWidth != 4 && VectorWidth != 8 &&
213
0
      VectorWidth != 16)
214
0
    return nullptr;
215
0
  ConstantInt *ConstElemCount = nullptr;
216
217
  // 1) Smoke tests that the original scalar loop TripCount (TC) belongs to
218
  // this loop.  The scalar tripcount corresponds the number of elements
219
  // processed by the loop, so we will refer to that from this point on.
220
0
  if (!SE->isLoopInvariant(EC, L)) {
221
0
    LLVM_DEBUG(dbgs() << "ARM TP: element count must be loop invariant.\n");
222
0
    return nullptr;
223
0
  }
224
225
  // 2) Find out if IV is an induction phi. Note that we can't use Loop
226
  // helpers here to get the induction variable, because the hardware loop is
227
  // no longer in loopsimplify form, and also the hwloop intrinsic uses a
228
  // different counter. Using SCEV, we check that the induction is of the
229
  // form i = i + 4, where the increment must be equal to the VectorWidth.
230
0
  auto *IV = ActiveLaneMask->getOperand(0);
231
0
  auto *IVExpr = SE->getSCEV(IV);
232
0
  auto *AddExpr = dyn_cast<SCEVAddRecExpr>(IVExpr);
233
234
0
  if (!AddExpr) {
235
0
    LLVM_DEBUG(dbgs() << "ARM TP: induction not an add expr: "; IVExpr->dump());
236
0
    return nullptr;
237
0
  }
238
  // Check that this AddRec is associated with this loop.
239
0
  if (AddExpr->getLoop() != L) {
240
0
    LLVM_DEBUG(dbgs() << "ARM TP: phi not part of this loop\n");
241
0
    return nullptr;
242
0
  }
243
0
  auto *Step = dyn_cast<SCEVConstant>(AddExpr->getOperand(1));
244
0
  if (!Step) {
245
0
    LLVM_DEBUG(dbgs() << "ARM TP: induction step is not a constant: ";
246
0
               AddExpr->getOperand(1)->dump());
247
0
    return nullptr;
248
0
  }
249
0
  auto StepValue = Step->getValue()->getSExtValue();
250
0
  if (VectorWidth != StepValue) {
251
0
    LLVM_DEBUG(dbgs() << "ARM TP: Step value " << StepValue
252
0
                      << " doesn't match vector width " << VectorWidth << "\n");
253
0
    return nullptr;
254
0
  }
255
256
0
  if ((ConstElemCount = dyn_cast<ConstantInt>(ElemCount))) {
257
0
    ConstantInt *TC = dyn_cast<ConstantInt>(TripCount);
258
0
    if (!TC) {
259
0
      LLVM_DEBUG(dbgs() << "ARM TP: Constant tripcount expected in "
260
0
                           "set.loop.iterations\n");
261
0
      return nullptr;
262
0
    }
263
264
    // Calculate 2 tripcount values and check that they are consistent with
265
    // each other. The TripCount for a predicated vector loop body is
266
    // ceil(ElementCount/Width), or floor((ElementCount+Width-1)/Width) as we
267
    // work it out here.
268
0
    uint64_t TC1 = TC->getZExtValue();
269
0
    uint64_t TC2 =
270
0
        (ConstElemCount->getZExtValue() + VectorWidth - 1) / VectorWidth;
271
272
    // If the tripcount values are inconsistent, we can't insert the VCTP and
273
    // trigger tail-predication; keep the intrinsic as a get.active.lane.mask
274
    // and legalize this.
275
0
    if (TC1 != TC2) {
276
0
      LLVM_DEBUG(dbgs() << "ARM TP: inconsistent constant tripcount values: "
277
0
                 << TC1 << " from set.loop.iterations, and "
278
0
                 << TC2 << " from get.active.lane.mask\n");
279
0
      return nullptr;
280
0
    }
281
0
  } else if (!ForceTailPredication) {
282
    // 3) We need to prove that the sub expression that we create in the
283
    // tail-predicated loop body, which calculates the remaining elements to be
284
    // processed, is non-negative, i.e. it doesn't overflow:
285
    //
286
    //   ((ElementCount + VectorWidth - 1) / VectorWidth) - TripCount >= 0
287
    //
288
    // This is true if:
289
    //
290
    //    TripCount == (ElementCount + VectorWidth - 1) / VectorWidth
291
    //
292
    // which what we will be using here.
293
    //
294
0
    auto *VW = SE->getSCEV(ConstantInt::get(TripCount->getType(), VectorWidth));
295
    // ElementCount + (VW-1):
296
0
    auto *Start = AddExpr->getStart();
297
0
    auto *ECPlusVWMinus1 = SE->getAddExpr(EC,
298
0
        SE->getSCEV(ConstantInt::get(TripCount->getType(), VectorWidth - 1)));
299
300
    // Ceil = ElementCount + (VW-1) / VW
301
0
    auto *Ceil = SE->getUDivExpr(ECPlusVWMinus1, VW);
302
303
    // Prevent unused variable warnings with TC
304
0
    (void)TC;
305
0
    LLVM_DEBUG({
306
0
      dbgs() << "ARM TP: Analysing overflow behaviour for:\n";
307
0
      dbgs() << "ARM TP: - TripCount = " << *TC << "\n";
308
0
      dbgs() << "ARM TP: - ElemCount = " << *EC << "\n";
309
0
      dbgs() << "ARM TP: - Start = " << *Start << "\n";
310
0
      dbgs() << "ARM TP: - BETC = " << *SE->getBackedgeTakenCount(L) << "\n";
311
0
      dbgs() << "ARM TP: - VecWidth =  " << VectorWidth << "\n";
312
0
      dbgs() << "ARM TP: - (ElemCount+VW-1) / VW = " << *Ceil << "\n";
313
0
    });
314
315
    // As an example, almost all the tripcount expressions (produced by the
316
    // vectoriser) look like this:
317
    //
318
    //   TC = ((-4 + (4 * ((3 + %N) /u 4))<nuw> - start) /u 4)
319
    //
320
    // and "ElementCount + (VW-1) / VW":
321
    //
322
    //   Ceil = ((3 + %N) /u 4)
323
    //
324
    // Check for equality of TC and Ceil by calculating SCEV expression
325
    // TC - Ceil and test it for zero.
326
    //
327
0
    const SCEV *Div = SE->getUDivExpr(
328
0
        SE->getAddExpr(SE->getMulExpr(Ceil, VW), SE->getNegativeSCEV(VW),
329
0
                       SE->getNegativeSCEV(Start)),
330
0
        VW);
331
0
    const SCEV *Sub = SE->getMinusSCEV(SE->getBackedgeTakenCount(L), Div);
332
0
    LLVM_DEBUG(dbgs() << "ARM TP: - Sub       = "; Sub->dump());
333
334
    // Use context sensitive facts about the path to the loop to refine.  This
335
    // comes up as the backedge taken count can incorporate context sensitive
336
    // reasoning, and our RHS just above doesn't.
337
0
    Sub = SE->applyLoopGuards(Sub, L);
338
0
    LLVM_DEBUG(dbgs() << "ARM TP: - (Guarded) = "; Sub->dump());
339
340
0
    if (!Sub->isZero()) {
341
0
      LLVM_DEBUG(dbgs() << "ARM TP: possible overflow in sub expression.\n");
342
0
      return nullptr;
343
0
    }
344
0
  }
345
346
  // Check that the start value is a multiple of the VectorWidth.
347
  // TODO: This could do with a method to check if the scev is a multiple of
348
  // VectorWidth. For the moment we just check for constants, muls and unknowns
349
  // (which use MaskedValueIsZero and seems to be the most common).
350
0
  if (auto *BaseC = dyn_cast<SCEVConstant>(AddExpr->getStart())) {
351
0
    if (BaseC->getAPInt().urem(VectorWidth) == 0)
352
0
      return SE->getMinusSCEV(EC, BaseC);
353
0
  } else if (auto *BaseV = dyn_cast<SCEVUnknown>(AddExpr->getStart())) {
354
0
    Type *Ty = BaseV->getType();
355
0
    APInt Mask = APInt::getLowBitsSet(Ty->getPrimitiveSizeInBits(),
356
0
                                      Log2_64(VectorWidth));
357
0
    if (MaskedValueIsZero(BaseV->getValue(), Mask,
358
0
                          L->getHeader()->getModule()->getDataLayout()))
359
0
      return SE->getMinusSCEV(EC, BaseV);
360
0
  } else if (auto *BaseMul = dyn_cast<SCEVMulExpr>(AddExpr->getStart())) {
361
0
    if (auto *BaseC = dyn_cast<SCEVConstant>(BaseMul->getOperand(0)))
362
0
      if (BaseC->getAPInt().urem(VectorWidth) == 0)
363
0
        return SE->getMinusSCEV(EC, BaseC);
364
0
    if (auto *BaseC = dyn_cast<SCEVConstant>(BaseMul->getOperand(1)))
365
0
      if (BaseC->getAPInt().urem(VectorWidth) == 0)
366
0
        return SE->getMinusSCEV(EC, BaseC);
367
0
  }
368
369
0
  LLVM_DEBUG(
370
0
      dbgs() << "ARM TP: induction base is not know to be a multiple of VF: "
371
0
             << *AddExpr->getOperand(0) << "\n");
372
0
  return nullptr;
373
0
}
374
375
void MVETailPredication::InsertVCTPIntrinsic(IntrinsicInst *ActiveLaneMask,
376
0
                                             Value *Start) {
377
0
  IRBuilder<> Builder(L->getLoopPreheader()->getTerminator());
378
0
  Module *M = L->getHeader()->getModule();
379
0
  Type *Ty = IntegerType::get(M->getContext(), 32);
380
0
  unsigned VectorWidth =
381
0
      cast<FixedVectorType>(ActiveLaneMask->getType())->getNumElements();
382
383
  // Insert a phi to count the number of elements processed by the loop.
384
0
  Builder.SetInsertPoint(L->getHeader(), L->getHeader()->getFirstNonPHIIt());
385
0
  PHINode *Processed = Builder.CreatePHI(Ty, 2);
386
0
  Processed->addIncoming(Start, L->getLoopPreheader());
387
388
  // Replace @llvm.get.active.mask() with the ARM specific VCTP intrinic, and
389
  // thus represent the effect of tail predication.
390
0
  Builder.SetInsertPoint(ActiveLaneMask);
391
0
  ConstantInt *Factor = ConstantInt::get(cast<IntegerType>(Ty), VectorWidth);
392
393
0
  Intrinsic::ID VCTPID;
394
0
  switch (VectorWidth) {
395
0
  default:
396
0
    llvm_unreachable("unexpected number of lanes");
397
0
  case 2:  VCTPID = Intrinsic::arm_mve_vctp64; break;
398
0
  case 4:  VCTPID = Intrinsic::arm_mve_vctp32; break;
399
0
  case 8:  VCTPID = Intrinsic::arm_mve_vctp16; break;
400
0
  case 16: VCTPID = Intrinsic::arm_mve_vctp8; break;
401
0
  }
402
0
  Function *VCTP = Intrinsic::getDeclaration(M, VCTPID);
403
0
  Value *VCTPCall = Builder.CreateCall(VCTP, Processed);
404
0
  ActiveLaneMask->replaceAllUsesWith(VCTPCall);
405
406
  // Add the incoming value to the new phi.
407
  // TODO: This add likely already exists in the loop.
408
0
  Value *Remaining = Builder.CreateSub(Processed, Factor);
409
0
  Processed->addIncoming(Remaining, L->getLoopLatch());
410
0
  LLVM_DEBUG(dbgs() << "ARM TP: Insert processed elements phi: "
411
0
             << *Processed << "\n"
412
0
             << "ARM TP: Inserted VCTP: " << *VCTPCall << "\n");
413
0
}
414
415
0
/// Convert every @llvm.get.active.lane.mask call in the current loop into an
/// MVE VCTP driven by an "elements remaining" counter.
///
/// \param TripCount the iteration count operand of the hardware-loop setup
///        intrinsic located by runOnLoop.
/// \returns true if the masks were replaced. Returns false if the loop has no
///          active lane masks, or if any of them fails IsSafeActiveMask.
///
/// All masks are validated before any of them is rewritten. Previously a
/// failure on a later mask left earlier masks already RAUW'd to VCTPs (and not
/// deleted) while the pass reported "no change".
/// NOTE(review): IsSafeActiveMask may still hoist code via makeLoopInvariant
/// without that being reflected in the return value.
bool MVETailPredication::TryConvertActiveLaneMask(Value *TripCount) {
  SmallVector<IntrinsicInst *, 4> ActiveLaneMasks;
  for (auto *BB : L->getBlocks())
    for (auto &I : *BB)
      if (auto *Int = dyn_cast<IntrinsicInst>(&I))
        if (Int->getIntrinsicID() == Intrinsic::get_active_lane_mask)
          ActiveLaneMasks.push_back(Int);

  if (ActiveLaneMasks.empty())
    return false;

  LLVM_DEBUG(dbgs() << "ARM TP: Found predicated vector loop.\n");

  // Phase 1: prove every mask safe before mutating anything, so that failure
  // leaves the loop untouched and `return false` is truthful.
  SmallVector<const SCEV *, 4> StartSCEVs;
  StartSCEVs.reserve(ActiveLaneMasks.size());
  for (auto *ActiveLaneMask : ActiveLaneMasks) {
    LLVM_DEBUG(dbgs() << "ARM TP: Found active lane mask: "
                      << *ActiveLaneMask << "\n");

    const SCEV *StartSCEV = IsSafeActiveMask(ActiveLaneMask, TripCount);
    if (!StartSCEV) {
      LLVM_DEBUG(dbgs() << "ARM TP: Not safe to insert VCTP.\n");
      return false;
    }
    LLVM_DEBUG(dbgs() << "ARM TP: Safe to insert VCTP. Start is " << *StartSCEV
                      << "\n");
    StartSCEVs.push_back(StartSCEV);
  }

  // Phase 2: materialise each start value in the preheader and rewrite the
  // corresponding mask into a VCTP.
  for (unsigned Idx = 0, E = ActiveLaneMasks.size(); Idx != E; ++Idx) {
    IntrinsicInst *ActiveLaneMask = ActiveLaneMasks[Idx];
    const SCEV *StartSCEV = StartSCEVs[Idx];
    SCEVExpander Expander(*SE, L->getHeader()->getModule()->getDataLayout(),
                          "start");
    Instruction *Ins = L->getLoopPreheader()->getTerminator();
    Value *Start = Expander.expandCodeFor(StartSCEV, StartSCEV->getType(), Ins);
    LLVM_DEBUG(dbgs() << "ARM TP: Created start value " << *Start << "\n");
    InsertVCTPIntrinsic(ActiveLaneMask, Start);
  }

  // Remove dead instructions and now dead phis.
  for (auto *II : ActiveLaneMasks)
    RecursivelyDeleteTriviallyDeadInstructions(II);
  for (auto *I : L->blocks())
    DeleteDeadPHIs(I);
  return true;
}
454
455
2.47k
/// Factory for the MVE tail-predication loop pass. Returns a freshly
/// heap-allocated instance; presumably the legacy pass manager that receives
/// it takes ownership — confirm at the call site.
Pass *llvm::createMVETailPredicationPass() {
  auto *TPPass = new MVETailPredication;
  return TPPass;
}
458
459
char MVETailPredication::ID = 0;
460
461
62
INITIALIZE_PASS_BEGIN(MVETailPredication, DEBUG_TYPE, DESC, false, false)
462
62
INITIALIZE_PASS_END(MVETailPredication, DEBUG_TYPE, DESC, false, false)